mirror of
https://github.com/nagisa77/OpenIsle.git
synced 2026-05-11 13:17:29 +08:00
Merge pull request #1098 from nagisa77/codex/store-accesstoken-as-jwt-token
Cache MCP session JWT tokens
This commit is contained in:
@@ -4,12 +4,13 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Annotated
|
from typing import Annotated, Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from mcp.server.fastmcp import Context, FastMCP
|
from mcp.server.fastmcp import Context, FastMCP
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from pydantic import Field as PydanticField
|
from pydantic import Field as PydanticField
|
||||||
|
from weakref import WeakKeyDictionary
|
||||||
|
|
||||||
from .config import get_settings
|
from .config import get_settings
|
||||||
from .schemas import (
|
from .schemas import (
|
||||||
@@ -50,6 +51,67 @@ search_client = SearchClient(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionTokenManager:
|
||||||
|
"""Cache JWT access tokens on a per-session basis."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._tokens: WeakKeyDictionary[Any, str] = WeakKeyDictionary()
|
||||||
|
|
||||||
|
def resolve(self, ctx: Context | None, token: str | None) -> str | None:
|
||||||
|
"""Resolve and optionally persist the token for the current session."""
|
||||||
|
|
||||||
|
session = self._get_session(ctx)
|
||||||
|
|
||||||
|
if isinstance(token, str):
|
||||||
|
stripped = token.strip()
|
||||||
|
if stripped:
|
||||||
|
if session is not None:
|
||||||
|
self._tokens[session] = stripped
|
||||||
|
logger.debug(
|
||||||
|
"Stored JWT token for session %s.",
|
||||||
|
self._describe_session(session),
|
||||||
|
)
|
||||||
|
return stripped
|
||||||
|
|
||||||
|
if session is not None and session in self._tokens:
|
||||||
|
logger.debug(
|
||||||
|
"Clearing stored JWT token for session %s due to empty input.",
|
||||||
|
self._describe_session(session),
|
||||||
|
)
|
||||||
|
del self._tokens[session]
|
||||||
|
return None
|
||||||
|
|
||||||
|
if session is not None:
|
||||||
|
cached = self._tokens.get(session)
|
||||||
|
if cached:
|
||||||
|
logger.debug(
|
||||||
|
"Reusing cached JWT token for session %s.",
|
||||||
|
self._describe_session(session),
|
||||||
|
)
|
||||||
|
return cached
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_session(ctx: Context | None) -> Any | None:
|
||||||
|
if ctx is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return ctx.session
|
||||||
|
except Exception: # pragma: no cover - defensive guard
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _describe_session(session: Any) -> str:
|
||||||
|
identifier = getattr(session, "mcp_session_id", None)
|
||||||
|
if isinstance(identifier, str) and identifier:
|
||||||
|
return identifier
|
||||||
|
return hex(id(session))
|
||||||
|
|
||||||
|
|
||||||
|
session_token_manager = SessionTokenManager()
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(_: FastMCP):
|
async def lifespan(_: FastMCP):
|
||||||
"""Lifecycle hook that disposes shared resources when the server stops."""
|
"""Lifecycle hook that disposes shared resources when the server stops."""
|
||||||
@@ -164,10 +226,10 @@ async def reply_to_post(
|
|||||||
if not sanitized_content:
|
if not sanitized_content:
|
||||||
raise ValueError("Reply content must not be empty.")
|
raise ValueError("Reply content must not be empty.")
|
||||||
|
|
||||||
sanitized_token = token.strip() if isinstance(token, str) else None
|
|
||||||
|
|
||||||
sanitized_captcha = captcha.strip() if isinstance(captcha, str) else None
|
sanitized_captcha = captcha.strip() if isinstance(captcha, str) else None
|
||||||
|
|
||||||
|
resolved_token = session_token_manager.resolve(ctx, token)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Creating reply for post_id=%s (captcha=%s)",
|
"Creating reply for post_id=%s (captcha=%s)",
|
||||||
@@ -176,7 +238,7 @@ async def reply_to_post(
|
|||||||
)
|
)
|
||||||
raw_comment = await search_client.reply_to_post(
|
raw_comment = await search_client.reply_to_post(
|
||||||
post_id,
|
post_id,
|
||||||
sanitized_token,
|
resolved_token,
|
||||||
sanitized_content,
|
sanitized_content,
|
||||||
sanitized_captcha,
|
sanitized_captcha,
|
||||||
)
|
)
|
||||||
@@ -271,10 +333,10 @@ async def reply_to_comment(
|
|||||||
if not sanitized_content:
|
if not sanitized_content:
|
||||||
raise ValueError("Reply content must not be empty.")
|
raise ValueError("Reply content must not be empty.")
|
||||||
|
|
||||||
sanitized_token = token.strip() if isinstance(token, str) else None
|
|
||||||
|
|
||||||
sanitized_captcha = captcha.strip() if isinstance(captcha, str) else None
|
sanitized_captcha = captcha.strip() if isinstance(captcha, str) else None
|
||||||
|
|
||||||
|
resolved_token = session_token_manager.resolve(ctx, token)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Creating reply for comment_id=%s (captcha=%s)",
|
"Creating reply for comment_id=%s (captcha=%s)",
|
||||||
@@ -283,7 +345,7 @@ async def reply_to_comment(
|
|||||||
)
|
)
|
||||||
raw_comment = await search_client.reply_to_comment(
|
raw_comment = await search_client.reply_to_comment(
|
||||||
comment_id,
|
comment_id,
|
||||||
sanitized_token,
|
resolved_token,
|
||||||
sanitized_content,
|
sanitized_content,
|
||||||
sanitized_captcha,
|
sanitized_captcha,
|
||||||
)
|
)
|
||||||
@@ -411,13 +473,11 @@ async def get_post(
|
|||||||
) -> PostDetail:
|
) -> PostDetail:
|
||||||
"""Fetch post details from the backend and validate the response."""
|
"""Fetch post details from the backend and validate the response."""
|
||||||
|
|
||||||
sanitized_token = token.strip() if isinstance(token, str) else None
|
resolved_token = session_token_manager.resolve(ctx, token)
|
||||||
if sanitized_token == "":
|
|
||||||
sanitized_token = None
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info("Fetching post details for post_id=%s", post_id)
|
logger.info("Fetching post details for post_id=%s", post_id)
|
||||||
raw_post = await search_client.get_post(post_id, sanitized_token)
|
raw_post = await search_client.get_post(post_id, resolved_token)
|
||||||
except httpx.HTTPStatusError as exc: # pragma: no cover - network errors
|
except httpx.HTTPStatusError as exc: # pragma: no cover - network errors
|
||||||
status_code = exc.response.status_code
|
status_code = exc.response.status_code
|
||||||
if status_code == 404:
|
if status_code == 404:
|
||||||
@@ -495,7 +555,7 @@ async def list_unread_messages(
|
|||||||
) -> UnreadNotificationsResponse:
|
) -> UnreadNotificationsResponse:
|
||||||
"""Retrieve unread notifications and return structured data."""
|
"""Retrieve unread notifications and return structured data."""
|
||||||
|
|
||||||
sanitized_token = token.strip() if isinstance(token, str) else None
|
resolved_token = session_token_manager.resolve(ctx, token)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -506,7 +566,7 @@ async def list_unread_messages(
|
|||||||
raw_notifications = await search_client.list_unread_notifications(
|
raw_notifications = await search_client.list_unread_notifications(
|
||||||
page=page,
|
page=page,
|
||||||
size=size,
|
size=size,
|
||||||
token=sanitized_token,
|
token=resolved_token,
|
||||||
)
|
)
|
||||||
except httpx.HTTPStatusError as exc: # pragma: no cover - network errors
|
except httpx.HTTPStatusError as exc: # pragma: no cover - network errors
|
||||||
message = (
|
message = (
|
||||||
|
|||||||
Reference in New Issue
Block a user