#!/usr/bin/env python

# This file is part of Window-Switch.
# Copyright (c) 2009-2013 Antoine Martin <antoine@nagafix.co.uk>
# Window-Switch is released under the terms of the GNU GPL v3

import os
import thread
import threading

from winswitch.util.common import alphanumfile
from winswitch.util.simple_logger import Logger
from winswitch.util.process_util import exec_nopipe


class FS_Client_Helper:
	"""
	Superclass for all filesystem client helpers.
	Callers should call mount and umount.
	This will call do_mount/do_umount in a separate thread with a lock held.
	Sub-classes must implement do_mount and may implement do_umount.
	(if do_mount returns True then umount will be added to close_callbacks)
	"""

	def __init__ (self, fs_type, mount_location, mount_command, umount_command):
		Logger(self)
		self.type = fs_type
		self.mount_location = mount_location
		self.mount_command = mount_command
		self.umount_command = umount_command
		self.lock = threading.Lock()

	def mount(self, user, mp, close_callbacks):
		thread.start_new_thread(self.threaded_mount, (user, mp, close_callbacks))

	def threaded_mount(self, user, mp, close_callbacks):
		try:
			self.lock.acquire()
			try:
				if self.do_mount(user, mp):
					close_callbacks.append(lambda : self.umount(user, mp))
					close_callbacks.append(lambda : os.rmdir(mp.mount_point))
				else:
					self.serror("failed to mount", user, mp, close_callbacks)
			except Exception, e:
				self.serr(None, e, user, mp, close_callbacks)
		finally:
			self.lock.release()

	def do_mount(self, user, mp):
		raise Exception("not implemented by %s" % self)

	def umount(self, user, mp):
		thread.start_new_thread(self.threaded_umount, (user,mp))

	def threaded_umount(self, user, mp):
		try:
			self.lock.acquire()
			try:
				self.do_umount(user, mp)
			except Exception, e:
				self.serr(None, e, user, mp)
		finally:
			self.lock.release()

	def do_umount(self, user, mp):
		cmd = self.umount_command.split(" ")
		cmd.append(mp.mount_point)
		exec_nopipe(cmd, wait=True)

	def get_local_mount_point(self, host, path):
		mnt_path = os.path.expanduser(self.mount_location)
		for part in [host, path]:
			sanitized = alphanumfile(part)
			mnt_path = self.build_local_mount_point(mnt_path, sanitized)
		return mnt_path

	def build_local_mount_point(self, path, part):
		mnt_path = os.path.join(path, part)
		if not os.path.exists(mnt_path):
			self.sdebug("creating %s" % mnt_path, path, part)
			os.makedirs(mnt_path)
		return mnt_path
