# coding: utf-8
"""
    falcon_oauthlib.provider.oauth2
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    Implements OAuth2 provider support for Falcon.
"""

import logging
from functools import wraps

from .validator import OAuth2RequestValidator
from oauthlib import oauth2
from oauthlib.oauth2 import Server
import falcon
from .utils import extract_params, maybe_args, patch_response, redirect

# Names exported by a star-import of this module.
__all__ = ('OAuth2Provider', 'OAuth2RequestValidator')

# Module-wide logger. The logger name is the fixed string 'oauth_service',
# not ``__name__``, so logging configuration must target that name.
log = logging.getLogger('oauth_service')


class OAuth2Provider(object):
    """Provide secure services using OAuth2.

    The server should provide an authorize handler and a token handler.
    But before the handlers are implemented, the server should provide
    some getters for the validation.

    Configure :meth:`tokengetter` and :meth:`tokensetter` to get and
    set tokens. Configure :meth:`grantgetter` and :meth:`grantsetter`
    to get and set grant tokens. Configure :meth:`clientgetter` to
    get the client.

    Configure :meth:`usergetter` if you need password credential
    authorization.

    With everything ready, implement the authorization workflow:

        * :meth:`authorize_handler` for consumer to confirm the grant
        * :meth:`token_endpoint` for client to exchange access token

    And now you can protect the resource with scopes::

        @app.route('/api/user')
        @oauth.require_oauth('email', 'username')
        def user():
            return jsonify(request.oauth.user)
    """

    def __init__(self, app=None):
        self._before_request_funcs = []
        self._after_request_funcs = []
        self._invalid_response = None
        if app:
            self.init_app(app)

    def init_app(self, app):
        """
        This callback can be used to initialize an application for the
        oauth provider instance.
        """
        self.app = app
        app.extensions = getattr(app, 'extensions', {})
        app.extensions['oauthlib.provider.oauth2'] = self

    def error_uri(self):
        """The error page URI.

        When something turns error, it will redirect to this error page.
        You can configure the error page URI with config::

            OAUTH2_PROVIDER_ERROR_URI = '/error'

        You can also define the error page by a named endpoint::

            OAUTH2_PROVIDER_ERROR_ENDPOINT = 'oauth.error'
        """
        error_uri = self.app.config.get('OAUTH2_PROVIDER_ERROR_URI')
        if error_uri:
            return error_uri
        error_endpoint = self.app.config.get('OAUTH2_PROVIDER_ERROR_ENDPOINT')
        if error_endpoint:
            return error_endpoint
        return '/oauth/errors'

    def server(self):
        """
        All in one endpoints. This property is created automatically
        if you have implemented all the getters and setters.

        However, if you are not satisfied with the getter and setter,
        you can create a validator with :class:`OAuth2RequestValidator`::

            class MyValidator(OAuth2RequestValidator):
                def validate_client_id(self, client_id):
                    # do something
                    return True

        And assign the validator for the provider::

            oauth._validator = MyValidator()
        """
        expires_in = self.app.config.get('OAUTH2_PROVIDER_TOKEN_EXPIRES_IN')
        token_generator = self.app.config.get(
            'OAUTH2_PROVIDER_TOKEN_GENERATOR', None
        )
        #if token_generator and not callable(token_generator):
        #    token_generator = import_string(token_generator)

        refresh_token_generator = self.app.config.get(
            'OAUTH2_PROVIDER_REFRESH_TOKEN_GENERATOR', None
        )
        #if refresh_token_generator and not callable(refresh_token_generator):
        #    refresh_token_generator = import_string(refresh_token_generator)

        if hasattr(self, '_validator'):
            return Server(
                self._validator,
                token_expires_in=expires_in,
                token_generator=token_generator,
                refresh_token_generator=refresh_token_generator,
            )

        if hasattr(self, '_clientgetter') and \
                hasattr(self, '_tokengetter') and \
                hasattr(self, '_tokensetter') and \
                hasattr(self, '_grantgetter') and \
                hasattr(self, '_grantsetter'):

            usergetter = None
            if hasattr(self, '_usergetter'):
                usergetter = self._usergetter

            validator = OAuth2RequestValidator(
                clientgetter=self._clientgetter,
                tokengetter=self._tokengetter,
                grantgetter=self._grantgetter,
                usergetter=usergetter,
                tokensetter=self._tokensetter,
                grantsetter=self._grantsetter,
            )
            self._validator = validator
            return Server(
                validator,
                token_expires_in=expires_in,
                token_generator=token_generator,
                refresh_token_generator=refresh_token_generator,
            )
        raise RuntimeError('application not bound to required getters')

    def before_request(self, f):
        """Register functions to be invoked before accessing the resource.

        The function accepts nothing as parameters, but you can get
        information from `Flask.request` object. It is usually useful
        for setting limitation on the client request::

            @oauth.before_request
            def limit_client_request():
                client_id = request.values.get('client_id')
                if not client_id:
                    return
                client = Client.get(client_id)
                if over_limit(client):
                    return abort(403)

                track_request(client)
        """
        self._before_request_funcs.append(f)
        return f

    def after_request(self, f):
        """Register functions to be invoked after accessing the resource.

        The function accepts ``valid`` and ``request`` as parameters,
        and it should return a tuple of them::

            @oauth.after_request
            def valid_after_request(valid, oauth):
                if oauth.user in black_list:
                    return False, oauth
                return valid, oauth
        """
        self._after_request_funcs.append(f)
        return f

    def invalid_response(self, f):
        """Register a function for responsing with invalid request.

        When an invalid request proceeds to :meth:`require_oauth`, we can
        handle the request with the registered function. The function
        accepts one parameter, which is an oauthlib Request object::

            @oauth.invalid_response
            def invalid_require_oauth(req):
                return jsonify(message=req.error_message), 401

        If no function is registered, it will return with ``abort(401)``.
        """
        self._invalid_response = f
        return f

    def clientgetter(self, f):
        """Register a function as the client getter.

        The function accepts one parameter `client_id`, and it returns
        a client object with at least these information:

            - client_id: A random string
            - client_secret: A random string
            - client_type: A string represents if it is `confidential`
            - redirect_uris: A list of redirect uris
            - default_redirect_uri: One of the redirect uris
            - default_scopes: Default scopes of the client

        The client may contain more information, which is suggested:

            - allowed_grant_types: A list of grant types
            - allowed_response_types: A list of response types
            - validate_scopes: A function to validate scopes

        Implement the client getter::

            @oauth.clientgetter
            def get_client(client_id):
                client = get_client_model(client_id)
                # Client is an object
                return client
        """
        self._clientgetter = f
        return f

    def usergetter(self, f):
        """Register a function as the user getter.

        This decorator is only required for **password credential**
        authorization::

            @oauth.usergetter
            def get_user(username, password, client, request,
                         *args, **kwargs):
                # client: current request client
                if not client.has_password_credential_permission:
                    return None
                user = User.get_user_by_username(username)
                if not user.validate_password(password):
                    return None

                # parameter `request` is an OAuthlib Request object.
                # maybe you will need it somewhere
                return user
        """
        self._usergetter = f
        return f

    def tokengetter(self, f):
        """Register a function as the token getter.

        The function accepts an `access_token` or `refresh_token` parameters,
        and it returns a token object with at least these information:

            - access_token: A string token
            - refresh_token: A string token
            - client_id: ID of the client
            - scopes: A list of scopes
            - expires: A `datetime.datetime` object
            - user: The user object

        The implementation of tokengetter should accepts two parameters,
        one is access_token the other is refresh_token::

            @oauth.tokengetter
            def bearer_token(access_token=None, refresh_token=None):
                if access_token:
                    return get_token(access_token=access_token)
                if refresh_token:
                    return get_token(refresh_token=refresh_token)
                return None
        """
        self._tokengetter = f
        return f

    def tokensetter(self, f):
        """Register a function to save the bearer token.

        The setter accepts two parameters at least, one is token,
        the other is request::

            @oauth.tokensetter
            def set_token(token, request, *args, **kwargs):
                save_token(token, request.client, request.user)

        The parameter token is a dict, that looks like::

            {
                u'access_token': u'6JwgO77PApxsFCU8Quz0pnL9s23016',
                u'token_type': u'Bearer',
                u'expires_in': 3600,
                u'scope': u'email address'
            }

        The request is an object, that contains an user object and a
        client object.
        """
        self._tokensetter = f
        return f

    def grantgetter(self, f):
        """Register a function as the grant getter.

        The function accepts `client_id`, `code` and more::

            @oauth.grantgetter
            def grant(client_id, code):
                return get_grant(client_id, code)

        It returns a grant object with at least these information:

            - delete: A function to delete itself
        """
        self._grantgetter = f
        return f

    def grantsetter(self, f):
        """Register a function to save the grant code.

        The function accepts `client_id`, `code`, `request` and more::

            @oauth.grantsetter
            def set_grant(client_id, code, request, *args, **kwargs):
                save_grant(client_id, code, request.user, request.scopes)
        """
        self._grantsetter = f
        return f

    def authorize_handler(self, f):
        """Authorization handler decorator.

        The returned wrapper has the signature
        ``decorated(req, resp, *args, **kwargs)``. For ``GET``/``HEAD``
        requests it:

        1. pre-validates the authorization request with oauthlib and adds
           ``scopes`` plus the validated credentials to ``kwargs``;
        2. calls ``f(*args, **kwargs)``;
        3. interprets the result of ``f``:

           * a truthy non-bool value is used as the response body
             (for example a rendered confirmation page);
           * ``True`` confirms the grant via
             :meth:`confirm_authorization_request`;
           * a falsy value means the user denied access, and the client
             is redirected with an ``access_denied`` error.

        Error redirects are ``303 See Other`` responses with the error
        encoded into the target URI. Fatal client errors (for example an
        invalid ``client_id`` or redirect URI) go to the provider's error
        page and never to the client-supplied ``redirect_uri``, which
        cannot be trusted in that case. Other OAuth2 errors go to the
        request's ``redirect_uri``, falling back to the error page.

        NOTE(review): ``f`` is not given ``req``/``resp``, and requests
        with any other HTTP method fall through without ``f`` being
        called. This mirrors the previous behaviour; confirm that it is
        intended.

        :param f: the view function to wrap.
        :returns: the wrapping function.
        """

        def send_redirect(resp, location):
            # Use ``set_header``: assigning into ``resp.headers`` does not
            # modify the response, because that attribute is a copy of
            # the headers (and does not exist on older Falcon versions).
            resp.status = falcon.HTTP_SEE_OTHER
            resp.set_header('Location', location)

        @wraps(f)
        def decorated(req, resp, *args, **kwargs):
            # raise if server not implemented
            server = self.server()
            uri, http_method, body, headers = extract_params(req)

            # ``error_uri`` is a method; it must be called to get the URI.
            error_uri = self.error_uri()
            redirect_uri = req.params.get('redirect_uri', error_uri)
            log.debug('Found redirect_uri %s.', redirect_uri)

            if req.method not in ('GET', 'HEAD'):
                return

            try:
                scopes, credentials = server.validate_authorization_request(
                    uri, http_method, body, headers
                )
                kwargs['scopes'] = scopes
                kwargs.update(credentials)
                rv = f(*args, **kwargs)
            except oauth2.FatalClientError as e:
                # Must be caught before OAuth2Error (its base class).
                log.debug('Fatal client error %r', e)
                send_redirect(resp, e.in_uri(error_uri))
                return
            except oauth2.OAuth2Error as e:
                log.debug('OAuth2Error: %r', e)
                send_redirect(resp, e.in_uri(redirect_uri))
                return

            if not rv:
                # denied by user
                e = oauth2.AccessDeniedError(state=req.params.get('state'))
                log.debug('OAuth2Error: %r', e)
                send_redirect(resp, e.in_uri(redirect_uri))
            elif isinstance(rv, bool):
                self.confirm_authorization_request(req, resp)
            else:
                resp.body = rv

        return decorated

    def confirm_authorization_request(self, req, resp):
        """Create the authorization response once the consumer confirmed.

        Reads ``scope``, ``client_id``, ``redirect_uri``,
        ``response_type`` and ``state`` from ``req.params``, asks the
        oauthlib server for the authorization response and writes it to
        ``resp`` through ``patch_response``.

        On failure the ``redirect`` helper is called with the error
        encoded into the target URI. Fatal client errors go to the
        provider's error page. Other OAuth2 errors go to the request's
        ``redirect_uri``, or to the error page when none was supplied.

        :param req: the incoming request.
        :param resp: the response object to populate on success.
        """
        server = self.server()
        scope = req.params.get('scope') or ''
        scopes = scope.split()
        credentials = dict(
            client_id=req.params.get('client_id'),
            redirect_uri=req.params.get('redirect_uri', None),
            response_type=req.params.get('response_type', None),
            state=req.params.get('state', None)
        )
        log.debug('Fetched credentials from request %r.', credentials)
        redirect_uri = credentials.get('redirect_uri')
        log.debug('Found redirect_uri %s.', redirect_uri)

        uri, http_method, body, headers = extract_params(req)
        try:
            headers, body, status = server.create_authorization_response(
                uri, http_method, body, headers, scopes, credentials)
            log.debug('Authorization successful.')
        except oauth2.FatalClientError as e:
            # Must be caught before OAuth2Error (its base class).
            log.debug('Fatal client error %r', e)
            # ``error_uri`` is a method; passing it uncalled would hand a
            # bound method to ``in_uri`` instead of a URI string.
            redirect(e.in_uri(self.error_uri()))
        except oauth2.OAuth2Error as e:
            log.debug('OAuth2Error: %r', e)
            redirect(e.in_uri(redirect_uri or self.error_uri()))
        else:
            patch_response(resp, headers, body, status)

    def verify_request(self, req, scopes):
        """Verify ``req`` against ``scopes`` and return the oauth data.

        Use this when the :meth:`protect` decorator does not fit, for
        example inside a responder::

            class YourResource:

                def on_get(self, req, resp):
                    valid, oauth_req = oauth.verify_request(req, ['email'])
                    if valid:
                        ...

        :param req: the incoming request; its parts are obtained with
                    ``extract_params``.
        :param scopes: the scopes the request must be authorized for.
        :returns: whatever the oauthlib server's ``verify_request``
                  returns; callers in this class unpack it as
                  ``(valid, oauth_request)``.
        """
        target, verb, payload, header_map = extract_params(req)
        oauth_server = self.server()
        return oauth_server.verify_request(
            target, verb, payload, header_map, scopes
        )

    def token_endpoint(self, method):
        """Access/refresh token handler decorator.

        The decorated responder is called first. It should return a dict
        of extra credentials for the token response, or ``None`` (any
        falsy value is treated as no extra credentials). The token
        response produced by the oauthlib server is then written to
        ``resp`` through ``patch_response``. The wrapper itself returns
        ``None``::

            class Token:

                @auth.token_endpoint
                def on_get(self, req, resp):
                    return None

            app.add_route('/auth/token', Token())
        """

        @wraps(method)
        def decorated(resource, req, resp, *args, **kwargs):
            oauth_server = self.server()
            uri, http_method, body, headers = extract_params(req)

            extra = method(resource, req, resp, *args, **kwargs)
            if not extra:
                extra = {}
            log.debug('Fetched extra credentials, %r.', extra)

            out_headers, out_body, out_status = (
                oauth_server.create_token_response(
                    uri, http_method, body, headers, extra
                )
            )
            patch_response(resp, out_headers, out_body, out_status)

        return decorated

    def register_endpoint(self, method):
        """Client registration handler decorator.

        The decorated responder should return a client object exposing
        ``client_id``, ``client_secret`` and ``user_id``, or a falsy
        value to skip token creation (the wrapper then returns without
        touching ``resp`` itself).

        For a returned client, the wrapper issues a token on its behalf.
        It rewrites the request into a ``password`` grant authenticated
        with HTTP Basic credentials built from the client's id and
        secret, then asks the oauthlib server for a token response.
        ``client_id`` and ``client_secret`` are added to the JSON body of
        that response before it is written to ``resp`` through
        ``patch_response``::

            class Register:

                @auth.register_endpoint
                def on_post(self, req, resp):
                    return create_client(req)

            app.add_route('/auth/register', Register())

        NOTE(review): the password sent is the fixed placeholder
        ``"dummy_password"``; presumably the registered user getter does
        not check it for this flow. Confirm.
        """
        import base64
        import json
        from urllib.parse import urlencode

        @wraps(method)
        def decorated(resource, req, resp, *args, **kwargs):
            server = self.server()
            client = method(resource, req, resp, *args, **kwargs)
            if not client:
                return
            credentials = {}
            uri, http_method, body, headers = extract_params(req)

            # b64encode works on bytes and returns bytes: encode the
            # "id:secret" pair first and decode the result, so the header
            # value is a plain string.
            pair = '{}:{}'.format(client.client_id, client.client_secret)
            token = base64.b64encode(pair.encode('utf-8')).decode('ascii')
            headers["AUTHORIZATION"] = "Basic " + token

            body = urlencode({
                "grant_type": "password",
                "username": client.user_id,
                "password": "dummy_password"
            })
            log.debug('Fetched extra credentials, %r.', credentials)
            headers, body, status = server.create_token_response(
                uri, http_method, body, headers, credentials
            )

            # Hand the freshly registered client its own credentials
            # together with the token.
            payload = json.loads(body)
            payload["client_id"] = client.client_id
            payload["client_secret"] = client.client_secret
            body = json.dumps(payload)
            patch_response(resp, headers, body, status)

        return decorated

    def revoke_endpoint(self, method):
        """Access/refresh token revoke decorator.

        The decorated responder is never called by the wrapper: the
        wrapper only builds the revocation response with the oauthlib
        server and writes it to ``resp`` through ``patch_response``,
        returning whatever ``patch_response`` returns. This matches
        [`RFC7009`_], under which any value from the responder would be
        discarded anyway.

        [`RFC7009`_] recommends allowing only the `POST` method::

            class RevokeToken:

                @auth.revoke_endpoint
                def on_post(self, req, resp):
                    pass

            app.add_route('/auth/revoke', RevokeToken())

        .. _`RFC7009`: http://tools.ietf.org/html/rfc7009
        """

        @wraps(method)
        def decorated(resource, req, resp, *args, **kwargs):
            oauth_server = self.server()
            uri, http_method, body, headers = extract_params(req)

            revocation = oauth_server.create_revocation_response(
                uri, headers=headers, body=body, http_method=http_method)
            out_headers, out_body, out_status = revocation
            return patch_response(resp, out_headers, out_body, out_status)

        return decorated

    @maybe_args
    def protect(self, method, *scopes):
        """Protect a resource responder with the given scopes.

        The wrapper verifies the request with :meth:`verify_request`.
        A valid request is passed on to ``method``. For an invalid one:

        * if the provider has an ``_on_error`` attribute, the result of
          ``self._on_error(req, resp)`` is returned;
        * otherwise ``falcon.HTTPUnauthorized`` is raised with a
          ``Bearer`` scheme whose realm is the space-joined scopes, or
          ``*`` when no scopes were given.

        The oauth request returned by verification is not attached to
        ``req``.

        NOTE(review): nothing in this class sets ``_on_error``; the
        handler registered through :meth:`invalid_response` is stored as
        ``_invalid_response`` and is not consulted here. Confirm which
        attribute is intended.
        """

        @wraps(method)
        def decorated(resource, req, resp, *args, **kwargs):
            valid, _oauth_req = self.verify_request(req, scopes)
            if valid:
                return method(resource, req, resp, *args, **kwargs)

            if hasattr(self, "_on_error"):
                return self._on_error(req, resp)

            realm = ' '.join(scopes) or '*'
            scheme = 'Bearer realm="{}"'.format(realm)
            raise falcon.HTTPUnauthorized('Auth required', 'Auth Required', scheme=scheme)

        return decorated
