#! python

"""
    orchestrator

    This module implements the background infrastructure for an orchestrator.
    Built on top of multiprocessing, it defines the class and helper functions
    needed to spawn and monitor worker processes.
    Assumptions:
        When run in system mode, package distribution is assumed to be done
        from the same node on which the system orchestrator is triggered.
"""

import sys
sys.dont_write_bytecode = True
try:
    from constants import *
    from multiprocessing.managers import BaseManager, SyncManager
    from multiprocessing import Process, JoinableQueue
    import time
    from queue import Queue, Empty
    import json
    import modules.feeder as mf
    from utils import errors
    from threading import Thread
    import os
    import traceback
    from utils import get_logger, inst_utils, utils_orchestrator
    logger = get_logger()

except ImportError:
    import platform
    arch = platform.processor()
    if 'arm' in arch:
        print "Just return with Success for arm based cards."
        sys.exit (0) 
    else:
        import traceback
        exc_info = sys.exc_info()
        TB = traceback.format_exc()
        print (TB)
        sys.exit (-1)

def monitor_queue (runq, nodes2track, helperobj):
    """Relay worker status messages from ``runq`` to the progress reporter.

    Reads messages until a ``None`` sentinel arrives. For each message that
    contains the key of a tracked node, "Successful run" reports that node as
    done and "Failed run" reports it as aborted, both through
    ``helperobj.progress_update``.

    :param runq: run queue handed to the prepare workers.
    :param nodes2track: node keys whose status should be reported.
    :param helperobj: object exposing ``progress_update``.
    """
    while True:
        data = runq.get()
        # None is the shutdown sentinel posted by the orchestrator.
        if data is None:
            break
        logger.debug("In monitor queue %s"%(data))
        for node in nodes2track:
            if node not in data:
                continue
            # Both markers are checked independently, as before.
            if "Successful run" in data:
                helperobj.progress_update (node_done = node)
            if "Failed run" in data:
                helperobj.progress_update (node_abort = node)

def handle_abort_optim_workflow():
    """Best-effort cleanup after a failed orchestration.

    Calls ``mf.handle_abort()``. It then deletes every file listed in
    ``mf.rem_on_abort`` that still exists. That step is skipped if
    ``handle_abort`` raises.

    Any error is logged at debug level and not propagated, so the caller's
    own error handling continues.
    """
    logger.debug("Handling abort for install source failure in optim mode")
    try:
        mf.handle_abort ()
        for path in mf.rem_on_abort:
            if os.path.isfile (path):
                logger.debug ("Remove file %s created as part of this operation"%(path))
                os.unlink (path)
    except Exception:
        # Swallowed on purpose: this runs from the top-level error path and
        # must not mask the original failure. Still, record why cleanup stopped.
        logger.debug ("Abort cleanup failed:\n%s" % (traceback.format_exc(),))

class Orchestrator (object):
    """Spawn and monitor the worker processes for one orchestration request.

    Work happens in two stages (see ``start``):

    * stage one: ``helperobj.feedtasks`` fills a ``JoinableQueue`` that a
      single child process (``mf._update_job_queue``) drains;
    * stage two: every item read from the manager's result queue starts one
      ``mf.PrepareWorker``. The orchestrator then waits for all workers and
      runs ``helperobj._post_orch_work``.
    """

    def __init__ (self, address, authkey, cfg_file, mode, oper_id):
        self.address = address             # Endpoint address for the orchestrator.
        self.authkey = authkey             # Authentication key to connect to orch.
        self.cfg = cfg_file                # Configuration parameters for orchestrator
        self.helperobj = None              # Object defining parameters used by orch.
        self.mode = mode                   # Local vs system mode for orchestrator.
        self.oper_id = oper_id             # Install id for this orchestration request.
        self.errorQ = Queue()              # Served through the manager as 'get_error_queue'.
        self.manager = None                # Started manager; set by _create_server_manager().

    def _create_server_manager (self):
        """Start the manager server exposing the result, run and error queues.

        The address the server actually bound to is written as JSON to the
        mode-specific address file. The requested port may be 0. Presumably
        clients read this file to locate the manager.
        """
        resultqueue = Queue()
        runqueue = Queue()
        # A fresh subclass, so register() does not add these typeids to
        # SyncManager itself.
        class ServerManager (SyncManager): pass
        ServerManager.register ('get_result_queue', callable=lambda:resultqueue)
        ServerManager.register ('get_run_queue', callable=lambda:runqueue)
        ServerManager.register ('get_error_queue', callable=lambda:self.errorQ)
        manager = ServerManager(address=self.address, authkey=self.authkey)
        manager.start()
        # Only publish the manager once it is running, so shut() never tries
        # to stop a server that was never started.
        self.manager = manager
        addr = manager.address
        logger.info('Manager is listening at address : %s'%(addr,))

        if self.mode == SYSTEM:
            addr_file = SYS_ORCH_ADDR_FILE
        else:
            addr_file = NODE_ORCH_ADDR_FILE

        with open(addr_file, 'w') as fd:
            fd.write (json.dumps(addr))

    def start (self):
        """Run both orchestration stages, then the post-orchestration work.

        Calls ``sys.exit(-1)`` when stage one fails. Raises ``Exception`` when
        a stage-two worker exits non-zero or outlives ``STAGE2TIMEOUT``.
        """
        # dirname() is '' when the script is launched by bare file name.
        # os.chdir('') raises, and the cwd is already the script directory.
        script_dir = os.path.dirname(sys.argv[0])
        if script_dir:
            os.chdir(script_dir)
        try:
            self.helperobj = mf.Taskfeeder()
            self.helperobj.parsecfgfile (self.cfg, self.mode, self.oper_id)
            self._create_server_manager ()
        except Exception:
            logger.debug(traceback.format_exc())
            raise

        import atexit
        atexit.register (self.helperobj._exit_work_items)

        self.helperobj.progress_update (progress_str = 'Initial config parsed. Stage one orchestration in progress.')
        localq = []
        self._run_stage_one (localq)
        self.helperobj.progress_update (progress_str = 'Stage one orchestration completed.')

        runq = self.manager.get_run_queue()
        self.helperobj.progress_update (progress_str = 'Stage two orchestration started.')
        jobs, nodes2track = self._start_workers (runq, localq)

        self.helperobj.progress_update (progress_str = 'Monitor progress for prepare operation.')
        self._wait_for_workers (runq, jobs, nodes2track)

        # Once all tasks are done, go ahead and run any post tasks.
        self.helperobj.progress_update (progress_str = 'Prepare operation completed.')
        self.helperobj._post_orch_work ()

    def _run_stage_one (self, localq):
        """Feed the stage-one tasks and wait for the child process to drain them.

        ``localq`` is filled by ``helperobj.feedtasks``. Stage two later
        removes keys from it as results arrive. Calls ``sys.exit(-1)`` if the
        child times out, exits non-zero, or leaves tasks behind.
        """
        tasksq = JoinableQueue()
        self.helperobj.feedtasks (tasksq, localq)
        proc = Process (target = mf._update_job_queue, args = (self.mode, tasksq,))
        proc.start()
        proc.join (STAGE1TIMEOUT)
        if proc.is_alive():
            logger.error('Stage1 not done in stipulated time. Exit with error')
            self.helperobj.progress_update (progress_str = 'Stage one timeout expired.')
            proc.terminate()
            sys.exit (-1)
        if proc.exitcode:
            # Finished in time but failed: do not report this as a timeout.
            logger.error('Stage1 process exited with code %s. Exit with error'%(proc.exitcode,))
            self.helperobj.progress_update (progress_str = 'Stage one tasks failed to complete.')
            sys.exit (-1)
        if not tasksq.empty():
            logger.error('Stage1 task queue not emptied and stage 1 finished. Exit with error.')
            self.helperobj.progress_update (progress_str = 'Stage one tasks failed to complete.')
            sys.exit (-1)
        tasksq.join ()

    def _start_workers (self, runq, localq):
        """Start one PrepareWorker per result until every key in ``localq`` is seen.

        Returns ``(jobs, nodes2track)``: the started workers and the node keys
        whose progress the monitor thread should report.
        """
        jobs = []
        nodes2track = []
        # One proxy suffices; every get_result_queue() call yields the same queue.
        resultq = self.manager.get_result_queue()
        index = 0
        while True:
            item = resultq.get ()
            work_item, item_key = self.helperobj._feed_actual_workers (item, self.cfg)
            if item_key:
                name = item_key
                self.helperobj.progress_update (node_ready = name)
                nodes2track.append (item_key)
            else:
                name = str(index)
            p = mf.PrepareWorker (runq, work_item, name, self.helperobj._helper_fn_worker)
            index += 1
            jobs.append (p)
            p.start ()
            for k in item.keys():
                if k in localq:
                    localq.remove(k)
            if not localq:
                break
        return jobs, nodes2track

    def _wait_for_workers (self, runq, jobs, nodes2track):
        """Wait for every worker while a thread relays their progress from ``runq``.

        A worker that exits non-zero, or is still running after
        ``STAGE2TIMEOUT``, counts as failed. Remaining live workers are then
        terminated and an ``Exception`` naming the failed worker is raised.
        """
        monitor_thread = Thread (target = monitor_queue, args = (runq, nodes2track, self.helperobj,))
        monitor_thread.start()
        job_failed = None
        for job in jobs:
            job.join(STAGE2TIMEOUT)
            if job.is_alive():
                # Same policy as stage one: not finishing within the timeout
                # is a failure (exitcode is still None here, not success).
                logger.error("Job %s did not finish within %s seconds"%(job, STAGE2TIMEOUT))
                job_failed = job.name
                break
            if job.exitcode:
                logger.error("Job %s returned with exitcode %d"%(job, job.exitcode))
                job_failed = job.name
                break
            logger.info("Done with %s"%(job))

        if job_failed is not None:
            for job in jobs:
                if job.is_alive():
                    logger.error('Terminate %s for failed operation'%(job))
                    job.terminate()

        # One sentinel is enough: monitor_queue returns on the first None it
        # reads, and join() waits for that to happen.
        runq.put (None)
        monitor_thread.join()

        if job_failed is not None:
            raise Exception ("Preparation failed on %s"%(job_failed))

    def shut (self):
        """Shut down the manager server, if it was ever started."""
        if self.manager is not None:
            self.manager.shutdown()

if __name__ == '__main__':
    # Usage: <script> <mode> <oper_id>, where mode is SYSTEM or NODE.
    if len(sys.argv) < 3:
        logger.error("Usage: %s <mode> <oper_id>" % (sys.argv[0],))
        sys.exit (-1)
    mode = sys.argv[1]
    oper_id = sys.argv[2]
    authkey = INSTAUTHKEY

    if mode == SYSTEM:
        cfg = SYSTEM_CFG
        addr_file = SYS_ORCH_ADDR_FILE
    elif mode == NODE:
        import shutil
        cfg = NODE_CFG
        cfg_alt = NODE_CFG_ALT
        if os.path.isfile (cfg):
            # Refresh the alternate copy from the current node config.
            if os.path.isfile (cfg_alt):
                os.unlink (cfg_alt)
            shutil.copy (cfg, cfg_alt)
        elif os.path.isfile (cfg_alt):
            # Primary node config missing: restore it from the alternate copy.
            shutil.copy (cfg_alt, cfg)
            os.unlink (cfg_alt)
        else:
            raise Exception ("Node config could not be retrieved")
        addr_file = NODE_ORCH_ADDR_FILE
    else:
        # Without this, cfg/addr_file would be unbound and fail later.
        logger.error("Unknown orchestrator mode %s" % (mode,))
        sys.exit (-1)

    try:
        logger.info ("Orchestrator started in %s mode\n"%(mode))
    except Exception:
        logger.error("Unable to initialize orchestrator logger\n")
        sys.exit (-1)
    try:
        inst_utils.precheck_env ()
    except Exception:
        # Best effort: a failed precheck does not stop orchestration, but
        # leave a record of why it failed.
        logger.debug (traceback.format_exc())
    ip = inst_utils.get_local_node_ip()
    orch = Orchestrator ((ip, 0), authkey, cfg, mode, oper_id)
    try:
        orch.start ()
    except BaseException:
        # start() also reports stage-one failures via sys.exit(), so
        # SystemExit must take this cleanup path as well.
        # Capture the traceback before cleanup runs.
        logger.error (traceback.format_exc())
        handle_abort_optim_workflow ()
        mdata_dict = {}
        if mode == SYSTEM and os.path.isfile (UPDATE_STATUS_FILE):
            with open(UPDATE_STATUS_FILE, 'r') as fErrIn:
                mdata_dict = json.load(fErrIn)
        mdata_dict['Errors'] = {}
        # The manager does not exist if start() failed before creating it.
        manager = getattr(orch, 'manager', None)
        if manager is not None:
            error_q = manager.get_error_queue()
            if not error_q.empty():
                logger.debug('Updating %s file with errors', UPDATE_STATUS_FILE)
                while True:
                    try:
                        err_ip, err_msg = error_q.get_nowait()
                    except Empty:
                        break
                    logger.error(err_ip)
                    logger.error(err_msg)
                    ip_errors = mdata_dict['Errors'].setdefault(err_ip, [])
                    if err_msg not in ip_errors:
                        ip_errors.append(err_msg)
        with open(UPDATE_STATUS_FILE, 'w') as fErrOut:
            json.dump(mdata_dict, fErrOut, indent=2)
            fErrOut.flush()
        # helperobj/orchobj are missing if start() failed early; the RP list
        # is then unknown, so the status file cannot be distributed.
        orchobj = getattr(orch.helperobj, 'orchobj', None)
        if orchobj is not None:
            logger.debug ('Copying %s file to following RP\'s XR vms : %s '%(UPDATE_STATUS_FILE, str(orchobj.xrrpvms)))
            utils_orchestrator.copy_status_update_rp_nodes (orchobj.xrrpvms)
        logger.error ("Error while orchestrating in %s mode\n"%(mode))
        sys.exit (-1)
    finally:
        if getattr(orch, 'manager', None) is not None:
            orch.shut()
        if os.path.isfile (addr_file):
            os.remove (addr_file)
        if mode == SYSTEM and os.path.isfile (UPDATE_STATUS_FILE):
            with open (UPDATE_STATUS_FILE, 'r') as fd:
                mdata_dict = json.load (fd)
            logger.debug ('Prep update at end of operation %s'%(mdata_dict,))
            os.remove (UPDATE_STATUS_FILE)

    print("Orchestration done")

