#-----------------------------------------------------
#
# Copyright (c) 2012-2013 by Cisco Systems, Inc.
# All rights reserved.
#-----------------------------------------------------

'''
Created on Apr 29, 2012

@author: rnethi
'''

from threading import Thread, Event
import logging
import falcon
import os
import time
from ..utils.utils import Utils
from .auth import check_auth_token, check_outh_token_validator, check_iox_token, check_child_auth_token
from .common import AuthenticatedResource, OauthResourceValidator, IOXResourceValidator, ChildResourceValidator
from ..runtime.hostingmgmt import HostingManager
from cheroot import server
from cheroot import wsgi
from .common import flush_request
from .api_middleware.multipart_middleware import MultipartMiddleware
from .api_middleware.accounting_middleware import AccountingMiddleware
from .api_middleware.compression_middleware import CompressionMiddleware
from .api_middleware.ratelimiter_middleware import RateLimiterMiddleware
from .api_middleware.security_middleware import SecurityMiddleware
from appfw.api.systeminfo import SystemInfo
from appfw.api.certutil import *
from appfw.utils.commandwrappers import openssl
import http.client

log = logging.getLogger("runtime.api")

def get_client_address(environ):
    """
    Return the client IP address for a WSGI request.

    When an ``X-Forwarded-For`` header is present and its last comma-separated
    entry is non-blank, that entry (whitespace-stripped) is returned.
    Otherwise ``REMOTE_ADDR`` is used. Earlier entries in the header are never
    used, because they are supplied by the client.

    :param environ: WSGI environ dict for the request.
    :return: client address string.
    :raises KeyError: if the forwarded header is unusable and ``REMOTE_ADDR``
        is missing from ``environ``.
    """
    forwarded = environ.get('HTTP_X_FORWARDED_FOR')
    if forwarded is not None:
        last_hop = forwarded.split(',')[-1].strip()
        # An empty header or a trailing comma yields a blank last hop; fall
        # back to the socket address instead of reporting an empty client.
        if last_hop:
            return last_hop
    return environ['REMOTE_ADDR']

class CustomWSGIServer(wsgi.Server):
    """
    Primary cheroot WSGI server.

    Overrides the error_log method so that wsgi server errors also appear in
    the caf.log file (through this module's "runtime.api" logger).
    APIService.run() assigns the SSL adapter as a class attribute of this
    class when SSL is enabled.
    """
    def error_log(self, msg="", level=20, traceback=False):
        """
        Forward a cheroot server message to the module logger.

        :param msg: message text supplied by cheroot.
        :param level: stdlib logging level supplied by cheroot (default 20, INFO).
        :param traceback: when True, attach the exception currently being
            handled to the log record.
        """
        # Honor the level cheroot passes instead of reporting every message
        # as an error. exc_info attaches the active traceback, as
        # log.exception() did.
        log.log(level, msg, exc_info=traceback)
            
class SecondaryWSGIServer(wsgi.Server):
    """
    Class that overrides wsgi.Server to use as the secondary server.

    It is kept separate from CustomWSGIServer on purpose. APIService.run()
    sets the SSL adapter on CustomWSGIServer only, and subclassing it here
    would inherit that adapter.
    """
    def error_log(self, msg="", level=20, traceback=False):
        """
        Forward a cheroot server message to the module logger.

        :param msg: message text supplied by cheroot.
        :param level: stdlib logging level supplied by cheroot (default 20, INFO).
        :param traceback: when True, attach the exception currently being
            handled to the log record.
        """
        # Honor the level cheroot passes instead of reporting every message
        # as an error. exc_info attaches the active traceback, as
        # log.exception() did.
        log.log(level, msg, exc_info=traceback)

class ResourceRoute(object):
    """
    Class decorator that registers a REST resource class as a falcon route
    at ``APIService.API_PATH_PREFIX + <path>``.

    Usage: ``@ResourceRoute("/some/path")``. Only the first positional
    argument is used. Any other args/kwargs are stored but not used here.

    Decorating a class does three things:
      * attaches the matching falcon "before" hooks, chosen by which validator
        base classes the resource inherits from;
      * instantiates the class and adds it to the primary APIService
        application, plus a separate instance to the secondary server's
        application when one exists;
      * records the route in APIService.ROUTE_MAP for startup logging.
    """

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

    def __call__(self, cls):
        # (base class, hook) pairs. A resource may match several, and the
        # hooks are attached in this order.
        hook_table = (
            (AuthenticatedResource, check_auth_token),
            (OauthResourceValidator, check_outh_token_validator),
            (IOXResourceValidator, check_iox_token),
            (ChildResourceValidator, check_child_auth_token),
        )
        for base_class, hook in hook_table:
            if issubclass(cls, base_class):
                falcon.hooks.before(hook)(cls)

        route = APIService.API_PATH_PREFIX + self.args[0]

        APIService.instance.application.add_route(route, cls())
        secondary = APIService.sec_instance
        if secondary and secondary.application:
            secondary.application.add_route(route, cls())

        APIService.ROUTE_MAP[route] = str(cls)
        return cls


class ResourceSink(object):
    """
    Class decorator that registers an instance of the decorated class as a
    falcon sink under the prefix ``APIService.API_PATH_PREFIX + <prefix>``.

    Only the first positional argument is used. A separate instance is added
    to the primary application and, when a secondary server exists, to its
    application. The prefix is also recorded in APIService.ROUTE_MAP for
    startup logging.
    """

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

    def __call__(self, cls):
        prefix = APIService.API_PATH_PREFIX + self.args[0]

        APIService.instance.application.add_sink(cls(), prefix)
        secondary = APIService.sec_instance
        if secondary and secondary.application:
            secondary.application.add_sink(cls(), prefix)

        # Keep the prefix for the "Serving following URLs" log
        APIService.ROUTE_MAP[prefix] = str(cls)
        return cls


class APIService(Thread):
    """
    Thread that hosts the CAF REST API.

    The API is a falcon application. It is served by cheroot, or by flup's
    FastCGI server when [api] use_internal_webserver is false.

    Class attributes:
      instance / sec_instance: the primary and optional secondary service
          objects created by startService().
      API_PATH_PREFIX: prefix prepended to routes registered through
          ResourceRoute / ResourceSink. It can be overridden by
          [api] api_path_prefix.
      ROUTE_MAP: uri -> handler description. It is used only to log the
          served URLs at startup.
      prim_server_event: set once the primary server object has been created.
          The secondary server thread waits on it in run().
    """
    instance = None
    sec_instance = None
    API_PATH_PREFIX = "/iox/api/v2/hosting"
    ROUTE_MAP = {}

    prim_server_event = Event()

    def __init__(self, config, is_secondary_server=False):
        """
        :param config: ConfigParser-like object holding the [api] settings
            (plus optional [authentication] and [secure_storage_server]).
        :param is_secondary_server: True to build the secondary server.
            run() never configures SSL for it.
        """
        Thread.__init__(self)
        # Before API service is instantiated, hosting manager is expected to be started
        self._hosting_manager = HostingManager.get_instance()
        self._app_manager = self._hosting_manager.get_service("app-management")
        self._cartridge_manager = self._hosting_manager.get_service("cartridge-management")
        self._network_manager = self._hosting_manager.get_service("network-management")
        self._device_manager = self._hosting_manager.get_service("device-management")
        self._console_service = self._hosting_manager.get_service("console-management")
        self._scp_service = self._hosting_manager.get_service("scp-management")
        self._task_service = self._hosting_manager.get_service("task-service")
        self._push_service = self._hosting_manager.get_service("push-service")
        self._autoconfigcli_manager = self._hosting_manager.get_service("autoconfigcli-management")
        self._bist_service = self._hosting_manager.get_service("bist-service")
        self._layer_reg = self._hosting_manager.get_service("layer_reg_service")
        self._smartlicense = self._hosting_manager.get_service("smartlicense")
        self._datastore_service = self._hosting_manager.get_service("datastore_service")
        self._is_secondary_server = is_secondary_server
        self._update_manager = self._hosting_manager.get_service("update-management")
        # Raw [api] use_ldevid value (False when unset); filled in by run()
        # when SSL is configured.
        self.ldevid_enabled = False

        self._config = config
        self._app = None
        self._port = None
        self._address = None
        self.server = None
        self.login_fail_attempt_ts = None
        self.login_fail_attempt_cnt = 0
        self.max_login_fail_attempt = 3
        self.login_fail_timeout = 60
        self._cipher = None

        self._set_app()

        # Daemon thread so that apiservice terminates properly, as it hangs
        # in some cases.
        # http://stackoverflow.com/questions/1635080/terminate-a-multi-thread-python-program
        # (Thread.setDaemon() is deprecated; the daemon attribute is equivalent.)
        self.daemon = True

        if self._config.has_section("authentication"):
            if self._config.has_option("authentication", "max_failed_login_cnt"):
                self.max_login_fail_attempt = self._config.getint("authentication", "max_failed_login_cnt")
            if self._config.has_option("authentication", "failed_login_timeout"):
                self.login_fail_timeout = self._config.getint("authentication", "failed_login_timeout")

        # Default value of api_path_prefix is /iox/api/v2/hosting. Note that
        # this overrides the class-wide prefix.
        if self._config.has_option("api", "api_path_prefix"):
            APIService.API_PATH_PREFIX = self._config.get("api", "api_path_prefix")

    def _set_app(self):
        """
        Create the falcon application with its middleware stack, in this order:

          * rate limiter, only when the "visulization" or "datastore" system
            config section has enabled == "yes";
          * accounting, then multipart;
          * compression, when [api] enable_compression (default True);
          * security.
        """
        api_middleware_list = []
        # NOTE: "visulization" is the config section name as spelled in config.
        vis_section = Utils.getSystemConfigSection("visulization")
        datastore_section = Utils.getSystemConfigSection("datastore")
        if (vis_section is not None and vis_section.get("enabled", "no") == "yes") or (datastore_section is not None and datastore_section.get("enabled", "no") == "yes"):
            api_middleware_list.append(RateLimiterMiddleware())
        api_middleware_list.append(AccountingMiddleware())
        api_middleware_list.append(MultipartMiddleware())
        if Utils.getSystemConfigValue("api", "enable_compression", True, "bool"):
            api_middleware_list.append(CompressionMiddleware())
        api_middleware_list.append(SecurityMiddleware())
        self._app = falcon.API(middleware=api_middleware_list)

    @property
    def is_secondary_server(self):
        return self._is_secondary_server

    @property
    def application(self):
        """The falcon application served by this instance."""
        return self._app

    @property
    def config(self):
        return self._config

    @property
    def app_manager(self):
        return self._app_manager

    @property
    def network_manager(self):
        return self._network_manager

    @property
    def bist_service(self):
        return self._bist_service

    @property
    def device_manager(self):
        return self._device_manager

    @property
    def cartridge_manager(self):
        return self._cartridge_manager

    @property
    def console_service(self):
        return self._console_service

    @property
    def scp_service(self):
        return self._scp_service

    @property
    def hosting_manager(self):
        return self._hosting_manager

    @property
    def task_service(self):
        return self._task_service

    @property
    def smart_license(self):
        return self._smartlicense

    @property
    def push_service(self):
        return self._push_service

    @property
    def layer_reg_serv(self):
        return self._layer_reg

    @property
    def datastore_service(self):
        return self._datastore_service

    @property
    def api_port(self):
        """Listening port as a string; None until run() starts."""
        return self._port

    @property
    def api_ip_address(self):
        """Listening address; None until run() starts."""
        return self._address

    @property
    def autoconfigcli_manager(self):
        return self._autoconfigcli_manager

    @property
    def update_manager(self):
        return self._update_manager

    def increment_failed_login_cnt(self):
        """
        Increment the failed-login counter.

        This does not touch login_fail_attempt_ts. Presumably callers record
        that timestamp themselves (verify against callers).
        """
        self.login_fail_attempt_cnt += 1

    def _enable_local_manager(self, apiAddress, apiPort):
        """Register the local manager web app, using its static/ and templates/ folders."""
        log.debug("Enabling local manager")
        from ..mgmt.webapp import register_localmanager

        local_manager_root = Utils.getLocalManagerRootFolder()
        staticFolder = os.path.join(local_manager_root, 'static')
        templateFolder = os.path.join(local_manager_root, 'templates')
        register_localmanager(self, apiAddress, apiPort, staticFolder, templateFolder)

    def add_route(self, uri_template, resource):
        """
        Add a falcon route to this instance's application.

        The route is also recorded in ROUTE_MAP while that map is still
        available. run() sets it to None on this instance after logging.
        """
        if self.ROUTE_MAP is not None:
            self.ROUTE_MAP[uri_template] = str(resource.__class__)
        self.application.add_route(uri_template, resource)

    def add_sink(self, sink, prefix="/"):
        """
        Add a falcon sink to this instance's application.

        The sink is also recorded in ROUTE_MAP while that map is still
        available. run() sets it to None on this instance after logging.
        """
        if self.ROUTE_MAP is not None:
            self.ROUTE_MAP[prefix] = str(sink.__class__)
        self.application.add_sink(sink, prefix)

    def _update_route_map(self):
        """
        Annotate ROUTE_MAP entries with the HTTP methods their resource accepts.

        Each entry changes from ``uri -> handler`` to
        ``uri -> [handler, allowed_methods]``. This relies on falcon's private
        ``_router.find``. Entries the router does not resolve are left as-is.
        A method counts as not allowed when its responder's string form
        contains "not_allowed".
        """
        for uri in list(self.ROUTE_MAP.keys()):
            node = self.application._router.find(uri)
            if not node:
                continue
            resource_handler = self.ROUTE_MAP[uri]
            if isinstance(resource_handler, list):
                # Another APIService instance sharing the class-level map
                # already annotated this entry; keep only its handler.
                resource_handler = resource_handler[0]
            allowed_methods = [str(method) for method, responder in node[1].items()
                               if "not_allowed" not in str(responder)]
            self.ROUTE_MAP[uri] = [resource_handler, allowed_methods]

    def _is_ldevid_disabled(self):
        """
        Return True when LDEVID certificates must not be used.

        ldevid_enabled holds the raw [api] use_ldevid string, or False when it
        is unset. Unset or empty values count as disabled. "no" and the other
        configparser false spellings ("false", "0", "off") also count as
        disabled, case-insensitively.
        """
        flag = self.ldevid_enabled
        if not flag:
            return True
        return str(flag).strip().lower() in ("no", "false", "0", "off")

    def run(self):
        """
        Thread entry point. Registers routes, configures SSL for the primary
        server, then serves requests until the web server stops.

        The secondary server first waits for the primary server object to be
        created.
        """
        log.debug("Starting CAF webserver")

        prim_apiAddress = self._config.get("api", "address")
        prim_apiPort = self._config.getint("api", "port")

        if self._is_secondary_server:
            # We are in secondary server thread. Wait for primary server start
            self.prim_server_event.wait()
            apiAddress = self._config.get("api", "secondary_server_address")
            apiPort = self._config.getint("api", "secondary_server_port")
        else:
            apiAddress = prim_apiAddress
            apiPort = prim_apiPort

        self._port = str(apiPort)
        self._address = apiAddress

        # Load all the REST resources. Importing them runs the
        # ResourceRoute/ResourceSink decorators, which register the routes.
        try:
            from . import allresources
            from .bare_resource import BareResource
        except Exception as ex:
            log.exception("Failed to load resources : %s" % str(ex))
            return

        log.debug("Imported all API Resources")
        # Check if LM has to be enabled
        enable_local_manager = True
        try:
            if self._config.has_option("api", "enable_local_manager"):
                enable_local_manager = self._config.getboolean("api", "enable_local_manager")
        except Exception:
            log.exception("Exception reading local manager enable setting. Assuming it to be true.")
            enable_local_manager = True

        if enable_local_manager:
            # Always registered with the primary address/port.
            self._enable_local_manager(prim_apiAddress, prim_apiPort)

        self._update_route_map()
        # Adding api bare resource version to determine API version supported by CAF
        self.add_route("/iox", BareResource())
        self.add_route("/iox/version", BareResource())
        log.debug("Routes:\n%s" % "\n".join(str(item) for item in self.ROUTE_MAP.items()))
        log.info("Serving following URLs : \n%s" % "\n".join(str(uri) for uri in self.ROUTE_MAP.keys()))

        # Drop this instance's reference to the route map. This only shadows
        # the class-level APIService.ROUTE_MAP, which the secondary server
        # still reads. add_route/add_sink skip bookkeeping from now on.
        self.ROUTE_MAP = None

        use_ssl = True
        if self._config.has_option("api", "use_ssl"):
            use_ssl = self._config.getboolean("api", "use_ssl")

        if use_ssl and not self._is_secondary_server:
            log.info("SSL is turned on. Will listen on https and use SSL")
            # Set on the class, so only CustomWSGIServer (the primary server)
            # picks up the adapter.
            CustomWSGIServer.ssl_adapter = self._create_ssl_adapter()
        else:
            log.info("SSL is disabled. Will listen on plain http")

        num_threads = 4
        if self._config.has_option("api", "num_threads"):
            num_threads = self._config.getint("api", "num_threads")

        timeout = 10
        if self._config.has_option("api", "timeout"):
            timeout = self._config.getint("api", "timeout")

        use_internal_webserver = True
        if self._config.has_option("api", "use_internal_webserver"):
            use_internal_webserver = self._config.getboolean("api", "use_internal_webserver")

        if use_internal_webserver:
            dispatcher = wsgi.PathInfoDispatcher({'/': self._app})
            # Use customwsgi so that we can see server errors in caf logs
            if not self._is_secondary_server:
                self.server = CustomWSGIServer((apiAddress, apiPort), dispatcher, numthreads=num_threads, request_queue_size=10, timeout=timeout)
                # Unblock the secondary server thread waiting in run().
                self.prim_server_event.set()
            else:
                self.server = SecondaryWSGIServer((apiAddress, apiPort), dispatcher, numthreads=num_threads, request_queue_size=10, timeout=timeout)

            try:
                log.info("Starting web server with ip %s on port %s", apiAddress, apiPort)
                self.server.start()
            except KeyboardInterrupt:
                self.server.stop()
                self.server = None
        else:
            # Bind to the fcgi socket
            from flup.server.fcgi import WSGIServer
            fcgi_bind_address = "/tmp/rp/lipc/iox_fcgi_socket"
            if self._config.has_option("api", "fcgi_socket_bind_address"):
                fcgi_bind_address = self._config.get("api", "fcgi_socket_bind_address")
            WSGIServer(application=self._app, bindAddress=fcgi_bind_address).run()

    def _create_ssl_adapter(self):
        """
        Build the cheroot SSL adapter for the primary server from [api] config.

        Reads the optional cipher, ssl_module and use_ldevid settings.
        ssl_module must be "builtin" or "pyopenssl"; any other value falls
        back to "builtin". Which arguments the adapter receives depends on
        the module:
          * builtin, or pyopenssl with LDEVID disabled: the certificate/key
            file paths;
          * pyopenssl with LDEVID enabled: the certificate/key contents. When
            no contents were obtained, they are read from the
            [api] sslcertificate / sslprivatekey files.
        """
        if self._config.has_option("api", "cipher"):
            self._cipher = self._config.get("api", "cipher")
        log.debug("SSL cipher to use %s", self._cipher)

        sslModule = "builtin"
        if self._config.has_option("api", "ssl_module"):
            configured_module = self._config.get("api", "ssl_module")
            if configured_module not in ("builtin", "pyopenssl"):
                log.warning("Invalid SSL module from configuration: %s, will use builtin SSL implementation", configured_module)
            else:
                sslModule = configured_module
        log.info("Using ssl module: %s", sslModule)

        self.ldevid_enabled = False
        if self._config.has_option("api", "use_ldevid"):
            self.ldevid_enabled = self._config.get("api", "use_ldevid")

        adapterClass = server.get_ssl_adapter_class(sslModule)
        ssl_key_cert = self.get_cert_and_key(sslModule)

        if sslModule == "builtin" or self._is_ldevid_disabled():
            return adapterClass(ssl_key_cert["cert_path"], ssl_key_cert["key_path"], None, self._cipher)

        certbuf = ssl_key_cert["cert_buf"]
        keybuf = ssl_key_cert["key_buf"]
        # An empty str (nothing fetched) or empty bytes (empty secure-storage
        # file) is unusable; load the configured default files instead.
        if not certbuf or not keybuf:
            log.info("Will use the default key and certificate")
            sslCertificate = self._config.get("api", "sslcertificate")
            sslPrivateKey = self._config.get("api", "sslprivatekey")
            with open(sslCertificate, "rb") as sslcert:
                certbuf = sslcert.read().strip()
            with open(sslPrivateKey, "rb") as sslkey:
                keybuf = sslkey.read().strip()
        return adapterClass(certbuf, keybuf, None, self._cipher)

    def get_ss_cert_and_key(self, certname, keyname):
        """
        Return the certificate and key stored in secure storage.

        Looks for ``<encrypted_dir>/certs/<certname>`` and
        ``<encrypted_dir>/certs/<keyname>``, where encrypted_dir is
        [secure_storage_server] encrypted_dir.

        :return: one of three values:
            * a dict with "cert_path", "key_path", "cert_buf" and "key_buf"
              when both files exist (the buffers are the stripped file
              contents, as bytes);
            * {} when [secure_storage_server] enabled is not "yes";
            * None when either file is missing.
        """
        cert_key_dict = {}
        secure_server_enabled = "no"
        if self._config.has_option("secure_storage_server", "enabled"):
            secure_server_enabled = self._config.get("secure_storage_server", "enabled")
        if secure_server_enabled != "yes":
            return cert_key_dict

        encrypted_dir = self._config.get("secure_storage_server", "encrypted_dir")
        ss_cert = os.path.join(encrypted_dir, "certs", certname)
        ss_privkey = os.path.join(encrypted_dir, "certs", keyname)
        if not (os.path.exists(ss_cert) and os.path.exists(ss_privkey)):
            log.error("Certificate and keys not found in secure storage")
            return None

        log.info("Using key and certificates stored in Secure Storage")
        with open(ss_cert, "rb") as sslcert:
            sscertbuf = sslcert.read().strip()
        with open(ss_privkey, "rb") as sslkey:
            sskeybuf = sslkey.read().strip()

        cert_key_dict["cert_path"] = ss_cert
        cert_key_dict["key_path"] = ss_privkey
        cert_key_dict["cert_buf"] = sscertbuf
        cert_key_dict["key_buf"] = sskeybuf
        log.debug("Certificate from SS:%s", sscertbuf)
        # The private key is deliberately not logged.
        return cert_key_dict

    def get_cert_and_key(self, sslModule):
        """
        Return the certificate and key to serve HTTPS with.

        The result is a dict with keys "cert_path", "key_path", "cert_buf"
        and "key_buf". The source depends on secure storage and LDEVID:

          * Secure storage unsupported, unreachable or disabled: the
            [api] sslcertificate / sslprivatekey paths, with empty buffers.
          * Secure storage enabled: the files of the same basenames stored
            there. If they are missing, the configured paths are used as above.
          * LDEVID enabled, the module is "pyopenssl", and the secure-storage
            certificate is the CAF-generated self-signed one (or cannot be
            inspected): cert_buf/key_buf are replaced with the LDEVID udi
            cert/key when they can be fetched.

        :param sslModule: "builtin" or "pyopenssl".
        """
        sslCertificate = self._config.get("api", "sslcertificate")
        sslPrivateKey = self._config.get("api", "sslprivatekey")
        default_cert_key = {
            "key_buf": "",
            "cert_buf": "",
            "cert_path": sslCertificate,
            "key_path": sslPrivateKey,
        }

        if not (SystemInfo.is_secure_storage_supported() and SystemInfo.is_secure_storage_server_up_and_available()):
            log.debug("Secure server is not available")
            return default_cert_key

        secure_server_enabled = "no"
        if self._config.has_option("secure_storage_server", "enabled"):
            secure_server_enabled = self._config.get("secure_storage_server", "enabled")
        # Check if secure storage is enabled
        if secure_server_enabled != "yes":
            log.debug("Secure server is disabled")
            return default_cert_key

        cert_key = self.get_ss_cert_and_key(os.path.basename(sslCertificate), os.path.basename(sslPrivateKey))
        if not cert_key:
            # get_ss_cert_and_key returns None (files missing) or {}.
            # Neither can be filled in place, so fall back to the
            # configured files.
            return default_cert_key

        if self._is_ldevid_disabled():
            log.debug("Ldevid is disabled explicitly")
            return cert_key

        use_ldevid = sslModule != "builtin" and self._ss_cert_needs_ldevid(cert_key["cert_path"])
        if use_ldevid and sslModule == "pyopenssl":
            self._load_ldevid_cert_and_key(cert_key)
        return cert_key

    def _ss_cert_needs_ldevid(self, ss_cert):
        """
        Return True when the secure-storage certificate should give way to LDEVID.

        That is the case when the certificate is the CAF-generated self-signed
        one (self-signed with a subject ending in
        "unstructuredName=0.0.0.0"), or when its details cannot be read.
        """
        try:
            certhelper = CertHelper.getInstance(self._config)
            subject = certhelper.getSubject(ss_cert)
            if certhelper.isSelfSignedCert(ss_cert) and subject.endswith("unstructuredName=0.0.0.0"):
                log.debug("Found caf generated self signed certifcate and will ignore for now")
                return True
        except Exception as ex:
            log.exception("Failed to get cert details: %s" % ex)
            return True
        return False

    def _load_ldevid_cert_and_key(self, cert_key):
        """
        Fetch the LDEVID udi certificate and key from the secure storage service.

        On success, stores them in cert_key["cert_buf"] / cert_key["key_buf"].
        Failures are logged and leave cert_key unchanged. The service address
        comes from [secure_storage_server] host/port (default
        127.0.0.1:9443).
        """
        host = "127.0.0.1"
        port = "9443"
        if self._config.has_option("secure_storage_server", "host"):
            host = self._config.get("secure_storage_server", "host")
        if self._config.has_option("secure_storage_server", "port"):
            port = self._config.get("secure_storage_server", "port")
        try:
            log.debug("Retrieving cert and keys from LDEVID")
            token = Utils.get_http_response(host, port, "/SS/token/caf/0")

            # Get the udi cert
            certbuf = Utils.get_http_response(host, port, "/SS/caf/udi?ss-Token=" + token)
            log.debug("Cert from LDEVID:%s", certbuf)
            # Get the udi key (deliberately not logged on success)
            keybuf = Utils.get_http_response(host, port, "/SS/caf/udikey?ss-Token=" + token)
            if not certbuf.startswith("-----BEGIN CERTIFICATE-----"):
                log.info("Not able to get the LDEVID cert: %s" % certbuf)
                raise Exception("Not able to get the LDEVID cert: %s" % certbuf)
            if not keybuf.startswith("-----BEGIN"):
                log.info("Not able to get the LDEVID key: %s" % keybuf)
                raise Exception("Not able to get the LDEVID key: %s" % keybuf)
            log.info("Using cert and keys of LDEVID")
            cert_key["key_buf"] = keybuf
            cert_key["cert_buf"] = certbuf
        except Exception as ex:
            log.exception("Failed to get udi key or cert from the secure storage service: %s" % str(ex))

    @classmethod
    def startService(cls, config):
        '''
        Start API Services

        Creates the primary instance (and the secondary one when the system
        config [api] enable_secondary_server is true) if not yet created,
        then starts their threads.
        '''
        if cls.instance is None:
            cls.instance = APIService(config)

        sec_instance_enable = Utils.getSystemConfigValue("api", "enable_secondary_server", False, "bool")
        if sec_instance_enable:
            if cls.sec_instance is None:
                cls.sec_instance = APIService(config, is_secondary_server=True)

        cls.instance.start()
        if cls.sec_instance:
            cls.sec_instance.start()

    @classmethod
    def stopService(cls):
        '''
        Stop API services

        Stops each running web server and clears the corresponding class
        reference. An instance whose server was never created (for example
        the FastCGI mode) is left in place.
        '''
        #TODO: Need to figure out how to gracefully stop this thread
        if cls.instance and cls.instance.server is not None:
            cls.instance.server.stop()
            # TODO: Setting the instance to none, so that c3 infra can be restarted when required
            # without getting "threads can only be started once
            cls.instance = None
            cls.prim_server_event.clear()

        if cls.sec_instance and cls.sec_instance.server is not None:
            cls.sec_instance.server.stop()
            cls.sec_instance = None

    def check_login_allowed(self, user=None):
        """
        Return True if a login attempt is currently allowed.

        Login is refused only when both of these hold:
          * a failure timestamp (login_fail_attempt_ts) is recorded and is at
            most login_fail_timeout seconds old;
          * login_fail_attempt_cnt has reached max_login_fail_attempt.

        ``user`` is accepted for interface compatibility but is not used.
        """
        log.debug("Last fail ts: %s, Fail attempt cnt:%s, timeout: %s,  Max fail cnt: %s, Current time: %s" ,
                  self.login_fail_attempt_ts, self.login_fail_attempt_cnt,
                  self.login_fail_timeout, self.max_login_fail_attempt,
                  time.time())
        if self.login_fail_attempt_ts is None:
            return True
        if time.time() > (self.login_fail_attempt_ts + self.login_fail_timeout):
            return True
        return self.login_fail_attempt_cnt < self.max_login_fail_attempt
