Cache MCP session JWT tokens

This commit is contained in:
Tim
2025-10-28 01:20:32 +08:00
parent c01349a436
commit 997dacdbe6

View File

@@ -4,12 +4,13 @@ from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from typing import Annotated
from typing import Annotated, Any
import httpx
from mcp.server.fastmcp import Context, FastMCP
from pydantic import ValidationError
from pydantic import Field as PydanticField
from weakref import WeakKeyDictionary
from .config import get_settings
from .schemas import (
@@ -50,6 +51,67 @@ search_client = SearchClient(
)
class SessionTokenManager:
"""Cache JWT access tokens on a per-session basis."""
def __init__(self) -> None:
self._tokens: WeakKeyDictionary[Any, str] = WeakKeyDictionary()
def resolve(self, ctx: Context | None, token: str | None) -> str | None:
"""Resolve and optionally persist the token for the current session."""
session = self._get_session(ctx)
if isinstance(token, str):
stripped = token.strip()
if stripped:
if session is not None:
self._tokens[session] = stripped
logger.debug(
"Stored JWT token for session %s.",
self._describe_session(session),
)
return stripped
if session is not None and session in self._tokens:
logger.debug(
"Clearing stored JWT token for session %s due to empty input.",
self._describe_session(session),
)
del self._tokens[session]
return None
if session is not None:
cached = self._tokens.get(session)
if cached:
logger.debug(
"Reusing cached JWT token for session %s.",
self._describe_session(session),
)
return cached
return None
@staticmethod
def _get_session(ctx: Context | None) -> Any | None:
if ctx is None:
return None
try:
return ctx.session
except Exception: # pragma: no cover - defensive guard
return None
@staticmethod
def _describe_session(session: Any) -> str:
identifier = getattr(session, "mcp_session_id", None)
if isinstance(identifier, str) and identifier:
return identifier
return hex(id(session))
session_token_manager = SessionTokenManager()
@asynccontextmanager
async def lifespan(_: FastMCP):
"""Lifecycle hook that disposes shared resources when the server stops."""
@@ -164,10 +226,10 @@ async def reply_to_post(
if not sanitized_content:
raise ValueError("Reply content must not be empty.")
sanitized_token = token.strip() if isinstance(token, str) else None
sanitized_captcha = captcha.strip() if isinstance(captcha, str) else None
resolved_token = session_token_manager.resolve(ctx, token)
try:
logger.info(
"Creating reply for post_id=%s (captcha=%s)",
@@ -176,7 +238,7 @@ async def reply_to_post(
)
raw_comment = await search_client.reply_to_post(
post_id,
sanitized_token,
resolved_token,
sanitized_content,
sanitized_captcha,
)
@@ -271,10 +333,10 @@ async def reply_to_comment(
if not sanitized_content:
raise ValueError("Reply content must not be empty.")
sanitized_token = token.strip() if isinstance(token, str) else None
sanitized_captcha = captcha.strip() if isinstance(captcha, str) else None
resolved_token = session_token_manager.resolve(ctx, token)
try:
logger.info(
"Creating reply for comment_id=%s (captcha=%s)",
@@ -283,7 +345,7 @@ async def reply_to_comment(
)
raw_comment = await search_client.reply_to_comment(
comment_id,
sanitized_token,
resolved_token,
sanitized_content,
sanitized_captcha,
)
@@ -411,13 +473,11 @@ async def get_post(
) -> PostDetail:
"""Fetch post details from the backend and validate the response."""
sanitized_token = token.strip() if isinstance(token, str) else None
if sanitized_token == "":
sanitized_token = None
resolved_token = session_token_manager.resolve(ctx, token)
try:
logger.info("Fetching post details for post_id=%s", post_id)
raw_post = await search_client.get_post(post_id, sanitized_token)
raw_post = await search_client.get_post(post_id, resolved_token)
except httpx.HTTPStatusError as exc: # pragma: no cover - network errors
status_code = exc.response.status_code
if status_code == 404:
@@ -495,7 +555,7 @@ async def list_unread_messages(
) -> UnreadNotificationsResponse:
"""Retrieve unread notifications and return structured data."""
sanitized_token = token.strip() if isinstance(token, str) else None
resolved_token = session_token_manager.resolve(ctx, token)
try:
logger.info(
@@ -506,7 +566,7 @@ async def list_unread_messages(
raw_notifications = await search_client.list_unread_notifications(
page=page,
size=size,
token=sanitized_token,
token=resolved_token,
)
except httpx.HTTPStatusError as exc: # pragma: no cover - network errors
message = (