#!/usr/bin/env python
# -*- coding: UTF-8 -*-

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

import fnmatch
import os
import re
import shutil
import stat
import subprocess
import time

from fcntl import flock, LOCK_SH, LOCK_EX

from gettext import gettext as _

# Locations of kernel and system files this module reads or writes.
MODULES_ROOT = "/lib/modules"
PROC_MODULES = "/proc/modules"
MODULES_LOADBOOT = "/etc/modules"
BLACKLIST_FILE = "/etc/modprobe.d/blacklist-restricted"

# Cache directory for the generated restricted/used module lists.
MANAGER_ROOT = "/var/cache/restricted-manager"

# Paths (plain files or directories of files) holding "alias" lines that
# override the modalias patterns found in the modules themselves.
MODALIAS_OVERRIDE = [
    "/usr/share/linux-restricted-modules/%s/modules.alias.override"
        % (os.uname()[2],),
    "/usr/share/restricted-manager/modalias_override",
]

# Only pci and usb modalias patterns are taken into account.
MODALIAS_MATCH = re.compile(r"(?:pci|usb):")

# Optional callable invoked repeatedly while waiting for the package
# installer to finish (e.g. to keep a UI responsive).
package_install_idle_function = None

# Optional X11 window ID handed to the package installer as its transient
# parent.
package_install_xid = None

#-----------------------------------------------------------------------------#
# The following functions deal with finding out information about kernel
# modules themselves.
#-----------------------------------------------------------------------------#

def get_modinfo(module):
    """Run /sbin/modinfo on *module* and return its fields as a dict.

    Every key maps to the list of values modinfo printed for it, in output
    order, because a field may be printed several times.  Output lines
    without a ":" are ignored.  Returns None when modinfo exits non-zero.
    """
    modinfo_proc = subprocess.Popen(("/sbin/modinfo", module),
                                    stdout=subprocess.PIPE)
    output = modinfo_proc.communicate()[0]
    if modinfo_proc.returncode != 0:
        return None

    fields = {}
    for entry in output.split("\n"):
        if ":" in entry:
            key, value = entry.split(":", 1)
            fields.setdefault(key.strip(), []).append(value.strip())
    return fields


#-----------------------------------------------------------------------------#
# The following functions deal with keeping a cached list of restricted
# kernel modules, generated from modules.dep/modinfo
#-----------------------------------------------------------------------------#

def am_admin():
    """Return True if the current user may write to the manager cache.

    Write access to MANAGER_ROOT is what this module treats as being an
    administrator."""
    may_write = os.access(MANAGER_ROOT, os.W_OK)
    return may_write

def max_mtime(*paths):
    """Return the newest modification time found among *paths*.

    Directories are searched recursively, so every file and subdirectory
    below them counts as well.  Paths that cannot be stat'ed are skipped;
    if nothing can be stat'ed the result is 0.
    """
    def _entries(top):
        # The path itself first, then everything underneath it.
        yield top
        if os.path.isdir(top):
            for root, subdirs, files in os.walk(top):
                for name in subdirs + files:
                    yield os.path.join(root, name)

    newest = 0
    for top in paths:
        for candidate in _entries(top):
            try:
                newest = max(newest, os.stat(candidate).st_mtime)
            except OSError:
                pass
    return newest

def load_restricted_list(force=False):
    """Load the list of restricted modules for the running kernel.

    The cached list in MANAGER_ROOT is used when it exists and is newer than
    modules.dep and the modalias override files.  Otherwise (or when *force*
    is set) the list is regenerated via generate_restricted_list().

    Returns a dict mapping module name to its list of modalias patterns.
    """
    kernel_version = os.uname()[2]
    modules_dep = os.path.join(MODULES_ROOT, kernel_version, "modules.dep")
    modules_restricted = os.path.join(MANAGER_ROOT,
                                      kernel_version + ".restricted")

    try:
        restricted_mtime = os.stat(modules_restricted).st_mtime
    except OSError:
        restricted_mtime = None

    if (force or restricted_mtime is None or
            max_mtime(modules_dep, *MODALIAS_OVERRIDE) > restricted_mtime):
        return generate_restricted_list(modules_dep, modules_restricted, force)

    try:
        restricted_file = open(modules_restricted, "r")
    except IOError:
        # The cache disappeared (or became unreadable) after the stat()
        # above; rebuild it rather than failing.
        return generate_restricted_list(modules_dep, modules_restricted, force)

    restricted = {}
    try:
        flock(restricted_file.fileno(), LOCK_SH)
        for line in restricted_file:
            words = line.split()
            # Tolerate blank lines instead of crashing on an empty list.
            if not words:
                continue
            # Line format: "<module> <pattern> <pattern> ..."
            restricted[words[0]] = words[1:]
    finally:
        restricted_file.close()

    return restricted

def generate_restricted_list(modules_dep, modules_restricted, force=False):
    """Generate the list of restricted modules for the running kernel.

    Every module listed in *modules_dep* is inspected with modinfo, except
    those shipped by the main linux-image package.  Each module whose
    license is not GPL compatible is recorded with its pci/usb modalias
    patterns.  Entries from the modalias override files then replace the
    scanned entries for the modules they name.

    The result is written to the *modules_restricted* cache.  Failing to
    write it is only an error when *force* is set.

    Returns a dict mapping module name (dashes normalized to underscores)
    to a list of modalias patterns.
    """
    restricted = {}
    modules_dir = os.path.dirname(modules_dep)

    # try to get a *.ko file list from the main kernel package to avoid testing
    # known-free drivers
    free_files = set()
    try:
        dpkg = subprocess.Popen(["dpkg", "-L", "linux-image-" + os.uname()[2]],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except OSError:
        # No dpkg available: simply inspect every module.
        dpkg = None
    if dpkg is not None:
        out = dpkg.communicate()[0]
        if dpkg.returncode == 0:
            free_files = set([os.path.basename(f) for f in out.splitlines()])

    dep_file = open(modules_dep, "r")
    try:
        for line in dep_file:
            # modules.dep lines look like "<module path>: <dependencies>"
            try:
                module_filename = line[:line.index(":")].strip()
            except ValueError:
                continue

            module = os.path.basename(module_filename)
            if module in free_files:
                continue
            if module.endswith(".ko"):
                module = module[:-3]
            # Use the underscore spelling so names compare equal to the ones
            # read from /proc/modules.
            module = module.replace("-", "_")

            # Newer modules.dep files list paths relative to the modules
            # directory; os.path.join() leaves absolute paths untouched.
            modinfo = get_modinfo(os.path.join(modules_dir, module_filename))
            if modinfo is None:
                continue
            license_name = modinfo.get("license", ["unknown"])[0]
            if is_license_gpl_compatible(license_name):
                continue

            patterns = restricted.setdefault(module, [])
            for alias_pattern in modinfo.get("alias", []):
                if MODALIAS_MATCH.match(alias_pattern):
                    patterns.append(alias_pattern)
    finally:
        dep_file.close()

    # Override patterns with modalias.override.
    restricted.update(parse_modules_aliases(*MODALIAS_OVERRIDE))

    try:
        save_restricted_list(modules_restricted, restricted)
    except (IOError, OSError):
        # It's a bit of a bugger if we can't write the cache, but not critical
        # unless that's what we were trying to do in the first place
        if force:
            raise

    return restricted

def parse_modules_alias(*filenames):
    """Parse modalias patterns from modprobe-style alias files.

    Lines of the form "alias <pattern> <module>" are collected if the
    pattern matches MODALIAS_MATCH (pci/usb).  Files that cannot be opened
    or read are skipped; patterns read before a read error are kept.

    Returns a dict mapping module name to its patterns, in file order.
    """
    modules = {}

    for filename in filenames:
        try:
            alias_file = open(filename, "r")
        except IOError:
            # Ignore files that couldn't be opened.
            continue

        try:
            try:
                for line in alias_file:
                    words = line.split()
                    if len(words) != 3:
                        continue
                    command, pattern, module = words
                    if command == "alias" and MODALIAS_MATCH.match(pattern):
                        modules.setdefault(module, []).append(pattern)
            except IOError:
                # Ignore files that couldn't be read.
                pass
        finally:
            alias_file.close()

    return modules

def parse_modules_aliases(*paths):
    """Parse modalias patterns from the given files and directory trees.

    A path that is a directory contributes every file below it, walked
    recursively in sorted order.  Any other path is parsed as a single
    file, so a plain override file such as modules.alias.override is read
    too (os.walk() yields nothing for a non-directory).  Paths that do not
    exist contribute nothing, because parse_modules_alias() skips
    unreadable files.

    Returns the dict produced by parse_modules_alias().
    """
    files = []

    for path in paths:
        if os.path.isdir(path):
            for dirpath, dirnames, filenames in os.walk(path):
                # Deterministic order, so later files consistently win.
                dirnames.sort()
                for filename in sorted(filenames):
                    files.append(os.path.join(dirpath, filename))
        else:
            files.append(path)

    return parse_modules_alias(*files)

def save_restricted_list(modules_restricted, restricted):
    """Write the restricted module list to the cache file on disk.

    Each line holds a module name followed by its space-separated modalias
    patterns, which is the format load_restricted_list() reads back.  The
    file is written under an exclusive flock.  If anything fails while
    writing, the partial file is deleted and the original exception
    propagates.
    """
    # Mask only world-write so the cache stays group-writable.  Apply the
    # mask just while creating the file, so the process umask is left as
    # the caller had it.
    old_umask = os.umask(stat.S_IWOTH)
    try:
        restricted_file = open(modules_restricted, "w")
    finally:
        os.umask(old_umask)

    written = False
    try:
        flock(restricted_file.fileno(), LOCK_EX)
        for module, alias_patterns in restricted.items():
            restricted_file.write(module + " " + " ".join(alias_patterns) +
                                  "\n")
        # close() flushes, so write errors such as a full disk show up here.
        restricted_file.close()
        written = True
    finally:
        if not written:
            # never keep broken cache files around; cleanup errors must not
            # hide the exception that got us here
            try:
                restricted_file.close()
            except (IOError, OSError):
                pass
            try:
                os.unlink(modules_restricted)
            except OSError:
                pass

def is_license_gpl_compatible(license):
    """Return True if *license* is one of the GPL-compatible license strings.

    The accepted strings mirror the kernel function of the same name; keep
    the two in sync.  Matching is exact and case-sensitive.
    """
    return license in ("GPL",
                       "GPL v2",
                       "GPL and additional rights",
                       "Dual BSD/GPL",
                       "Dual MIT/GPL",
                       "Dual MPL/GPL")

def notify_reboot_required():
    """Trigger update-notifier's "Reboot required" notification.

    Runs the notify-reboot-required helper and then touches the
    dpkg-run-stamp file.  This is best-effort: IOError/OSError (for example
    a missing helper or no permission) is ignored, and once one step fails
    the remaining step is skipped.
    """
    helper = "/usr/share/update-notifier/notify-reboot-required"
    stamp = "/var/lib/update-notifier/dpkg-run-stamp"
    try:
        subprocess.call([helper])
        stamp_file = open(stamp, "w")
        stamp_file.close()
    except (IOError, OSError):
        pass

#-----------------------------------------------------------------------------#
# The following function compares the list of restricted modules to the
# connected hardware and returns the set of modules that match.
#-----------------------------------------------------------------------------#

def connected_hardware(restricted, sysfs_root="/sys/devices"):
    """Return the set of restricted modules matching the connected hardware.

    *restricted* maps module names to lists of fnmatch-style modalias
    patterns.  A module is reported when any of its patterns matches a
    modalias read from a "modalias" file in the device tree under
    *sysfs_root*.  Modules without any pattern are always reported, since
    there is no way to rule them out.
    """
    devices = set()

    modules_with_patterns = {}
    for module, patterns in restricted.items():
        if patterns:
            modules_with_patterns[module] = patterns
        else:
            # No alias patterns known, so no way to check for matching
            # hardware. Just add it to the list of "connected" devices just in
            # case it actually happens to be connected.
            devices.add(module)

    connected_aliases = []
    for dirpath, dirnames, filenames in os.walk(sysfs_root):
        if "modalias" not in filenames:
            continue
        try:
            modalias_file = open(os.path.join(dirpath, "modalias"))
            try:
                connected_aliases.append(modalias_file.read().strip())
            finally:
                modalias_file.close()
        except (IOError, OSError):
            # Skip modalias files that vanish or cannot be read mid-walk.
            continue

    for module, patterns in modules_with_patterns.items():
        for pattern in patterns:
            if fnmatch.filter(connected_aliases, pattern):
                # Matching hardware is connected; no need to test the
                # module's remaining patterns.
                devices.add(module)
                break

    return devices


#-----------------------------------------------------------------------------#
# The following functions deal with keeping a cached list of used modules.
#-----------------------------------------------------------------------------#

def load_used_list():
    """Return the restricted modules recorded as used.

    Reads the "used" file in MANAGER_ROOT under a shared flock, one module
    per line with surrounding whitespace stripped.  Returns an empty list
    if the file cannot be opened.
    """
    path = os.path.join(MANAGER_ROOT, "used")

    try:
        used_file = open(path, "r")
    except IOError:
        return []

    try:
        flock(used_file.fileno(), LOCK_SH)
        entries = [line.strip() for line in used_file]
    finally:
        used_file.close()

    return entries

def save_used_list(used):
    """Write the list of used restricted modules to disk, one per line.

    The "used" file in MANAGER_ROOT is written under an exclusive flock.
    Only the world-write bit is masked at creation, so the file stays
    group-writable.
    """
    modules_used = os.path.join(MANAGER_ROOT, "used")

    # Apply the mask only while creating the file; leave the process umask
    # as the caller had it.
    old_umask = os.umask(stat.S_IWOTH)
    try:
        used_file = open(modules_used, "w")
    finally:
        os.umask(old_umask)

    try:
        flock(used_file.fileno(), LOCK_EX)
        for module in used:
            used_file.write("%s\n" % (module,))
    finally:
        used_file.close()


#-----------------------------------------------------------------------------#
# The following functions allow us to place code in Python modules alongside
# this one that overrides kernel module information and provides special
# handling
#-----------------------------------------------------------------------------#

def get_specials():
    """Load the special handlers defined in modules next to this one.

    Every ``*.py`` file in this module's directory is imported as
    ``RestrictedManager.<name>``.  Each attribute that has an
    ``is_handler`` member is instantiated with its ``name``, and the
    instance is kept when its ``_modinfo`` is truthy.

    Returns a dict mapping handler name to handler instance.
    """
    specials_dir = os.path.dirname(__file__) or os.curdir
    # Sorted so the result is deterministic when two files define a handler
    # with the same name (the later file wins).
    names = [filename[:-3] for filename in sorted(os.listdir(specials_dir))
             if filename.endswith(".py")]

    specials = {}
    for name in names:
        package = __import__("RestrictedManager.%s" % name)
        pymod = getattr(package, name)

        for attr in dir(pymod):
            obj = getattr(pymod, attr)
            if not hasattr(obj, "is_handler"):
                continue
            # Construct the handler once and keep that very instance; the
            # constructor can be expensive (DefaultHandler runs modinfo).
            handler = obj(obj.name)
            if handler._modinfo:
                specials[obj.name] = handler

    return specials

def get_handler(specials, module):
    """Return the handler for *module*, or None if it has no modinfo.

    A special handler registered under the module's name takes precedence.
    Otherwise a DefaultHandler is built for the module.
    """
    try:
        handler = specials[module]
    except KeyError:
        handler = DefaultHandler(module)

    if not handler._modinfo:
        return None
    return handler

def get_handlers():
    """Return a dict mapping handler name to handler for this machine.

    Every restricted module matching connected hardware gets its handler via
    get_handler(); modules for which that returns nothing are left out.
    Special handlers whose name is not in the restricted module list are
    always included.
    """
    restricted = load_restricted_list()
    hardware = connected_hardware(restricted)
    specials = get_specials()

    handlers = {}
    for module in hardware:
        handler = get_handler(specials, module)
        if handler:
            handlers[handler.name] = handler

    unrelated = [(name, special) for name, special in specials.items()
                 if name not in restricted]
    handlers.update(unrelated)

    return handlers


#-----------------------------------------------------------------------------#
# Here we declare the default handler for a restricted module that we find
# in the kernel.
#-----------------------------------------------------------------------------#

class DefaultHandler(object):
    """Handler for a restricted kernel module found in the running kernel.

    An instance describes one module (name, description, rationale) and
    enables or disables it by editing the blacklist file (BLACKLIST_FILE).
    The lists of loaded and blacklisted modules are class attributes shared
    by all instances; load_module_list() and load_blacklist() fill them.
    """

    def __init__(self, module):
        self._module = module
        # None when modinfo does not know the module.
        self._modinfo = get_modinfo(module)
        self._changed = False

    # To define overrides for a module handling, or even define a completely
    # new restricted driver unrelated to a kernel module, create a file in
    # the RestrictedManager module that contains one or more classes.
    #
    # Either derive from this class, or ensure you have the same properties
    # and object-level functions.
    #
    # Your class will need to contain the is_handler member to be used.
    #
    # from RestrictedManager.core import DefaultHandler
    # class EvilDriver(DefaultHandler):
    #     is_handler = True

    # Override the name property (usually just as a variable) to set the name
    # of the kernel module that your new handler deals with -- set it to some
    # useful string if it's not a true kernel module.
    #
    #     name = "evil"
    @property
    def name(self):
        """Name of the kernel module itself."""
        return self._module

    # Override the description property if the module's own string isn't that
    # useful; all non-kernel modules will need to set this.
    #
    #     description = "Support for the evil hardware family"
    @property
    def description(self):
        """One-line description of the module (human name)."""
        return self._modinfo.get("description", [self.name])[0]

    # Override the rationale property to provide a paragraph or two of
    # rationale why a user would want this module enabled.
    #
    #    rationale = "Without this module, the blue light is orange"
    @property
    def rationale(self):
        """Rationale as to why this driver might be enabled."""
        return _("This driver is necessary to support the hardware, there "
                "is no free/open alternative.\n\n"
                "If this driver is not enabled, the hardware will not "
                "function in Ubuntu.")

    def is_changed(self):
        """Returns True if the module has been enabled/disabled at least
        once."""

        return self._changed

    # Override the is_loaded function if finding out whether the driver is
    # in use or not is not as simple as checking /proc/modules
    def is_loaded(self):
        """Returns True if the module is loaded."""
        return self.name in self._modules

    # Override the is_enabled function if finding out whether the driver is
    # enabled or not is not as simple as checking the blacklist file
    # (BLACKLIST_FILE, /etc/modprobe.d/blacklist-restricted)
    def is_enabled(self):
        """Returns True if the module is enabled."""
        return self.name not in self._blacklist

    # Override the enable function if you need to do more than remove the
    # driver from the blacklist file (BLACKLIST_FILE) to enable it
    def enable(self):
        """Enable the module."""

        self._changed = True
        if self.name in self._blacklist:
            self._blacklist.remove(self.name)
            self.save_blacklist()

    # Override the disable function if you need to do more than add the
    # driver to the blacklist file (BLACKLIST_FILE) to disable it
    def disable(self):
        """Disable the module."""

        self._changed = True
        if self.name not in self._blacklist:
            self._blacklist.append(self.name)
            self.save_blacklist()

    def can_change(self):
        """Check whether we can actually modify settings of this handler.

        This might not be the case if e. g. the user manually modified a
        configuration file. Return an explanatory text if settings can not be
        changed, or None if changing is ok."""

        return None

    # The following classmethods are part of the default handler, you don't
    # need to override or worry about them (you may want to call them though)
    @classmethod
    def load_module_list(klass):
        """Load the list of modules in the kernel.

        Stores the first space-separated field of each PROC_MODULES line in
        the shared klass._modules list."""
        klass._modules = []

        proc_modules = open(PROC_MODULES, "r")
        try:
            for line in proc_modules:
                try:
                    line = line[:line.index(" ")]
                except ValueError:
                    pass

                klass._modules.append(line.strip())
        finally:
            proc_modules.close()

    @classmethod
    def load_blacklist(klass):
        """Load the list of blacklisted modules.

        Reads "blacklist <module>" lines from BLACKLIST_FILE under a shared
        flock; "#" starts a comment.  A missing or unreadable file means an
        empty blacklist."""
        klass._blacklist = []

        try:
            blacklist_file = open(BLACKLIST_FILE, "r")
        except IOError:
            return

        try:
            flock(blacklist_file.fileno(), LOCK_SH)
            for line in blacklist_file:
                # Drop the comment part, then split into words; this also
                # accepts indented lines and rejects e.g. "blacklisted".
                words = line.split("#", 1)[0].split()
                if len(words) >= 2 and words[0] == "blacklist":
                    klass._blacklist.append(words[1])
        finally:
            blacklist_file.close()

    @classmethod
    def save_blacklist(klass):
        """Save the list of blacklisted modules.

        Removes BLACKLIST_FILE when the list is empty; otherwise rewrites it
        with one sorted "blacklist <module>" line per entry under an
        exclusive flock."""
        if not klass._blacklist:
            try:
                os.unlink(BLACKLIST_FILE)
            except OSError:
                pass
            return

        # Mask group/world write (022) only while creating the file, so the
        # process umask is left as the caller had it.
        old_umask = os.umask(stat.S_IWGRP | stat.S_IWOTH)
        try:
            blacklist_file = open(BLACKLIST_FILE, "w")
        finally:
            os.umask(old_umask)

        try:
            flock(blacklist_file.fileno(), LOCK_EX)
            blacklist_file.write(
                "# This file is used to disable restricted drivers\n")
            for module in sorted(klass._blacklist):
                blacklist_file.write("blacklist %s\n" % module)
        finally:
            blacklist_file.close()

    # Create a backup of a config file, using the format
    # original_path.restricted-manager.YYYYMMDD-HHMMSS, or if that already
    # exists, original_path.restricted-manager.YYYYMMDD-HHMMSS.N where N is the
    # first available integer starting from zero.
    #
    # The original file is locked exclusively during the operation to avoid a
    # race condition.
    @classmethod
    def backup_conffile(klass, filename):
        """Create a backup of a config file before overwriting it.

        Does nothing if *filename* does not exist."""
        if not os.path.exists(filename):
            return

        f = open(filename, "r")
        try:
            flock(f.fileno(), LOCK_EX)

            backup_prefix = "%s.restricted-manager.%s" % (filename,
                time.strftime("%Y%m%d-%H%M%S"))
            backup = backup_prefix

            i = 0
            while os.path.exists(backup):
                backup = backup_prefix + "." + str(i)
                i += 1

            shutil.copy2(filename, backup)
        finally:
            # Closing the file also releases the flock.
            f.close()

    @classmethod
    def package_installed(klass, package):
        """Check whether the given package is installed."""

        dpkg = subprocess.Popen(["dpkg-query", "-W", "-f${Status}", package],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out = dpkg.communicate()[0]
        if dpkg.returncode != 0:
            return False
        # The last word of the status is the package state; guard against
        # empty output instead of raising IndexError.
        status = out.split()
        return bool(status) and status[-1] == "installed"

    @classmethod
    def _run_synaptic(klass, selection):
        """Feed one "<package> <action>" selection to synaptic and wait.

        While waiting, package_install_idle_function (if set) is called
        every 0.1 seconds.  Returns synaptic's exit status.

        Might throw OSError if synaptic is not available."""

        argv = ["/usr/sbin/synaptic", "--set-selections", "--non-interactive",
            "--hide-main-window"]
        if package_install_xid:
            argv += ["--parent-window-id", str(package_install_xid)]

        synaptic = subprocess.Popen(argv, stdin=subprocess.PIPE)
        if package_install_idle_function:
            synaptic.stdin.write(selection)
            synaptic.stdin.close()
            while synaptic.poll() is None:
                time.sleep(0.1)
                package_install_idle_function()
        else:
            synaptic.communicate(selection)
        return synaptic.returncode

    @classmethod
    def install_package(klass, package):
        """Install given package through synaptic and return synaptic's exit
        status.

        Might throw OSError if synaptic is not available."""

        return klass._run_synaptic(package + " install")

    @classmethod
    def remove_package(klass, package):
        """Remove given package through synaptic and return synaptic's exit
        status.

        Might throw OSError if synaptic is not available."""

        return klass._run_synaptic(package + " deinstall")

    @classmethod
    def enable_etcmodules(klass, module):
        """Add given module to /etc/modules.

        Does nothing if the module is already present."""

        contents = ""
        try:
            f = open(MODULES_LOADBOOT)
            try:
                contents = f.read()
            finally:
                f.close()
        except IOError:
            # A missing file is created by the append below.
            pass

        if module in contents.splitlines():
            return

        f = open(MODULES_LOADBOOT, "a")
        try:
            # Don't glue the module name onto a last line that lacks its
            # terminating newline.
            if contents and not contents.endswith("\n"):
                f.write("\n")
            f.write(module + "\n")
        finally:
            f.close()

    @classmethod
    def disable_etcmodules(klass, module):
        """Remove given module from /etc/modules.

        Does nothing if the module is not present."""

        try:
            f = open(MODULES_LOADBOOT)
            try:
                mods = f.read().splitlines()
            finally:
                f.close()
        except IOError:
            return

        if module not in mods:
            return

        # Drop every occurrence so the module is really no longer listed.
        mods = [m for m in mods if m != module]
        f = open(MODULES_LOADBOOT, "w")
        try:
            f.write("\n".join(mods) + "\n")
        finally:
            f.close()

# Fill the class-level lists shared by every handler once, at import time:
# DefaultHandler._modules (names read from PROC_MODULES) and
# DefaultHandler._blacklist (modules listed in BLACKLIST_FILE).
DefaultHandler.load_module_list()
DefaultHandler.load_blacklist()
