#!/usr/bin/env python3
# -*- python -*-
#
#  File: fuss-server-config
#
#  Copyright (C) 2007-2016 Christopher R. Gabriel <cgabriel@truelite.it>,
#                          Elena Grandi <elena@truelite.it>,
#                          Progetto Fuss <info@fuss.bz.it>
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 2 of the License, or
#  (at your option) any later version.

from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import getpass
import logging
import os
import re
import shutil
import subprocess
import sys

from gettext import gettext as _

import apt
import netaddr
import netifaces
import ruamel.yaml

# Decide whether ask() can use the graphical front end: that requires the
# gtk and gnome modules to be importable and gtk to report a default
# display.  In every other case the terminal front end is used.
xwin = False
try:
    import gtk
    r = gtk.gdk.display_get_default()
    # display_get_default() gives back a false value without a display
    xwin = bool(r)
    import gnome
    import gnome.ui
except ImportError:
    # any of the modules above is missing: no graphical front end
    xwin = False

# Compatibility shim: where the raw_input builtin exists (Python 2) use it
# as input, so that input() always returns the typed text unevaluated.
# Where it does not exist the lookup raises NameError and the regular
# input builtin is left untouched.
try:
    input = raw_input
except NameError:
    pass

# Directory holding the executables (create.yml, upgrade.yml, purge.yml,
# captive_portal.yml, test.sh) that the subcommands chdir into and exec.
ansible_data_path = "/usr/share/fuss-server/"
# Default configuration file: used by Configuration when no other path is
# given and as default for the 'configure -f' option.
conf_file = '/etc/fuss-server/fuss-server.yaml'
# Template copied over the configuration file by Configuration.load() when
# bootstrapping or when the configuration file does not exist yet.
clean_config_file = '/usr/share/doc/fuss-server/examples/fuss-server.yaml.example'

# Version of the 'fuss-server' package as reported by the apt cache; it is
# handed to the playbooks as the 'fuss_server_version' extra variable.
# The lookup runs at import time.  When it raises AttributeError
# (presumably because the package is unknown to apt or not installed --
# TODO confirm against python-apt) the placeholder 'dev' is used.
try:
    VERSION = apt.cache.Cache().get('fuss-server').installed.version
except AttributeError:
    VERSION = 'dev'

class Configuration(object):
    """
    The fuss-server settings stored in a YAML file.

    The class loads the file (load), reports the entries that are
    missing or invalid (check), asks the user for them (ask) through a
    GNOME druid window when ``xwin`` is set or on the terminal
    otherwise, and writes the result back (save).

    Validation of a single key is done by the method named
    ``_check_<key>`` when it exists; keys without such a method are
    valid whenever their value is non-empty.
    """

    # For every supported key: [short title, longer help text]; both are
    # shown to the user by ask().
    known = {
        "localnet": [
            _("Local network address"),
            _("The format is netaddr/cidr, ex. 192.168.1.0/24")
            ],
        "domain": [
            _("Domain name"),
            _("The domain for this network, ex. 'institute.lan'")
            ],
        "pass": [
            _("Master password"),
            _("The master password for this server")
            ],
        "geoplace": [
            _("Locality"),
            _("Locality e/o address name, ex. 'Bolzano'")
            ],
        "fuss_zone":[
            _("FUSS Zone"),
            _(
                "Please press ENTER to use 'Prov-BZ' if you're in a school "
                "within the Province of Bozen-Bolzano: The inventory data of "
                "server and clients will be collected by the Intendenza "
                "Scolastica Italiana (Autonomous Provincie of  "
                "Bozen-Bolzano) for statistical purposes. "
                "\n"
                "Otherwise enter any other string to prevent inventory data "
                "collection and keep the default fuss-server behaviour. "
                "Full support for additional zones will be added in "
                "version 13 of fuss-server. "
                "\n"
                "See https://fuss-tech-guide.readthedocs.io/it/latest/server/fuss_zone.html "
                "for more details in this feature."
              ),
            ],
        "workgroup": [
            _("Windows Workgroup"),
            _("The Windows WorkGroup for this network, ex. 'institute'")
            ],
        "dhcp_range": [
            _("DHCP Server Range"),
            _("The IP range of address given by the DHCP Server, ex. '192.168.1.10 192.168.1.100'")
            ],
        "external_ifaces": [
            _("WAN Interface"),
            _("The WAN interface(s) of the server, ex. 'eth0'")
            ],
        "internal_ifaces": [
            _("LAN Interfaces"),
            _("The LAN interface(s) of the server, ex. 'eth1 eth2'")
            ],
        "hotspot_iface": [
            _("Hot Spot Interface"),
            _("The Hotspot interface of the server, ex. 'eth3'")
            ],
        "hotspot_network": [
            _("Hot Spot Network (CIDR)"),
            _("The Hotspot network of the server, ex. '10.1.0.0/24'")
            ],
        }

    def __init__(self, c_file=conf_file, reconf_all=False, cp_mandatory=False):
        """
        :param c_file: path of the YAML configuration file.
        :param reconf_all: when true, check() reports every known key
            found in the file, so that all of them get asked again.
        :param cp_mandatory: when true, empty hotspot (captive portal)
            values are considered invalid instead of "no hotspot".
        """
        self.c_file = c_file
        self.reconf_all = reconf_all
        self.cp_mandatory = cp_mandatory

    @staticmethod
    def _is_text(value):
        """
        Tell whether value is a string.

        Answers typed by the user go through a YAML parser, so a value
        may well be a number, a boolean or a list; the string based
        validators use this to reject those instead of crashing.
        """
        # type('') covers the unicode type when unicode_literals is
        # active on Python 2; on Python 3 both entries are str.
        return isinstance(value, (type(''), str))

    def _is_cidr(self, value):
        """
        Tell whether value is a string in address/prefix form that
        netaddr accepts as a network.
        """
        if not self._is_text(value):
            return False
        if len(value.split('/')) != 2:
            return False
        try:
            netaddr.IPNetwork(value)
        except (netaddr.AddrFormatError, ValueError):
            return False
        return True

    def _current_value(self, key):
        """
        Return the current value of key as text suitable to be shown or
        pre-filled; missing and empty (None) values give ''.
        """
        value = self.data.get(key)
        if value is None:
            return ''
        if isinstance(value, (list, tuple)):
            # interface lists are edited as space separated names
            return " ".join(str(x) for x in value)
        return str(value)

    def check(self):
        """
        Check the current configuration.

        Return a list of entries that are missing or problematic.
        Exits with status 3 when the configuration file does not exist.
        """
        if not os.path.isfile(self.c_file):
            logging.error("Can't find configuration file - exiting")
            # TODO: this should probably raise an exception
            sys.exit(3)
        found_keys = set()
        missing_keys = []
        for key, value in self.data.items():
            if key not in self.known:
                # Entries we know nothing about can't be validated nor
                # described to the user by ask(): leave them alone.
                continue
            found_keys.add(key)
            # If we don't have a method to check for the validity of
            # data, it means that every non-empty value is valid.
            method = getattr(self, '_check_{}'.format(key), bool)
            if self.reconf_all or not method(value):
                missing_keys.append(key)
        for key in self.known:
            if key not in found_keys:
                missing_keys.append(key)
        # The network crosscheck parses localnet, so it only makes
        # sense once localnet itself is known to be valid.
        if 'localnet' not in missing_keys:
            for key in self._crosscheck_network():
                if key not in missing_keys:
                    missing_keys.append(key)
        for key in self._crosscheck_hotspot():
            if key not in missing_keys:
                missing_keys.append(key)
        return missing_keys

    def _crosscheck_network(self):
        """
        Check localnet, internal_ifaces and dhcp_range against each
        other and against the addresses configured on this machine.

        Return the list of keys that have to be asked again.  Must only
        be called when localnet is a valid network.
        """
        wrong = []
        # network configured on an interface -> name of that interface
        all_networks = {
            netaddr.IPNetwork(x['addr']+'/'+x['netmask']): iface
            for iface in netifaces.interfaces()
            for x in netifaces.ifaddresses(iface).get(
                netifaces.AF_INET,
                [])
            # skip address entries that do not carry both pieces
            if 'addr' in x and 'netmask' in x
            }
        localnet = netaddr.IPNetwork(self.data['localnet'])
        if not self.data['internal_ifaces']:
            logging.warning(_("No internal interfaces are configured"))
            return ['internal_ifaces']
        try:
            if all_networks[localnet] not in self.data['internal_ifaces']:
                logging.warning(_(
                    "The value for local network {localnet} is not " +
                    "configured on any local interface ({ifaces})"
                    ).format(
                        localnet=str(localnet),
                        ifaces=str(self.data['internal_ifaces'])
                        )
                    )
                wrong = ['localnet', 'internal_ifaces']
        except KeyError:
            logging.warning("No interface found for localnet {}".format(
                self.data['localnet']
                ))
            wrong = ['localnet']
        dhcp_range = self.data['dhcp_range']
        if not self._is_text(dhcp_range):
            # empty (None) or parsed by YAML as something else
            wrong.append('dhcp_range')
            return wrong
        for ip in dhcp_range.split():
            try:
                range_addrs = netaddr.IPAddress(ip)
            except (netaddr.AddrFormatError, ValueError):
                # not an address at all
                wrong.append('dhcp_range')
                break
            if range_addrs not in localnet:
                wrong.append('dhcp_range')
                break
        return wrong

    def _crosscheck_hotspot(self):
        """
        Check the hotspot values against each other and against the
        LAN and WAN interfaces.

        Return the list of keys that have to be asked again.
        """
        # Either both hotspot values should be filled or none (and if
        # none, we're done with the crosscheck).
        fields = ('hotspot_iface', 'hotspot_network')
        filled = [bool(self.data[f]) for f in fields]
        if not all(filled):
            if any(filled):
                return [f for i, f in enumerate(fields) if not filled[i]]
            else:
                return []

        hotspot_iface = self.data['hotspot_iface']
        # the interface lists may still be empty (None) at this point
        if hotspot_iface in (self.data['internal_ifaces'] or ()):
            logging.warning(_("Hot spot interface cannot be the same as a LAN interface"))
            return ['hotspot_iface']
        if hotspot_iface in (self.data['external_ifaces'] or ()):
            logging.warning(_("Hot spot interface cannot be the same as a WAN interface"))
            return ['hotspot_iface']
        if 'tun' in str(hotspot_iface):
            logging.warning(_("Hot spot interface cannot be a tunnel interface"))
            return ['hotspot_iface']

        # TODO: only run this check on the "other" interfaces (from
        # internal_ifaces)
        #        ip_route = subprocess.check_output(['ip', 'route'])
        #        hotspot_net = netaddr.IPNetwork(self.data['hotspot_network'])
        #        hs_net_s = str(hotspot_net.network).encode('utf-8')
        #
        #        for line in ip_route.split(b'\n'):
        #            if line.strip().startswith(hs_net_s) and b'tun' not in line:
        #                logging.warning((
        #                    "Network {} already used\n" +
        #                    "Please choose another one"
        #                    ).format(str(hotspot_net.network)))
        #                return ['hotspot_network']

        return []

    def _check_external_ifaces(self, value):
        """
        The value should be a list of interfaces available on this
        machine.
        """
        if not isinstance(value, list):
            return False
        available = netifaces.interfaces()
        for iface in value:
            if iface not in available:
                logging.warning("Interface {} is not available".format(
                    iface))
                return False
        return True

    _check_internal_ifaces = _check_external_ifaces

    def _check_localnet(self, value):
        """
        Localnet should be a valid address in CIDR format
        """
        if not value:
            return False
        return self._is_cidr(value)

    def _check_dhcp_range(self, value):
        """
        dhcp_range should be made of two valid ips separated by a
        single space
        """
        if not value or not self._is_text(value):
            return False
        ips = value.split(' ')
        if len(ips) != 2:
            return False
        for ip in ips:
            try:
                netaddr.IPAddress(ip)
            except (netaddr.AddrFormatError, ValueError):
                return False
        return True

    def _check_pass(self, value):
        """
        pass should not contain any of the ampersand, backslash, slash
        or dollar chars, nor be composed of just numbers
        """
        if not value:
            return False
        try:
            int(value)
        except (TypeError, ValueError, OverflowError):
            # if we can't get a number out of the password everything is
            # fine, as long as it is a string (checked below)
            pass
        else:
            logging.warning("password must not be composed by just numbers")
            return False
        if not self._is_text(value):
            # the YAML parser made a list, a mapping, ... out of it
            logging.warning("password must be a plain string")
            return False
        # add more forbidden char if needed
        forbiddenchars = set('$\\/&')
        if any((c in forbiddenchars) for c in value):
            logging.warning("password must not contain &, \\, /, or $")
            return False
        return True

    def _check_domain(self, value):
        """
        domain should be made up of two alphanumeric names separated by
        one dot.
        The TLD .local isn't allowed, because it's reserved to mDNS
        """
        if not value or not self._is_text(value):
            return False
        allowed = re.compile(r"^[\w]+\.[\w]+$")
        if not allowed.match(value):
            return False
        if value.endswith('.local'):
            logging.warning(".local domains are not allowed")
            return False
        return True

    def _check_workgroup(self, value):
        """
        workgroup should be made of alphanumeric
        """
        if not value:
            return False
        allowed = re.compile(r"^[\w]+$")
        # a workgroup such as 1234 reaches us as a number: not valid
        if not self._is_text(value) or not allowed.match(value):
            logging.warning(
                "Workgroup must contain only alphanumeric characters"
                )
            return False
        return True

    def _check_hotspot_iface(self, value):
        """
        Hotspot interface should be available on this machine; it may
        be empty unless the captive portal is mandatory.
        """
        if not value:
            # empty values are allowed, in case no hotspot is present
            return not self.cp_mandatory
        if value not in netifaces.interfaces():
            logging.warning("Interface {} is not available".format(
                value))
            return False
        return True

    def _check_hotspot_network(self, value):
        """
        Hotspot network should be a valid address in CIDR format; it
        may be empty unless the captive portal is mandatory.
        """
        if not value:
            # empty values are allowed, in case no hotspot is present
            return not self.cp_mandatory
        return self._is_cidr(value)

    def load(self, bootstrap=False):
        """
        Load configuration data from a file.

        When bootstrap is true, or the file does not exist, the file is
        first (re)created from the clean example.  Exits with status 3
        when the file is empty or lacks one of the known keys (with the
        exception of fuss_zone, which is allowed to be missing).
        """
        if bootstrap or not os.path.exists(self.c_file):
            logging.info("Creating a new configuration file with empty values")
            confdir = os.path.dirname(os.path.realpath(self.c_file))
            if not os.path.isdir(confdir):
                os.makedirs(confdir)
            shutil.copyfile(clean_config_file, os.path.realpath(self.c_file))
        with open(self.c_file) as fp:
            self.data = ruamel.yaml.load(fp, ruamel.yaml.RoundTripLoader)
        if not self.data:
            logging.error(
                "The configuration file seems to be empty.\n" +
                "Please delete it to restart from a new valid one."
                )
            # TODO: this should probably raise an exception
            sys.exit(3)
        invalid = False
        for k in self.known:
            if k not in self.data and k != "fuss_zone":
                logging.error(
                    "Missing value in the configuration file: {}".format(k)
                    )
                invalid = True
        if invalid:
            logging.error(
                "Please add the missing values to the configuration file\n" +
                "or delete it to start from a clean one."
                )
            # TODO: this should probably raise an exception
            sys.exit(3)

    def save(self):
        """
        Save configuration data to file, setting safe permissions.

        The file is set to mode 0640 before being rewritten; the
        process umask is left at 022 afterwards, also when writing
        fails.
        """
        os.chmod(self.c_file, 0o640)
        os.umask(0o27)
        try:
            with open(self.c_file, "w") as fp:
                ruamel.yaml.dump(
                    self.data,
                    stream=fp,
                    Dumper=ruamel.yaml.RoundTripDumper
                    )
        finally:
            os.umask(0o22)

    def ask(self, missing_conf):
        """
        Ask the user for the values of the keys in missing_conf, store
        the answers in self.data and save them.

        Exits with status 8 when missing_conf is empty.  Answers are
        not validated here: callers are expected to run check() again.
        """
        logging.info("Asking for configuration")
        if xwin:
            self._ask_gui(missing_conf)
        else:
            self._ask_tty(missing_conf)

    def _ask_gui(self, missing_conf):
        """
        Graphical front end of ask(): one druid page per key.
        """
        # key -> gtk.Entry holding the answer for that key
        entries = {}

        def build_druid_page(key, question, help_text, default=""):
            page = gnome.ui.DruidPageStandard()
            page.set_title(question)
            v = gtk.VBox()
            page.append_item(help_text, v, '')

            h = gtk.HBox()
            h.pack_start(gtk.Label(_("Please enter you choice")))
            entry = gtk.Entry()
            if 'assword' in question:
                # don't show passwords while they are typed
                entry.set_visibility(False)
            entry.set_text(default)
            entries[key] = entry
            h.pack_start(entry)
            v.pack_start(h)
            page.show_all()
            return page

        def completed(widget, pars):
            for key, entry in entries.items():
                if "ifaces" in key:
                    self.data[key] = entry.get_text().split()
                else:
                    self.data[key] = ruamel.yaml.safe_load(
                        entry.get_text()
                        )
            self.save()
            gtk.main_quit()

        if not missing_conf:
            d = gtk.MessageDialog(
                parent=None,
                flags=gtk.DIALOG_MODAL,
                type=gtk.MESSAGE_ERROR, buttons=gtk.BUTTONS_OK
                )
            d.set_markup(_("Looks like you've already configured this Fuss Server.\n\nUse the '-r' option to reconfigure it all"))
            d.show_all()
            d.run()
            d.destroy()
            sys.exit(8)

        w = gtk.Window()
        w.set_default_size(500, 500)
        w.set_title(_("Fuss Server Configuration"))
        w.connect("delete_event", gtk.main_quit)

        druid = gnome.ui.Druid()
        druid.connect("cancel", gtk.main_quit)
        w.add(druid)
        start_page = gnome.ui.DruidPageEdge(0)
        start_page.set_title(_("Fuss Server Configuration"))
        start_page.set_text(_("Welcome to the Fuss Server configuration"))
        druid.add(start_page)
        for key in missing_conf:
            druid.add(build_druid_page(
                key,
                self.known[key][0],
                self.known[key][1],
                self._current_value(key)
                ))

        end_page = gnome.ui.DruidPageEdge(1)
        end_page.set_title(_("Fuss Server Configuration"))
        end_page.set_text(_("All done! Thank you!"))
        end_page.connect("finish", completed)
        druid.add(end_page)

        w.show_all()
        gtk.main()

    def _ask_tty(self, missing_conf):
        """
        Terminal front end of ask(): one prompt per key.
        """
        if not missing_conf:
            print(_("Looks like you've already configured this Fuss Server."))
            print("")
            print(_("Use the '-r' option to reconfigure it all"))
            sys.exit(8)

        for key in missing_conf:
            title, help_text = self.known[key]
            is_password = "assword" in title
            print("#"*80)
            print(_("Please insert"), title)
            print("")
            print(help_text)
            current = self._current_value(key)
            # never echo the stored password on the terminal
            if current and not is_password:
                print(_("Current value")+": ", current)
                print("")
            if is_password:
                pw1 = getpass.getpass()
                if getpass.getpass(_("Confirm password: ")) != pw1:
                    # the old value is kept; the next check() decides
                    # whether the question has to be asked again
                    logging.warning(_("Password mismatch"))
                else:
                    self.data[key] = ruamel.yaml.safe_load(pw1)
            elif "ifaces" in key:
                self.data[key] = input(_("Your choice? ")).split()
            elif "zone" in key:
                self.data[key] = input(_("Your choice? [Prov-BZ]")) or "Prov-BZ"
            else:
                self.data[key] = ruamel.yaml.safe_load(
                    input(_("Your choice? "))
                    )
        self.save()


def fail_if_not_root():
    """
    Log an error and exit with status 5 unless the user id of the
    process is 0.
    """
    if os.getuid() == 0:
        return
    logging.error("Can't execute fuss-server - Are you root?")
    sys.exit(5)


def _config(c, bootstrap=False):
    c.load(bootstrap)
    res = c.check()
    # in any case, only reconfigure everything once, then ask just the
    # missing bits
    c.reconf_all = False
    while len(res) > 0:
        c.ask(res)
        res = c.check()


def configure(args):
    """
    Handler of the 'configure' subcommand: ask for the missing (or,
    with -r, all the) configuration values.
    """
    logging.info("Asking for missing configuration")
    using_default_file = args.configuration_file == conf_file
    if using_default_file:
        # Usually we can't work except as root, but when working on a
        # different configuration file it is convenient to allow to
        # check and set the configuration as a normal user.
        fail_if_not_root()
    settings = Configuration(
        c_file=args.configuration_file,
        reconf_all=args.reconfigure_all,
        )
    _config(settings, bootstrap=args.bootstrap)


def create(args):
    """
    Handler of the 'create' subcommand (and of a bare invocation):
    complete the configuration, then replace this process with the
    create.yml executable.
    """
    logging.info("Applying configuration")
    fail_if_not_root()
    _config(Configuration())
    os.chdir(ansible_data_path)
    playbook = os.path.join(ansible_data_path, 'create.yml')
    version_var = 'fuss_server_version={}'.format(VERSION)
    argv = [
        'fuss-server',
        '-i', 'localhost,',
        '-c', 'local',
        '--force-handlers',
        '-e', version_var,
        ]
    # execvp does not return: the playbook takes over the process
    os.execvp(playbook, argv)


def upgrade(args):
    """
    Handler of the 'upgrade' subcommand: complete the configuration,
    then replace this process with the upgrade.yml executable.
    """
    logging.info("Upgrading configuration")
    fail_if_not_root()
    # if the captive portal has already been configured, we also include
    # its roles in the upgrade playbook. refs: #977
    cp_configured = os.path.exists("/etc/ansible/facts.d/fuss_server_cp.fact")
    _config(Configuration(cp_mandatory=cp_configured))
    os.chdir(ansible_data_path)
    if cp_configured:
        cp_flag = "true"
    else:
        cp_flag = "false"
    playbook = os.path.join(ansible_data_path, 'upgrade.yml')
    argv = [
        'fuss-server',
        '-i', 'localhost,',
        '-c', 'local',
        '--force-handlers',
        '-e', 'fuss_server_version={}'.format(VERSION),
        '-e', '{{captive_portal: {}}}'.format(cp_flag),
        ]
    # execvp does not return: the playbook takes over the process
    os.execvp(playbook, argv)


def purge(args):
    """
    Handler of the 'purge' subcommand: complete the configuration,
    then replace this process with the purge.yml executable.
    """
    logging.info("Purging")
    fail_if_not_root()
    _config(Configuration())
    os.chdir(ansible_data_path)
    playbook = os.path.join(ansible_data_path, 'purge.yml')
    argv = [
        'fuss-server',
        '-i', 'localhost,',
        '-c', 'local',
        '--force-handlers',
        ]
    # execvp does not return: the playbook takes over the process
    os.execvp(playbook, argv)


def captive_portal(args):
    """
    Handler of the 'cp' subcommand: complete the configuration, with
    the hotspot values being mandatory, then replace this process with
    the captive_portal.yml executable.
    """
    logging.info("Applying configuration for a captive portal")
    fail_if_not_root()
    _config(Configuration(cp_mandatory=True))
    os.chdir(ansible_data_path)
    playbook = os.path.join(ansible_data_path, 'captive_portal.yml')
    version_var = 'fuss_server_version={}'.format(VERSION)
    argv = [
        'fuss-server',
        '-i', 'localhost,',
        '-c', 'local',
        '--force-handlers',
        '-e', version_var,
        ]
    # execvp does not return: the playbook takes over the process
    os.execvp(playbook, argv)


def test(args):
    """
    Handler of the 'test' subcommand: replace this process with the
    test.sh executable.
    """
    logging.info("Testing the server")
    fail_if_not_root()
    os.chdir(ansible_data_path)
    script = os.path.join(ansible_data_path, 'test.sh')
    # execvp does not return: the script takes over the process
    os.execvp(script, ['fuss-server'])


def self_test(args):
    """
    Handler of the 'selftest' subcommand: run the unit tests of the
    validation methods of Configuration.

    Exits with status 1 when any of the tests does not pass.
    """
    import unittest

    class TestCheck(unittest.TestCase):
        def setUp(self):
            self.c = Configuration()

        def test_localnet(self):
            self.assertTrue(self.c._check_localnet('192.168.5.23/24'))
            self.assertFalse(self.c._check_localnet(''))
            self.assertFalse(
                self.c._check_localnet('192.168.5.23 255.255.255.0')
                )

        def test_dhcp_range(self):
            self.assertTrue(
                self.c._check_dhcp_range('192.168.5.23 192.168.5.42')
                )
            self.assertFalse(
                self.c._check_dhcp_range('192.168.5.23')
                )
            self.assertFalse(
                self.c._check_dhcp_range('192.168.5.0/24')
                )

        def test_check_domain(self):
            self.assertTrue(
                self.c._check_domain('scuola.lan')
                )
            self.assertFalse(
                self.c._check_domain('this.is.not.valid')
                )
            self.assertFalse(
                self.c._check_domain('scuola.local')
                )
            self.assertTrue(
                self.c._check_domain('local.lan')
                )

        def test_check_workgroup(self):
            self.assertTrue(
                self.c._check_workgroup('workgroup')
                )
            self.assertFalse(
                self.c._check_workgroup('scuola.lan')
                )

        def test_crosscheck_hotspot(self):
            self.c.data = {
                'external_ifaces': ['eth0'],
                'internal_ifaces': ['eth1', 'eth2'],
                'hotspot_iface': 'eth3',
                'hotspot_network': '192.168.5.0/24',
                }

            # All valid values
            self.assertEqual(self.c._crosscheck_hotspot(), [])

            # both hotspot variables empty: valid
            self.c.data['hotspot_iface'] = ''
            self.c.data['hotspot_network'] = ''
            self.assertEqual(self.c._crosscheck_hotspot(), [])

            # only one hotspot variable empty: invalid
            self.c.data['hotspot_iface'] = 'eth3'
            self.c.data['hotspot_network'] = ''
            self.assertEqual(self.c._crosscheck_hotspot(), ['hotspot_network'])

            self.c.data['hotspot_iface'] = ''
            self.c.data['hotspot_network'] = '192.168.5.0/24'
            self.assertEqual(self.c._crosscheck_hotspot(), ['hotspot_iface'])

            # hotspot interface can't be the same as an internal or
            # external one
            self.c.data['hotspot_iface'] = 'eth0'
            self.assertEqual(self.c._crosscheck_hotspot(), ['hotspot_iface'])
            self.c.data['hotspot_iface'] = 'eth1'
            self.assertEqual(self.c._crosscheck_hotspot(), ['hotspot_iface'])

            # hotspot interface can't be a tun one; the name must not be
            # one of the LAN/WAN interfaces, or the checks above would
            # reject it before the tunnel check is reached
            self.c.data['hotspot_iface'] = 'tun0'
            self.assertEqual(self.c._crosscheck_hotspot(), ['hotspot_iface'])

        def test_password(self):
            self.assertTrue(self.c._check_pass('abcdefg'))
            self.assertFalse(self.c._check_pass('abcd&'))
            self.assertFalse(self.c._check_pass('abcd\\'))
            self.assertFalse(self.c._check_pass('abcd/'))
            self.assertFalse(self.c._check_pass('abcd$'))
            self.assertFalse(self.c._check_pass('1234'))
            self.assertFalse(self.c._check_pass(1234))

    suite = unittest.TestLoader().loadTestsFromTestCase(TestCheck)
    result = unittest.TextTestRunner(verbosity=1).run(suite)
    # let the caller (a shell, a packaging script) notice failures
    if not result.wasSuccessful():
        sys.exit(1)


def main():
    """
    Parse the command line and run the handler of the chosen
    subcommand.
    """
    parser = argparse.ArgumentParser(
        description='Configure a FUSS server.'
        )
    # handler used when no subcommand sets its own
    parser.set_defaults(func=create)
    subparser = parser.add_subparsers(
        title='subcommands',
        description='Run fuss-server <command> -h for help on the subcommands.',
        dest='create'  # this is ignored by python 2.7, but works with 3.4+
        )

    # Arguments are described as (flags, keyword options) pairs.
    limit = (('--limit',), {'help': "Ignored for compatibility"})
    configure_arguments = (
        (
            ('-r', '--reconfigure-all'),
            {
                'action': "store_true",
                'help': "Reconfigure all options",
                },
            ),
        (
            ('-b', '--bootstrap'),
            {
                'action': "store_true",
                'help': "Delete all current configuration and start with a new empty file",
                },
            ),
        (
            ('-f', '--configuration-file'),
            {
                'help': "Use a different configuration file (for testing)",
                'default': conf_file,
                },
            ),
        )

    # (name, help, handler, arguments) of every subcommand, in the order
    # in which they are registered.
    commands = (
        ('create', 'install dependencies and configuration',
         create, (limit,)),
        ('upgrade', 'apply a new configuration to an existing fuss-server',
         upgrade, (limit,)),
        ('purge', 'clean configuration',
         purge, (limit,)),
        ('configure', 'configure configuration',
         configure, configure_arguments),
        ('test', 'test the server configuration',
         test, ()),
        ('cp', 'install a captive portal',
         captive_portal, (limit,)),
        ('selftest', 'run tests on this script',
         self_test, ()),
        )
    for name, summary, handler, arguments in commands:
        command_parser = subparser.add_parser(name, help=summary)
        for flags, options in arguments:
            command_parser.add_argument(*flags, **options)
        command_parser.set_defaults(func=handler)

    args = parser.parse_args()
    args.func(args)


if __name__ == '__main__':
    main()
