#! /usr/bin/env python

# Tracks file descriptors in strace output.

# Expects strace output that has been post-processed with:

# strace2tsv | sed 's/ <<\([^>]*\)>>\t/\t\1\t/g;s/ <<\([^>]*\)>>$/\t\1/'

import sys, getopt, re, copy, string

# global variables

input_file = None                       # input file name, None for stdin
fin = sys.stdin                         # file object for input
fout = sys.stdout                       # file object for output
sct = {}                                # maps syscall names to handlers
pt = {}                                 # maps a pid to a file descriptor table,
                                        # which in turn maps an fd number to a
                                        # Tracker holding that fd's information
counter = 0                             # used to generate tracker identifiers

def main():
    """Parse the command line, then open I/O and process the trace.

    Accepts -h/--help, -o/--output FILE and at most one positional INPUT
    file name (stored in the module-level input_file).  Exits with status 2
    on bad usage and 0 after printing help.
    """
    global input_file
    input_file = None                   # may be given in argv
    output_file = None                  # may be given with -o option in argv
    try:
        opts, args = getopt.getopt(sys.argv[1:], "ho:", ["help", "output="])
    except getopt.GetoptError:
        usage()                         # unknown option or missing argument
        sys.exit(2)
    for flag, value in opts:
        if flag in ("-h", "--help"):
            usage()
            sys.exit(0)
        if flag in ("-o", "--output"):
            output_file = value
    if len(args) > 1:                   # at most one input file
        usage()
        sys.exit(2)
    if args:
        input_file = args[0]
    openio(output_file)

def usage():
    "write the one-line usage summary to stdout"
    sys.stdout.write("Usage: %s [-h] [-o OUTPUT] [INPUT]\n" % sys.argv[0])

def openio(output_file):
    """Open the input and output streams, then run the main loop.

    Input comes from the module-level input_file (stdin when it is empty or
    None) and output goes to output_file (stdout when empty or None).
    Failure to open either file is reported on stderr -- never on stdout,
    where it would be mixed into the tab-separated output -- and exits with
    status 1.  Files opened here are closed once run() finishes.
    """
    global fin, fout
    fin = sys.stdin                     # default input
    fout = sys.stdout                   # default output
    if input_file:
        try:
            fin = open(input_file, 'r')
        except IOError:
            sys.stderr.write("cannot open " + input_file + " for reading\n")
            sys.exit(1)
    if output_file:
        try:
            fout = open(output_file, 'w')
        except IOError:
            sys.stderr.write("cannot open " + output_file + " for writing\n")
            if fin is not sys.stdin:
                fin.close()
            sys.exit(1)
    try:
        run()
    finally:
        # flush and release whatever we opened; leave the std streams alone
        if fin is not sys.stdin:
            fin.close()
        if fout is not sys.stdout:
            fout.close()

def run():
    """The main loop: read records from fin and dispatch them by syscall.

    Each line is split on tabs; field 0 is the pid and field 1 the syscall
    name.  A pid seen for the first time gets a descriptor table holding
    stdin, stdout and stderr trackers.  Records whose syscall has a handler
    in sct are passed to it as (pid, fds, record, nf, lineno); other records
    are ignored.  At end of input every descriptor still open is reported
    through its tracker's do_close.
    """
    global pt, counter
    pt = {}                             # initialize the process table
    counter = 0                         # initialize tracker labeler
    init()                              # initialize syscall table
    # iterate with readline: Python 2 file iteration uses a read-ahead
    # buffer, which would delay processing of piped input
    for lineno, line in enumerate(iter(fin.readline, ""), 1):
        record = line.strip().split("\t")
        nf = len(record)                # number of fields
        if nf <= 1:
            continue
        pid = getpid(record[0])
        if pid < 0:
            continue
        fds = pt.get(pid)
        if fds is None:                 # give unknown pids stdio
            fds = _stdio_table(pid, record[-1])
            pt[pid] = fds
        # look the handler up separately: wrapping the call itself in
        # try/except KeyError would also hide KeyErrors raised by handlers
        handler = sct.get(record[1])
        if handler is not None:
            handler(pid, fds, record, nf, lineno)
    _close_all_trackers()

def _stdio_table(pid, sc):
    "build the descriptor table for a new pid: trackers for fds 0, 1 and 2"
    fds = {}
    for fd, name in enumerate(("stdin", "stdout", "stderr")):
        fds[fd] = StreamTracker(pid, fd, name, sc)
    return fds

def _close_all_trackers():
    "report every descriptor still open at end of input, in pid/fd order"
    for pid in sorted(pt):
        fds = pt[pid]
        for fd in sorted(fds):
            fds[fd].do_close()
        # clearing also empties tables shared via CLONE_FILES, so a shared
        # tracker is reported only once
        fds.clear()

def do_close(pid, fds, record, nf, lineno):
    """On a successful close, report the fd's tracker and drop it from fds.

    Closing an fd that has no tracker is reported on stderr with the line
    number of the offending record.
    """
    if nf <= 2 or geterrno(record, nf) != 0:
        return
    fd = getfd(record[2])
    if fd < 0:
        return
    try:
        tracker = fds[fd]
    except KeyError:
        sys.stderr.write(str(pid)
                         + ": close on non-existent file descriptor "
                         + str(fd) + " on line " + str(lineno) + "\n")
        return
    tracker.do_close()
    del fds[fd]

def do_open(pid, fds, record, nf, lineno):
    """Associate the fd returned by a successful open with its file.

    record[4] is the open flags field (OpenTracker classifies it with
    parse_type).  When those flags include O_CLOEXEC, the new tracker is
    marked close-on-exec so that do_execve drops it.
    """
    if nf > 6:
        fd = geterrno(record, nf)
        if fd >= 0:
            warn_fds_update(pid, fds, fd, record[1], lineno)
            tracker = OpenTracker(pid, fd, record, nf)
            if "O_CLOEXEC" in record[4]:
                tracker.close_on_exec = True
            fds[fd] = tracker

def do_socket(pid, fds, record, nf, lineno):
    """Associate the fd returned by a successful socket call with a tracker.

    The new SocketTracker has no parent (-1).  When SOCK_CLOEXEC appears
    among the call's argument fields (record[2:-2], i.e. everything before
    the result and trailing context), the tracker is marked close-on-exec
    so that do_execve drops it.
    """
    if nf > 6:
        fd = geterrno(record, nf)
        if fd >= 0:
            warn_fds_update(pid, fds, fd, record[1], lineno)
            tracker = SocketTracker(pid, fd, -1, record, nf)
            if any("SOCK_CLOEXEC" in arg for arg in record[2:-2]):
                tracker.close_on_exec = True
            fds[fd] = tracker

def do_pipe(pid, fds, record, nf, lineno):
    """Track both ends of a successful pipe call.

    record[2] holds the descriptor pair after an opening bracket.  The read
    end is taken from the first space-separated token and the write end
    from the third.  Each end's pal is set to the other end's tracker id.
    """
    if geterrno(record, nf) != 0:
        return
    tokens = record[2][1:].split(" ")   # drop the leading bracket
    if len(tokens) < 3:
        return
    readfd = getfd(tokens[0])
    writefd = getfd(tokens[2])
    if readfd < 0 or writefd < 0:
        return
    reader = PipeTracker(pid, readfd, record[-1])
    writer = PipeTracker(pid, writefd, record[-1])
    reader.pal, writer.pal = writer.id, reader.id
    for fd, tracker in ((readfd, reader), (writefd, writer)):
        warn_fds_update(pid, fds, fd, record[1], lineno)
        fds[fd] = tracker

def warn_fds_update(pid, fds, fd, syscall, lineno):
    """Warn on stderr when SYSCALL is about to overwrite a tracked fd.

    Nothing is written when fd has no entry in fds.
    """
    if fd not in fds:
        return
    sys.stderr.write("%s: %s over a previously existing file descriptor"
                     " %s on line %s\n" % (pid, syscall, fd, lineno))

def do_socketpair(pid, fds, record, nf, lineno):
    """Track both ends of a successful socketpair call.

    record[5] holds the descriptor pair as "[A, B]".  Each end's pal is set
    to the other end's tracker id.  Records too short to contain record[5]
    are skipped; they used to raise IndexError and abort the whole run.
    """
    if nf <= 5 or geterrno(record, nf) != 0:
        return
    pair = record[5][1:][:-1]           # strip brackets
    pair = pair.split(" ")
    pair[0] = pair[0][:-1]              # drop the comma after the first fd
    if len(pair) <= 1:
        return
    fstfd = getfd(pair[0])
    sndfd = getfd(pair[1])
    if fstfd < 0 or sndfd < 0:
        return
    first = SocketPairTracker(pid, fstfd, record[-1])
    second = SocketPairTracker(pid, sndfd, record[-1])
    first.pal = second.id
    second.pal = first.id
    warn_fds_update(pid, fds, fstfd, record[1], lineno)
    fds[fstfd] = first
    warn_fds_update(pid, fds, sndfd, record[1], lineno)
    fds[sndfd] = second

# file descriptor duplication

def do_dup(pid, fds, record, nf, lineno):
    "mirror dup(): the returned fd becomes a copy of the tracker for record[2]"
    if nf <= 4:
        return
    result = geterrno(record, nf)
    source = getfd(record[2])
    # dup never replaces an open fd, so an existing entry is only warned about
    dup_tracker(pid, fds, source, result, False)

def do_dup2(pid, fds, record, nf, lineno):
    "mirror dup2(): the target fd is closed and replaced by a copy"
    if nf <= 5:
        return
    target = geterrno(record, nf)
    source = getfd(record[2])
    if target == source:
        return                          # duplicating onto itself: no change
    dup_tracker(pid, fds, source, target, True)

def do_fcntl64(pid, fds, record, nf, lineno):
    """Handle the fcntl commands that affect descriptor tracking.

    record[2] is the fd and record[4] the command.  A successful F_SETFD is
    forwarded to the fd's tracker.  F_DUPFD and F_DUPFD_CLOEXEC copy the
    tracker onto the returned fd; the _CLOEXEC variant also marks the copy
    close-on-exec.  Other commands are ignored.
    """
    if nf <= 5:
        return
    result = geterrno(record, nf)
    oldfd = getfd(record[2])
    if oldfd < 0:
        return
    cmd = record[4]
    if cmd == "F_SETFD":
        if result == 0:
            try:
                tracker = fds[oldfd]
            except KeyError:
                warn_missing(pid, oldfd, record[1])
            else:
                tracker.do_fcntl64(record, nf)
    elif cmd in ("F_DUPFD", "F_DUPFD_CLOEXEC"):
        dup_tracker(pid, fds, oldfd, result, True)
        # dup_tracker leaves close_on_exec False on the copy; the copy only
        # exists when oldfd was tracked
        if cmd == "F_DUPFD_CLOEXEC" and oldfd in fds and result in fds:
            fds[result].close_on_exec = True

def dup_tracker(pid, fds, oldfd, newfd, close):
    """Make newfd's tracker a copy of oldfd's tracker in fds.

    If newfd is already tracked and close is true, that tracker is reported
    via do_close and removed from fds.  It is removed even when oldfd turns
    out to be unknown, so it is never reported a second time.  If close is
    false, a warning is written instead.  The copy always starts with
    close_on_exec False.  Negative fds are ignored.
    """
    if oldfd < 0 or newfd < 0:
        return
    if newfd in fds:
        if close:
            fds.pop(newfd).do_close()
        else:
            sys.stderr.write(str(pid) + ": dup "
                             + " over a previously existing"
                             + " file descriptor "
                             + str(newfd) + "\n")
    try:
        source = fds[oldfd]
    except KeyError:
        sys.stderr.write(str(pid) + ": duplicating "
                         + " a non-existent file descriptor "
                         + str(oldfd) + "\n")
        return
    tracker = copy.copy(source)         # shares the id of the original
    tracker.close_on_exec = False
    fds[newfd] = tracker

# syscall dispatchers to file descriptors

def do_read(pid, fds, record, nf, lineno):
    "forward a read on fd record[2] to its tracker, warning if it is untracked"
    if nf <= 3:
        return
    fd = getfd(record[2])
    if fd < 0:
        return
    if fd in fds:
        fds[fd].do_read(record, nf)
    else:
        warn_missing(pid, fd, record[1])

def do_write(pid, fds, record, nf, lineno):
    "forward a write on fd record[2] to its tracker, warning if it is untracked"
    if nf <= 3:
        return
    fd = getfd(record[2])
    if fd < 0:
        return
    if fd in fds:
        fds[fd].do_write(record, nf)
    else:
        warn_missing(pid, fd, record[1])

def do_bind(pid, fds, record, nf, lineno):
    "forward a bind on fd record[2] to its tracker, warning if it is untracked"
    if nf <= 3:
        return
    fd = getfd(record[2])
    if fd < 0:
        return
    if fd in fds:
        fds[fd].do_bind(record, nf)
    else:
        warn_missing(pid, fd, record[1])

def do_accept(pid, fds, record, nf, lineno):
    """Track the connection fd returned by a successful accept.

    The new SocketTracker's parent is the listening fd's tracker id (-1 when
    that fd is untracked); its op is the syscall name and its addr is
    record[4].
    """
    if nf <= 5:
        return
    listen_fd = getfd(record[2])
    conn_fd = geterrno(record, nf)
    if listen_fd < 0 or conn_fd < 0:
        return
    server = fds.get(listen_fd)
    if server is None:
        parent = -1
        warn_missing(pid, listen_fd, record[1])
    else:
        parent = server.id
        server.do_accept(record, nf)
    warn_fds_update(pid, fds, conn_fd, record[1], lineno)
    conn = SocketTracker(pid, conn_fd, parent, record, nf)
    conn.op = record[1]
    conn.addr = record[4]
    fds[conn_fd] = conn

def do_connect(pid, fds, record, nf, lineno):
    "forward a connect on fd record[2] to its tracker, warning if untracked"
    if nf <= 3:
        return
    fd = getfd(record[2])
    if fd < 0:
        return
    if fd in fds:
        fds[fd].do_connect(record, nf)
    else:
        warn_missing(pid, fd, record[1])

def do_recv(pid, fds, record, nf, lineno):
    "forward a recv on fd record[2] to its tracker, warning if it is untracked"
    if nf <= 3:
        return
    fd = getfd(record[2])
    if fd < 0:
        return
    if fd in fds:
        fds[fd].do_recv(record, nf)
    else:
        warn_missing(pid, fd, record[1])

def do_send(pid, fds, record, nf, lineno):
    "forward a send on fd record[2] to its tracker, warning if it is untracked"
    if nf <= 3:
        return
    fd = getfd(record[2])
    if fd < 0:
        return
    if fd in fds:
        fds[fd].do_send(record, nf)
    else:
        warn_missing(pid, fd, record[1])

def warn_missing(pid, fd, syscall):
    "warn on stderr that SYSCALL referred to an fd with no tracker"
    message = ("%s: %s applied to non-existent  file descriptor %s\n"
               % (pid, syscall, fd))
    sys.stderr.write(message)

# syscalls that do not involve file descriptors

def do_unlink(pid, fds, record, nf, lineno):
    """Write an Unlink line: pid, record[2], record[3], result, record[-1]."""
    if nf <= 5:
        return
    fields = [str(pid), "Unlink", record[2], record[3],
              str(geterrno(record, nf)), record[-1]]
    fout.write("\t".join(fields) + "\n")

def do_execve(pid, fds, record, nf, lineno):
    """Report an execve and apply close-on-exec to the fd table.

    An Exec line (pid, record[2], record[3], result, record[-1]) is always
    written.  Only a successful exec (result 0) replaces the program image.
    Only then are trackers marked close_on_exec reported via do_close and
    removed.  A failed exec returns with every fd still open.
    """
    if nf <= 5:
        return
    result = geterrno(record, nf)
    fields = [str(pid), "Exec", record[2], record[3], str(result), record[-1]]
    fout.write("\t".join(fields) + "\n")
    if result != 0:
        return
    # collect first: the table cannot change size while being iterated
    doomed = sorted(fd for fd, tracker in fds.items() if tracker.close_on_exec)
    for fd in doomed:
        fds[fd].do_close()
        del fds[fd]

def do_clone(pid, fds, record, nf, lineno):
    """Give the child of a successful clone an fd table in pt.

    With CLONE_FILES in the flags (record[3]), parent and child share one
    table.  Otherwise the child gets a deep copy whose trackers are reset to
    the child pid with no I/O recorded.  A Share or Clone line naming the
    child is written either way.
    """
    if nf <= 3:
        return
    child = geterrno(record, nf)
    # -1 means the call failed and clone returns 0 only inside the child,
    # so only a positive result names a new process
    if child <= 0:
        return
    if "CLONE_FILES" in record[3]:
        pt[child] = fds                 # share file descriptors
        kind = "Share"
    else:
        child_fds = copy.deepcopy(fds)  # clone file descriptors for child
        for tracker in child_fds.values():
            tracker.reset(child)        # clear I/O history and update pid
        # register even an empty copy, so the child is not later treated as
        # an unknown pid and handed fresh stdio trackers
        pt[child] = child_fds
        kind = "Clone"
    fout.write(str(pid) + "\t" + kind + "\t" + str(child) + "\t"
               + record[-1] + "\n")

def do_verbatim(pid, fds, record, nf, lineno):
    "copy the record to the output tagged Misc, replacing field 0 with pid"
    fout.write("\t".join([str(pid), "Misc"] + record[1:]) + "\n")

def getpid(str):
    "parse a pid field with the same rules as getfd (-1 if not an integer)"
    pid = getfd(str)
    return pid

def getfd(str):
    "parse an integer field such as an fd number; -1 if it is not an integer"
    try:
        value = int(str)
    except ValueError:
        value = -1
    return value

def geterrno(record, nf):
    """Return the syscall result from the next-to-last field of a record.

    Only the text before the first space is parsed (so "-1 ENOENT ..." gives
    -1).  Returns -1 when nf <= 3 or that text is not an integer.
    """
    if nf <= 3:
        return -1
    head = record[-2].split(" ")[0]
    try:
        return int(head)
    except ValueError:
        return -1

def init():
    """Fill the module-level syscall table sct with a handler per syscall.

    run() ignores records whose syscall is not listed here.  Plain "fcntl"
    shares the fcntl64 handler: its arguments (fd, command, argument) have
    the same layout, and 64-bit strace reports the call under that name.
    Without this entry, F_SETFD and F_DUPFD would go untracked there.
    """
    sct.update({
        "close": do_close,
        "open": do_open,
        "socket": do_socket,
        "pipe": do_pipe,
        "socketpair": do_socketpair,
        "dup": do_dup,
        "dup2": do_dup2,
        "fcntl64": do_fcntl64,
        "fcntl": do_fcntl64,
        "read": do_read,
        "write": do_write,
        "bind": do_bind,
        "accept": do_accept,
        "connect": do_connect,
        "recv": do_recv,
        "send": do_send,
        "unlink": do_unlink,
        "execve": do_execve,
        "clone": do_clone,
        "setsid": do_verbatim,
        "chdir": do_verbatim,
        "umask": do_verbatim,
    })

# file descriptors

class Tracker:
    """State common to every tracked file descriptor.

    Each tracker takes a unique id from the module-level counter and
    remembers its fd number, the context string sc it was created with, and
    whether it is close-on-exec.  It accumulates read/write activity until
    do_close completes an output line that a subclass has begun.
    """
    def __init__(self, pid, fd, sc):
        global counter
        self.id = counter
        counter = counter + 1
        self.fd = fd
        self.sc = sc
        self.close_on_exec = False
        self.reset(pid)

    def reset(self, pid):
        "assign the tracker to PID and forget any recorded I/O"
        self.pid = pid
        self.read_sc = None
        self.write_sc = None
        self.read = None
        self.written = None

    def do_read(self, record, nf):
        "note that the fd was read, keeping record[3] when present"
        self.read = "R"
        if nf > 3:
            self.read_sc = record[3]

    def do_write(self, record, nf):
        "note that the fd was written, keeping record[3] when present"
        self.written = "W"
        if nf > 3:
            self.write_sc = record[3]

    def do_fcntl64(self, record, nf):
        """Apply fcntl(fd, F_SETFD, flags).

        F_SETFD replaces the fd flags.  Close-on-exec is therefore on
        exactly when FD_CLOEXEC appears in the argument (record[5]), and an
        F_SETFD of 0 clears a flag set earlier.
        """
        if nf > 5 and record[4] == "F_SETFD":
            self.close_on_exec = "FD_CLOEXEC" in record[5]

    def do_close(self):
        "finish the subclass's output line with the common fields"
        fields = [self.read or "", self.read_sc or "",
                  self.written or "", self.write_sc or "",
                  str(self.fd), str(self.id), self.sc]
        fout.write("\t" + "\t".join(fields) + "\n")

    # by default, socket syscalls only produce a warning on stderr

    def warn(self, syscall):
        sys.stderr.write(str(self.pid) + ": socket operation "
                         + syscall + " on file descriptor "
                         + str(self.fd) + "\n")

    def do_bind(self, record, nf):
        self.warn(record[1])

    def do_listen(self, record, nf):
        self.warn(record[1])

    def do_accept(self, record, nf):
        self.warn(record[1])

    def do_connect(self, record, nf):
        self.warn(record[1])

    def do_recv(self, record, nf):
        self.warn(record[1])

    def do_send(self, record, nf):
        self.warn(record[1])

# standard streams, and file descriptors created using pipe or socketpair

class StreamTracker(Tracker):
    "tracker for a standard stream (stdin, stdout or stderr) of a process"
    def __init__(self, pid, fd, name, sc):
        "create a tracker for the standard stream NAME on descriptor FD"
        Tracker.__init__(self, pid, fd, sc)
        self.name = name

    def do_close(self):
        "begin the Stream output line, then add the common fields"
        fout.write("%s\tStream\t%s" % (self.pid, self.name))
        Tracker.do_close(self)

class PipeTracker(Tracker):
    "tracker for one end of a pipe; pal is the other end's tracker id"
    def __init__(self, pid, fd, sc):
        "create a tracker for one descriptor returned by pipe"
        Tracker.__init__(self, pid, fd, sc)
        self.pal = -1                   # filled in once both ends exist

    def do_close(self):
        "begin the Pipe output line, then add the common fields"
        fout.write("%s\tPipe\t%s" % (self.pid, self.pal))
        Tracker.do_close(self)

class SocketPairTracker(Tracker):
    "tracker for one end of a socketpair; pal is the other end's tracker id"
    def __init__(self, pid, fd, sc):
        "create a tracker for one descriptor returned by socketpair"
        Tracker.__init__(self, pid, fd, sc)
        self.pal = -1                   # filled in once both ends exist

    def do_close(self):
        "begin the SockPr output line, then add the common fields"
        fout.write("%s\tSockPr\t%s" % (self.pid, self.pal))
        Tracker.do_close(self)

    # recv and send on a socketpair end are recorded as reads and writes

    def do_recv(self, record, nf):
        self.do_read(record, nf)

    def do_send(self, record, nf):
        self.do_write(record, nf)

# file descriptors created using open

class OpenTracker(Tracker):
    "tracker for a descriptor returned by open"
    def __init__(self, pid, fd, record, nf):
        "create a tracker for the results of an open syscall"
        Tracker.__init__(self, pid, fd, record[-1])
        # reads record[2], record[3] and record[4], so the record must have
        # at least five fields (do_open only calls this when nf > 6)
        self.filename, self.open_sc = record[2], record[3]
        self.type = parse_type(record[4])   # "FILE" or "DIRECTORY"

    def do_close(self):
        "begin the Open output line, then add the common fields"
        parts = [str(self.pid), "Open", self.filename,
                 self.open_sc or "", self.type or ""]
        fout.write("\t".join(parts))
        Tracker.do_close(self)

isDir = re.compile(".*O_DIRECTORY.*")

def parse_type(str):
    """Classify an open flags string as "DIRECTORY" or "FILE".

    STR is the flags field of an open record, such as
    "O_RDONLY|O_DIRECTORY".  Only the presence of O_DIRECTORY is examined;
    anything else is reported as "FILE".
    """
    return "DIRECTORY" if isDir.match(str) else "FILE"

# file descriptors created using socket

class SocketTracker(Tracker):
    """Tracker for a socket descriptor.

    parent is the tracker id of the listening socket for accepted
    connections (-1 otherwise).  op and addr record the last bind/connect
    (or accept) call name and its address field.
    """
    def __init__(self, pid, fd, parent, record, nf):
        "create a tracker for the results of a socket syscall"
        Tracker.__init__(self, pid, fd, record[-1])
        self.parent = parent
        self.op = None
        self.addr = None

    def do_close(self):
        "begin the Socket output line, then add the common fields"
        parts = [str(self.pid), "Socket", self.op or "",
                 str(self.parent), self.addr or ""]
        fout.write("\t".join(parts))
        Tracker.do_close(self)

    def _note_endpoint(self, record, nf):
        "remember the syscall name and its address field (record[4])"
        if nf > 4:
            self.op, self.addr = record[1], record[4]

    def do_bind(self, record, nf):
        self._note_endpoint(record, nf)

    def do_accept(self, record, nf):
        pass                            # the accepted fd gets its own tracker

    def do_connect(self, record, nf):
        self._note_endpoint(record, nf)

    # recv and send on a socket are recorded as reads and writes

    def do_recv(self, record, nf):
        self.do_read(record, nf)

    def do_send(self, record, nf):
        self.do_write(record, nf)

# when run as a script: parse argv, open input/output and process the trace
if __name__ == "__main__":
    main()
