import threading
import time

from appfw.utils.infraexceptions import QueueException
import logging
from .tasks import Task
import traceback
log = logging.getLogger("taskmgmt")
from datetime import datetime
from appfw.runtime.stats import StatsCollector


class Worker(object):
    """Pull task ids off a queue and execute the matching registered tasks.

    A single call to :meth:`run` performs one iteration: dequeue one task id,
    look the task up in the registry, execute or requeue it, acknowledge the
    queue entry, then sleep for ``sleep_interval`` seconds.  The caller is
    expected to invoke :meth:`run` repeatedly.
    """

    def __init__(self, task_queue, task_registry, sleep_interval=1.0, save_result=True):
        """
        :param task_queue: queue-like object exposing ``dequeue()``,
            ``enqueue(task_id)`` and ``queue_task_done()``.
        :param task_registry: object exposing ``get_task(task_id)`` and
            ``unregister(task)``.
        :param sleep_interval: seconds to sleep at the end of every
            :meth:`run` iteration.
        :param save_result: when True, execution outcome (result or error
            details plus timing) is stored on the task via
            ``task.set_task_data``.
        """
        self.delay = sleep_interval
        self.storageiface = task_queue
        self.task_registry = task_registry
        self.save_result = save_result

    @staticmethod
    def _record_error(ex):
        """Increment the TASKMGR error counter and store *ex* as the last error, if stats are enabled."""
        stats = StatsCollector.getInstance()
        if stats.enabled:
            taskmgr_registry = stats.get_statsregistry("TASKMGR", "task")
            taskmgr_registry.counter("error_cnt").inc()
            taskmgr_registry.gauge("last_error").set_value(str(ex))

    def run(self):
        """Run one dequeue / handle / acknowledge / sleep iteration.

        Exceptions from dequeueing or handling are counted and logged rather
        than raised, so the caller's loop keeps going.
        """
        task_id = None
        try:
            task_id = self.storageiface.dequeue()
        except QueueException as ex:
            self._record_error(ex)
            log.exception('Queue exception %s', str(ex))
        except Exception as ex:
            self._record_error(ex)
            log.exception('Unknown exception dequeueing task %s.', str(ex))

        try:
            if task_id:
                try:
                    task = self.task_registry.get_task(task_id)
                    if task:
                        self.handle_task(task)
                finally:
                    # Acknowledge the dequeued entry even when lookup or
                    # handling failed. Previously a failure skipped this call
                    # and left the entry unacknowledged. NOTE(review): this
                    # presumably mirrors queue.Queue.task_done() — confirm.
                    self.storageiface.queue_task_done()
        except Exception as ex:
            self._record_error(ex)
            log.exception("Caught exception %s in worker thread while handling the task", str(ex))
        self.sleep()

    def sleep(self):
        """Block the calling thread for ``self.delay`` seconds."""
        time.sleep(self.delay)

    def ready_to_run(self, task):
        """Return True if *task* has no scheduled time or that time has passed."""
        return task.execute_time is None or task.execute_time <= time.time()

    def requeue_task(self, task):
        """Put *task*'s id back on the queue unless the task has been revoked.

        :returns: False when the task is revoked (nothing is enqueued),
            otherwise None.  Enqueue failures are logged, not raised.
        """
        if task.is_revoked:
            log.debug("Task %s is revoked", task.task_id)
            return False
        try:
            self.storageiface.enqueue(task.task_id)
        except Exception:
            log.exception("requeue task exception")

    def handle_task(self, task):
        """Execute *task* if it is due and not revoked; otherwise requeue it.

        A periodic task that has a scheduled time is rescheduled
        ``task.interval`` seconds after it finishes and put back on the queue.
        """
        if not self.ready_to_run(task):
            self.requeue_task(task)
        elif not task.is_revoked:
            self.process_task(task)
            if task.execute_time is not None and task.is_periodic is True:
                task.execute_time = time.time() + task.interval
                self.requeue_task(task)

    def process_task(self, task):
        """Execute *task*, record its runtime and optionally store its outcome.

        Exceptions raised by ``task.execute()`` are logged. When
        ``save_result`` is set, their details are stored alongside the timing
        data. Non-periodic tasks are unregistered afterwards.

        :raises TypeError: if *task* is not a :class:`Task` instance.
        """
        if not isinstance(task, Task):
            raise TypeError('Unknown object: %s' % task)
        result = None
        task_data = {}
        start = time.time()
        try:
            result = task.execute()
        except Exception as ex:
            log.exception("Exception while executing task %s", str(ex))
            if self.save_result:
                task_data["task"] = task.serialize()
                task_data["exception"] = str(ex)
                task_data["traceback"] = traceback.format_exc()
        finally:
            duration = time.time() - start
            if StatsCollector.getInstance().enabled:
                taskmgr_registry = StatsCollector.getInstance().get_statsregistry("TASKMGR", "task_id-%s" % task.task_id)
                taskmgr_registry.histogram("runtime").add(duration)

        if self.save_result:
            # Store error details (if any) and timing in ONE call. The
            # previous second set_task_data() call could overwrite the error
            # information if set_task_data replaces rather than merges.
            task_data["duration"] = duration
            task_data["task_execution_started_at"] = start
            task_data["result"] = result
            task.set_task_data(task_data)
        # Unregister non-periodic tasks so finished tasks do not accumulate
        # in the registry.
        if not task.is_periodic:
            self.task_registry.unregister(task)

class TaskExecutor(object):
    """Create, start, stop and health-check the threads that execute tasks.

    Each worker thread repeatedly calls ``Worker.run()`` until :meth:`stop`
    sets the shared stop flag.  The number of threads is controlled by
    ``num_workers`` (one by default).  Threads are created in the
    constructor but only started by :meth:`start`.
    """
    def __init__(self,  task_queue, registry, num_workers=1, delay=0.5,
                 check_worker_health=False,
                 health_check_interval=1):
        """
        :param task_queue: queue shared by all workers.
        :param registry: task registry shared by all workers.
        :param num_workers: number of worker threads to create.
        :param delay: seconds each worker sleeps between iterations.
        :param check_worker_health: stored as ``self._health_check``.
        :param health_check_interval: seconds, stored as a number of
            ``_stop_flag_timeout`` ticks in ``self._health_check_interval``.
        """
        self.workers = num_workers
        self.delay = delay
        self.task_queue = task_queue
        self.registry = registry

        # Configure health-check and consumer main-loop attributes.
        # NOTE(review): _health_check, _health_check_interval and the
        # health-check counter are not read by any method of this class;
        # presumably a caller/subclass drives check_worker_health() — confirm.
        self._stop_flag_timeout = 0.1
        self._health_check = check_worker_health
        self._health_check_interval = (float(health_check_interval) /
                                       self._stop_flag_timeout)
        self.__health_check_counter = 0

        self.stop_flag = self.get_stop_flag()

        # Create (but do not start) the worker threads.
        self.worker_threads = []
        for i in range(self.workers):
            worker = self._create_worker()
            thread_id = self._create_worker_thread(worker, 'Worker-%d' % (i + 1))
            self.worker_threads.append((worker, thread_id))

    def _create_worker(self):
        """Build a Worker bound to this executor's queue, registry and delay."""
        return Worker(task_queue=self.task_queue, task_registry=self.registry, sleep_interval=self.delay)

    def _create_worker_thread(self, worker_thread, name):
        """Return an unstarted daemon thread that loops ``worker_thread.run()`` until stopped.

        An exception escaping ``run()`` is logged and ends the thread.
        :meth:`check_worker_health` can then replace the dead thread.
        """
        def _run():
            try:
                while not self.stop_flag.is_set():
                    worker_thread.run()
            except Exception as ex:
                log.exception('Caught Exception %s in %s worker thread', str(ex), name)
        return self.create_thread(_run, name)

    def start(self):
        """Start every worker thread created in the constructor."""
        log.info('Task Executor started with %s', self.workers)

        for _, thread in self.worker_threads:
            log.debug("Starting Worker thread ")
            thread.start()

    def stop(self):
        """Signal all worker loops to exit after their current iteration."""
        self.stop_flag.set()
        log.info('Shutting down task executor')

    def check_worker_health(self):
        """Replace and start any worker thread that is no longer alive.

        :returns: True if every worker was alive, False if at least one was
            restarted.
        """
        log.debug('Checking worker health.')
        workers = []
        restart_occurred = False
        for i, (worker, worker_t) in enumerate(self.worker_threads):
            if not self.is_alive(worker_t):
                log.warning('Worker %d died, restarting.', i + 1)
                worker = self._create_worker()
                worker_t = self._create_worker_thread(worker, 'Worker-%d' % (i + 1))
                worker_t.start()
                restart_occurred = True
            workers.append((worker, worker_t))

        if restart_occurred:
            self.worker_threads = workers
        else:
            log.debug('Workers are up and running.')

        return not restart_occurred

    def get_stop_flag(self):
        """Return the event used to signal worker threads to stop."""
        return threading.Event()

    def create_thread(self, runnable, name):
        """Return an unstarted daemon thread named *name* that runs *runnable*."""
        t = threading.Thread(target=runnable, name=name)
        t.daemon = True
        return t

    def is_alive(self, proc):
        """Return True if thread *proc* is currently running."""
        # Thread.isAlive() was removed in Python 3.9. is_alive() is available
        # on Python 2.6+ and all 3.x releases.
        return proc.is_alive()
