#!/usr/bin/python3
# -*- coding: UTF-8 -*-

"""Small daemon program that is used to push DHCP information to a
provisioning server.

DHCP information is passed to it by the dxtorc application.

"""

__license__ = """
    Copyright (C) 2010-2011  Avencall

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along
    with this program; if not, write to the Free Software Foundation, Inc.,
    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA..
"""

import errno
import json
import configparser
import logging
import optparse
import os
import urllib.request, urllib.error, urllib.parse
import urllib.parse
import socket
import sys

from xivo.xivo_logging import setup_logging

# Path of the configuration file read by main() at startup.
CONFIG_FILE = '/etc/xivo-dxtora/config.conf'
# Path of the log file handed to setup_logging().
LOG_FILE_NAME = '/var/log/xivo-dxtora.log'
# Path of the pid file created while the daemon is running.
PID_FILE = '/var/run/xivo-dxtora.pid'
# Path of the Unix datagram socket on which DHCP information is received.
UNIX_SERVER_ADDR = '/var/run/xivo-dxtora.ctl'
# MIME type used in the Accept / Content-Type headers of the requests sent
# to the provisioning server.
PROV_MIME_TYPE = 'application/vnd.proformatique.provd+json'

# Logger shared by everything in this module.
logger = logging.getLogger('xivo-dxtora')


class DHCPInfoSourceError(Exception):
    """Signal a failure while pulling DHCP information from a source.

    Typically raised when the data received by a source is incomplete
    or holds an invalid value.

    """


class DHCPInfoSinkError(Exception):
    """Signal a failure while pushing DHCP information to a sink.

    Sinks raise it (or wrap lower level errors in it) when the
    information could not be delivered.

    """


class PidFileError(Exception):
    """Signal that the pid file could not be created."""


class StreamDHCPInfoSink(object):
    """A DHCP information sink backed by a file-like object.

    Every pushed DHCP information object is converted to a string and
    written to the file object, one object per line.

    Mostly meant for testing and debugging.

    """
    def __init__(self, fobj):
        """Create a sink writing to fobj, an object with a write method."""
        self._fobj = fobj

    def push(self, dhcp_info):
        """Write the string form of dhcp_info followed by a newline."""
        line = '%s\n' % (dhcp_info,)
        self._fobj.write(line)

    def close(self):
        """Do nothing; the file object is left untouched."""


class ProvServerDHCPInfoSink(object):
    """A destination for DHCP information objects.

    Send the DHCP information to a provisioning server via its REST API.

    The URI of the DHCP info resource is not configured directly: it is
    looked up from the 'dev.dhcpinfo' link of the base resource the first
    time it is needed, then cached.

    """
    def __init__(self, base_uri, username=None, password=None):
        """Create a new sink.

        base_uri -- the URI of the base resource of the provisioning server
        username, password -- when both are non-empty, HTTP digest
          authentication is set up with these credentials

        No request is sent at construction time.

        """
        handlers = []
        if username and password:
            pwd_manager = urllib.request.HTTPPasswordMgrWithDefaultRealm()
            pwd_manager.add_password(None, base_uri, username, password)
            handlers.append(urllib.request.HTTPDigestAuthHandler(pwd_manager))
        self._opener = urllib.request.build_opener(*handlers)
        self._base_uri = base_uri
        # URI of the DHCP info resource; None until it has been discovered
        self._resource_uri = None

    def close(self):
        """Do nothing; there's no persistent resource to release."""
        pass

    def _update_resource_uri(self):
        """Fetch the base resource and cache the URI of its DHCP info link.

        Raise DHCPInfoSinkError if the base resource has no link with a
        'dev.dhcpinfo' relation.

        """
        headers = {'Accept': PROV_MIME_TYPE}
        request = urllib.request.Request(self._base_uri, headers=headers)
        # the context manager closes the response even if parsing fails
        with self._opener.open(request) as f:
            content = json.load(f)

        for link in content['links']:
            if link['rel'] == 'dev.dhcpinfo':
                # the href may be relative to the base URI
                self._resource_uri = urllib.parse.urljoin(self._base_uri, link['href'])
                break
        else:
            raise DHCPInfoSinkError('no link to DHCP info resource on base resource')

    def _do_push(self, dhcp_info, retry_on_404=True):
        """POST dhcp_info to the DHCP info resource.

        If the server answers with a 404 and retry_on_404 is true, the
        resource URI is looked up again and the push is retried once.
        Any other HTTP error is propagated.

        """
        if self._resource_uri is None:
            self._update_resource_uri()
        headers = {'Accept': '*/*', 'Content-Type': PROV_MIME_TYPE}
        content = json.dumps({'dhcp_info': dhcp_info}).encode('utf-8')
        request = urllib.request.Request(self._resource_uri, content, headers=headers)
        try:
            # the response body is read and discarded; using the response
            # as a context manager guarantees it is closed even if the
            # read fails
            with self._opener.open(request) as f:
                f.read()
        except urllib.error.HTTPError as e:
            if e.code == 404 and retry_on_404:
                # the cached resource URI may be outdated: refresh it and
                # try again, only once
                self._update_resource_uri()
                self._do_push(dhcp_info, False)
            else:
                raise

    def push(self, dhcp_info):
        """Send dhcp_info to the provisioning server.

        Raise DHCPInfoSinkError on failure; any other exception raised
        while pushing is wrapped in a DHCPInfoSinkError.

        """
        logger.info('Pushing DHCP info to prov server')
        try:
            self._do_push(dhcp_info)
        except DHCPInfoSinkError:
            raise
        except Exception as e:
            # XXX we are wrapping a bit too much
            raise DHCPInfoSinkError(e) from e


class UnixSocketDHCPInfoSource(object):
    """A source of DHCP information objects.

    Use a Unix datagram socket to get DHCP information.

    Each datagram is a newline separated list of values: the operation,
    the IP address and, for 'commit' operations only, the MAC address
    followed by zero or more options.

    """
    def __init__(self, ctl_file, remove_file=False):
        """Create a new source.

        ctl_file -- the path the Unix socket is bound to
        remove_file -- if true, try to remove ctl_file before binding

        Raise a socket.error exception if the socket can't be bound to
        ctl_file.

        """
        if remove_file:
            try:
                os.remove(ctl_file)
            except OSError:
                # best effort: if the file is still in the way, bind()
                # below raises
                pass
        self._ctl_file = ctl_file
        self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
        try:
            self._sock.bind(ctl_file)
        except socket.error:
            # don't leak the socket if it can't be bound
            self._sock.close()
            raise

    def close(self):
        """Close the socket and remove the control file."""
        self._sock.close()
        _remove(self._ctl_file)

    def _check_op(self, op):
        """Raise DHCPInfoSourceError if op isn't one of the 3 valid values."""
        if op not in ('commit', 'expiry', 'release'):
            raise DHCPInfoSourceError("invalid 'op' value: %s" % op)

    def _check_ip(self, ip):
        """Raise DHCPInfoSourceError if ip isn't a dotted quad ip."""
        try:
            # Note that inet_aton accept strings with less than three dots.
            socket.inet_aton(ip)
        except (socket.error, ValueError):
            # ValueError is raised (instead of socket.error) when the
            # string has an embedded null character
            raise DHCPInfoSourceError("invalid 'ip' value: %s" % ip)

    def _check_mac(self, mac):
        """Accept any mac value; no validation is done for now."""
        # check that mac is a lowercase, colon separated mac
        # TODO if we really care
        pass

    def _check_options(self, options):
        """Raise DHCPInfoSourceError if an item of options is invalid.

        Each option must be at least 3 characters long and its first 3
        characters must be a base 10 integer between 0 and 255.

        """
        for option in options:
            if len(option) < 3:
                raise DHCPInfoSourceError("invalid 'option' value: too short")
            try:
                num = int(option[:3], 10)
            except ValueError:
                raise DHCPInfoSourceError("invalid 'option' value: not int")
            if not 0 <= num <= 255:
                raise DHCPInfoSourceError("invalid 'option' value: invalid code")

    def _decode(self, data):
        """Takes the raw data from a request and return an dhcp_info
        dict ('op', 'ip', 'mac' and 'options').

        'mac' and 'options' are only present for 'commit' operations.

        Raise DHCPInfoSourceError if data can't be decoded as text, if a
        mandatory value is missing or if a value is invalid.

        """
        try:
            text = data.decode()
        except UnicodeDecodeError as e:
            # the data comes from outside the process: report undecodable
            # data as a source error, not as an unexpected exception
            raise DHCPInfoSourceError('invalid data: %s' % e)
        # empty lines (e.g. after a trailing newline) are ignored
        lines = [line for line in text.split('\n') if line]
        dhcp_info = {}

        def check_and_add(key, value):
            # validate value with the matching _check_<key> method before
            # storing it
            check_fun = getattr(self, '_check_' + key)
            check_fun(value)
            dhcp_info[key] = value

        try:
            check_and_add('op', lines[0])
            check_and_add('ip', lines[1])
            if dhcp_info['op'] == 'commit':
                check_and_add('mac', lines[2])
                check_and_add('options', lines[3:])
        except IndexError as e:
            # a mandatory line is missing
            raise DHCPInfoSourceError(e)
        return dhcp_info

    def pull(self):
        """Get a dhcp_info object from the source.

        Note: this is a blocking call.

        Raise DHCPInfoSourceError if the received data is invalid.

        """
        logger.info('Pulling DHCP info from unix socket')
        # at most 2048 bytes of the datagram are read; the sender address
        # is ignored
        data = self._sock.recvfrom(2048)[0]
        return self._decode(data)


class Agent(object):
    """Relay DHCP information from a source to a sink, endlessly.

    The source must provide a pull method returning a DHCP information
    object and the sink a push method accepting it.

    """
    def __init__(self, source, sink):
        """Create an agent reading from source and writing to sink."""
        self._source, self._sink = source, sink

    def run(self):
        """Pull from the source and push to the sink, forever.

        Errors raised while handling an item are logged and the loop
        goes on with the next item; only exceptions that don't derive
        from Exception get out of this method.

        """
        while True:
            try:
                info = self._source.pull()
                logger.info("Pulled DHCP info: (%(op)s, %(ip)s)", info)
                logger.debug('DHCP info: %s', info)
                self._sink.push(info)
            except DHCPInfoSourceError as source_err:
                logger.error('Error while pulling info from source: %s', source_err)
            except DHCPInfoSinkError as sink_err:
                logger.error('Error while pushing info to sink: %s', sink_err)
            except Exception:
                # keep the daemon alive whatever happens, but leave a
                # traceback in the log
                logger.exception('Unspecified exception')


class PidFile(object):
    """A pid file holding the pid of the current process.

    Creating an instance first removes the pid file if it refers to a
    process that doesn't exist anymore, then creates the pid file. The
    creation fails with a PidFileError if the pid file already exists.

    Call close to remove the pid file.

    """
    def __init__(self, pid_file):
        """Create the pid file at the path pid_file.

        Raise PidFileError if the pid file can't be created or if an
        existing pid file has an invalid content.

        """
        self._pid_file = pid_file
        self._remove_stale_pid_file()
        self._create_pid_file()

    def _read_pid(self):
        """Return the pid stored in the pid file, or None if there's no
        pid file.

        Raise PidFileError if the content of the pid file isn't a
        positive integer.

        """
        try:
            fobj = open(self._pid_file)
        except IOError as e:
            if e.errno == errno.ENOENT:
                # pidfile doesn't exist
                return None
            raise
        with fobj:
            content = fobj.read().strip()
        try:
            pid = int(content)
        except ValueError:
            raise PidFileError("invalid content in pid file %s: %r" %
                               (self._pid_file, content))
        if pid <= 0:
            # os.kill gives a special meaning to pids <= 0 (process
            # groups / every process), so never probe them
            raise PidFileError("invalid pid in pid file %s: %s" %
                               (self._pid_file, pid))
        return pid

    def _remove_stale_pid_file(self):
        """Remove the pid file if the process it refers to doesn't exist.

        If the process exists, the pid file is left in place and an
        error is logged.

        """
        pid = self._read_pid()
        if pid is None:
            return
        # check if a process with the given pid exist by sending a signal
        # with value 0 (see "man 2 kill" for more info).
        try:
            os.kill(pid, 0)
        except OSError as e:
            if e.errno != errno.ESRCH:
                raise
            # no such process -- remove stale pidfile
            logger.info('Found stale pidfile %s, removing it', self._pid_file)
            os.remove(self._pid_file)
        else:
            logger.error('Found fresh pidfile %s', self._pid_file)

    def _create_pid_file(self):
        """Create the pid file, failing if it already exists.

        The pid is first written to a temporary file which is then hard
        linked to the pid file path: os.link fails if the destination
        already exists, and the pid file never shows up half written.

        Raise PidFileError if the pid file can't be created.

        """
        pid = os.getpid()
        tmp_pid_file = self._pid_file + '.' + str(pid)
        try:
            fpid = open(tmp_pid_file, 'w')
        except EnvironmentError as e:
            raise PidFileError("couldn't create tmp pid file: %s" % e)
        try:
            with fpid:
                fpid.write("%s\n" % pid)
            os.link(tmp_pid_file, self._pid_file)
        except EnvironmentError as e:
            raise PidFileError("couldn't create pid file: %s" % e)
        finally:
            # the temporary file is of no use anymore, whether the pid
            # file has been created or not
            os.unlink(tmp_pid_file)

    def _remove_pid_file(self):
        """Remove the pid file, ignoring errors."""
        _remove(self._pid_file)

    def close(self):
        """Remove the pid file."""
        self._remove_pid_file()


def _read_config_default():
    return {'general':
                {'retries': 0,
                 'foreground': False},
            'prov_server':
                {'username': 'admin',
                 'password': 'admin'}}


def _read_config_from_file(filename):
    config = configparser.RawConfigParser()
    fobj = open(filename)
    try:
        config.read_file(fobj)
    finally:
        fobj.close()

    result = {'general': {}, 'prov_server': {}}
    def _get(section, option, getname='get'):
        if config.has_option(section, option):
            getfun = getattr(config, getname)
            result[section][option] = getfun(section, option)
    _get('general', 'retries', 'getint')
    _get('prov_server', 'dev_mgr_uri')
    _get('prov_server', 'username')
    _get('prov_server', 'password')
    return result


def _read_config_from_commandline():
    parser = optparse.OptionParser()
    parser.add_option('-u', '--user', dest='user',
                      help='user name and password for server authentication')
    parser.add_option('-f', '--foreground', action='store_true', dest='foreground',
                      help="don't daemonize")
    opt, args = parser.parse_args()

    result = {'general': {}, 'prov_server': {}}
    if opt.foreground:
        result['general']['foreground'] = opt.foreground
    if opt.user:
        user, passwd = opt.user.split(':', 1)
        result['prov_server']['username'] = user
        result['prov_server']['password'] = passwd
    if len(args) >= 1:
        result['prov_server']['dev_mgr_uri'] = args[0]
    return result


def _read_config(filename):
    """Return the effective configuration.

    Start from the default values, override them with the values found
    in the file filename, then with the values given on the command
    line.

    """
    merged = _read_config_default()
    # ordered from the lowest to the highest priority; each source is
    # only read once the previous one has been merged
    loaders = (
        lambda: _read_config_from_file(filename),
        _read_config_from_commandline,
    )
    for load in loaders:
        overrides = load()
        for section, values in merged.items():
            values.update(overrides[section])
    return merged


def _remove(file):
    # remove the file if present in a way that doesn't modify the stack
    # trace and doesn't raise an exception
    try:
        pass
    finally:
        try:
            os.remove(file)
        except EnvironmentError:
            pass


def _daemonize():
    """Detach the current process so that it runs in the background.

    The process forks twice, the parent exiting each time, so only the
    second child returns from this function. Between the two forks, the
    working directory is changed to "/", a new session is created and
    the umask is reset to 0. Finally, the standard file descriptors
    (0, 1 and 2) are redirected to the null device.

    If a fork fails, the error is logged and the process exits with
    status 1.

    """
    try:
        pid = os.fork()
        if pid != 0:
            # parent of the first fork: exit right away; os._exit doesn't
            # raise SystemExit, so no cleanup code runs in the parent
            os._exit(0)
    except OSError as e:
        logger.exception("first fork() failed: %d (%s)", e.errno, e.strerror)
        sys.exit(1)

    # only the first child gets here
    os.chdir("/")
    os.setsid()
    os.umask(0)

    try:
        pid = os.fork()
        if pid != 0:
            # parent of the second fork (the first child): exit right away
            os._exit(0)
    except OSError as e:
        logger.exception("second fork() failed: %d (%s)", e.errno, e.strerror)
        sys.exit(1)

    # only the second child gets here: make stdin, stdout and stderr
    # refer to the null device
    devnull_fd = os.open(os.devnull, os.O_RDWR)
    try:
        for fd in (0, 1, 2):
            os.dup2(devnull_fd, fd)
    finally:
        # the duplicates stay open; only the original descriptor is closed
        os.close(devnull_fd)


def _sig_handler(signum, frame):
    """Handle a signal by logging a message and raising SystemExit.

    Both arguments are ignored; they are only there to match the
    signature expected from a signal handler.

    """
    logger.info('Received signal, exiting.')
    raise SystemExit


def main():
    """Entry point of the daemon.

    Read the configuration, set up logging, then run an agent that pulls
    DHCP information from the Unix socket and pushes it to the
    provisioning server. Unless running in foreground, the process is
    daemonized first.

    Exit with status 1 if no device manager URI is configured.

    """
    import signal

    config = _read_config(CONFIG_FILE)

    setup_logging(LOG_FILE_NAME, config['general']['foreground'], debug=False)

    # the device manager URI has no default value: it must come from the
    # configuration file or from the command line
    if 'dev_mgr_uri' not in config['prov_server']:
        logger.error('error: no device manager URI specified. Exiting.')
        sys.exit(1)
    base_uri = config['prov_server']['dev_mgr_uri']

    username = config['prov_server'].get('username')
    password = config['prov_server'].get('password')
    sink = ProvServerDHCPInfoSink(base_uri, username, password)
    # resources are released in the reverse order of their acquisition:
    # signal handler, source, pid file, then sink
    try:
        if not config['general']['foreground']:
            _daemonize()
        # the pid file is created after daemonizing, so it holds the pid
        # of the process that keeps running
        pidfile = PidFile(PID_FILE)
        try:
            source = UnixSocketDHCPInfoSource(UNIX_SERVER_ADDR, True)
            try:
                # on SIGTERM the handler raises SystemExit, which unwinds
                # through the finally clauses below
                signum = signal.SIGTERM
                old_handler = signal.signal(signum, _sig_handler)
                try:
                    agent = Agent(source, sink)
                    agent.run()
                finally:
                    signal.signal(signum, old_handler)
            finally:
                source.close()
        finally:
            pidfile.close()
    finally:
        sink.close()


# Run the daemon only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
    main()
