diff --git a/bluesky_httpserver/authentication.py b/bluesky_httpserver/authentication.py index 673ca29..9effc80 100644 --- a/bluesky_httpserver/authentication.py +++ b/bluesky_httpserver/authentication.py @@ -7,7 +7,7 @@ from datetime import datetime, timedelta from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Request, Response, Security +from fastapi import APIRouter, Depends, HTTPException, Request, Response, Security, WebSocket from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.responses import JSONResponse from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes @@ -202,7 +202,6 @@ def get_current_principal( # otherwise it is None. The original set of API key scopes is used for generating new # API keys. roles, scopes, api_key_scopes = {}, {}, None - if api_key is not None: if authenticators: # Tiled is in a multi-user configuration with authentication providers. @@ -356,6 +355,40 @@ def get_current_principal( return principal +def get_current_principal_websocket( + websocket: WebSocket, + scopes: str, +): + app = websocket.app + security_scopes = SecurityScopes(scopes=scopes or []) + settings = app.dependency_overrides[get_settings]() + authenticators = app.dependency_overrides[get_authenticators]() + api_access_manager = app.dependency_overrides[get_api_access_manager]() + + auth_header = websocket.headers.get("Authorization", "") + access_token, api_key = None, None + if auth_header.startswith("Bearer "): + access_token = auth_header[len("Bearer") :].strip() + if auth_header.startswith("ApiKey "): + api_key = auth_header[len("ApiKey") :].strip() + + principal = None + try: + principal = get_current_principal( + request=websocket, + security_scopes=security_scopes, + access_token=access_token, + api_key=api_key, + settings=settings, + authenticators=authenticators, + api_access_manager=api_access_manager, + ) + except HTTPException as ex: + print(f"WebSocket connection failed: {ex}") + + return principal + + def create_session(settings, identity_provider, id, scopes): with get_sessionmaker(settings.database_settings)() as db: # Have we seen this Identity before? diff --git a/bluesky_httpserver/authorization/_defaults.py b/bluesky_httpserver/authorization/_defaults.py index ecdf4dc..37448b8 100644 --- a/bluesky_httpserver/authorization/_defaults.py +++ b/bluesky_httpserver/authorization/_defaults.py @@ -73,6 +73,7 @@ "write:plan:control", "write:execute", "write:history:edit", + "user:apikeys", } _DEFAULT_SCOPES_USER = { @@ -91,6 +92,7 @@ "write:plan:control", "write:execute", "write:history:edit", + "user:apikeys", } _DEFAULT_SCOPES_OBSERVER = { @@ -103,6 +105,7 @@ "read:console", "read:lock", "read:testing", + "user:apikeys", } # ============================================================================================= diff --git a/bluesky_httpserver/routers/core_api.py b/bluesky_httpserver/routers/core_api.py index 7eaa74e..397972b 100644 --- a/bluesky_httpserver/routers/core_api.py +++ b/bluesky_httpserver/routers/core_api.py @@ -14,7 +14,7 @@ else: from pydantic_settings import BaseSettings -from ..authentication import get_current_principal +from ..authentication import get_current_principal, get_current_principal_websocket from ..console_output import ConsoleOutputEventStream, StreamingResponseFromClass from ..resources import SERVER_RESOURCES as SR from ..settings import get_settings @@ -1139,7 +1139,12 @@ def is_alive(self): @router.websocket("/console_output/ws") -async def console_output_ws(websocket: WebSocket): +async def console_output_ws(websocket: WebSocket, scopes=["read:console"]): + principal = get_current_principal_websocket(websocket=websocket, scopes=scopes) + if not principal: + await websocket.close(code=4001, reason="Invalid token") + return + await websocket.accept() q = SR.console_output_stream.add_queue(websocket) wsmon = WebSocketMonitor(websocket) @@ -1151,6 +1156,8 @@ async def console_output_ws(websocket: WebSocket): await websocket.send_text(msg) except asyncio.TimeoutError: pass + except RuntimeError: # 'send' after the client is disconnected + pass except WebSocketDisconnect: pass finally: @@ -1158,11 +1165,17 @@ async def console_output_ws(websocket: WebSocket): @router.websocket("/status/ws") -async def status_ws(websocket: WebSocket): +async def status_ws(websocket: WebSocket, scopes=["read:monitor"]): + principal = get_current_principal_websocket(websocket=websocket, scopes=scopes) + if not principal: + await websocket.close(code=4001, reason="Invalid token") + return + await websocket.accept() q = SR.system_info_stream.add_queue_status(websocket) wsmon = WebSocketMonitor(websocket) wsmon.start() + try: while wsmon.is_alive: try: @@ -1170,6 +1183,8 @@ async def status_ws(websocket: WebSocket): await websocket.send_text(msg) except asyncio.TimeoutError: pass + except RuntimeError: # 'send' after the client is disconnected + pass except WebSocketDisconnect: pass finally: @@ -1177,7 +1192,12 @@ async def status_ws(websocket: WebSocket): @router.websocket("/info/ws") -async def info_ws(websocket: WebSocket): +async def info_ws(websocket: WebSocket, scopes=["read:monitor"]): + principal = get_current_principal_websocket(websocket=websocket, scopes=scopes) + if not principal: + await websocket.close(code=4001, reason="Invalid token") + return + await websocket.accept() q = SR.system_info_stream.add_queue_info(websocket) wsmon = WebSocketMonitor(websocket) @@ -1189,6 +1209,8 @@ async def info_ws(websocket: WebSocket): await websocket.send_text(msg) except asyncio.TimeoutError: pass + except RuntimeError: # 'send' after the client is disconnected + pass except WebSocketDisconnect: pass finally: diff --git a/bluesky_httpserver/tests/test_auth_for_websockets.py b/bluesky_httpserver/tests/test_auth_for_websockets.py new file mode 100644 index 0000000..2a76109 --- /dev/null +++ b/bluesky_httpserver/tests/test_auth_for_websockets.py @@ -0,0 +1,175 @@ +import json +import pprint +import threading +import time as ttime + +import pytest +from bluesky_queueserver.manager.tests.common import re_manager, re_manager_cmd # noqa F401 +from websockets.sync.client import connect + +from .conftest import fastapi_server_fs # noqa: F401 +from .conftest import ( + SERVER_ADDRESS, + SERVER_PORT, + request_to_json, + setup_server_with_config_file, + wait_for_environment_to_be_closed, + wait_for_environment_to_be_created, +) + +config_toy_test = """ +authentication: + allow_anonymous_access: True + providers: + - provider: toy + authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + args: + users_to_passwords: + bob: bob_password + alice: alice_password + cara: cara_password + tom: tom_password +api_access: + policy: bluesky_httpserver.authorization:DictionaryAPIAccessControl + args: + users: + bob: + roles: + - admin + - expert + alice: + roles: advanced + tom: + roles: user +""" + + +class _ReceiveSystemInfoSocket(threading.Thread): + """ + Catch streaming console output by connecting to /console_output/ws socket and + save messages to the buffer. + """ + + def __init__(self, *, endpoint, api_key=None, token=None, **kwargs): + super().__init__(**kwargs) + self.received_data_buffer = [] + self._exit = False + self._api_key = api_key + self._token = token + self._endpoint = endpoint + + def run(self): + websocket_uri = f"ws://{SERVER_ADDRESS}:{SERVER_PORT}/api{self._endpoint}" + if self._token is not None: + additional_headers = {"Authorization": f"Bearer {self._token}"} + elif self._api_key is not None: + additional_headers = {"Authorization": f"ApiKey {self._api_key}"} + else: + additional_headers = {} + + try: + with connect(websocket_uri, additional_headers=additional_headers) as websocket: + while not self._exit: + try: + msg_json = websocket.recv(timeout=0.1, decode=False) + try: + msg = json.loads(msg_json) + self.received_data_buffer.append(msg) + except json.JSONDecodeError: + pass + except TimeoutError: + pass + except Exception as ex: + print(f"Failed to connect to server: {ex}") + + def stop(self): + """ + Call this method to stop the thread. Then send a request to the server so that some output + is printed in ``stdout``. + """ + self._exit = True + + def __del__(self): + self.stop() + + +# fmt: off +@pytest.mark.parametrize("ws_auth_type", ["apikey", "token", "apikey_invalid", "token_invalid", "none"]) +# fmt: on +def test_websocket_auth_01( + tmpdir, + monkeypatch, + re_manager_cmd, # noqa: F811 + fastapi_server_fs, # noqa: F811 + ws_auth_type, +): + """ + Test authentication for websockets. The test is run only on ``/status/ws`` websocket. + The other websockets are expected to use the same authentication scheme. + """ + + # Start RE Manager + params = ["--zmq-publish-console", "ON"] + re_manager_cmd(params) + + setup_server_with_config_file(config_file_str=config_toy_test, tmpdir=tmpdir, monkeypatch=monkeypatch) + fastapi_server_fs() + + resp1 = request_to_json("post", "/auth/provider/toy/token", login=("bob", "bob_password")) + assert "access_token" in pprint.pformat(resp1) + token = resp1["access_token"] + + resp3 = request_to_json( + "post", "/auth/apikey", json={"expires_in": 900, "note": "API key for testing"}, token=token + ) + assert "secret" in resp3, pprint.pformat(resp3) + assert "note" in resp3, pprint.pformat(resp3) + assert resp3["note"] == "API key for testing" + assert resp3["scopes"] == ["inherit"] + api_key = resp3["secret"] + + endpoint = "/status/ws" + if ws_auth_type == "none": + ws_params = {} + elif ws_auth_type == "apikey": + ws_params = {"api_key": api_key} + elif ws_auth_type == "apikey_invalid": + ws_params = {"api_key": "InvalidApiKey"} + elif ws_auth_type == "token": + ws_params = {"token": token} + elif ws_auth_type == "token_invalid": + ws_params = {"token": "InvalidToken"} + else: + assert False, f"Unknown authentication type: {ws_auth_type!r}" + + rsc = _ReceiveSystemInfoSocket(endpoint=endpoint, **ws_params) + rsc.start() + ttime.sleep(1) # Wait until the client connects to the socket + + resp1 = request_to_json("post", "/environment/open", api_key=api_key) + assert resp1["success"] is True, pprint.pformat(resp1) + + assert wait_for_environment_to_be_created(timeout=10, api_key=api_key) + + resp2b = request_to_json("post", "/environment/close", api_key=api_key) + assert resp2b["success"] is True, pprint.pformat(resp2b) + + assert wait_for_environment_to_be_closed(timeout=10, api_key=api_key) + + # Wait until capture is complete + ttime.sleep(2) + rsc.stop() + rsc.join() + + buffer = rsc.received_data_buffer + if ws_auth_type in ("none", "apikey_invalid", "token_invalid"): + assert len(buffer) == 0 + elif ws_auth_type in ("apikey", "token"): + assert len(buffer) > 0 + for msg in buffer: + assert "time" in msg, msg + assert isinstance(msg["time"], float), msg + assert "msg" in msg + assert isinstance(msg["msg"], dict) + else: + assert False, f"Unknown authentication type: {ws_auth_type!r}" diff --git a/bluesky_httpserver/tests/test_console_output.py b/bluesky_httpserver/tests/test_console_output.py index 1b87e53..1f089ec 100644 --- a/bluesky_httpserver/tests/test_console_output.py +++ b/bluesky_httpserver/tests/test_console_output.py @@ -353,7 +353,8 @@ def __init__(self, api_key=API_KEY_FOR_TESTS, **kwargs): def run(self): websocket_uri = f"ws://{SERVER_ADDRESS}:{SERVER_PORT}/api/console_output/ws" - with connect(websocket_uri) as websocket: + additional_headers = {"Authorization": f"ApiKey {self._api_key}"} + with connect(websocket_uri, additional_headers=additional_headers) as websocket: while not self._exit: try: msg_json = websocket.recv(timeout=0.1, decode=False) diff --git a/bluesky_httpserver/tests/test_system_info_socket.py b/bluesky_httpserver/tests/test_system_info_socket.py index b20c98c..4d25dd6 100644 --- a/bluesky_httpserver/tests/test_system_info_socket.py +++ b/bluesky_httpserver/tests/test_system_info_socket.py @@ -35,17 +35,21 @@ def __init__(self, *, endpoint, api_key=API_KEY_FOR_TESTS, **kwargs): def run(self): websocket_uri = f"ws://{SERVER_ADDRESS}:{SERVER_PORT}/api{self._endpoint}" - with connect(websocket_uri) as websocket: - while not self._exit: - try: - msg_json = websocket.recv(timeout=0.1, decode=False) + additional_headers = {"Authorization": f"ApiKey {self._api_key}"} + try: + with connect(websocket_uri, additional_headers=additional_headers) as websocket: + while not self._exit: try: - msg = json.loads(msg_json) - self.received_data_buffer.append(msg) - except json.JSONDecodeError: + msg_json = websocket.recv(timeout=0.1, decode=False) + try: + msg = json.loads(msg_json) + self.received_data_buffer.append(msg) + except json.JSONDecodeError: + pass + except TimeoutError: pass - except TimeoutError: - pass + except Exception as ex: + print(f"Failed to connect to server: {ex}") def stop(self): """