From 749ab560fff9b1d2b23d34437acf1493b3de3c20 Mon Sep 17 00:00:00 2001 From: tim Date: Tue, 28 Oct 2025 01:55:46 +0800 Subject: [PATCH] Revert "Cache MCP session JWT tokens" This reverts commit 997dacdbe6a70e07fde54faa3bf59f5ebe3183f0. --- mcp/src/openisle_mcp/server.py | 86 +++++----------------------------- 1 file changed, 13 insertions(+), 73 deletions(-) diff --git a/mcp/src/openisle_mcp/server.py b/mcp/src/openisle_mcp/server.py index 25a8f8331..9e4670701 100644 --- a/mcp/src/openisle_mcp/server.py +++ b/mcp/src/openisle_mcp/server.py @@ -4,13 +4,12 @@ from __future__ import annotations import logging from contextlib import asynccontextmanager -from typing import Annotated, Any +from typing import Annotated import httpx from mcp.server.fastmcp import Context, FastMCP from pydantic import ValidationError from pydantic import Field as PydanticField -from weakref import WeakKeyDictionary from .config import get_settings from .schemas import ( @@ -51,67 +50,6 @@ search_client = SearchClient( ) -class SessionTokenManager: - """Cache JWT access tokens on a per-session basis.""" - - def __init__(self) -> None: - self._tokens: WeakKeyDictionary[Any, str] = WeakKeyDictionary() - - def resolve(self, ctx: Context | None, token: str | None) -> str | None: - """Resolve and optionally persist the token for the current session.""" - - session = self._get_session(ctx) - - if isinstance(token, str): - stripped = token.strip() - if stripped: - if session is not None: - self._tokens[session] = stripped - logger.debug( - "Stored JWT token for session %s.", - self._describe_session(session), - ) - return stripped - - if session is not None and session in self._tokens: - logger.debug( - "Clearing stored JWT token for session %s due to empty input.", - self._describe_session(session), - ) - del self._tokens[session] - return None - - if session is not None: - cached = self._tokens.get(session) - if cached: - logger.debug( - "Reusing cached JWT token for session %s.", - self._describe_session(session), - ) - return cached - - return None - - @staticmethod - def _get_session(ctx: Context | None) -> Any | None: - if ctx is None: - return None - try: - return ctx.session - except Exception: # pragma: no cover - defensive guard - return None - - @staticmethod - def _describe_session(session: Any) -> str: - identifier = getattr(session, "mcp_session_id", None) - if isinstance(identifier, str) and identifier: - return identifier - return hex(id(session)) - - -session_token_manager = SessionTokenManager() - - @asynccontextmanager async def lifespan(_: FastMCP): """Lifecycle hook that disposes shared resources when the server stops.""" @@ -226,9 +164,9 @@ async def reply_to_post( if not sanitized_content: raise ValueError("Reply content must not be empty.") - sanitized_captcha = captcha.strip() if isinstance(captcha, str) else None + sanitized_token = token.strip() if isinstance(token, str) else None - resolved_token = session_token_manager.resolve(ctx, token) + sanitized_captcha = captcha.strip() if isinstance(captcha, str) else None try: logger.info( @@ -238,7 +176,7 @@ async def reply_to_post( ) raw_comment = await search_client.reply_to_post( post_id, - resolved_token, + sanitized_token, sanitized_content, sanitized_captcha, ) @@ -333,9 +271,9 @@ async def reply_to_comment( if not sanitized_content: raise ValueError("Reply content must not be empty.") - sanitized_captcha = captcha.strip() if isinstance(captcha, str) else None + sanitized_token = token.strip() if isinstance(token, str) else None - resolved_token = session_token_manager.resolve(ctx, token) + sanitized_captcha = captcha.strip() if isinstance(captcha, str) else None try: logger.info( @@ -345,7 +283,7 @@ async def reply_to_comment( ) raw_comment = await search_client.reply_to_comment( comment_id, - resolved_token, + sanitized_token, sanitized_content, sanitized_captcha, ) @@ -473,11 +411,13 @@ async def get_post( ) -> PostDetail: """Fetch post details from the backend and validate the response.""" - resolved_token = session_token_manager.resolve(ctx, token) + sanitized_token = token.strip() if isinstance(token, str) else None + if sanitized_token == "": + sanitized_token = None try: logger.info("Fetching post details for post_id=%s", post_id) - raw_post = await search_client.get_post(post_id, resolved_token) + raw_post = await search_client.get_post(post_id, sanitized_token) except httpx.HTTPStatusError as exc: # pragma: no cover - network errors status_code = exc.response.status_code if status_code == 404: @@ -555,7 +495,7 @@ async def list_unread_messages( ) -> UnreadNotificationsResponse: """Retrieve unread notifications and return structured data.""" - resolved_token = session_token_manager.resolve(ctx, token) + sanitized_token = token.strip() if isinstance(token, str) else None try: logger.info( @@ -566,7 +506,7 @@ async def list_unread_messages( raw_notifications = await search_client.list_unread_notifications( page=page, size=size, - token=resolved_token, + token=sanitized_token, ) except httpx.HTTPStatusError as exc: # pragma: no cover - network errors message = (