"""OAuth2 service (an ``OAuthService`` built on ``CAFAbstractService``).

Users, clients, authorization grants and access tokens are persisted in
dbm files and exposed to an ``OAuth2RequestValidator`` / ``OAuth2Provider``.
"""
__author__ = 'srinathc'

# Only the service class is part of this module's public API.
__all__ = ["OAuthService"]

import logging
import os
import dbm
import uuid
import datetime
import json
import base64
import urllib.request, urllib.parse, urllib.error
import yaml

from functools import partial
from .model import *
from .validator import *
from .oauth_provider import OAuth2Provider
from ..runtime.caf_abstractservice import CAFAbstractService

# Module-level logger; every message emitted in this module goes to the
# "oauth_service" logger.
log = logging.getLogger("oauth_service")

def db_get(db, key, default=None):
    """Return ``db[key]``, or ``default`` (None unless given) if ``key`` is absent.

    Works with any mapping that raises ``KeyError`` for a missing key, which
    includes the stdlib ``dbm`` objects used by this module. A single lookup
    (EAFP) avoids a separate membership test followed by indexing.
    """
    try:
        return db[key]
    except KeyError:
        return default

class OAuthService(CAFAbstractService):
    """OAuth2 service that persists its state in dbm stores.

    Users, clients, authorization grants and access tokens are kept in four
    dbm files under ``persistent_store``. They are opened by :meth:`start`
    and closed by :meth:`stop`. The getter/setter methods below are handed to
    an ``OAuth2RequestValidator``, which is then attached to an
    ``OAuth2Provider``.

    Key layout (keys are written as str and enumerated back as bytes):

    * client store: ``ci:<client_id>`` and ``ai:<app_id>``, both holding the
      same client JSON;
    * token store: ``at:<access_token>`` and, when present,
      ``rt:<refresh_token>``, both holding the same token JSON;
    * grant store: JSON text of ``{"client_id": ..., "code": ...}``;
    * user store: ``<username>``.

    ``config`` is a mapping (``enabled`` and ``token_expiration`` are read
    here). It is written back to ``config_file`` as YAML whenever it changes.
    """

    __singleton = None  # the one, true Singleton
    OAUTH_SERVICE_API_MAP = {"clients": "/oauth/clients",
                             "get_client": "/oauth/clients/{client_id}",
                             "tokens": "/oauth/tokens",
                             "get_token": "/oauth/tokens/{access_token}",
                             "protected": "/oauth/protected",
                             "revoke": "/oauth/revoke",
                             "verify": "/oauth/verify",
                             "authorize": "/oauth/authorize"
                             }

    # Access-token lifetime (seconds) used when ``token_expiration`` is
    # missing from the config or cannot be parsed as an int.
    _DEFAULT_TOKEN_EXPIRATION = 86400

    # Lifetime (seconds) of a stored authorization grant.
    _GRANT_EXPIRATION_SECONDS = 100

    # Filename suffixes the stdlib dbm backends may create for one store
    # (dbm.gnu: none, dbm.ndbm: .db or .dir/.pag, dbm.dumb: .dat/.dir/.bak).
    _DBM_SUFFIXES = ("", ".db", ".dir", ".pag", ".dat", ".bak")

    def __new__(cls, *args, **kwargs):
        # Check to see if a __singleton exists already for this class.
        # Compare class types instead of just looking for None so that
        # subclasses will create their own __singleton objects.
        if cls != type(cls.__singleton):
            cls.__singleton = super().__new__(cls)
        return cls.__singleton

    @classmethod
    def getInstance(cls, *args):
        '''
        Return the singleton instance, creating it from ``args`` on first use.
        Later calls ignore ``args``.
        '''
        if cls.__singleton is None:
            cls.__singleton = OAuthService(*args)
        return cls.__singleton

    def __init__(self, params, persistent_store):
        """
        :param params: object exposing ``name``, ``config`` (mapping) and
            ``config_file`` (path the config is saved to as YAML).
        :param persistent_store: path prefix under which the dbm store files
            ("user", "client", "grant", "token") are created.
        """
        # NOTE(review): __new__ returns the existing singleton, so calling
        # OAuthService(...) again re-runs this method and drops any open store
        # handles without closing them.
        log.info("Creating OAuth Service")
        self.name = params.name
        self.config = params.config
        self._config_file = params.config_file

        log.info("persistent_store: %s", persistent_store)
        self.persistent_store = persistent_store
        # Store handles and OAuth objects are created by start().
        self._user_db = None
        self._client_db = None
        self._grant_db = None
        self._access_token_db = None
        self._oauthValidator = None
        self._provider = None
        self._running = False
        log.info("oauth provider is created")

    @property
    def is_running(self):
        """True between a successful start() and the next stop()."""
        return self._running

    @property
    def enabled(self):
        """Value of the ``enabled`` config key (False when absent).

        If the config cannot be read, ``enabled`` is reset to False and the
        config is saved.
        """
        try:
            return self.config.get("enabled", False)
        except Exception as ex:
            log.error("Error in parsing enabled attribute:%s", ex)
            self.config["enabled"] = False
            self._save_data()
            return False

    @property
    def expires_in(self):
        """Access-token lifetime in seconds, taken from ``token_expiration``.

        If the value cannot be converted to int, the default is written back
        to the config, saved, and returned.
        """
        try:
            return int(self.config.get("token_expiration", self._DEFAULT_TOKEN_EXPIRATION))
        except Exception as ex:
            log.error("Error in parsing token_expiration attribute:%s", ex)
            self.config["token_expiration"] = self._DEFAULT_TOKEN_EXPIRATION
            self._save_data()
            return self._DEFAULT_TOKEN_EXPIRATION

    def start(self):
        """Open the stores and build the validator and provider.

        Does nothing when the service is disabled or already running.
        """
        if not self.enabled or self._running:
            return
        log.info("Starting oauth service")
        self._user_db = self._open_db("user")
        self._client_db = self._open_db("client")
        self._grant_db = self._open_db("grant")
        self._access_token_db = self._open_db("token")
        log.info("Oauth service: all dbs are created")

        self._oauthValidator = OAuth2RequestValidator(
            self.client_getter,
            self.token_getter,
            self.grant_getter,
            self.user_getter,
            self.token_setter,
            self.grant_setter,
            self.token_deleter
        )
        expires_in = self.expires_in
        log.info("token expires in %d", expires_in)
        self._oauthValidator.config['OAUTH2_PROVIDER_TOKEN_EXPIRES_IN'] = expires_in
        log.info("oauth validator is created")

        self._provider = OAuth2Provider()
        self._provider._validator = self._oauthValidator
        self._provider.init_app(self._oauthValidator)
        self._running = True

    def _open_db(self, name):
        """Open (creating if needed) the dbm store ``name`` under persistent_store.

        If the existing store cannot be opened, its files are deleted and an
        empty store is created in their place.
        """
        path = os.path.join(self.persistent_store, name)
        try:
            return dbm.open(path, "c")
        except Exception:
            log.error("Oauth %s config file is invalid/corrupted. Recreating %s config file.", name, name)
            self._remove_db_files(path)
            return dbm.open(path, "c")

    def _remove_db_files(self, path):
        """Delete every file a dbm backend may have created for ``path``."""
        # dbm backends add their own suffixes, so removing only ``path`` can
        # miss the real files (and raise FileNotFoundError).
        for suffix in self._DBM_SUFFIXES:
            try:
                os.remove(path + suffix)
            except FileNotFoundError:
                pass

    def stop(self, forceful=True):
        """Close any opened stores and mark the service as not running.

        :param forceful: accepted for interface compatibility; unused.
        """
        log.debug("Stopping OAuth Service")
        # Handles are None if start() never opened them. Closed handles are
        # kept so that later access raises dbm.error (relied on by set_config).
        for db in (self._user_db, self._client_db, self._access_token_db, self._grant_db):
            if db is not None:
                db.close()
        self._running = False

    def set_config(self, config):
        """Validate and apply ``config``, restarting the service as needed.

        :raises ValueError: if ``config`` has unknown keys or clients exist.
        :raises Exception: if stopping or restarting the service fails.
        """
        if not self.validate_config(config):
            log.error("Given config %s is invalid!" % config)
            raise ValueError("Given config %s is invalid!" % config)

        try:
            # The client store is None if the service was never started.
            clients = self.list_clients() if self._client_db is not None else []
        except dbm.error:
            log.error("Seems like client is already closed, so assuming no clients are there!")
            clients = []
        if clients:
            msg = "To update the config of OAUTH service, first remove all clients installed %s" % clients
            log.error(msg)
            raise ValueError(msg)

        try:
            if self.is_running:
                self.stop()
        except Exception as ex:
            log.exception("OAUTH service stop failed, with reason: %s" % str(ex))
            raise Exception("OAUTH service stop failed, with reason: %s" % str(ex)) from ex

        self._update_config(config)
        try:
            if self.config.get("enabled", None):
                self.start()
            else:
                log.debug("OAUTH service is disabled as part of new config update!")
        except Exception as ex:
            log.exception("Error while setting up the OAUTH service with new config %s, cause: %s" % (config, str(ex)))
            self.stop()
            raise Exception("Error while setting up the OAUTH service with new config %s, cause: %s" % (config, str(ex))) from ex

    def get_config(self):
        """Return the live config mapping (not a copy)."""
        return self.config

    def validate_config(self, config):
        """Return True iff every key of ``config`` already exists in the current config."""
        log.debug("Validating the given config %s", config)
        allowed_keys = set(self.config.keys())
        for key in config:
            if key not in allowed_keys:
                log.debug("Invalid key %s, has been found in new config", key)
                return False
        return True

    def _update_config(self, config):
        """Merge ``config`` into the current config and save it.

        :raises ValueError/TypeError: if ``token_expiration`` is not int-like.
        """
        if 'token_expiration' in config:
            expires_in = int(config["token_expiration"])
            # The validator only exists once start() has run.
            if self._oauthValidator is not None:
                self._oauthValidator.config['OAUTH2_PROVIDER_TOKEN_EXPIRES_IN'] = expires_in
        self.config.update(config)
        self._save_data()

    def _save_data(self):
        """
        Write ``self.config`` as block-style YAML to ``self._config_file``
        (per an earlier note, typically repo/running_config/.oauth).
        """
        with open(self._config_file, "w") as f:
            yaml.safe_dump(self.config, f, default_flow_style=False)
            log.debug("Saved OAuth configuration to %s", self._config_file)

    @property
    def validator(self):
        """The OAuth2RequestValidator built by start(), or None."""
        return self._oauthValidator

    @property
    def provider(self):
        """The OAuth2Provider built by start(), or None."""
        return self._provider

    def user_getter(self, username, *args, **kwargs):
        '''
        Return the stored User named ``username``, or None.
        '''
        user_str = db_get(self._user_db, str(username))
        return User.from_json(str(user_str, "utf-8")) if user_str else None

    def user_setter(self, username, password):
        '''
        Create a User, store it under ``username`` (overwriting any existing
        entry), and return it.
        '''
        user = User(username, password)
        self._user_db[username] = user.to_json()
        return user

    def token_generator(self):
        '''
        Generate a new token ID. Change this method to change the format of
        access and refresh tokens.
        '''
        return uuid.uuid4()

    def _load_client(self, key):
        """Load the client stored under ``key``, attach its user, or return None."""
        client_str = db_get(self._client_db, key)
        client = Client.from_json(str(client_str, "utf-8")) if client_str else None
        if client and client.user_id:
            client.user = self.user_getter(client.user_id)
        return client

    def get_client_by_appid(self, app_id):
        '''
        Return the stored client registered for ``app_id``, or None.
        '''
        if not app_id:
            return None
        return self._load_client("ai:" + app_id)

    def client_getter(self, client_id):
        '''
        Return the stored client with ``client_id``, or None.
        '''
        if not client_id:
            return None
        return self._load_client("ci:" + client_id)

    def list_clients(self):
        '''
        Return all stored clients, one per app (read from the ``ai:`` entries).
        '''
        prefix = "ai:".encode()
        return [Client.from_json(str(self._client_db[key], "utf-8"))
                for key in list(self._client_db.keys())
                if key.startswith(prefix)]

    def client_setter(self, client):
        '''
        Store ``client`` under both its client_id and its app_id keys.
        '''
        client_json = client.to_json()
        self._client_db["ci:" + client.client_id] = client_json
        self._client_db["ai:" + client.app_id] = client_json

    def client_deleter(self, client_id):
        '''
        Delete the client with ``client_id`` (both index entries) and every
        access token issued to it. No-op if the client is unknown.
        '''
        client_str = db_get(self._client_db, "ci:" + client_id)
        if not client_str:
            return
        client = Client.from_json(str(client_str, "utf-8"))
        del self._client_db["ci:" + client.client_id]
        del self._client_db["ai:" + client.app_id]
        for token in self.get_tokens_for_client(client.client_id):
            self.token_deleter(access_token=token.access_token)

    def _iter_access_tokens(self):
        """Yield every stored token once, reading only the ``at:`` entries.

        A token with a refresh token is also stored under an ``rt:`` key;
        skipping those entries avoids yielding the same token twice.
        """
        prefix = "at:".encode()
        for key in list(self._access_token_db.keys()):
            if key.startswith(prefix):
                yield AccessToken.from_json(str(self._access_token_db[key], "utf-8"))

    def get_tokens_for_client(self, client_id):
        '''
        Return all tokens issued to ``client_id``, each listed once.
        '''
        return [token for token in self._iter_access_tokens()
                if token.client_id == client_id]

    def list_tokens(self):
        '''
        Return all tokens in the system, each listed once.
        '''
        return list(self._iter_access_tokens())

    def token_getter(self, access_token=None, refresh_token=None):
        '''
        Return the token for ``access_token`` or ``refresh_token``, with its
        user and client attached, or None. ``access_token`` wins if both are
        given.
        '''
        token_str = None
        if access_token:
            token_str = db_get(self._access_token_db, "at:" + access_token)
        elif refresh_token:
            token_str = db_get(self._access_token_db, "rt:" + refresh_token)
        if not token_str:
            return None
        token = AccessToken.from_json(str(token_str, "utf-8"))
        if token.user_id:
            token.user = self.user_getter(token.user_id)
        if token.client_id:
            token.client = self.client_getter(token.client_id)
        return token

    def token_setter(self, token, request, *args, **kwargs):
        '''
        Persist a newly issued token.

        :param token: token dict with ``access_token`` and ``token_type``, and
            optionally ``refresh_token``, ``scope`` (space separated) and
            ``expires_in`` (seconds).
        :param request: request whose ``client.client_id`` and
            ``user.username`` are recorded as the token's owner.

        A negative ``expires_in`` stores the token without an expiry time. A
        missing ``expires_in`` falls back to the configured lifetime.
        '''
        expires_in = token.get('expires_in')
        if expires_in is None:
            # ``None >= 0`` would raise TypeError; use the configured lifetime.
            expires_in = self.expires_in
        expires = (datetime.datetime.utcnow() + datetime.timedelta(seconds=expires_in)
                   if expires_in >= 0 else None)
        log.debug("Access Token %s", token)
        scope = token.get('scope')
        tok = AccessToken(
            access_token=token['access_token'],
            refresh_token=token.get('refresh_token'),
            token_type=token['token_type'],
            scopes=scope.split() if scope else [],
            expires=expires,
            client_id=request.client.client_id,
            user_id=request.user.username,
        )
        tok_json = tok.to_json()
        self._access_token_db["at:" + token["access_token"]] = tok_json
        if token.get("refresh_token"):
            self._access_token_db["rt:" + token["refresh_token"]] = tok_json

    def delete_tokens_for_app(self, app_id):
        '''
        Delete every token issued to the client registered for ``app_id``.
        '''
        client = self.get_client_by_appid(app_id)
        tokens = self.get_tokens_for_client(client.client_id) if client else []
        for token in tokens:
            self.token_deleter(token.access_token)

    def token_deleter(self, access_token=None, refresh_token=None, *args, **kwargs):
        '''
        Delete a token, looked up by access token or refresh token.
        ``access_token`` wins if both are given. Both the ``at:`` and the
        ``rt:`` entries of the token are removed.
        '''
        token_str = None
        if access_token:
            token_str = db_get(self._access_token_db, "at:" + access_token)
        elif refresh_token:
            token_str = db_get(self._access_token_db, "rt:" + refresh_token)
        if token_str:
            token = AccessToken.from_json(str(token_str, "utf-8"))
            del self._access_token_db["at:" + token.access_token]
            if token.refresh_token:
                del self._access_token_db["rt:" + token.refresh_token]

    @staticmethod
    def _grant_key(client_id, code):
        """Grant-store key: JSON of the client id and ``code['code']``."""
        return json.dumps({"client_id": client_id, "code": code['code']})

    def grant_getter(self, client_id, code):
        '''
        Return the stored grant for ``client_id`` and ``code['code']``, with
        its user and client attached, or None.
        '''
        grant_str = db_get(self._grant_db, self._grant_key(client_id, code))
        grant = Grant.from_json(str(grant_str, "utf-8")) if grant_str else None
        if grant and grant.user_id:
            grant.user = self.user_getter(grant.user_id)
        if grant and grant.client_id:
            grant.client = self.client_getter(grant.client_id)
        return grant

    def grant_setter(self, client_id, code, request, *args, **kwargs):
        '''
        Store a grant for ``client_id`` and ``code['code']``. The grant
        expires _GRANT_EXPIRATION_SECONDS from now.
        '''
        expires = datetime.datetime.utcnow() + datetime.timedelta(seconds=self._GRANT_EXPIRATION_SECONDS)
        grant = Grant(
            client_id=client_id,
            code=code['code'],
            redirect_uri=request.redirect_uri,
            scopes=' '.join(request.scopes),
            user=None,  # no current-user lookup is done here
            expires=expires
        )
        self._grant_db[self._grant_key(client_id, code)] = grant.to_json()

    def register_app(self, app_id, user_id="dummy_user", access_scopes=None, default_scopes=None):
        '''
        Ensure a client exists for ``app_id`` and return it.

        Returns None when the service is not running. If a new client is
        created, it gets:

        * a placeholder user (``user_id`` with a dummy password);
        * two fresh UUID strings (presumably the client id and secret);
        * ``access_scopes``, or ``default_scopes`` when that is empty.
        '''
        # None defaults avoid sharing one mutable list across calls.
        access_scopes = [] if access_scopes is None else access_scopes
        default_scopes = [] if default_scopes is None else default_scopes
        if not self.is_running:
            return None
        client = self.get_client_by_appid(app_id)
        if not client:
            log.debug("Registering App with access scopes %s and default scopes %s", access_scopes, default_scopes)
            user = self.user_setter(user_id, "dummy_password")
            scopes = access_scopes if access_scopes else default_scopes
            client = Client(app_id, user.username, str(uuid.uuid4()), str(uuid.uuid4()), "confidential", [], scopes)
            self.client_setter(client)
        return client

    def unregister_app(self, app_id):
        '''
        Delete the client registered for ``app_id`` and its tokens, if any.
        '''
        client = self.get_client_by_appid(app_id)
        if client:
            self.client_deleter(client.client_id)