#!/usr/bin/env python
#
# A tool that provides client and server side options and allows the client
# to send set commands to the server. Replies are sent back to the client in
# the form of json output. Hence this is a limited form of remote execution.
#
# Copyright (c) 2017 by cisco Systems, Inc.
# All rights reserved.
#
import getopt
import imp
import json
import logging
import os
import select
import shlex
import socket
import sys
import traceback
import time
import argparse

# Command plugins consulted by ApplyCmdClient when it builds its help text:
# the description() of every value is appended to the tool description.
# Nothing in this file registers a plugin, so the mapping starts out empty.
plugins = dict()


class ApplyCmdClient(object):
    """Client side of the apply-cmd tool.

    Sends one command string to a server over a UNIX domain socket, waits
    for a JSON reply carrying the keys "stdout", "stderr" and "exitcode",
    and then writes those results to files or to this process's own
    streams / exit status.
    """

    #
    # Upper bound on the size of a reply. Anything larger is treated as a
    # sign that something has gone horribly wrong.
    #
    max_reply_bytes = 10 * 1024 * 1024

    def __init__(self):
        """Set up logging and parse the command line.

        Arguments are read from sys.argv by argparse, which exits the
        process itself on a usage error. The parsed options are stored in
        self.opts.
        """
        #
        # Results of the remote command; filled in by client_wait_reply().
        #
        self.stdout = ""
        self.stderr = ""
        self.exitcode = ""

        self.bufsz = 1024
        self.socket_client = None
        self.client_timeout_in_secs = 10

        self.handler = logging.StreamHandler()
        self.handler.setLevel(logging.INFO)

        self.logger = logging.getLogger(__name__)
        self.logger.addHandler(self.handler)
        self.logger.setLevel(logging.INFO)

        logging.addLevelName(logging.INFO,
                             "\033[1;31m%s\033[1;0m" %
                             logging.getLevelName(logging.INFO))
        logging.addLevelName(logging.ERROR,
                             "\033[1;41m%s\033[1;0m" %
                             logging.getLevelName(logging.ERROR))

        #
        # Initialize the parser
        #
        description = """
A tool (with client and server side) that allows remote execution of a
set of commands. Such commands should be carefully crafted to disallow
general breakout on the remote server.

The backend commands are implemented via plugins.

One example plugin is the VRF plugin which allows us to remotely do
ip netns add/del <vrf> from XR. By remote here we mean the VM host;
so we're using unix domain sockets to break out of the XR LXC.

Usage:

app_hosting_apply_cmd_client.py --socket <file> --cmd <string> --stdout_path log_file -stderr_path err_file

Commands supported:
"""

        #
        # values() exists on both Python 2 and 3 (iteritems() does not) and
        # the plugin key is not needed here.
        #
        for plugin in plugins.values():
            description += plugin.description()

        #
        # Actually hand the description to argparse; the raw formatter
        # keeps the hand-laid-out line breaks intact in --help.
        #
        arger = argparse.ArgumentParser(
            description=description,
            formatter_class=argparse.RawDescriptionHelpFormatter)

        arger.add_argument("-socket", "--socket",
                           help="UNIX domain socket filename",
                           required=True)

        arger.add_argument("-cmd", "--cmd",
                           help="Name of command to run on the server",
                           required=True)

        arger.add_argument("-stdout_path", "--stdout_path",
                           help="Place server command output in this file")

        arger.add_argument("-stderr_path", "--stderr_path",
                           help="Place server command errors in this file")

        arger.add_argument("-exitcode", "--exitcode",
                           help="Place server command exitcode in this file")

        #
        # nargs="?" lets "-d" work as a plain flag while still accepting
        # the older "-d <value>" form.
        #
        arger.add_argument("-d", "--debug",
                           nargs="?", const=True,
                           help="Add json debugs")

        self.opts = arger.parse_args()

        if self.opts.debug:
            self.logger.setLevel(logging.INFO)
        else:
            self.logger.setLevel(logging.ERROR)

    def client_send_cmd(self):
        """Open the socket, send the command and wait for the reply.

        Raises:
            ValueError: if the socket cannot be opened or the reply cannot
                be read (see client_open_socket / client_wait_reply).
            socket.error: if connecting or sending fails.
        """
        self.client_open_socket()
        self.logger.info("Tx [{0}]".format(self.opts.cmd))

        #
        # Sockets carry bytes. On Python 2 a str already is bytes, so this
        # only encodes text strings.
        #
        payload = self.opts.cmd
        if not isinstance(payload, bytes):
            payload = payload.encode("utf-8")

        #
        # sendall() keeps going until the whole command is written; send()
        # may stop after a partial write.
        #
        self.socket_client.sendall(payload)
        self.client_wait_reply()

    def client_wait_reply(self):
        """Read the server's reply and store its results on self.

        Fragments are collected until the server closes its end of the
        connection. Each wait for more data is bounded by
        self.client_timeout_in_secs. The collected data is parsed as JSON
        and its "stdout", "stderr" and "exitcode" values are stored in the
        attributes of the same names.

        Raises:
            ValueError: on timeout, on a read error, when the connection
                closes before any data arrived, when the reply exceeds
                max_reply_bytes, or when the reply is not JSON containing
                the three expected keys.
        """
        fragments = []
        total_len = 0

        while True:
            self.logger.info("Waiting on reply")

            #
            # Waiting in select() before every read means a slow server is
            # simply waited for (up to the timeout) rather than failing
            # with a would-block error, and no sleep is needed.
            #
            ready = select.select([self.socket_client], [], [],
                                  self.client_timeout_in_secs)
            if not ready[0]:
                raise ValueError("Socket {0} timed out on read".format(
                                 self.opts.socket))

            self.logger.info("Got event on socket")

            try:
                fragment = self.socket_client.recv(self.bufsz)
            except socket.error as exception:
                raise ValueError("Socket {0} read failed: {1}".format(
                                 self.opts.socket, exception))

            #
            # An empty read means the server closed the connection, which
            # is what marks the end of the reply.
            #
            if not fragment:
                if not fragments:
                    raise ValueError("Socket {0} read, no data".format(
                                     self.opts.socket))
                self.logger.info("Read end of data")
                break

            self.logger.info("Read fragment, len {0}".format(
                             len(fragment)))

            fragments.append(fragment)
            total_len += len(fragment)

            self.logger.info("Fragments total len {0}".format(total_len))

            #
            # Check something has not gone horribly wrong.
            #
            if total_len > self.max_reply_bytes:
                raise ValueError("Socket {0} read too much data".format(
                                 self.opts.socket))

        data = b"".join(fragments)
        if not isinstance(data, str):
            # Python 3: recv() returned bytes, json and logging want text.
            data = data.decode("utf-8")

        self.logger.info("Rx [{0}]".format(data))

        output = json.loads(data)
        try:
            stdout = output["stdout"]
            stderr = output["stderr"]
            exitcode = output["exitcode"]
        except (KeyError, TypeError) as exception:
            raise ValueError("Socket {0} sent a malformed reply: {1}".format(
                             self.opts.socket, exception))

        #
        # Only publish the results once all three fields are known to be
        # present, so a bad reply never leaves them half updated.
        #
        self.stdout = stdout
        self.stderr = stderr
        self.exitcode = exitcode

    def client_open_socket(self):
        """Create the client socket and connect it to self.opts.socket.

        Raises:
            ValueError: if a socket is already open or the socket file does
                not exist.
            socket.error: if the connection attempt fails.
        """
        if self.socket_client is not None:
            raise ValueError("Socket {0} already open".format(
                             self.opts.socket))

        if not os.path.exists(self.opts.socket):
            raise ValueError("Socket {0} does not exist".format(
                             self.opts.socket))

        self.socket_client = socket.socket(socket.AF_UNIX,
                                           socket.SOCK_STREAM)

        self.logger.info("Connect to socket {0}".format(
                         self.opts.socket))

        #
        # Use the same timeout as the reply wait instead of a second,
        # separately hard-coded value.
        #
        self.socket_client.settimeout(self.client_timeout_in_secs)
        self.socket_client.connect(self.opts.socket)

    def socket_cleanup(self):
        """Close the client socket if one is open."""
        if self.socket_client is not None:
            self.socket_client.close()
            self.socket_client = None

    def _write_output(self, path, text, stream):
        """Write text to the file at path, or to stream when path is None."""
        if path is not None:
            with open(path, 'w') as f:
                f.write(text)
        else:
            stream.write(text)

    def write_results(self):
        """Save the results of the server to disk if asked for.

        stdout and stderr go to the files named by --stdout_path and
        --stderr_path, or to this process's own stdout / stderr. The exit
        code is written to the --exitcode file when one was given;
        otherwise this method does not return and exits the process with
        that code.
        """
        self._write_output(self.opts.stdout_path, self.stdout, sys.stdout)
        self._write_output(self.opts.stderr_path, self.stderr, sys.stderr)

        if self.opts.exitcode is not None:
            with open(self.opts.exitcode, 'w') as f:
                f.write(str(self.exitcode))
        else:
            sys.exit(self.exitcode)

    def cleanup(self):
        """Close the socket and emit the results; runs on success or failure."""
        self.socket_cleanup()
        self.write_results()

    def main(self):
        """Run the client: send the command, then always clean up.

        On any failure the error and traceback are reported and the
        process exits with status 1. On success the exit status is the
        server's exit code unless it was redirected to a file with
        --exitcode.
        """
        try:
            self.client_send_cmd()

        except Exception as exception:
            self.logger.error(format(exception))
            traceback.print_exc(file=sys.stdout)
            #
            # Record the failure so the exit-code file (or the exit status
            # taken in write_results) reports 1 rather than the empty
            # placeholder value.
            #
            self.exitcode = 1
            sys.exit(1)

        finally:
            self.cleanup()


if __name__ == '__main__':
    # Script entry point: constructing the client parses the command line,
    # main() then performs the request and reports the results.
    client = ApplyCmdClient()
    client.main()
