diff --git a/mcp/src/openisle_mcp/server.py b/mcp/src/openisle_mcp/server.py index 9e4670701..25a8f8331 100644 --- a/mcp/src/openisle_mcp/server.py +++ b/mcp/src/openisle_mcp/server.py @@ -4,12 +4,13 @@ from __future__ import annotations import logging from contextlib import asynccontextmanager -from typing import Annotated +from typing import Annotated, Any import httpx from mcp.server.fastmcp import Context, FastMCP from pydantic import ValidationError from pydantic import Field as PydanticField +from weakref import WeakKeyDictionary from .config import get_settings from .schemas import ( @@ -50,6 +51,67 @@ search_client = SearchClient( ) +class SessionTokenManager: + """Cache JWT access tokens on a per-session basis.""" + + def __init__(self) -> None: + self._tokens: WeakKeyDictionary[Any, str] = WeakKeyDictionary() + + def resolve(self, ctx: Context | None, token: str | None) -> str | None: + """Resolve and optionally persist the token for the current session.""" + + session = self._get_session(ctx) + + if isinstance(token, str): + stripped = token.strip() + if stripped: + if session is not None: + self._tokens[session] = stripped + logger.debug( + "Stored JWT token for session %s.", + self._describe_session(session), + ) + return stripped + + if session is not None and session in self._tokens: + logger.debug( + "Clearing stored JWT token for session %s due to empty input.", + self._describe_session(session), + ) + del self._tokens[session] + return None + + if session is not None: + cached = self._tokens.get(session) + if cached: + logger.debug( + "Reusing cached JWT token for session %s.", + self._describe_session(session), + ) + return cached + + return None + + @staticmethod + def _get_session(ctx: Context | None) -> Any | None: + if ctx is None: + return None + try: + return ctx.session + except Exception: # pragma: no cover - defensive guard + return None + + @staticmethod + def _describe_session(session: Any) -> str: + identifier = getattr(session, "mcp_session_id", None) + if isinstance(identifier, str) and identifier: + return identifier + return hex(id(session)) + + +session_token_manager = SessionTokenManager() + + @asynccontextmanager async def lifespan(_: FastMCP): """Lifecycle hook that disposes shared resources when the server stops.""" @@ -164,10 +226,10 @@ async def reply_to_post( if not sanitized_content: raise ValueError("Reply content must not be empty.") - sanitized_token = token.strip() if isinstance(token, str) else None - sanitized_captcha = captcha.strip() if isinstance(captcha, str) else None + resolved_token = session_token_manager.resolve(ctx, token) + try: logger.info( "Creating reply for post_id=%s (captcha=%s)", @@ -176,7 +238,7 @@ async def reply_to_post( ) raw_comment = await search_client.reply_to_post( post_id, - sanitized_token, + resolved_token, sanitized_content, sanitized_captcha, ) @@ -271,10 +333,10 @@ async def reply_to_comment( if not sanitized_content: raise ValueError("Reply content must not be empty.") - sanitized_token = token.strip() if isinstance(token, str) else None - sanitized_captcha = captcha.strip() if isinstance(captcha, str) else None + resolved_token = session_token_manager.resolve(ctx, token) + try: logger.info( "Creating reply for comment_id=%s (captcha=%s)", @@ -283,7 +345,7 @@ async def reply_to_comment( ) raw_comment = await search_client.reply_to_comment( comment_id, - sanitized_token, + resolved_token, sanitized_content, sanitized_captcha, ) @@ -411,13 +473,11 @@ async def get_post( ) -> PostDetail: """Fetch post details from the backend and validate the response.""" - sanitized_token = token.strip() if isinstance(token, str) else None - if sanitized_token == "": - sanitized_token = None + resolved_token = session_token_manager.resolve(ctx, token) try: logger.info("Fetching post details for post_id=%s", post_id) - raw_post = await search_client.get_post(post_id, sanitized_token) + raw_post = await search_client.get_post(post_id, resolved_token) except httpx.HTTPStatusError as exc: # pragma: no cover - network errors status_code = exc.response.status_code if status_code == 404: @@ -495,7 +555,7 @@ async def list_unread_messages( ) -> UnreadNotificationsResponse: """Retrieve unread notifications and return structured data.""" - sanitized_token = token.strip() if isinstance(token, str) else None + resolved_token = session_token_manager.resolve(ctx, token) try: logger.info( @@ -506,7 +566,7 @@ async def list_unread_messages( raw_notifications = await search_client.list_unread_notifications( page=page, size=size, - token=sanitized_token, + token=resolved_token, ) except httpx.HTTPStatusError as exc: # pragma: no cover - network errors message = (