diff --git a/changelog_entry.yaml b/changelog_entry.yaml index b98eef7e6..f1908ea51 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -1,4 +1,4 @@ - bump: patch changes: changed: - - Updated PolicyEngine US to 1.168.1. \ No newline at end of file + - API now checks for authenticated user, but only prints access errors rather than failing. diff --git a/policyengine_api/api.py b/policyengine_api/api.py index 4b2f58090..78429e532 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -7,8 +7,11 @@ import flask import yaml from flask_caching import Cache +from authlib.integrations.flask_oauth2 import ResourceProtector +from policyengine_api.validator import Auth0JWTBearerTokenValidator from policyengine_api.utils import make_cache_key from .constants import VERSION +import policyengine_api.auth_context as auth_context # from werkzeug.middleware.profiler import ProfilerMiddleware @@ -40,6 +43,13 @@ app = application = flask.Flask(__name__) +## as per https://auth0.com/docs/quickstart/backend/python/interactive +require_auth = ResourceProtector() +validator = Auth0JWTBearerTokenValidator() +require_auth.register_token_validator(validator) + +auth_context.configure(app, require_auth=require_auth) + app.config.from_mapping( { "CACHE_TYPE": "RedisCache", diff --git a/policyengine_api/auth_context.py b/policyengine_api/auth_context.py new file mode 100644 index 000000000..73f135007 --- /dev/null +++ b/policyengine_api/auth_context.py @@ -0,0 +1,41 @@ +from flask import Flask, g +from werkzeug.local import LocalProxy +from authlib.integrations.flask_oauth2 import ResourceProtector + + +def configure(app: Flask, require_auth: ResourceProtector): + """ + Configure the application to attempt to get and validate a bearer token. + If there is a token and it's valid the user id is added to the request context + which can be accessed via get_user_id + Otherwise, the request is accepted but get_user_id returns None + + This supports our current auth model where only user-specific actions are restricted and + then only to allow the user + """ + + # If the user is authenticated then get the user id from the token + # And add it to the flask request context. + @app.before_request + def get_user(): + try: + token = require_auth.acquire_token() + print(f"Validated JWT for sub {g.authlib_server_oauth2_token.sub}") + except Exception as ex: + print(f"Unable to parse a valid bearer token from request: {ex}") + + +def get_user() -> None | str: + # I didn't see this documented anywhere, but if you look at the source code + # the validator stores the token in the flask global context under this name. + if "authlib_server_oauth2_token" not in g: + print( + "authlib_server_oauth2_token is not in the flask global context. Please make sure you called 'configure' on the app" + ) + return None + if "sub" not in g.authlib_server_oauth2_token: + print( + "ERROR: authlib_server_oauth2_token does not contain a sub field. The JWT validator should force this to be true." + ) + return None + return g.authlib_server_oauth2_token.sub diff --git a/policyengine_api/routes/user_profile_routes.py b/policyengine_api/routes/user_profile_routes.py index d859629c6..275a89c82 100644 --- a/policyengine_api/routes/user_profile_routes.py +++ b/policyengine_api/routes/user_profile_routes.py @@ -1,4 +1,5 @@ from flask import Blueprint, Response, request +from policyengine_api.auth_context import get_user from policyengine_api.utils.payload_validators import validate_country from policyengine_api.data import database import json @@ -9,6 +10,21 @@ user_service = UserService() +# TODO: This does nothing pending refresh of user tokens +# to include auth information. Once that happens this will throw +# a 403 unauthorized exception if the authenticated user does +# not match +def assert_auth_user_is(user_id: str): + current_user = get_user() + if current_user is None: + print("ERROR: No user is logged in. Ignoring.") + if current_user != user_id: + print( + f"ERROR: Request is autheticated as {current_user} not expected user {user_id}" + ) + return + + @user_profile_bp.route("//user-profile", methods=["POST"]) @validate_country def set_user_profile(country_id: str) -> Response: @@ -24,6 +40,8 @@ def set_user_profile(country_id: str) -> Response: username = payload.pop("username", None) user_since = payload.pop("user_since") + assert_auth_user_is(auth0_id) + created, row = user_service.create_profile( primary_country=country_id, auth0_id=auth0_id, @@ -112,6 +130,11 @@ def update_user_profile(country_id: str) -> Response: if user_id is None: raise BadRequest("Payload must include user_id") + current = user_service.get_profile(user_id=user_id) + if current is None: + raise NotFound("No such user id") + assert_auth_user_is(current.auth0_id) + updated = user_service.update_profile( user_id=user_id, primary_country=primary_country, diff --git a/policyengine_api/validator.py b/policyengine_api/validator.py new file mode 100644 index 000000000..204c62f16 --- /dev/null +++ b/policyengine_api/validator.py @@ -0,0 +1,26 @@ +# As defined by https://auth0.com/docs/quickstart/backend/python/interactive +import json +from urllib.request import urlopen + +from authlib.oauth2.rfc7523 import JWTBearerTokenValidator +from authlib.jose.rfc7517.jwk import JsonWebKey + + +class Auth0JWTBearerTokenValidator(JWTBearerTokenValidator): + def __init__( + self, + audience="https://api.policyengine.org/", + ): + issuer = "https://policyengine.uk.auth0.com/" + jsonurl = urlopen( + f"https://policyengine.uk.auth0.com/.well-known/jwks.json" + ) + public_key = JsonWebKey.import_key_set(json.loads(jsonurl.read())) + super(Auth0JWTBearerTokenValidator, self).__init__(public_key) + self.claims_options = { + "exp": {"essential": True}, + "aud": {"essential": True, "value": audience}, + "iss": {"essential": True, "value": issuer}, + # Provides the user id as we currently use it. + "sub": {"essential": True}, + } diff --git a/setup.py b/setup.py index 1219b10c8..52655771c 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ "streamlit", "werkzeug", "Flask-Caching>=2,<3", + "Authlib", ], extras_require={ "dev": [