#!/usr/bin/env python
########################################################################
### FILE:	greylistd.py
### PURPOSE:	Simple grey-listing daemon.  See greylistd(8).
###		For an introduction to greylisting, see:
### 		http://projects.puremagic.com/greylisting/
### 		Essentially, this program listens for connections on a
###		UNIX domain socket, presumably from an MTA such as Exim.
###		It reads an identifier (referred to as a "triplet" in 
###		the greylisting whitepaper above), and returns a single
###		word ("white", "grey" or "black") depending on prior
###		knowledge of said identifier.
########################################################################

from time         import time, ctime
from socket       import socket, AF_UNIX, SOCK_STREAM
from os           import remove, chmod
from os.path      import join
from pwd          import getpwnam
from grp          import getgrnam
from sys          import stdout, stderr, exit
from signal       import signal, SIGTERM, SIGHUP, SIGUSR1, SIG_IGN, SIG_DFL
from select       import select
from ConfigParser import ConfigParser



### Colors
WHITE      = "white"
GREY       = "grey"
BLACK      = "black"



### Defaults for various configuration items
datafile   = "/var/lib/greylistd/data"
conffile   = "/etc/greylistd/config"

timeouts   = { "retryMin"  : 60 * 60,
               "retryMax"  : 60 * 60 * 4,
               "expire"    : 60 * 60 * 24 * 60 }

socketconf = { "path"      : "/var/run/greylistd/socket",
               "mode"      : "0660" }


### Timestamps and count data for each triplet hash
peerdata   = { WHITE : {},
               GREY  : {},
               BLACK : {} }

### Index of elements in 'peerdata' data
IDX_LAST   = 0
IDX_FIRST  = 1
IDX_COUNT  = 2


### List statistics (how many triplets has entered each list)
stats      = { WHITE   : 0,
               GREY    : 0,
               BLACK   : 0,
               "start" : 0 }


### The UNIX domain socket handler
listener   = None
sockets	   = []



def expireKeys (now):
    for (listKey, timeKey,   timeoutKey) in (
        (GREY,    IDX_LAST,  "retryMax"),
        (WHITE,   IDX_LAST,  "expire"),
        (BLACK,   IDX_FIRST, "expire")):
        for (dataKey, data) in peerdata[listKey].items():
            if data[timeKey] + timeouts[timeoutKey] < now:
                del peerdata[listKey][dataKey]


def checkKey (key, update=True):
    now      = int(time())
    expireKeys(now)

    if not key:
        return "error: no data provided"

    elif peerdata[WHITE].has_key(key):
        status  = WHITE

    elif peerdata[BLACK].has_key(key):
        status  = BLACK

    elif peerdata[GREY].has_key(key):
        if peerdata[GREY][key][IDX_FIRST] + timeouts["retryMin"] < now:
            status  = WHITE
            if update:
                stats[WHITE] += 1
                del peerdata[GREY][key]
        else:
            status  = GREY

    else:
        status = GREY
        stats[GREY] += 1


    (lastseen, firstseen, count) = peerdata[status].get(key, (now, now, 0))

    if update:
        peerdata[status][key]  = (now, firstseen, count + 1)

    return status



def duration (secs):
    plural = ("", "s")

    if secs < 60:
        return "%d second%s"%(secs, plural[secs > 1])

    elif secs < 60 * 60:
        (mins, secs) = (secs / 60, secs % 60)
        return "%s%s%s%s%s%s%s" % (mins, " minute", plural[mins != 1],
                                   secs and " and " or "",
                                   secs or "",
                                   secs and " second" or "",
                                   plural[secs > 1] or "")

    elif (secs + 30) < 60 * 60 * 24:
        (hrs, mins) = ((secs + 30) / 3600, ((secs + 30) / 60) % 60)
        return "%s%s%s%s%s%s%s" % (hrs, " hour", plural[hrs != 1],
                                   mins and " and " or "",
                                   mins or "",
                                   mins and " minute" or "",
                                   plural[mins > 1] or "")

    else:
        (days, hrs) = ((secs + 1800) / 86400, ((secs + 1800) / 3600) % 24)
        return "%s%s%s%s%s%s%s" % (days, " day", plural[days != 1],
                                   hrs and " and " or "",
                                   hrs or "",
                                   hrs and " hour" or "",
                                   plural[hrs > 1] or "")



def getStats (key=None):
    text = []
    now  = int(time())
    expireKeys(now)

    if not key:
        starttime = stats.get("start", 0)
        if starttime:
            title = "List statistics since %s (%s ago)"%(
                ctime(starttime), duration(now - starttime))
        else:
            title = "List statistics"

        text.append(title)
        text.append("-" * len(title))
        hits    = {}
        current = {}

        for color in peerdata.keys():
            current[color] = len(peerdata[color])
            hits[color]    = 0
            for (key, data) in peerdata[color].items():
                (lastseen, firstseen, count) = data
                hits[color] += count


        for color in peerdata.keys():
            hitdigits = len(str(max(hits.values())))
            curdigits = len(str(max(current.values())))
            text.append("The %5slist currently contains %s items, "
                        "matching %s requests."%
                        (color,
                         str(current[color]).rjust(curdigits),
                         str(hits[color]).rjust(hitdigits)))


        previousGrey  = stats[GREY] - len(peerdata[GREY])
        expiredGrey   = previousGrey - stats[WHITE]

        if previousGrey:
            digits = len(str(previousGrey))

            text.append("")
            text.append("Of %s items that were initially greylisted:"%
                        str(previousGrey).rjust(digits))

            text.append(" - %s (%5.1f%%) became whitelisted"%
                        (str(stats[WHITE]).rjust(digits),
                         100.0 * stats[WHITE] / previousGrey))

            text.append(" - %s (%5.1f%%) expired from the greylist"%
                        (str(expiredGrey).rjust(digits),
                         100.0 * expiredGrey / previousGrey))


    else:
        for (color, items) in peerdata.items():
            if items.has_key(key):
                (lasttime, firsttime, count) = items[key]
                text.append("This item was %slisted %s ago, on %s"%
                            (color,
                             duration(now - firsttime),
                             ctime(firsttime)))
                text.append("It matched %s requests, last time on %s"%
                            (count, ctime(lasttime)))
                break

        else:
            text = [ "error: This item does not exist in any list" ]


    return "\n".join(text)



def runCommand (line):
    now     = int(time())
    words   = line.split()
    command = words and words[0].lower() or None
    args    = " ".join(words[1:]).lower()
    key     = hash(args)


    if command == "add":
        color = WHITE
        if (len(words) > 1) and words[1].lower() in peerdata.keys():
            color = words[1].lower()
            args  = " ".join(words[2:]).lower()
            key   = hash(args)

        if not args:
            return "error: no data provided"

        (lastseen, firstseen, count) = peerdata[color].get(key, (now, now, 0))

        for itemList in peerdata.values():
            if itemList.has_key(key):
                del itemList[key]

        peerdata[color][key] = (now, firstseen, count)
        return "Added to %slist: %s"%(color, args)

    elif command == "delete":
        if not args:
            return "error: no data provided"

        for (color, itemList) in peerdata.items():
            if itemList.has_key(key):
                del itemList[key]
                return "Removed from %slist: %s"%(color, args)
        else:
            return "error: this data does not exist"

    elif command == "check":
        return checkKey(key, False)

    elif command == "update":
        return checkKey(key, True)

    elif command == "stats":
        return getStats(key)

    elif command == "clear":
        for key in peerdata.keys():
            peerdata[key].clear()

        for key in stats.keys():
            stats[key] = 0

        stats["start"] = now

        return "all data and statistics cleared."

    elif command:
        args    = " ".join(words).lower()
        key     = hash(args)
        return checkKey(key, True)

    else:
        return "error: no data"



def loadConfig ():
    global timeouts, socketconf, datafile
    parser = ConfigParser()
    parser.read(conffile)

    for key in timeouts.keys():
        try:
            timeouts[key] = parser.getint("timeouts", key)
        except:
            pass
        

    for key in socketconf.keys():
        try:
            socketconf[key] = parser.get("socket", key)
        except:
            pass

    try:
        datafile = parser.get("data", "path")
    except:
        pass

    del parser



def loadData (signum=None, frame=None):
    parser = ConfigParser()
    parser.read(datafile)
    now    = int(time())

    for word in peerdata.keys():
        if parser.has_section(word):
            for (key, string) in parser.items(word):
                try:
                    key       = int(key)
                    data      = string.split()
                    lastseen  = int(data[0])
                    firstseen = (len(data) > 1) and int(data[1]) or now
                    count     = (len(data) > 2) and int(data[2]) or 0

                    peerdata[word][key] = (lastseen, firstseen, count)
                except:
                    pass

    if parser.has_section("statistics"):
        for key in stats.keys():
            try:
                stats[key] = int(parser.get("statistics", key))
            except:
                pass
    else:
        stats["start"] = now

    del parser
    expireKeys(time())


def saveData (signum=None, frame=None):
    expireKeys(time())

    try:
        fp = open(datafile, "w")

    except Exception, e:
        print("Cannot save to %s: %s\n"%(datafile, e[1]))
        exit(1)

    parser = ConfigParser()
    for (word, dataList) in peerdata.items():
        parser.add_section(word)
        for (key, data) in dataList.items():
            parser.set(word, "%d"%key, "%d %d %d"%data)


    if not parser.has_section("statistics"):
        parser.add_section("statistics")

    for (key, data) in stats.items():
        parser.set("statistics", key, str(data))

    parser.write(fp)
    fp.close()
    del parser



def startup ():
    global listener, sockets

    loadConfig()
    listener = socket(AF_UNIX, SOCK_STREAM)

    try:
        listener.bind(socketconf["path"])
        listener.listen(1)
    except Exception, e:
        stderr.write("Could not bind/listen to socket %s: %s\n"%
                     (socketconf["path"], e[1]))
        exit(1)


    try:
        chmod(socketconf["path"], int(socketconf["mode"], 8))
    except Exception, e:
        stderr.write("Could not change mode of socket %s: %s\n"%
                     (socketconf["path"], e[1]))
        exit(1)
                                    

    sockets = [ listener ]
    loadData()


def cleanup ():
    global listener, sockets
    remove(socketconf["path"])
    listener.close()
    listener = None
    sockets  = []
    saveData()
    

def term (signum=None, frame=None):
    print ("term")
    cleanup()
    exit(0)


def hangup (signum=None, frame=None):
    print "hangup"
    cleanup()
    startup()


signal(SIGTERM, term)
signal(SIGHUP,  hangup)
signal(SIGUSR1, saveData)

startup()

while listener:
    try:
        (inlist, outlist, errlist) = select(sockets, [], [])

    except KeyboardInterrupt:
        cleanup()

    else:
        client = inlist[0]
        if client is listener:
            (client, a) = listener.accept()
            sockets.append(client)

        else:
            try:
                line = client.recv(16384)
            except:
                pass

            else:
                status = runCommand(line)

                try:
                    client.send(status)
                    client.close()

                except:
                    pass

            sockets.remove(client)

