from LTSPAgent.plugin import Plugin, auto_update
from storm.locals import *
import datetime, paramiko, re, sys, logging, os, codecs, inspect, socket, subprocess, uuid
from hashlib import sha1

class Account(object):
    """Storm model for one row of the 'account' table (one WebLive account).

    The table is created, and older schemas extended, by weblive.init_plugin().
    """
    __storm_table__ = "account"
    id = Int(primary=True)      # SERIAL primary key
    server = Unicode()          # server id (key of the "servers" config section)
    username = Unicode()        # login name created on that server
    password = Unicode()        # SHA-1 hex digest of the password, not the password itself
    fullname = Unicode()        # full name passed to adduser --gecos
    source = Unicode()          # origin of the request, e.g. "json:<client>"
    session = Unicode()         # comma separated list of requested sessions
    locale = Unicode()          # comma separated list of locales (may hold the string "None")
    token = Unicode()           # random UUID, also written to ~/.weblive_token on the server
    enabled = Bool()            # set to False once the account has been expired/deleted
    created = DateTime()        # creation or last renewal time (naive local datetime.now())

class weblive(Plugin):
    """WebLive plugin: hands out temporary user accounts on remote servers.

    Accounts are recorded in the PostgreSQL 'account' table (see Account) and
    created/removed on the configured servers over SSH as root. The
    @auto_update methods (presumably scheduled periodically by the Plugin base
    with the given interval in seconds) refresh locale/package lists, expire
    old accounts and track which servers are reachable.
    """

    # Variables
    # NOTE: these are class-level (shared) containers, mutated in place.
    unavailable_servers=[]      # server ids that currently fail over SSH
    disabled_servers=[]         # server ids disabled through set_disabled()
    locales={}                  # serverid -> [(locale, nice name), ...]
    locales_description={}      # locale prefix -> nice name (from locales.list)
    packages={}                 # serverid -> [[name, version, autoinstall_only], ...]

    def init_plugin(self):
        """Prepare logging, database and locales, then start the plugin threads.

        Exits the process when the configured database is not PostgreSQL.
        """

        # Disable paramiko logging
        logging.getLogger('paramiko.transport').setLevel(logging.CRITICAL)

        # Initialize the shared DB connection
        database=self.get_config_path(self.config,"general","database")
        if "postgres://" not in database:
            self.LOGGER.critical("WebLive only supports PostgreSQL")
            sys.exit(1)

        self.db=create_database(database)

        # Create the TABLE if it doesn't exist, then add every column an older
        # schema may lack. Each statement runs in its own transaction so that
        # one "already exists" failure cannot skip the following statements
        # (e.g. 'session' used to be skipped whenever 'source' already existed).
        store=Store(self.db)
        self._try_schema_change(store,"CREATE TABLE account (id SERIAL, server VARCHAR, username VARCHAR, password VARCHAR, fullname VARCHAR, source VARCHAR, session VARCHAR, locale VARCHAR, token VARCHAR, enabled BOOLEAN, created timestamp without time zone)")
        for column in ("source","session","password","locale","token"):
            self._try_schema_change(store,"ALTER TABLE account ADD COLUMN %s VARCHAR" % column)

        # Load locale nice names from the first locales.list found
        for path in ("/usr/share/ltsp-agent/plugins/weblive/locales.list","plugins/weblive/locales.list"):
            if os.path.exists(path):
                with codecs.open(path,'r','utf-8') as locales_list:
                    for line in locales_list.readlines():
                        fields=line.split()
                        if not fields:
                            # Skip blank lines instead of failing with IndexError
                            continue
                        self.locales_description[fields[0]]=" ".join(fields[1:])
                break

        # Call parent function (start the threads)
        Plugin.init_plugin(self)

    def _try_schema_change(self,store,statement):
        """Execute a DDL 'statement' and commit; roll back and return False if it fails."""

        try:
            store.execute(statement)
            store.commit()
            return True
        except Exception:
            # Expected when the table/column already exists
            store.rollback()
            return False

    def _server_config(self,serverid,key):
        """Return configuration value 'key' of server 'serverid'."""

        return self.get_config_path(self.config,"servers",serverid,key)

    def _server_address(self,serverid):
        """Return the [hostname, port] pair clients use to reach 'serverid'."""

        return [self._server_config(serverid,"server"),int(self._server_config(serverid,"port"))]

    def _mark_server_down(self,serverid):
        """Flag 'serverid' as unavailable and fire the server-updown hook (once)."""

        if serverid not in self.unavailable_servers:
            self.LOGGER.error("Server '%s' is unavailable." % serverid)
            self.unavailable_servers.append(serverid)
            self.call_hook("server-updown",[serverid,"DOWN"])

    def _run_command(self,ssh,command):
        """Run 'command' over 'ssh', block until it finishes and return its exit status."""

        return ssh.exec_command(command)[1].channel.recv_exit_status()

    @staticmethod
    def _append_unique(current,value):
        """Append 'value' to the comma separated string 'current' unless already listed.

        A NULL/empty 'current' (e.g. rows from before the column existed) simply
        becomes 'value'.
        """

        if not current:
            return unicode(value)
        if value in current.split(","):
            return current
        return current + ",%s" % value

    def get_ssh(self,serverid):
        """Connect to server 'serverid' as root and return the paramiko SSHClient.

        Unknown host keys are accepted automatically (AutoAddPolicy). Connection
        errors propagate to the caller after the half-open client is closed.
        """

        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            ssh.connect(
                self._server_config(serverid,"server"),
                port=int(self._server_config(serverid,"port")),
                username="root",
                password=self._server_config(serverid,"password"),
                allow_agent=False,
                look_for_keys=False,
                timeout=10
            )
        except Exception:
            ssh.close()
            raise
        return ssh

    def call_hook(self,name,params):
        """Run hook 'name' from plugins/weblive/hooks/ with 'params' as arguments.

        The local 'plugins/weblive/hooks' directory wins over the system-wide
        one when both contain the hook. Returns False if no hook exists or it
        cannot be executed, otherwise the hook's exit status.
        """

        hook=None
        for path in ("/usr/share/ltsp-agent/plugins/weblive/hooks","plugins/weblive/hooks"):
            candidate="%s/%s" % (path,name)
            if os.path.exists(candidate):
                hook=candidate

        if hook is None:
            return False

        try:
            process=subprocess.Popen([hook]+params,stdout=subprocess.PIPE,stderr=subprocess.PIPE)
        except OSError:
            # Hooks are called from error handlers; don't let a broken hook escape
            self.LOGGER.error("Unable to execute hook '%s'." % hook)
            return False

        # communicate() drains both pipes; wait() alone deadlocks once the hook
        # fills a pipe buffer.
        process.communicate()
        return process.returncode

    ## Exported functions

    def create_user(self,serverid,username,fullname,password,source,session,locale):
        """Create (or renew) a temporary account on the specified server.

        Returns [hostname, port] on success, otherwise:
          1      the server's user limit is reached
          2      the login exists with a different password
          3      invalid fullname (alphanumerics and spaces only)
          4      invalid (non-empty lowercase letters only) or blacklisted login
          5      invalid password (non-empty alphanumerics only)
          7      unknown, unavailable or disabled server
          False  the server could not be reached or setting up the account failed
        """

        # Get the store
        store=Store(self.db)

        if serverid not in self.get_config_path(self.config,"servers") or serverid in self.unavailable_servers or serverid in self.disabled_servers:
            # Invalid server
            return 7

        # The patterns end in \Z rather than $: '$' also matches before a
        # trailing newline, which would then reach the shell commands below.
        if not re.match(r"^[A-Za-z0-9 ]*\Z",fullname):
            return 3

        if not re.match(r"^[a-z]+\Z",username) or username in self.get_config_path(self.config,"general","username_blacklist"):
            return 4

        if not re.match(r"^[A-Za-z0-9]+\Z",password):
            return 5

        if store.find(Account, server=unicode(serverid), enabled=True).count() >= int(self._server_config(serverid,"userlimit")):
            # Reached user limit
            return 1

        # Accept invalid locales but mark them as None
        if locale not in [loc[0] for loc in self.list_locales(serverid)]:
            locale = None

        existing=store.find(Account, server=unicode(serverid), enabled=True, username=unicode(username))
        if existing.count() == 1:
            user=existing.one()
            if user.password != sha1(password).hexdigest():
                # Different user already exists
                return 2

            # Account already exists: renew it and remember any new session/locale
            user.created = datetime.datetime.now()
            user.session = self._append_unique(user.session,session)
            if locale:
                # Unknown locales (None) are not recorded
                user.locale = self._append_unique(user.locale,locale)
            store.commit()
            return self._server_address(serverid)

        # Build the database record; it is only saved once the server side worked
        token = uuid.uuid4()
        user=Account()
        user.username = unicode(username)
        user.fullname = unicode(fullname)
        user.server = unicode(serverid)
        user.source = unicode(source)
        user.session = unicode(session)
        user.password = unicode(sha1(password).hexdigest())
        user.locale = unicode(locale)
        user.token = unicode(token)
        user.enabled = True
        user.created = datetime.datetime.now()

        # All interpolated values were validated above (or come from the
        # server's own 'locale -a' output / uuid4), so they are shell-safe.
        ssh=None
        try:
            ssh = self.get_ssh(serverid)

            # Create the account; adduser asks for the password twice
            stdin, stdout, stderr = ssh.exec_command("adduser --quiet --gecos '%s' %s" % (fullname,username))
            stdin.write('%s\n' % password)
            stdin.flush()
            stdin.write('%s\n' % password)
            stdin.flush()
            stdout.read()

            # Every following command is waited for: closing the connection
            # while they are still running could cut them short.
            if self._server_config(serverid,"autoinstall") != 'false':
                # Create the container when autoinstall is enabled
                self._run_command(ssh,"weblive-arkose adduser %s" % username)

            #FIXME: Should only do it when on an x2go capable server
            self._run_command(ssh,"adduser %s x2gousers" % username)

            # Set the locale if we know it
            if locale:
                self._run_command(ssh,"echo %s > /home/%s/.weblive_locale" % (locale,username))
                self._run_command(ssh,"chown %s.%s /home/%s/.weblive_locale" % (username,username,username))

            # Set the communication token
            self._run_command(ssh,"echo %s > /home/%s/.weblive_token" % (token,username))
            self._run_command(ssh,"chown %s.%s /home/%s/.weblive_token" % (username,username,username))

            # Set the home directory permissions
            self._run_command(ssh,"chmod 700 /home/%s" % username)
        except Exception:
            self._mark_server_down(serverid)
            return False
        finally:
            if ssh is not None:
                ssh.close()

        # Save the new user (a DB failure here no longer marks the server down)
        store.add(user)
        store.commit()

        self.LOGGER.info("Created user '%s' on '%s' for '%s'." % (username,serverid,session))
        return self._server_address(serverid)

    def delete_user(self,serverid,username):
        """Delete one user; False if the server is unknown, unavailable or disabled."""

        if serverid not in self.get_config_path(self.config,"servers") or serverid in self.unavailable_servers or serverid in self.disabled_servers:
            # Invalid server
            return False

        self.delete_users({serverid:[username]})
        return True

    def delete_users(self,users):
        """Delete a batch of users, 'users' mapping serverid -> list of usernames.

        Unavailable/disabled servers are skipped. Otherwise the matching enabled
        accounts are disabled in the database first, then their sessions,
        files and home directories are removed on the server over SSH. Only
        usernames that have an enabled account in the database are acted upon,
        so nothing unvalidated reaches the shell. The caller's dict is left
        untouched.
        """

        # Get the store
        store=Store(self.db)

        for serverid in users:
            if serverid in self.unavailable_servers or serverid in self.disabled_servers:
                # Server is unavailable, skip it for now
                continue

            # Build a new list rather than removing from the one being iterated
            # (which silently skipped the username following a removed one).
            to_delete=[]
            for username in users[serverid]:
                found=False
                for account in store.find(Account, server=unicode(serverid), enabled=True, username=unicode(username)):
                    account.enabled = False
                    found=True
                if found:
                    store.commit()
                    to_delete.append(username)

            if not to_delete:
                continue

            ssh=None
            try:
                ssh = self.get_ssh(serverid)
                autoinstall = self._server_config(serverid,"autoinstall") != 'false'

                # Disconnect the users and remove everything they own
                for username in to_delete:
                    self._run_command(ssh,"nxserver --force-terminate %s" % username)
                    self._run_command(ssh,"pkill -9 -u %s" % username)
                    self._run_command(ssh,"umount /home/%s/.gvfs" % username)
                    self._run_command(ssh,"find /dev/shm /var /tmp -user \"%s\" -delete" % username)
                    self._run_command(ssh,"deluser --force --remove-home --quiet %s" % username)
                    if autoinstall:
                        self._run_command(ssh,"weblive-arkose deluser %s" % username)

                # Kill remaining nxserver processes
                self._run_command(ssh,"while pkill -9 -P 1 -u nx; do : ; done")
            except Exception:
                self._mark_server_down(serverid)
            else:
                for username in to_delete:
                    self.LOGGER.info("Deleted user '%s' from '%s'." % (username,serverid))
            finally:
                if ssh is not None:
                    ssh.close()

    def list_everything(self):
        """Return list_servers() with each server's 'locales' and 'packages' added."""

        everything=self.list_servers()
        for serverid, server in everything.items():
            server['locales']=self.list_locales(serverid)
            server['packages']=self.list_packages(serverid)
        return everything

    def list_locales(self,serverid):
        """Return the cached [(locale, nice name), ...] of a server ([] if unknown)."""

        return self.locales.get(serverid,[])

    def list_package_blacklist(self):
        """List all blacklisted packages"""

        return self.get_config_path(self.config,"general","package_blacklist")

    def list_packages(self,serverid):
        """Return the server's cached packages minus blacklisted ones.

        Packages flagged as autoinstall-only are hidden when the server has
        autoinstall set to 'false'.
        """

        if serverid not in self.packages:
            return []

        blacklist=self.get_config_path(self.config,"general","package_blacklist")
        autoinstall=self._server_config(serverid,"autoinstall") != 'false'

        return [package for package in self.packages[serverid]
                if package[0] not in blacklist and (autoinstall or not package[2])]

    def list_servers(self):
        """Return {serverid: description dict} for every available, enabled server."""

        # Get the store
        store=Store(self.db)

        servers = {}
        for serverid in self.get_config_path(self.config,"servers"):
            if serverid in self.unavailable_servers or serverid in self.disabled_servers:
                continue

            servers[serverid]={
                'title':self._server_config(serverid,"title"),
                'description':self._server_config(serverid,"description"),
                'users':store.find(Account, server=unicode(serverid), enabled=True).count(),
                'userlimit':int(self._server_config(serverid,"userlimit")),
                'timelimit':int(self._server_config(serverid,"timelimit")),
                'autoinstall':(self._server_config(serverid,"autoinstall") != 'false'),
            }

        return servers

    def list_users(self,serverid,enabled = True):
        """List the (by default only enabled) accounts of a server as dicts of strings."""

        # Get the store
        store=Store(self.db)

        if enabled:
            users=store.find(Account, server=unicode(serverid), enabled=True)
        else:
            users=store.find(Account, server=unicode(serverid))

        return [{
            'username':str(user.username),
            'fullname':str(user.fullname),
            'source':str(user.source),
            'session':str(user.session),
            'locale':str(user.locale),
            'created':str(user.created)} for user in users]

    @auto_update(3600)
    def update_locales(self):
        """Refresh self.locales with the UTF-8 locales reported by each active server."""

        for serverid in self.get_config_path(self.config,"servers"):
            if serverid in self.unavailable_servers or serverid in self.disabled_servers:
                continue

            ssh=None
            try:
                ssh = self.get_ssh(serverid)

                # Get supported locales
                locales=[]
                for line in ssh.exec_command("locale -a")[1].readlines():
                    line=line.strip()
                    if line.endswith('utf8'):
                        nice=self.locales_description.get(line.split('.')[0],line)
                        locales.append((line,nice))
            except Exception:
                self._mark_server_down(serverid)
            else:
                if locales:
                    self.locales[serverid]=locales
            finally:
                if ssh is not None:
                    ssh.close()

    def _read_package_list(self,ssh):
        """Parse /var/cache/weblive.pkglist on the server ('a;b;True|False' per line).

        Fields are assumed to be name;version;autoinstall-only, matching the
        dpkg fallback layout. Returns [] when the file is missing, empty or
        malformed, so a partial list is never mixed with the dpkg output.
        """

        packages=[]
        try:
            for line in ssh.exec_command("cat /var/cache/weblive.pkglist")[1].readlines():
                line=line.strip()
                if not line:
                    continue
                fields=line.split(';')
                packages.append([fields[0],fields[1],fields[2] == "True"])
        except Exception:
            return []
        return packages

    def _read_dpkg_list(self,ssh):
        """Return [[name, version, False], ...] for every 'ii' line of 'dpkg -l'."""

        packages=[]
        for line in ssh.exec_command("dpkg -l")[1].readlines():
            line=line.strip()
            if line.startswith('ii'):
                packages.append(line.split()[1:3]+[False])
        return packages

    @auto_update(3600)
    def update_packages(self):
        """Refresh self.packages for each active server (package list, else dpkg)."""

        for serverid in self.get_config_path(self.config,"servers"):
            if serverid in self.unavailable_servers or serverid in self.disabled_servers:
                continue

            ssh=None
            try:
                ssh = self.get_ssh(serverid)

                # Try using the package list if it exists, else good old dpkg
                packages=self._read_package_list(ssh)
                if not packages:
                    packages=self._read_dpkg_list(ssh)
            except Exception:
                self._mark_server_down(serverid)
            else:
                if packages:
                    self.packages[serverid]=packages
            finally:
                if ssh is not None:
                    ssh.close()

    @auto_update(60)
    def update_users(self):
        """Delete every enabled account older than its server's 'timelimit' seconds."""

        # Get the store
        store=Store(self.db)
        now=datetime.datetime.now()

        servers={}
        for serverid in self.get_config_path(self.config,"servers"):
            accounts=list(store.find(Account, server=unicode(serverid), enabled=True))
            if not accounts:
                continue

            timelimit=int(self._server_config(serverid,"timelimit"))
            expired=[]
            for user in accounts:
                age=now-user.created
                # timedelta.seconds alone ignores whole days; use the full age
                if age.days*86400+age.seconds >= timelimit:
                    expired.append(user.username)

            if expired:
                servers[serverid]=expired

        if servers:
            self.delete_users(servers)

        return True

    @auto_update(120)
    def update_servers(self):
        """Try to contact every enabled server and update the unavailable list."""

        for serverid in self.get_config_path(self.config,"servers"):
            if serverid in self.disabled_servers:
                continue

            ssh=None
            try:
                ssh = self.get_ssh(serverid)

                # Test the connection
                ssh.exec_command("echo OK")
            except Exception:
                self._mark_server_down(serverid)
            else:
                # It worked, so we can remove the server from the blacklist
                if serverid in self.unavailable_servers:
                    self.LOGGER.info("Server '%s' is back online." % serverid)
                    self.unavailable_servers.remove(serverid)
                    self.call_hook("server-updown",[serverid,"UP"])
            finally:
                if ssh is not None:
                    ssh.close()

    def set_disabled(self, serverid, state):
        """Disable (state == True) or re-enable a server.

        Returns False for unknown servers or when the server is already in the
        requested state, True otherwise. Disabling also clears its
        'unavailable' flag.
        """

        if serverid not in self.get_config_path(self.config,"servers"):
            return False

        # Compare with True explicitly: RPC callers may pass non-bool values
        if state == True:
            if serverid in self.disabled_servers:
                return False

            if serverid in self.unavailable_servers:
                self.unavailable_servers.remove(serverid)

            self.disabled_servers.append(serverid)
            self.LOGGER.info("Server '%s' has been disabled." % serverid)
        else:
            if serverid not in self.disabled_servers:
                return False

            self.disabled_servers.remove(serverid)
            self.LOGGER.info("Server '%s' has been re-enabled." % serverid)

        return True

    ## Export the functions

    def json_handler(self,query,client):
        """Dispatch the JSON request 'query' from 'client' to an exported function.

        Returns {'status': 'ok', 'message': <result>} or
        {'status': 'error', 'message': code}, where code is -1 (no 'action'),
        -2 (missing parameter) or -3 (action not exported over JSON).
        """

        reply={}

        # Check if we at least have a function name
        if 'action' not in query:
            reply['status']="error"
            reply['message']=-1
            return reply

        # Check if function is exported over JSON
        if query['action'] not in self.json_functions():
            reply['status']="error"
            reply['message']=-3
            return reply

        reply['status']='ok'
        function=getattr(self,query['action'])
        function_params=inspect.getargspec(function).args
        function_params.remove('self')

        # Set the source (always overrides any client-supplied value)
        query['source']="json:%s" % client

        # FIXME: Hack for backward compatibility for clients without
        # language field (expires 2012/10 at the latest)
        if query['action'] == 'create_user' and 'locale' not in query:
            query['locale']="None"

        attrib=[]
        for param in function_params:
            if param not in query:
                reply['status']="error"
                reply['message']=-2
                return reply
            attrib.append(query[param])

        reply['message']=function(*attrib)

        return reply

    def rpc_functions(self):
        """Names of the methods exported over RPC."""

        return [
            'delete_user',
            'set_disabled',
            'list_users',
        ]

    def json_functions(self):
        """Names of the methods exported over JSON (see json_handler)."""

        return [
            'create_user',
            'list_everything',
            'list_locales',
            'list_package_blacklist',
            'list_packages',
            'list_servers',
        ]
