#!/usr/bin/python3 -sE
# -*- coding: utf-8 -*-
"""
@author: Tomas Hozza <thozza@redhat.com>
@author: Pavel Šimerda <psimerda@redhat.com>
"""

import os
import sys
import fcntl
import shutil
import glob
import subprocess
import logging
import logging.handlers
import struct
import signal
import json

import gi
gi.require_version('NM', '1.0')

from gi.repository import NM

# Python compatibility: make sure os.O_CLOEXEC exists, because Lock passes it
# to os.open().
# NOTE(review): 0x80000 is presumably the Linux value of O_CLOEXEC - confirm
# before relying on this fallback on another platform.
if not hasattr(os, "O_CLOEXEC"):
    os.O_CLOEXEC = 0x80000

# Write-only handle on the null device; passed as stdout/stderr to child
# processes whose output should be discarded.
DEVNULL = open(os.devnull, "wb")

# Root logger: INFO level by default, with a syslog handler and a stream
# handler attached. The level is switched to DEBUG further down when debugging
# is requested.
log = logging.getLogger()
log.setLevel(logging.INFO)
log.addHandler(logging.handlers.SysLogHandler())
log.addHandler(logging.StreamHandler())

# NetworkManager reportedly doesn't pass the PATH environment variable, so set
# a fixed one for the external commands run by this script.
os.environ['PATH'] = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"

# Maximum negative cache TTL set by the dnssec-trigger script on setup
# (passed to unbound's "cache-max-negative-ttl:" option).
UNBOUND_MAX_NEG_CACHE_TTL = 5

class UserError(Exception):
    """Error that is reported to the user as a plain message.

    main() logs the message and exits with status 1 instead of letting the
    exception propagate as a traceback.
    """


def pidof(process_name):
    """
    Get pids for process with given name

    :param process_name: String with name of process to get PIDs of
    :return: list with PIDs represented as int. If there is no such process,
             or the ``pidof`` utility itself cannot be executed, the list is
             empty
    """
    try:
        output = subprocess.check_output(['pidof', process_name])
    except subprocess.CalledProcessError:
        # pidof exited with a non-zero status: there is no process with the
        # given name
        return []
    except OSError as error:
        # The pidof utility is missing or not executable. Report it instead of
        # crashing the caller with a traceback.
        logging.getLogger().warning("Cannot run pidof: %s", error)
        return []
    return [int(pid) for pid in output.decode().split()]


class Lock:
    """Lock used to serialize the script

    The lock is an exclusive fcntl lock on the file at ``path``. Creating an
    instance opens (and if necessary creates) the lock file; entering the
    context blocks until the lock is acquired and leaving it releases the
    lock. The file descriptor stays open for the lifetime of the object.
    """

    path = "/run/dnssec-trigger/lock"

    def __init__(self):
        # We don't use os.makedirs(..., exist_ok=True) to ensure Python 2 compatibility.
        # Attempt the creation unconditionally and tolerate an existing
        # directory: checking os.path.exists() first would race with another
        # instance of the script creating the directory at the same time.
        dirname = os.path.dirname(self.path)
        try:
            os.makedirs(dirname)
        except OSError:
            if not os.path.isdir(dirname):
                raise
        self.lock = os.open(self.path, os.O_WRONLY | os.O_CREAT | os.O_CLOEXEC, 0o600)

    def __enter__(self):
        fcntl.lockf(self.lock, fcntl.LOCK_EX)
        # Return the lock so that "with Lock() as lock:" yields a usable object
        return self

    def __exit__(self, t, v, tb):
        fcntl.lockf(self.lock, fcntl.LOCK_UN)


class Config:
    """Global configuration options

    Boolean options are read from ``path``; each recognised ``option=value``
    line sets the option to True when the value is exactly ``yes`` and to
    False otherwise. Unknown options and lines without ``=`` are ignored, and
    a missing or unreadable file leaves the defaults in place. Options are
    exposed as attributes, e.g. ``config.debug``.
    """

    path = "/etc/dnssec-trigger/dnssec.conf"

    # Default values of all recognised options
    bool_options = {
        "debug": False,
        "validate_connection_provided_zones": True,
        "add_wifi_provided_zones": False,
        "use_vpn_global_forwarders": False,
        "use_resolv_conf_symlink": False,
        "use_resolv_secure_conf_symlink": False,
        "use_private_address_ranges": True,
        "set_search_domains": False,
        "keep_positive_answers": False,
    }

    def __init__(self):
        # Work on a private copy so that parsing the file doesn't modify the
        # class-level defaults shared by every instance.
        self.bool_options = dict(self.bool_options)
        try:
            with open(self.path) as config_file:
                for line in config_file:
                    if '=' in line:
                        option, value = [part.strip() for part in line.split("=", 1)]
                        if option in self.bool_options:
                            self.bool_options[option] = (value == "yes")
        except IOError:
            # No readable configuration file: keep the defaults
            pass
        log.debug(self)

    def __getattr__(self, option):
        # Only called when normal attribute lookup fails. Unknown names must
        # raise AttributeError (not KeyError) so that hasattr() and
        # getattr(..., default) behave as expected.
        try:
            return self.bool_options[option]
        except KeyError:
            raise AttributeError(option)

    def __str__(self):
        return "<Config {}>".format(self.bool_options)

    @property
    def flush_command(self):
        """Name of the unbound-control command used to flush cached data for a zone."""
        return "flush_negative" if self.keep_positive_answers else "flush_zone"


# Module-wide configuration instance, read once at import time. The "debug"
# option raises the verbosity of the root logger.
config = Config()
if config.debug:
    log.setLevel(logging.DEBUG)


class ConnectionList:
    """Filtered, iterable view of NetworkManager active connections"""

    # Active connections as returned by the client; fetched by the first
    # instance and then shared by all instances through the class.
    nm_connections = None

    def __init__(self, client, only_default=False, only_vpn=False, skip_wifi=False):
        """Remember the filters and fetch the connection list if not cached yet.

        :param client: NetworkManager client object
        :param only_default: yield only default connections
        :param only_vpn: yield only VPN connections
        :param skip_wifi: do not yield WiFi connections
        :raises UserError: when NetworkManager is not running
        """
        if not client.get_nm_running():
            raise UserError("NetworkManager is not running.")
        if self.nm_connections is None:
            owner = self.__class__
            owner.nm_connections = client.get_active_connections()
        self.skip_wifi = skip_wifi
        self.only_default = only_default
        self.only_vpn = only_vpn
        log.debug(self)

    def __repr__(self):
        template = "<ConnectionList(only_default={only_default}, only_vpn={only_vpn}, skip_wifi={skip_wifi}, connections={})>"
        return template.format(list(self), **vars(self))

    def _rejects(self, connection):
        """Tell whether the connection is filtered out of this list."""
        return bool(
            # connections that should be ignored
            connection.ignore
            # connections without servers
            or not connection.servers
            # WiFi connections, if requested
            or (self.skip_wifi and connection.is_wifi)
            # non-default connections, if requested
            or (self.only_default and not connection.is_default)
            # non-VPN connections, if requested
            or (self.only_vpn and not connection.is_vpn)
        )

    def __iter__(self):
        for active in self.nm_connections:
            candidate = Connection(active)
            if not self._rejects(candidate):
                yield candidate

    def get_zone_connection_mapping(self):
        """Map each zone to the preferred connection among those providing it."""
        mapping = {}
        for candidate in self:
            for zone in candidate.zones:
                current = mapping.get(zone)
                if current is None:
                    mapping[zone] = candidate
                else:
                    mapping[zone] = self.get_preferred_connection(current, candidate)
        return mapping

    @staticmethod
    def get_preferred_connection(first, second):
        """Pick one of two connections: VPN wins, then default, then the first one."""
        for trait in ("is_vpn", "is_default"):
            second_has = getattr(second, trait)
            first_has = getattr(first, trait)
            if second_has and not first_has:
                return second
            if first_has and not second_has:
                return first
        # Equal on both criteria: prefer first connection
        return first


class Connection:
    """Snapshot of the DNS-related data of a NetworkManager active connection

    Attributes set by the constructor:

    * ``type`` -- "vpn", "wifi", "other" or "ignore" (no devices, not a VPN)
    * ``is_default`` -- True for an IPv4 or IPv6 default connection
    * ``uuid`` -- UUID reported by the connection
    * ``zones`` -- IPv4 domains and searches (deduplicated) followed by IPv6 domains
    * ``servers`` -- IPv4 nameservers followed by IPv6 nameservers
    """

    def __init__(self, connection):
        devices = connection.get_devices()

        if connection.get_vpn():
            kind = "vpn"
        elif devices:
            # Only the first device decides between WiFi and everything else
            device_type = devices[0].get_device_type().value_name
            kind = "wifi" if device_type == "NM_DEVICE_TYPE_WIFI" else "other"
        else:
            kind = "ignore"
        self.type = kind

        default4 = connection.get_default()
        self.is_default = bool(default4 or connection.get_default6())
        self.uuid = connection.get_uuid()

        zones = []
        servers = []

        ip4_config = connection.get_ip4_config()
        if ip4_config is not None:
            # The set removes names listed both as domain and as search
            zones.extend(set(ip4_config.get_domains() + ip4_config.get_searches()))
            servers.extend(ip4_config.get_nameservers())

        ip6_config = connection.get_ip6_config()
        if ip6_config is not None:
            zones.extend(ip6_config.get_domains())
            servers.extend(ip6_config.get_nameservers())

        self.zones = zones
        self.servers = servers

    def __repr__(self):
        template = "<Connection(uuid={uuid}, type={type}, default={is_default}, zones={zones}, servers={servers})>"
        return template.format(**vars(self))

    @property
    def ignore(self):
        """True when the connection should be skipped altogether."""
        return "ignore" == self.type

    @property
    def is_vpn(self):
        """True for VPN connections."""
        return "vpn" == self.type

    @property
    def is_wifi(self):
        """True when the first device of the connection is a WiFi device."""
        return "wifi" == self.type


class UnboundZoneConfig:
    """Dictionary-like proxy for the forward zones configured in Unbound.

    ``cache`` maps a zone name to a ``(servers, secure)`` tuple, where
    ``servers`` is a set. Iterating the object yields the zone names.
    """

    def __init__(self):
        # Raises CalledProcessError when "unbound-control status" fails
        subprocess.check_call(["unbound-control", "status"], stdout=DEVNULL, stderr=DEVNULL)
        self.cache = {}
        listing = subprocess.check_output(["unbound-control", "list_forwards"]).decode()
        for line in listing.split('\n'):
            if not line:
                continue
            name, servers, secure = self._parse_forward(line)
            self.cache[name] = servers, secure
        log.debug(self)

    @staticmethod
    def _parse_forward(line):
        """Split one non-empty "list_forwards" line into (name, servers, secure)."""
        fields = line.split(" ")
        # Drop the last character of the zone name
        # (presumably the trailing dot - confirm against unbound-control output)
        name = fields.pop(0)[:-1]
        if fields[0] == 'IN':
            del fields[0]
        keyword = fields.pop(0)
        if keyword in ('forward', 'forward:'):
            del fields[0]
        # NOTE(review): a leading '+i' marks the entry as secure here, while
        # _commit() passes '+i' for zones that are NOT secure - confirm which
        # reading matches the unbound-control output.
        secure = bool(fields) and fields[0] == '+i'
        if secure:
            del fields[0]
        return name, set(fields), secure

    def __repr__(self):
        return "<UnboundZoneConfig(data={cache})>".format(cache=self.cache)

    def __iter__(self):
        return iter(self.cache)

    def add(self, zone, servers, secure):
        """Install a forward zone into Unbound.

        :param zone: name of the forward zone
        :param servers: iterable of forwarder addresses
        :param secure: False to install the zone with the "+i" flag
        """
        self._commit(zone, set(servers), secure)

        state = "validated" if secure else "insecure"
        log.info("Connection provided zone '{}' ({}): {}".format(
            zone, state, ', '.join(servers)))

    def remove(self, zone):
        """Remove a forward zone from Unbound."""
        self._commit(zone, None, None)

    def _commit(self, name, servers, secure):
        """Add the zone (non-empty servers) or remove it, then flush caches."""
        # FIXME: Older versions of unbound don't print +i for insecure zones
        # and thus we cannot see whether the zone has changed or not. Therefore
        # we have no other chance than to re-add existing zones as well,
        # instead of comparing against the cached (servers, secure) pair.
        if not servers:
            self.cache.pop(name, None)
            self._control(["forward_remove", name])
        else:
            self.cache[name] = servers, secure
            command = ["forward_add"]
            if not secure:
                command.append("+i")
            command.append(name)
            command.extend(servers)
            self._control(command)
            # Unbound doesn't switch an insecure zone to a secure zone when "+i" is
            # specified and there is no "-i" to add a secure zone explicitly.
            if secure:
                self._control(["insecure_remove", name])
        self._control([config.flush_command, name])
        self._control(["flush_requestlist"])

        log.debug(self)

    @staticmethod
    def _control(args):
        """Run unbound-control with the given arguments; raise on failure."""
        log.debug("unbound-control: {}".format(args))
        command = ["unbound-control"] + args
        subprocess.check_call(command, stdout=DEVNULL, stderr=DEVNULL)


class UnboundLocalZoneConfig:
    """Dictionary-like proxy for the local zones configured in Unbound.

    ``cache`` maps a zone name (without trailing dots) to its local zone type.
    Iterating the object yields the zone names.
    """

    def __init__(self):
        # Raises CalledProcessError when "unbound-control status" fails
        subprocess.check_call(["unbound-control", "status"], stdout=DEVNULL, stderr=DEVNULL)
        self.cache = {}
        listing = subprocess.check_output(["unbound-control", "list_local_zones"]).decode()
        for line in listing.split('\n'):
            if not line:
                continue
            fields = line.split(" ")
            zone_name = fields[0].rstrip(".")
            zone_type = fields[1]
            self.cache[zone_name] = zone_type
        log.debug(self)

    def __repr__(self):
        return "<UnboundLocalZoneConfig(data={cache})>".format(cache=self.cache)

    def __iter__(self):
        return iter(self.cache)

    def add(self, zone, type):
        """Install a local zone of the given type into Unbound."""
        self.cache[zone] = type
        self._control(["local_zone", zone, type])
        log.debug(self)

    def remove(self, zone):
        """Remove a local zone from Unbound; zones not in the cache are left alone."""
        known_type = self.cache.pop(zone, None)
        if not known_type:
            return
        self._control(["local_zone_remove", zone])
        log.debug(self)

    @staticmethod
    def _control(args):
        """Run unbound-control with the given arguments; raise on failure."""
        log.debug("unbound-control: {}".format(args))
        command = ["unbound-control"] + args
        subprocess.check_call(command, stdout=DEVNULL, stderr=DEVNULL)


class Store:
    """A proxy object to access stored zones or global servers.

    The items are kept in the set ``cache``, loaded from the file
    ``/run/dnssec-trigger/<name>`` (one item per line, blank lines skipped)
    and written back only by :meth:`commit`.
    """

    def __init__(self, name):
        self.name = name
        self.cache = set()
        self.path = os.path.join("/run/dnssec-trigger", name)
        self.path_tmp = self.path + ".tmp"

        try:
            with open(self.path) as zone_file:
                for line in zone_file:
                    line = line.strip()
                    if line:
                        self.cache.add(line)
        except IOError:
            # Nothing stored yet: start with an empty set
            pass
        log.debug(self)

    def __repr__(self):
        return "<Store(name={name}, {cache})>".format(**vars(self))

    def __iter__(self):
        # Don't return the set itself, as it doesn't support update during
        # iteration.
        return iter(list(self.cache))

    def __contains__(self, zone):
        # Direct set lookup; without this method "zone in store" falls back
        # to __iter__ and copies the whole set for every membership test.
        return zone in self.cache

    def __bool__(self):
        return bool(self.cache)

    # Python 2 spelling of __bool__
    __nonzero__ = __bool__

    def add(self, zone):
        """Add zone to the cache."""

        self.cache.add(zone)
        log.debug(self)

    def update(self, zones):
        """Commit a new set of items and return True when it differs"""

        zones = set(zones)

        if zones != self.cache:
            self.cache = zones
            log.debug(self)
            return True

        return False

    def remove(self, zone):
        """Remove zone from the cache.

        :raises KeyError: when the zone is not stored
        """

        self.cache.remove(zone)
        log.debug(self)

    def commit(self):
        """Write data back to disk.

        The items are written in sorted order to a temporary file which is
        then renamed over the real one.
        """

        # We don't use os.makedirs(..., exist_ok=True) to ensure Python 2 compatibility.
        # Attempt the creation unconditionally and tolerate an existing
        # directory instead of checking first, which would race with other
        # processes creating it.
        dirname = os.path.dirname(self.path_tmp)
        try:
            os.makedirs(dirname)
        except OSError:
            if not os.path.isdir(dirname):
                raise
        with open(self.path_tmp, "w") as zone_file:
            zone_file.writelines("{}\n".format(zone) for zone in sorted(self.cache))
        os.rename(self.path_tmp, self.path)


class Application:
    """Command line front end of dnssec-trigger-script.

    The constructor parses ``argv``: optional ``--debug`` and ``--async``
    flags followed by an action such as ``--setup``. The action is mapped to
    the method ``run_<action>`` (dashes replaced by underscores), which
    :meth:`run` then executes.
    """

    resolvconf = "/etc/resolv.conf"
    resolvconf_tmp = "/etc/.resolv.conf.dnssec-trigger"
    resolvconf_secure = "/etc/resolv-secure.conf"
    resolvconf_secure_tmp = "/etc/.resolv-secure.conf.dnssec-trigger"
    resolvconf_backup = "/run/dnssec-trigger/resolv.conf.backup"
    resolvconf_trigger = "/run/dnssec-trigger/resolv.conf"
    resolvconf_trigger_tmp = resolvconf_trigger + ".tmp"
    resolvconf_networkmanager = "/run/NetworkManager/resolv.conf"

    resolvconf_localhost_contents = "# Generated by dnssec-trigger-script\nnameserver 127.0.0.1\n"

    # Reverse zones handled as private address range zones
    rfc1918_reverse_zones = [
        "c.f.ip6.arpa",
        "d.f.ip6.arpa",
        "168.192.in-addr.arpa",
        ] + ["{}.172.in-addr.arpa".format(octet) for octet in range(16, 32)] + [
        "10.in-addr.arpa"]

    # Oldest NetworkManager version that is sent SIGHUP to rewrite resolv.conf
    nm_sighup_min_version = (1, 0, 3)

    def __init__(self, argv):
        """Parse the command line and connect to NetworkManager.

        :param argv: argument list including the program name; recognised
                     flags are removed from it in place
        :raises UserError: when no valid action is given
        """
        if len(argv) > 1 and argv[1] == '--debug':
            argv.pop(1)
            log.setLevel(logging.DEBUG)
        if len(argv) > 1 and argv[1] == '--async':
            argv.pop(1)
            # The parent exits immediately, the child carries on
            if os.fork():
                sys.exit()
        if len(argv) < 2 or not argv[1].startswith('--'):
            self.usage()
        try:
            self.method = getattr(self, "run_" + argv[1][2:].replace('-', '_'))
        except AttributeError:
            self.usage()

        self.client = NM.Client().new()

    def nm_handles_resolv_conf(self):
        """Tell whether NetworkManager is running and manages resolv.conf.

        Returns False when NetworkManager is not running or when its
        configuration file contains a ``dns=none`` or ``dns=unbound`` line.
        """
        if not self.client.get_nm_running():
            log.debug("NetworkManager is not running")
            return False
        try:
            with open("/etc/NetworkManager/NetworkManager.conf") as nm_config_file:
                for line in nm_config_file:
                    if line.strip() in ("dns=none", "dns=unbound"):
                        log.debug("NetworkManager doesn't handle resolv.conf")
                        return False
        except IOError:
            # No readable configuration file: assume the default behaviour
            pass
        log.debug("NetworkManager handles resolv.conf")
        return True

    def usage(self):
        """Raise UserError carrying the usage message."""
        raise UserError("Usage: dnssec-trigger-script [--debug] [--async] --prepare|--setup|--update|--update-global-forwarders|--update-connection-zones|--cleanup")

    def run(self):
        """Execute the action selected on the command line."""
        log.debug("Running: {}".format(self.method.__name__))
        self.method()

    def _check_resolv_conf(self, path):
        """Return True when the file at path already has the expected contents."""
        try:
            with open(path) as source:
                contents = source.read()
        except IOError:
            return False
        if contents != self.resolvconf_localhost_contents:
            log.info("Rewriting {!r}!".format(path))
            return False
        return True

    def _write_resolv_conf(self, path):
        """Replace the file at path with the localhost resolv.conf contents."""
        self._try_remove(path)
        with open(path, "w") as target:
            target.write(self.resolvconf_localhost_contents)

    def _install_resolv_conf(self, path, path_tmp, symlink=False):
        """Install resolv.conf at path, going through path_tmp and a rename.

        With ``symlink`` the result is a symbolic link to the dnssec-trigger
        resolv.conf. Otherwise a regular file is written and made immutable,
        unless the existing file already has the expected contents.
        """
        if symlink:
            self._try_remove(path_tmp)
            os.symlink(self.resolvconf_trigger, path_tmp)
            self._try_set_mutable(path)
            os.rename(path_tmp, path)
        elif not self._check_resolv_conf(path):
            self._write_resolv_conf(path_tmp)
            self._try_set_mutable(path)
            os.rename(path_tmp, path)
            self._try_set_immutable(path)

    def _try_remove(self, path):
        """Remove the file at path, ignoring failures."""
        self._try_set_mutable(path)
        try:
            os.remove(path)
        except OSError:
            pass

    def _try_set_immutable(self, path):
        """Set the immutable attribute on path; failures are not fatal."""
        try:
            subprocess.call(["chattr", "+i", path], stdout=DEVNULL, stderr=DEVNULL)
        except OSError as error:
            # chattr itself could not be executed (e.g. not installed)
            log.debug("Cannot run chattr: {}".format(error))

    def _try_set_mutable(self, path):
        """Clear the immutable attribute of an existing regular path; failures are not fatal."""
        if os.path.exists(path) and not os.path.islink(path):
            try:
                subprocess.call(["chattr", "-i", path], stdout=DEVNULL, stderr=DEVNULL)
            except OSError as error:
                # chattr itself could not be executed (e.g. not installed)
                log.debug("Cannot run chattr: {}".format(error))

    @staticmethod
    def _version_tuple(version):
        """Convert a dotted version string into a tuple of integers.

        Only the leading digits of each component are used and parsing stops
        at the first component without any ("1.0.3-rc1" gives (1, 0, 3)).
        """
        numbers = []
        for component in (version or "").split("."):
            digits = ""
            for character in component:
                if character not in "0123456789":
                    break
                digits += character
            if not digits:
                break
            numbers.append(int(digits))
        return tuple(numbers)

    @classmethod
    def _nm_version_at_least(cls, version, minimum):
        """Compare a version string numerically against a tuple of integers.

        A plain string comparison would sort "1.0.10" before "1.0.3".
        """
        found = cls._version_tuple(version)
        minimum = tuple(minimum)
        # Pad the shorter tuple with zeros so that "1.2" equals (1, 2, 0)
        width = max(len(found), len(minimum))
        found += (0,) * (width - len(found))
        minimum += (0,) * (width - len(minimum))
        return found >= minimum

    def _restart_nm(self):
        """Make NetworkManager provide /etc/resolv.conf again.

        Links to NetworkManager's own resolv.conf when it exists, otherwise
        sends SIGHUP (NetworkManager 1.0.3 and newer) and as a last resort
        restarts the service.
        """
        if os.path.exists(self.resolvconf_networkmanager):
            os.symlink(self.resolvconf_networkmanager, self.resolvconf)
        # Sending SIGHUP will cause NM to reload config and write the /etc/resolv.conf
        elif self._nm_version_at_least(self.client.get_version(), self.nm_sighup_min_version):
            log.debug("Sending SIGHUP to NM to rewrite the resolv.conf")
            for pid in pidof('NetworkManager'):
                os.kill(pid, signal.SIGHUP)
        else:
            try:
                subprocess.check_call(["systemctl", "--ignore-dependencies", "try-restart", "NetworkManager.service"], stdout=DEVNULL, stderr=DEVNULL)
            except subprocess.CalledProcessError:
                subprocess.check_call(["/etc/init.d/NetworkManager", "restart"], stdout=DEVNULL, stderr=DEVNULL)

    def run_prepare(self):
        """Prepare for starting dnssec-trigger

        Called by the service manager before starting dnssec-trigger daemon.
        """

        # Backup resolv.conf when appropriate
        if not self.nm_handles_resolv_conf():
            try:
                log.info("Backing up {} as {}...".format(self.resolvconf, self.resolvconf_backup))
                shutil.move(self.resolvconf, self.resolvconf_backup)
            except IOError as error:
                log.warning("Cannot back up {!r} as {!r}: {}".format(self.resolvconf, self.resolvconf_backup, error.strerror))

        # Make sure dnssec-trigger daemon doesn't get confused by existing files.
        self._try_remove(self.resolvconf)
        self._try_remove(self.resolvconf_secure)
        self._try_remove(self.resolvconf_trigger)

    def run_setup(self):
        """Set up resolv.conf with localhost nameserver

        Called by dnssec-trigger.
        """

        # Set the maximum negative cache TTL
        self._unbound_set_negative_cache_ttl(UNBOUND_MAX_NEG_CACHE_TTL)

        if config.set_search_domains:
            zones = set(sum((connection.zones for connection in ConnectionList(self.client)), []))
            log.info("Search domains: " + ' '.join(zones))
            # Start from the class-level template so that the search line is
            # not appended twice
            self.resolvconf_localhost_contents = self.__class__.resolvconf_localhost_contents
            self.resolvconf_localhost_contents += "search {}\n".format(' '.join(zones))

        self._install_resolv_conf(self.resolvconf_trigger, self.resolvconf_trigger_tmp, False)
        self._install_resolv_conf(self.resolvconf, self.resolvconf_tmp, config.use_resolv_conf_symlink)
        self._install_resolv_conf(self.resolvconf_secure, self.resolvconf_secure_tmp, config.use_resolv_secure_conf_symlink)

    def run_restore(self):
        """Restore resolv.conf with original data

        Called by dnssec-trigger or internally as part of other actions.
        """

        self._try_remove(self.resolvconf)
        self._try_remove(self.resolvconf_secure)
        self._try_remove(self.resolvconf_trigger)

        log.info("Recovering {}...".format(self.resolvconf))
        if self.nm_handles_resolv_conf():
            # try to make NM to rewrite resolv.conf or as a last resort restart it
            self._restart_nm()
        else:
            try:
                shutil.move(self.resolvconf_backup, self.resolvconf)
            except IOError as error:
                log.warning("Cannot restore {!r} from {!r}: {}".format(self.resolvconf, self.resolvconf_backup, error.strerror))

    def run_cleanup(self):
        """Clean up after dnssec-trigger daemon

        Called by the service manager after stopping dnssec-trigger daemon.
        """

        self.run_restore()

        stored_zones = Store('zones')
        stored_servers = Store('servers')
        unbound_zones = UnboundZoneConfig()

        # provide upgrade path for previous versions
        old_zones = glob.glob("/run/dnssec-trigger/????????-????-????-????-????????????")
        if old_zones:
            log.info("Reading zones from the legacy zone store")
            for filename in old_zones:
                with open(filename) as source:
                    log.debug("Reading zones from {}".format(filename))
                    for line in source:
                        zone = line.strip()
                        # A blank line would otherwise become an empty zone
                        # name handed to unbound-control below
                        if zone:
                            stored_zones.add(zone)
                os.remove(filename)

        log.debug("clearing unbound configuration")
        for zone in stored_zones:
            unbound_zones.remove(zone)
            stored_zones.remove(zone)
        for server in stored_servers:
            stored_servers.remove(server)
        # Both store files are (re)written here, now empty
        stored_zones.commit()
        stored_servers.commit()

    @property
    def global_forwarders(self):
        """Name servers to use as global forwarders.

        Servers of the VPN connections when use_vpn_global_forwarders is set
        and such connections exist, servers of the default connections
        otherwise.
        """
        connections = None
        if config.use_vpn_global_forwarders:
            connections = list(ConnectionList(self.client, only_vpn=True))
        if not connections:
            connections = list(ConnectionList(self.client, only_default=True))

        return sum((connection.servers for connection in connections), [])

    def run_update(self):
        """Update unbound and dnssec-trigger configuration."""

        self.run_update_global_forwarders()
        self.run_update_connection_zones()

    @staticmethod
    def _unbound_set_negative_cache_ttl(ttl):
        """Set unbound's cache-max-negative-ttl option; raise on failure."""
        CMD = ["unbound-control", "set_option", "cache-max-negative-ttl:", str(ttl)]
        log.debug(" ".join(CMD))
        subprocess.check_call(CMD, stdout=DEVNULL, stderr=DEVNULL)

    @staticmethod
    def dnssec_trigger_control(args):
        """Run dnssec-trigger-control with the given arguments; raise on failure."""
        log.debug("dnssec-trigger-control: {}".format(args))
        subprocess.check_call(["dnssec-trigger-control"] + args, stdout=DEVNULL, stderr=DEVNULL)

    def run_update_global_forwarders(self):
        """Configure global forwarders using dnssec-trigger-control."""

        with Lock():
            self.dnssec_trigger_control(["status"])

            servers = Store('servers')

            if servers.update(self.global_forwarders):
                UnboundZoneConfig._control([config.flush_command, "."])
                self.dnssec_trigger_control(["submit"] + list(servers))
                servers.commit()
                log.info("Global forwarders: {}".format(' '.join(servers)))
            else:
                log.info("Global forwarders: {} (unchanged)".format(' '.join(servers)))

    def run_update_connection_zones(self):
        """Configures forward zones in the unbound using unbound-control."""

        with Lock():
            connections = ConnectionList(self.client, skip_wifi=not config.add_wifi_provided_zones).get_zone_connection_mapping()
            unbound_zones = UnboundZoneConfig()
            unbound_local_zones = UnboundLocalZoneConfig()
            stored_zones = Store('zones')

            # Remove any zones managed by dnssec-trigger that are no longer
            # valid.
            log.debug("removing zones that are no longer valid")
            for zone in stored_zones:
                # leave zones that are provided by some connection
                if zone in connections:
                    continue

                if zone in self.rfc1918_reverse_zones:
                    # if zone is private address range reverse zone and we are configured to use them, leave it
                    if config.use_private_address_ranges:
                        continue
                    # otherwise add Unbound local zone of type 'static' like Unbound does and remove it later
                    unbound_local_zones.add(zone, "static")

                # Remove all zones that are not in connections, including
                # private address range reverse zones when we are NOT
                # configured to use them
                if zone in unbound_zones:
                    unbound_zones.remove(zone)
                stored_zones.remove(zone)

            # Install all zones coming from connections except those installed
            # by other means than dnssec-trigger-script.
            log.debug("installing connection provided zones")
            for zone in connections:
                # Reinstall a known zone or install a new zone.
                if zone in stored_zones or zone not in unbound_zones:
                    unbound_zones.add(zone, connections[zone].servers, secure=config.validate_connection_provided_zones)
                    stored_zones.add(zone)

            # Configure forward zones for reverse name resolution of private addresses.
            # RFC 1918 zones will be installed, except those already provided by connections
            # and those installed by other means than by dnssec-trigger-script.
            # RFC 1918 zones will be removed if there are no global forwarders.
            if config.use_private_address_ranges:
                log.debug("configuring RFC 1918 private zones")
                # The property builds new connection lists on every access,
                # so evaluate it once instead of for every zone.
                global_forwarders = self.global_forwarders
                for zone in self.rfc1918_reverse_zones:
                    # Ignore a connection provided zone as it's been already
                    # processed.
                    if zone in connections:
                        continue
                    if global_forwarders:
                        # Reinstall a known zone or install a new zone.
                        log.debug("Installing RFC 1918 private zone '%s' not present in unbound or connections", zone)
                        if zone in stored_zones or zone not in unbound_zones:
                            unbound_zones.add(zone, global_forwarders, secure=False)
                            stored_zones.add(zone)
                            unbound_local_zones.remove(zone)
                    else:
                        # There are no global forwarders, therefore remove the zone
                        log.debug("Removing RFC 1918 private zone '%s' since there are no global forwarders", zone)
                        if zone in unbound_zones:
                            unbound_zones.remove(zone)
                        if zone in stored_zones:
                            stored_zones.remove(zone)
                        unbound_local_zones.add(zone, "static")

            stored_zones.commit()

    def run_update_all(self):
        """Send all usable connections to dnssec-trigger as one JSON document.

        The document is printed and passed to "dnssec-trigger-control
        update_all"; a failure of that command is not fatal.
        """
        connections = [
            {
                "type": con.type,
                "default": con.is_default,
                "zones": con.zones,
                "servers": con.servers,
            }
            for con in ConnectionList(self.client)
        ]
        json_struct = {"connections": connections}

        print("Running update all with these connections:")
        print(json.dumps(json_struct, sort_keys=True, indent=4, separators=(',', ': ')))
        try:
            self.dnssec_trigger_control(["update_all", json.dumps(json_struct)])
        except subprocess.CalledProcessError as error:
            # Best effort: the failure is only recorded in the debug log
            log.debug("dnssec-trigger-control update_all failed: {}".format(error))


def main():
    """Run the application and turn expected failures into exit status 1.

    UserError is logged as a plain message. A failed "<name>-control status"
    call is reported as an inability to connect to <name>; any other failed
    command is re-raised.
    """
    try:
        Application(sys.argv).run()
    except UserError as error:
        log.error(error)
        # sys.exit() rather than the exit() helper, which is only provided
        # by the site module
        sys.exit(1)
    except subprocess.CalledProcessError as error:
        if len(error.cmd) == 2 and error.cmd[0].endswith('-control') and error.cmd[1] == "status":
            # Strip the "-control" suffix to name the daemon itself
            log.error("Cannot connect to {}.".format(error.cmd[0][:-8]))
            sys.exit(1)
        else:
            raise


if __name__ == "__main__":
    main()

