#! /usr/bin/env python3

# Copyright (C) 2010-2017 Progetto Fuss <info@fuss.bz.it>
#                         Elena Grandi <elena@truelite.it>,
#                         Christopher R. Gabriel <cgabriel@truelite.it>
# Copyright (C) 2017-2023 The FUSS Project <info@fuss.bz.it>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; version 2 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program or from the site that you downloaded it
# from; if not, write to the Free Software Foundation, Inc., 59 Temple
# Place, Suite 330, Boston, MA  02111-1307   USA

import argparse
import json
import logging
from logging.handlers import SysLogHandler
import os
import re
import shutil
import socket
import subprocess
import sys
import xmlrpc.client

from gettext import gettext as _

from octofuss import helpers

import apt

import psutil

import yaml

# Version of the installed fuss-client package, passed to the playbooks as
# fuss_client_version. Evaluated once, at import time, through the apt cache.
# NOTE(review): presumably .get() yields None when apt does not know the
# package and .installed is None when it is not installed; either case would
# raise AttributeError on import — confirm this crash is intended.
VERSION = apt.cache.Cache().get('fuss-client').installed.version

# Directory holding the ansible playbooks; it is also the working directory
# for the ansible run.
ansible_data_path = "/usr/share/fuss-client"
# Filesystem label of the USB key that may carry the server ssh key.
USB_LABEL = 'portachiavi'
# Where that USB key gets mounted.
MOUNT = '/mnt/portachiavi'
# Three rising notes: played after the USB key was unmounted successfully,
# i.e. it is safe to remove it.
BEEP_COMMAND_SAFE = [
    'beep',
    '-f', '440.0', '-n',
    '-f', '493.9', '-n',
    '-f', '523.2',
    ]
# Repeated low note: played when the USB key may still be mounted and the
# user has to umount it manually.
BEEP_COMMAND_UNSAFE = [
    'beep',
    '-f', '349.2', '-d', '200', '-r', '3',
    ]


class KerberosKeyGenerationError(Exception):
    """Signal that a step creating or fetching the kerberos key failed."""


class Client(object):
    """Command line client that configures this workstation via ansible.

    __init__ parses the command line and sets up logging; main() runs the
    action selected on the command line (add, upgrade, standalone, cloud
    or list groups).
    """

    # Keytabs whose presence means this machine already has a kerberos key;
    # a newly generated key is downloaded to /root/krb5.keytab.
    _KEYTAB_PATHS = ('/etc/krb5.keytab', '/root/krb5.keytab')

    def __init__(self):
        self.server = None
        self.clusters = []
        # configuration returned by the server (root_ssh_key, ldap_*, ...)
        self.conf = {}
        # mount point of the USB key with the server ssh key, while it is
        # mounted by this process
        self.mounted_key = None
        # NOTE(review): not read anywhere in this file; kept in case other
        # code relies on the attribute existing
        self.mounted_usb = None
        # ssh private key used to authenticate as root on the server
        self.ssh_key = None
        self._parse_args()
        self._setup_logging()

    def _log_level(self):
        """Return the log level selected on the command line.

        --debug takes precedence over --verbose, so passing both still
        yields debug output.
        """
        if self.args.debug:
            return logging.DEBUG
        if self.args.verbose:
            return logging.INFO
        return logging.WARNING

    def _setup_logging(self):
        """Configure the root logger and, with --syslog, a syslog handler."""
        log_level = self._log_level()
        logging.basicConfig(
            level=log_level,
            format='%(levelname)s %(message)s'
            )
        if self.args.syslog:
            # SysLogHandler() without an address sends UDP datagrams to
            # localhost:514, which the local syslog daemon may not accept;
            # prefer the local unix socket when it exists.
            if os.path.exists('/dev/log'):
                syslog = SysLogHandler(address='/dev/log')
            else:
                syslog = SysLogHandler()
            syslog.setLevel(log_level)
            syslog.setFormatter(
                logging.Formatter('fuss-client: %(levelname)s %(message)s')
            )
            logging.getLogger('').addHandler(syslog)

    def _test_connection(self):
        """Pick a server and fetch its cluster list and configuration.

        On failure a warning is logged and self.server is reset to None,
        which the callers treat as "no valid server".
        """
        if self.server is None:
            self.server = self._get_server()
        url = 'http://%s:13400/clientdeploy' % self.server
        cluster_url = 'http://%s:13400/octofuss' % self.server
        if not self.clusters:
            try:
                self.clusters = helpers.list_cluster(cluster_url)
            except OSError:
                # connection refused, but also name resolution failures,
                # timeouts, unreachable networks...
                logging.warning(
                    "Could not connect to server to get the list of clusters"
                )
                self.server = None
                self.clusters = []
                return
        try:
            with xmlrpc.client.ServerProxy(url) as proxy:
                # This will set variables such as root_ssh_key and ldap_*
                self.conf = proxy.get_conf()
        except Exception as e:
            # any failure (network, XML-RPC fault, ...) makes the server
            # unusable for this run
            logging.warning(str(e))
            self.server = None
            self.clusters = []

    def _get_server(self):
        """Return the server given with -s or discovered on the network.

        Exits with status 1 when discovery finds no server and asks the user
        to choose when several servers answer.
        """
        if self.args.useserver:
            return self.args.useserver
        # NOTE(review): the -t/--timeout option is not passed on here
        servers = helpers.discover_server()
        if not servers:
            print(_(
                "Could not find a fuss server, maybe try to specify one "
                "with -s"
            ))
            sys.exit(1)
        if len(servers) == 1:
            return servers[0]
        header = (
            "We found several Fuss Server on this Network\n"
            "\n"
            "Please choose the one you want to use:"
        )
        return self.choice_in_list(servers, header)

    def _get_cluster(self):
        """Return the cluster (workstation group) this machine should join.

        --groupjoin wins; otherwise the only cluster offered by the server
        is used, or the user chooses among several. When the server offers
        no cluster the (empty) --groupjoin value is returned.
        """
        cluster = self.args.groupjoin
        if cluster:
            return cluster
        if len(self.clusters) == 1:
            return self.clusters[0]
        if len(self.clusters) > 1:
            header = (
                "This server has several workstation groups\n"
                "\n"
                "Please choose the one desired for this machine:"
            )
            return self.choice_in_list(self.clusters, header)
        return cluster

    def _get_server_ssh_key(self):
        """Mount the labelled USB key and return the server key path on it.

        Returns None (after explaining that the root password will be asked
        instead) when the key cannot be mounted or holds no key file. On a
        successful mount self.mounted_key is set so it can be unmounted later.
        """
        os.makedirs(MOUNT, exist_ok=True)
        try:
            subprocess.check_call([
                'mount',
                '/dev/disk/by-label/' + USB_LABEL,
                MOUNT,
                ])
        except subprocess.CalledProcessError as e:
            # if we can't mount the usb key, we can still ask for a
            # password
            print(_(
                "Could not mount an usb key with server keys: {e}.\n" +
                "We will have to ask the root password to identity " +
                "with the server."
                ).format(e=e))
            return None
        else:
            self.mounted_key = MOUNT
        key_path = os.path.join(
            MOUNT,
            'server_key',
            'client-ed25519')
        if not os.path.isfile(key_path):
            print(_(
                "Could not find a key in {path}.\n" +
                "We will have to ask the root password to identity " +
                "with the server."
                ).format(path=key_path))
            return None
        return key_path

    @staticmethod
    def _run_auth_step(argv):
        """Run one ssh/scp command; raise KerberosKeyGenerationError on a
        non-zero exit status."""
        if subprocess.call(argv) != 0:
            raise KerberosKeyGenerationError()

    def _fetch_keytab_with_sudo(self, hostname):
        """Create and download the keytab through the sudo-capable --user."""
        user = self.args.user
        print(_(
            """
                Authenticating on the server.

                We are now going to let the server know that this client
                is authorized; do to so we will have to enter the
                password for user {user} a few times.

                Depending on the server load there could be a long
                delay; this is perfectly normal.
                """
            ).format(user=user))
        remote = '{user}@proxy'.format(user=user)
        self._run_auth_step([
            'ssh',
            '-t',
            remote,
            'sudo add_client_principal {hostname}'.format(hostname=hostname),
            ])
        # the keytab is created in root's home: move it where the user's
        # scp below can read it
        self._run_auth_step([
            'ssh',
            '-t',
            remote,
            'sudo mv /root/{hostname}.keytab ~{user}/'.format(
                hostname=hostname,
                user=user,
                ),
            ])
        self._run_auth_step([
            'scp',
            '{user}@proxy:{hostname}.keytab'.format(
                user=user,
                hostname=hostname,
                ),
            '/root/krb5.keytab',
            ])

    def _fetch_keytab_as_root(self, hostname):
        """Create and download the keytab logging in as root on the server.

        Uses self.ssh_key when available, otherwise ssh asks for the root
        password.
        """
        print(_(
            """
                Authenticating on the server.

                We are now going to let the server know that this client
                is authorized; do to so we will have to enter the root
                password for the server twice.

                Depending on the server load there could be a long
                delay; this is perfectly normal.
                """
            ))
        ssh_opts = ['-i', self.ssh_key] if self.ssh_key else []
        self._run_auth_step(
            ['ssh'] + ssh_opts + [
                'root@proxy',
                'add_client_principal {hostname}'.format(hostname=hostname),
            ])
        self._run_auth_step(
            ['scp'] + ssh_opts + [
                'root@proxy:{hostname}.keytab'.format(hostname=hostname),
                '/root/krb5.keytab',
            ])
        # don't leave a copy of the key lying around on the server
        self._run_auth_step(
            ['ssh'] + ssh_opts + [
                'root@proxy',
                'rm /root/{hostname}.keytab'.format(hostname=hostname),
            ])

    def _umount_usb_key(self):
        """Unmount the USB key mounted by _get_server_ssh_key.

        Returns True on success (and forgets the mount point), False when
        umount fails.
        """
        try:
            subprocess.check_call(['umount', self.mounted_key])
        except subprocess.CalledProcessError:
            return False
        self.mounted_key = None
        return True

    def _report_authenticated(self):
        """Tell the user authentication succeeded and release the USB key."""
        if getattr(self, 'mounted_key', None):
            if self._umount_usb_key():
                print(_(
                    "We have successfully authenticated on the server " +
                    "and umounted the USB key; it is now safe to " +
                    "remove it from the computer."
                    ))
                subprocess.call(BEEP_COMMAND_SAFE)
            else:
                print(_(
                    "We have successfully authenticated on the server " +
                    "but we weren't able to umount the USB key; please " +
                    "do so manually before removing it."
                    ))
                subprocess.call(BEEP_COMMAND_UNSAFE)
        else:
            print(_(
                "We have successfully authenticated on the server; " +
                "you can now umount any USB key you used to provide an " +
                "ssh key and remove it."
                ))
            subprocess.call(BEEP_COMMAND_UNSAFE)

    def _get_kerberos_key(self):
        """Obtain a kerberos keytab for this machine from the server.

        Returns False when a keytab is already present and True after a new
        one has been downloaded to /root/krb5.keytab.

        Raises KerberosKeyGenerationError when any ssh/scp step fails; a USB
        key mounted in the process is unmounted again first, so that a retry
        can mount it instead of falling back to a password.
        """
        # if we want to change the hostname the kerberos key will have
        # to be regenerated anyway.
        if self.args.hostname:
            for path in self._KEYTAB_PATHS:
                try:
                    os.remove(path)
                except FileNotFoundError:
                    # it's ok if the file didn't exist in the first place
                    pass
        if any(os.path.isfile(path) for path in self._KEYTAB_PATHS):
            # krb5.keytab is already existing, nothing to do
            return False
        hostname = socket.gethostname()
        # use the key from the command line or try to find one on an usb
        # stick
        if self.args.ssh_key:
            self.ssh_key = self.args.ssh_key
        else:
            self.ssh_key = self._get_server_ssh_key()

        try:
            if self.args.user:
                self._fetch_keytab_with_sudo(hostname)
            else:
                self._fetch_keytab_as_root(hostname)
        except KerberosKeyGenerationError:
            if getattr(self, 'mounted_key', None):
                self._umount_usb_key()
            raise
        self._report_authenticated()
        return True

    @staticmethod
    def _rename_host_in_line(line, current, new):
        """Return *line* with whole occurrences of hostname *current*
        replaced by *new*.

        A match must not be glued to other hostname characters: renaming
        ``pc1`` leaves ``pc10``, ``ip6-pc1`` and ``foo.pc1`` alone, while a
        following ``.`` is allowed so ``pc1.example.org`` is renamed too.
        """
        pattern = r'(?<![\w.-])' + re.escape(current) + r'(?![\w-])'
        # a function replacement keeps backslashes in *new* literal
        return re.sub(pattern, lambda match: new, line)

    def _set_hostname(self, hname):
        """Rename this machine to *hname* in /etc/hosts and in the system."""
        current = socket.gethostname()
        if current == hname:
            return
        with open('/etc/hosts', 'r') as rfp:
            with open('/etc/hosts.fussnew', 'w') as wfp:
                for line in rfp:
                    wfp.write(self._rename_host_in_line(line, current, hname))
        # the replacement file must keep the permissions of the original
        shutil.copymode('/etc/hosts', '/etc/hosts.fussnew')
        shutil.move('/etc/hosts.fussnew', '/etc/hosts')
        for command in (
            ['hostnamectl', 'set-hostname', hname],
            ['service', 'networking', 'force-reload'],
        ):
            if subprocess.call(command) != 0:
                logging.warning("Command failed: %s", " ".join(command))

    def choice_in_list(self, o_list, header):
        """Show *header* and the numbered *o_list*; return the chosen item.

        Keeps asking until the user enters a valid index.

        Raises ValueError if *o_list* is empty (there would be nothing to
        choose and the prompt would never end).
        """
        if not o_list:
            raise ValueError("no options to choose from")
        print(header)
        for i, opt in enumerate(o_list):
            print(i, " - ", opt)
        print("")
        choice = None
        while choice not in range(len(o_list)):
            answer = input("Your choice? (enter the number) ")
            try:
                choice = int(answer)
            except ValueError:
                choice = None
        return o_list[choice]

    def prevent_multiple(self):
        """Exit with status 1 if a fuss-client ansible run is in progress."""
        for proc in psutil.process_iter():
            try:
                cmdline = " ".join(proc.cmdline())
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                # zombies (a NoSuchProcess subclass), processes that exited
                # in the meantime and ones we may not inspect don't matter
                continue
            # only matching fuss-client in the cmdline also gets
            # octofuss-client and will never run.
            if (
                "fuss-client" in cmdline and
                "ansible-playbook" in cmdline
            ):
                logging.error(
                    "It seems that fuss-client is already running: "
                    "multiple simultaneous runs are not supported."
                )
                sys.exit(1)

    @staticmethod
    def _extra_var(name, value):
        """Return the ``-e`` argument pair setting ansible variable *name*.

        The variable is JSON-encoded: with ``key=value`` syntax ansible would
        split values containing spaces or quotes (wifi SSIDs and passwords,
        cluster names, ssh keys).
        """
        return ['-e', json.dumps({name: value}, ensure_ascii=False)]

    def _ansible_options(self, extra_vars):
        """Return the argv for a local playbook run with *extra_vars*.

        *extra_vars* is an iterable of (name, value) pairs.
        """
        options = [
            'fuss-client',
            '-i', 'localhost,',
            '-c', 'local',
            '--force-handlers',
            ]
        for name, value in extra_vars:
            options.extend(self._extra_var(name, value))
        return options

    def _common_extra_vars(self):
        """Return the (name, value) extra vars shared by every playbook."""
        extra_vars = []
        if self.args.slick_greeter:
            # a string, as the former ``slick_greeter=true`` syntax produced
            extra_vars.append(('slick_greeter', 'true'))
        if self.args.unofficial:
            extra_vars.append(('unofficial', True))
        if self.args.iso:
            extra_vars.append(('iso', True))
        if self.args.light:
            extra_vars.append(('light', True))
        if self.args.locale:
            extra_vars.append(('locale_default', self.args.locale))
        if self.args.keyboard:
            extra_vars.append(('keyboard_default', self.args.keyboard))
        return extra_vars

    def _client_extra_vars(self, cluster):
        """Return the (name, value) extra vars for the joined/cloud playbooks.

        Combines the server address, the package version, the optional
        *cluster*, the shared command line settings and the configuration
        fetched from the server.
        """
        conf = self.conf
        extra_vars = [
            ('server', str(self.server)),
            ('fuss_client_version', VERSION),
            ]
        if cluster:
            extra_vars.append(('cluster', str(cluster)))
        extra_vars.extend(self._common_extra_vars())
        if 'root_ssh_key' in conf:
            extra_vars.append(('root_ssh_key', str(conf['root_ssh_key'])))
        if 'ldap_server' in conf:
            extra_vars.append(('ldap_server', str(conf['ldap_server'])))
            extra_vars.append(('ldap_base', str(conf['ldap_base'])))
        if self.args.domain:
            extra_vars.append(('domain', self.args.domain))
        elif 'domain_name' in conf:
            extra_vars.append(('domain', str(conf['domain_name'])))
        if self.args.wifi_ssid:
            extra_vars.append(('wifissid', self.args.wifi_ssid))
        if self.args.wifi_pass:
            extra_vars.append(('wifipass', self.args.wifi_pass))
        return extra_vars

    def _run(self, add_to_cluster=True, cloud_only=False):
        """Configure this machine as a client of the server.

        Contacts the server (exit 2 if none is usable), optionally picks a
        cluster, obtains a kerberos key unless this is a dry run or a cloud
        setup (exit 1 on failure), then replaces this process with the
        joined.yml or cloud.yml playbook.
        """
        self.prevent_multiple()
        logging.info("Adding the machine")
        if self.args.hostname:
            self._set_hostname(self.args.hostname)
        self._test_connection()
        if not self.server:
            logging.error("No valid server found!")
            sys.exit(2)
        cluster = self._get_cluster() if add_to_cluster else None
        os.chdir(ansible_data_path)
        ansible_options = self._ansible_options(
            self._client_extra_vars(cluster)
        )
        if self.args.dryrun:
            ansible_options.append('--check')
        else:
            if not cloud_only:
                try:
                    self._get_kerberos_key()
                except KerberosKeyGenerationError:
                    print(_(
                        """
                        Key generation failed.

                        If this was because of user error (e.g. a wrong
                        password) please try again.

                        If you were asked for the root password and don't
                        have it, you can use ``-u USERNAME`` to specify a user
                        that is enabled to run sudo on the server.

                        """
                        ))
                    sys.exit(1)
        playbook = "cloud.yml" if cloud_only else "joined.yml"
        os.execvp(
            os.path.join(ansible_data_path, playbook),
            ansible_options
            )

    def add(self):
        """Add the machine to a Fuss network."""
        self._run()

    def upgrade(self):
        """Re-run the configuration with the variant used last time.

        The variant is read from the ansible facts written by a previous
        run; without them the machine is treated as a joined client.
        """
        if self.args.groupjoin:
            print(_(
                "Warning: --groupjoin is ignored when running " +
                "fuss-client --upgrade, please run fuss-client --add " +
                "instead."
                ))
        try:
            with open("/etc/ansible/facts.d/fuss_client.fact") as fp:
                loaded = yaml.safe_load(fp)
        except FileNotFoundError:
            loaded = None
        # an empty or malformed facts file must not make .get() crash
        previous_run_facts = loaded if isinstance(loaded, dict) else {}

        variant = previous_run_facts.get("variant")
        if variant == "standalone":
            self.standalone()
        elif variant == "cloud":
            self._run(add_to_cluster=False, cloud_only=True)
        else:
            self._run(add_to_cluster=False)

    def cloud(self):
        """Set up a cloud client after warning that /home gets wiped."""
        print(_(
            "Warning: running fuss-client --cloud will set up a "
            "script that deletes all contents of the /home directory "
            "at shutdown or reboot.\n"
            "Press enter to continue"
        ))
        if not self.args.non_interactive:
            input()
        self._run(cloud_only=True)

    def standalone(self):
        """Configure a machine not connected to a Fuss server.

        Replaces this process with the standalone.yml playbook.
        """
        self.prevent_multiple()
        logging.info("Configuring a standalone machine")
        os.chdir(ansible_data_path)
        extra_vars = [('fuss_client_version', VERSION)]
        if self.args.domain:
            extra_vars.append(('domain', self.args.domain))
        extra_vars.extend(self._common_extra_vars())
        if self.args.fuss_zone:
            extra_vars.append(('fuss_zone_override', self.args.fuss_zone))
        ansible_options = self._ansible_options(extra_vars)
        if self.args.dryrun:
            ansible_options.append('--check')
        os.execvp(
            os.path.join(ansible_data_path, 'standalone.yml'),
            ansible_options
            )

    def listavail(self):
        """Print the clusters/groups available on the server, one per line."""
        self._test_connection()
        if not self.server:
            logging.error("No valid server found!")
            sys.exit(2)
        logging.info("Listing available clusters on %s", self.server)
        for c in self.clusters:
            print(c)

    def _parse_args(self):
        """Parse the command line into self.args."""
        parser = argparse.ArgumentParser(
            description=_(
                "Connect to an [Octo]Fuss server to configure a workstation"
                ),
            )

        # Main actions: add / upgrade / standalone / cloud / list groups
        parser.set_defaults(func=self.add)
        group = parser.add_mutually_exclusive_group()
        group.add_argument(
            '-a', '--add',
            action="store_const",
            const=self.add,
            dest='func',
            help=_("Add the machine to a Fuss Network")
            )
        group.add_argument(
            '-U', '--upgrade',
            action="store_const",
            const=self.upgrade,
            dest='func',
            help=_(
                "Update configuration for a machine connected "
                "to a Fuss Network"
            )
            )
        group.add_argument(
            '-S', '--standalone',
            action="store_const",
            const=self.standalone,
            dest='func',
            help=_("Configure a standalone machine")
            )
        group.add_argument(
            '-c', '--cloud',
            action="store_const",
            const=self.cloud,
            dest='func',
            help=_(
                "Add the machine to a Fuss Network as a cloud client "
                "(/home is wiped at every shutdown or reboot)"
            )
            )
        group.add_argument(
            '-l', '--listgroups',
            action="store_const",
            const=self.listavail,
            dest='func',
            help=_("list available clusters/groups on server"),
            )

        # Client configuration
        parser.add_argument(
            '--slick-greeter',
            action="store_true",
            default=True,
            help=_("Use slick-greeter instead of lightdm-gtk-greeter"),
            )
        parser.add_argument(
            '--no-slick-greeter',
            action="store_false",
            dest='slick_greeter',
            help=_("Use lightdm-gtk-greeter instead of slick-greeter"),
            )
        # NOTE(review): args.timeout is not read anywhere in this file
        parser.add_argument(
            '-t', '--timeout',
            help=_("Network search timeout (seconds)"),
            )
        parser.add_argument(
            '-g', '--groupjoin',
            metavar="CONFSERVERGROUP",
            help=_(
                "inform the server to which computers group "
                "we want to belong to"
                ),
            )
        parser.add_argument(
            '-s', '--useserver',
            metavar="SERVERHOST",
            help=_("connect now to this host"),
            )
        parser.add_argument(
            '-d', '--domain',
            help=_("local domain. only specify if autodetection fails."),
            )
        parser.add_argument(
            '-H', '--hostname',
            help=_("Set the local hostname before configuring the client."),
            )
        parser.add_argument(
            '-u', '--user',
            help=_("sudo-capable user on the server."),
            )
        parser.add_argument(
            '-k', '--ssh-key',
            help=_("path to an ssh key"),
            )
        parser.add_argument(
            '--unofficial',
            help=_("Also include contrib and non-free."),
            action="store_true",
            )
        parser.add_argument(
            '--locale',
            help=_("default locale for the machine"),
            )
        parser.add_argument(
            '--keyboard',
            help=_("default keyboard layout for the machine"),
            )
        parser.add_argument(
            '--iso',
            help=_("Ignore configurations that don't work in a chroot"),
            action="store_true",
            )
        parser.add_argument(
            '--light',
            help=_("Skip installing heavyweight dependencies"),
            action="store_true",
            )
        parser.add_argument(
            '--wifi-ssid',
            help=_("Wireless SSID to use instead of ethernet"),
        )
        parser.add_argument(
            '--wifi-pass',
            help=_("Wireless password to use instead of ethernet"),
        )
        parser.add_argument(
            '--fuss-zone',
            help=_("Fuss zone for standalone machines"),
        )
        parser.add_argument(
            '--non-interactive',
            action="store_true",
            help=_("Avoid blocking for interactions if possible."),
            )

        # Logging
        parser.add_argument(
            '-v', '--verbose',
            action="store_true",
            help=_("verbose output"),
            )
        parser.add_argument(
            '--debug',
            action="store_true",
            help=_("debug output"),
            )
        parser.add_argument(
            '--syslog',
            action="store_true",
            help=_("send loggin output to syslog"),
            )
        parser.add_argument(
            '--dryrun',
            action="store_true",
            help=_("simulate actions"),
            )

        self.args = parser.parse_args()
        if self.args.fuss_zone and self.args.func != self.standalone:
            parser.error(
                "--fuss-zone can only be used with the --standalone option."
            )

    def main(self):
        """Run the action selected on the command line."""
        self.args.func()


if __name__ == '__main__':
    # Parse the command line, then dispatch to the selected action.
    client = Client()
    client.main()
