#!/usr/bin/env python

"""copies modules into place, using the environment variables MODULEDIR and INITRDDIR"""

# copyright 2004 vagrant@freegeek.org, distributed under the terms of the
# GNU General Public License version 2 or any later version.

# $Id: copy_modules,v 1.12 2004/04/13 22:29:44 vagrant Exp $

import os, shutil, sys

# names of the files in config_dir that list the modules to include
module_files=['modules','modules.extra','network_cards']

# directory for configuration
config_dir='/etc/lessdisks/mkinitrd'
if not os.path.isdir(config_dir):
  sys.stderr.write('no config dir: %s\n' % config_dir)
  sys.exit(1)

# The last path component of MODULEDIR is taken as the kernel version.
module_dir=os.environ.get('MODULEDIR')
if not module_dir:
  # without it there is no kernel version to work with
  sys.stderr.write('no module dir defined!\n')
  sys.exit(1)
if not os.path.exists(module_dir):
  # only a warning: module_dir itself is used just to derive the version
  sys.stderr.write('no module dir for %s\n' % module_dir)

# normpath drops any trailing '/', which would otherwise make basename ''
kernel_version=os.path.basename(os.path.normpath(module_dir))
if not kernel_version:
  sys.stderr.write('cannot derive a kernel version from %s\n' % module_dir)
  sys.exit(1)

# staging directory the modules are copied into and depmod is run against
tempdir=os.environ.get('INITRDDIR')
if not tempdir:
  sys.stderr.write('no initrd dir (INITRDDIR) defined!\n')
  sys.exit(1)

## begin function declarations ##

def listLines(filenames, config_dir):
  """Return the concatenated lines of every existing file in filenames.

  Each name is resolved relative to config_dir; names that do not refer
  to a regular file are skipped silently.  Lines keep their trailing
  newlines and are returned in file order, then line order.
  """
  lines=[]
  for filename in filenames:
    path=os.path.join(config_dir, filename)
    if os.path.isfile(path):
      # 'with' closes the file even if reading fails part-way
      with open(path, 'r') as handle:
        lines.extend(handle.readlines())
  return lines

def getRequestedModules(config_dir='/etc/mkinitrd', module_files=('modules',), module_list=None):
  """Collect the module names requested in the config files.

  The files named in module_files are read from config_dir via listLines.
  Every line that is neither blank (including whitespace-only) nor a
  comment (first field starting with '#', even if indented) names a module
  in its first whitespace-separated field; further fields are ignored.

  Names are appended to module_list and that list is returned.  When
  module_list is omitted, a fresh list is used for every call; the old
  shared mutable default made results accumulate across calls.
  """
  if module_list is None:
    module_list=[]
  for line in listLines(module_files, config_dir):
    fields=line.split()
    # split() also swallows '\r' and stray spaces, so such lines count as blank
    if not fields or fields[0].startswith('#'):
      continue
    module_list.append(fields[0])
  return module_list

def getModuleDepList(kernel_version, modules_root='/lib/modules'):
  """Parse modules.dep for a kernel into a dict keyed by module name.

  Reads <modules_root>/<kernel>/modules.dep, where kernel is kernel_version
  or, when that is empty/None, the running kernel's release (the same
  string `uname -r` prints).

  Entries are separated by blank lines; every non-blank line following
  another non-blank line is a continuation of the current entry.  The
  value stored for a module is the flat list of whitespace-separated
  tokens of its whole entry.  Its first token is the module's own path
  with a trailing ':'.  Continuation backslashes are kept as tokens; the
  callers filter them out.  The key is the file name cut at its first
  '.o' (e.g. 'foo' for '.../foo.o:').
  """
  if kernel_version:
    kernel=kernel_version
  else:
    # os.uname()[2] is the release `uname -r` reports, without a subprocess
    kernel=os.uname()[2]

  allmodules={}
  current=None  # entry currently being accumulated; None between entries
  dep_path=os.path.join(modules_root, kernel, 'modules.dep')
  with open(dep_path) as depfile:
    for raw in depfile:
      tokens=raw.split()
      if not tokens:
        # a blank line ends the current entry
        current=None
      elif current is None:
        current=tokens[0].split('/')[-1].split('.o')[0]
        allmodules[current]=tokens
      elif current in allmodules:
        allmodules[current].extend(tokens)
      else:
        print('what to do with %s' % (tokens,))
  return allmodules

def getModules(requested_modules, kernel_version, allmodules=None):
  """Return the sorted, de-duplicated names of modules and their listed deps.

  requested_modules may carry a '.o' suffix or not.  allmodules maps module
  names to modules.dep token lists (as getModuleDepList builds them); it is
  read for kernel_version only when not supplied.  An explicitly passed
  empty dict is respected and not re-read.

  For each requested module found in allmodules, every token of its entry
  that ends in '.o' (after dropping a trailing ':') is reduced to a bare
  module name.  With getModuleDepList's entries the first token is the
  module's own path, so known modules appear in the result themselves.
  Unknown names are reported on stdout and skipped.
  """
  if allmodules is None:
    allmodules=getModuleDepList(kernel_version)

  found=set()  # set membership instead of list.count keeps this linear
  for requested in requested_modules:
    name=requested.split('.o')[0]
    entry=allmodules.get(name)
    if entry is None:
      print('no module named: %s' % name)
      continue
    for token in entry:
      path=token.split(':')[0]
      if path.endswith('.o'):
        found.add(path.split('/')[-1].split('.o')[0])
  return sorted(found)

def makeModuleDirs(module_paths, tempdir='/tmp/modules'):
  """Create, below tempdir, the parent directory of every module path.

  Each path is mirrored under tempdir with any leading '/' dropped.  For
  example, '/lib/modules/x/foo.o' needs tempdir/lib/modules/x; this is the
  same place tempdir + '/' + path resolves to.  Existing directories are
  left alone.  If a non-directory already occupies a needed path, a
  warning is printed and that path is skipped.
  """
  dirs_to_make=[]
  seen=set()
  for module in module_paths:
    # dirname rather than splitting on the file name, which may also
    # occur earlier in the path; lstrip so os.path.join keeps tempdir
    parent=os.path.join(tempdir, os.path.dirname(module).lstrip('/'))
    if parent not in seen:
      seen.add(parent)
      dirs_to_make.append(parent)
  for target in dirs_to_make:
    if os.path.isdir(target):
      continue
    if os.path.exists(target):
      print('WARNING: %s is not a directory...?' % target)
    else:
      os.makedirs(target)

def getModuleList(modules, kernel_version, allmodules=None, module_paths=None):
  """Return the '.o' file paths needed for the given module names.

  allmodules maps module names to modules.dep token lists (as
  getModuleDepList builds them).  It is read for kernel_version only when
  not supplied; an explicitly passed empty dict is respected.  For each
  name in modules found in allmodules, the tokens of its entry that end in
  '.o' (after dropping a trailing ':') are collected.  Paths keep their
  order of first appearance, with duplicates removed.  Unknown names are
  skipped silently.

  module_paths is accepted for backward compatibility but ignored; a new
  list is always returned.
  """
  if allmodules is None:
    allmodules=getModuleDepList(kernel_version)
  result=[]
  seen=set()  # O(1) duplicate check instead of list.count
  for module in modules:
    for token in allmodules.get(module, ()):
      path=token.split(':')[0]
      if path.endswith('.o') and path not in seen:
        seen.add(path)
        result.append(path)
  return result

def copyModules(module_paths, tempdir='/tmp/modules'):
  """Copy every module file to the same relative location below tempdir.

  Parent directories are created first with makeModuleDirs.  Each path is
  then copied with shutil.copyfile (file contents only) to
  tempdir + '/' + path.  os.path.join is deliberately avoided here: given
  an absolute path it would discard tempdir entirely.
  """
  makeModuleDirs(module_paths, tempdir)
  for source in module_paths:
    destination='%s/%s' % (tempdir, source)
    shutil.copyfile(source, destination)

def getAllModuleDeps(requested_modules, kernel_version, n=30, i=0, last_modules=None):
  """Resolve requested_modules to the sorted closure of their dependencies.

  modules.dep is read once.  getModules is then applied repeatedly, each
  pass taking the previous pass's result as input, until a pass returns
  the same list as the one before it.  The old code fed the original
  request into every pass and so never got past the first level.

  At most n passes are made.  If they do not settle, a warning goes to
  stderr and the last result is returned rather than None.  i is unused
  and kept only for backward compatibility.  last_modules seeds the
  convergence comparison.
  """
  allmodules=getModuleDepList(kernel_version)
  modules=list(requested_modules)
  for i in range(n):
    # a known module's entry starts with its own path, so a pass never
    # drops a resolved name: the result only grows until it is stable
    modules=getModules(modules, kernel_version, allmodules)
    if modules == last_modules:
      # no change, we're done resolving dependencies
      return modules
    last_modules=modules
  sys.stderr.write('warning: module dependencies not resolved after %d passes\n' % n)
  return modules

def makeModuleDepFile(kernel_version, tempdir):
  """Run depmod for kernel_version against the module tree in tempdir.

  /sbin/depmod is run with '-a -b tempdir' and the symbol map
  /boot/System.map-<kernel_version>.  A failure to start it, or a non-zero
  exit status, is reported on stderr but is not fatal; the previous
  os.system call ignored the result as well.  Afterwards the listed
  modules.* files, which this script treats as unnecessary, are deleted
  from tempdir/lib/modules/<kernel_version> if present.
  """
  import subprocess
  command=['/sbin/depmod', '-a', '-b', tempdir,
           '-F', '/boot/System.map-%s' % kernel_version, kernel_version]
  # an argument list, not a shell string: the paths come from environment
  # variables and may contain spaces or shell metacharacters
  try:
    status=subprocess.call(command)
  except OSError as err:
    sys.stderr.write('could not run %s: %s\n' % (command[0], err))
  else:
    if status != 0:
      sys.stderr.write('%s exited with status %d\n' % (command[0], status))

  # remove unnecessary dependency files if they're there...
  moddir=os.path.join(tempdir, 'lib', 'modules', kernel_version)
  for name in ('modules.generic_string','modules.ieee1394map','modules.isapnpmap','modules.parportmap','modules.pcimap','modules.pnpbiosmap','modules.usbmap'):
    path=os.path.join(moddir, name)
    if os.path.exists(path):
      os.unlink(path)

## end function declarations ##
# Driver: read the requested module names, resolve their dependencies,
# copy the module files into tempdir and build a modules.dep there.
requested_modules=getRequestedModules(config_dir, module_files)
modules=getAllModuleDeps(requested_modules, kernel_version)
if modules is None:
  # guard against getAllModuleDeps producing no result; getModuleList
  # would otherwise fail with a TypeError iterating over None
  sys.stderr.write('could not resolve module dependencies\n')
  sys.exit(1)
module_paths=getModuleList(modules, kernel_version)
copyModules(module_paths, tempdir)
makeModuleDepFile(kernel_version, tempdir)
