Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,24 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None
"""Store client information."""
...

async def get_oauth_metadata(self) -> OAuthMetadata | None:
"""Get stored authorization server metadata.

Optional: implementations may return ``None`` if metadata persistence
is not desired. Implementations that persist tokens across restarts
should also persist metadata so :meth:`OAuthClientProvider._refresh_token`
can resolve the correct token endpoint without rediscovering metadata
on every restart.
"""
return None

async def set_oauth_metadata(self, metadata: OAuthMetadata) -> None:
"""Store authorization server metadata.

Optional: no-op by default. See :meth:`get_oauth_metadata`.
"""
return


@dataclass
class OAuthContext:
Expand Down Expand Up @@ -473,10 +491,19 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool: # p
self.context.clear_tokens()
return False

async def _initialize(self) -> None: # pragma: no cover
"""Load stored tokens and client info."""
async def _initialize(self) -> None:
"""Load stored tokens, client info, and authorization server metadata."""
self.context.current_tokens = await self.context.storage.get_tokens()
self.context.client_info = await self.context.storage.get_client_info()
# Restore authorization server metadata so ``_refresh_token`` can
# resolve the correct token endpoint without rediscovering it on
# every restart. ``getattr`` preserves backward compatibility with
# storage implementations predating ``get_oauth_metadata``: they
# return ``None`` and the refresh path falls back to the legacy
# ``<base_url>/token`` behaviour as before.
meta_getter = getattr(self.context.storage, "get_oauth_metadata", None)
if meta_getter is not None:
self.context.oauth_metadata = await meta_getter()
self._initialized = True

def _add_auth_header(self, request: httpx.Request) -> None:
Expand Down Expand Up @@ -507,7 +534,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
"""HTTPX auth flow integration."""
async with self.context.lock:
if not self._initialized:
await self._initialize() # pragma: no cover
await self._initialize()

# Capture protocol version from request headers
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
Expand Down Expand Up @@ -572,6 +599,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
break
if ok and asm:
self.context.oauth_metadata = asm
# Persist so subsequent restarts can resolve the
# correct token endpoint without rediscovery.
meta_setter = getattr(self.context.storage, "set_oauth_metadata", None)
if meta_setter is not None:
await meta_setter(asm)
break
else:
logger.debug(f"OAuth metadata discovery failed: {url}")
Expand Down Expand Up @@ -612,7 +644,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Step 5: Perform authorization and complete token exchange
token_response = yield await self._perform_authorization()
await self._handle_token_response(token_response)
except Exception: # pragma: no cover
except Exception:
logger.exception("OAuth flow error")
raise

Expand Down
260 changes: 258 additions & 2 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,26 @@ class MockTokenStorage:
def __init__(self):
self._tokens: OAuthToken | None = None
self._client_info: OAuthClientInformationFull | None = None
self._oauth_metadata: OAuthMetadata | None = None

async def get_tokens(self) -> OAuthToken | None:
return self._tokens # pragma: no cover
return self._tokens

async def set_tokens(self, tokens: OAuthToken) -> None:
self._tokens = tokens

async def get_client_info(self) -> OAuthClientInformationFull | None:
return self._client_info # pragma: no cover
return self._client_info

async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
self._client_info = client_info

async def get_oauth_metadata(self) -> OAuthMetadata | None:
return self._oauth_metadata

async def set_oauth_metadata(self, metadata: OAuthMetadata) -> None:
self._oauth_metadata = metadata


@pytest.fixture
def mock_storage():
Expand Down Expand Up @@ -2618,3 +2625,252 @@ async def callback_handler() -> tuple[str, str | None]:
await auth_flow.asend(final_response)
except StopAsyncIteration:
pass


# --- Regression coverage for #1318: restore oauth_metadata on _initialize ---


@pytest.mark.anyio
async def test_initialize_restores_oauth_metadata(
oauth_provider: OAuthClientProvider,
mock_storage: MockTokenStorage,
):
"""``_initialize`` should restore ``oauth_metadata`` from storage.

Without this, ``_refresh_token`` loses the authoritative token endpoint
discovered during the prior session and falls back to ``<base_url>/token``
after every restart — a 404 for servers whose token endpoint sits on a
non-standard path.
"""
stored_metadata = OAuthMetadata(
issuer=AnyHttpUrl("https://auth.example.com"),
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
token_endpoint=AnyHttpUrl("https://auth.example.com/oauth/v3/token"),
)
await mock_storage.set_oauth_metadata(stored_metadata)

await oauth_provider._initialize()

assert oauth_provider.context.oauth_metadata is not None
assert str(oauth_provider.context.oauth_metadata.token_endpoint) == ("https://auth.example.com/oauth/v3/token")


@pytest.mark.anyio
async def test_refresh_token_uses_persisted_metadata_endpoint(
oauth_provider: OAuthClientProvider,
mock_storage: MockTokenStorage,
valid_tokens: OAuthToken,
):
"""After a restart with persisted metadata, ``_refresh_token`` uses the
correct ``token_endpoint`` rather than the ``<base_url>/token`` fallback.
"""
custom_token_endpoint = "https://auth.example.com/oauth/v3/token"
await mock_storage.set_tokens(valid_tokens)
await mock_storage.set_client_info(
OAuthClientInformationFull(
client_id="test_client",
client_secret="test_secret",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
token_endpoint_auth_method="client_secret_post",
)
)
await mock_storage.set_oauth_metadata(
OAuthMetadata(
issuer=AnyHttpUrl("https://auth.example.com"),
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
token_endpoint=AnyHttpUrl(custom_token_endpoint),
)
)

await oauth_provider._initialize()
request = await oauth_provider._refresh_token()

assert str(request.url) == custom_token_endpoint


@pytest.mark.anyio
async def test_initialize_backward_compat_without_metadata_methods(
client_metadata: OAuthClientMetadata,
valid_tokens: OAuthToken,
):
"""Storage implementations predating ``get_oauth_metadata`` keep working.

Duck-typed ``TokenStorage`` instances written before this method was
introduced must not raise ``AttributeError`` on ``_initialize``.
"""

class LegacyStorage:
"""Duck-typed storage matching the pre-change ``TokenStorage``."""

def __init__(self, tokens: OAuthToken | None):
self._tokens = tokens
self._client_info: OAuthClientInformationFull | None = None

async def get_tokens(self) -> OAuthToken | None:
return self._tokens

async def set_tokens(self, tokens: OAuthToken) -> None:
self._tokens = tokens # pragma: no cover

async def get_client_info(self) -> OAuthClientInformationFull | None:
return self._client_info

async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
self._client_info = client_info # pragma: no cover

legacy_storage = LegacyStorage(valid_tokens)

async def redirect_handler(url: str) -> None:
pass # pragma: no cover

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state" # pragma: no cover

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=legacy_storage, # type: ignore[arg-type]
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

await provider._initialize()

assert provider.context.current_tokens is valid_tokens
assert provider.context.oauth_metadata is None


@pytest.mark.anyio
async def test_token_storage_protocol_default_metadata_methods():
"""``TokenStorage`` provides no-op defaults for the optional metadata methods.

Storage subclasses that don't care about metadata persistence can inherit
``TokenStorage`` without overriding ``get_oauth_metadata`` /
``set_oauth_metadata``; the default ``get`` returns ``None`` and the
default ``set`` is a no-op (equivalent to opting out of persistence).
"""
from mcp.client.auth.oauth2 import TokenStorage

class DefaultStorage(TokenStorage):
async def get_tokens(self) -> OAuthToken | None:
return None # pragma: no cover

async def set_tokens(self, tokens: OAuthToken) -> None: ... # pragma: no cover

async def get_client_info(self) -> OAuthClientInformationFull | None:
return None # pragma: no cover

async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: ... # pragma: no cover

storage = DefaultStorage()
assert await storage.get_oauth_metadata() is None

metadata = OAuthMetadata(
issuer=AnyHttpUrl("https://auth.example.com"),
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
)
# No-op: set completes without storing
await storage.set_oauth_metadata(metadata)
assert await storage.get_oauth_metadata() is None


@pytest.mark.anyio
async def test_auth_flow_discovery_with_legacy_storage_skips_metadata_persistence(
client_metadata: OAuthClientMetadata,
):
"""OAuth discovery succeeds when storage lacks ``set_oauth_metadata``.

Covers the ``getattr`` fallback branch in ``async_auth_flow`` that bypasses
persistence for storage implementations predating the metadata API.
"""

class LegacyStorage:
"""Duck-typed storage matching the pre-change ``TokenStorage``."""

def __init__(self) -> None:
self._tokens: OAuthToken | None = None
self._client_info: OAuthClientInformationFull | None = None

async def get_tokens(self) -> OAuthToken | None:
return self._tokens

async def set_tokens(self, tokens: OAuthToken) -> None:
self._tokens = tokens # pragma: no cover

async def get_client_info(self) -> OAuthClientInformationFull | None:
return self._client_info

async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
self._client_info = client_info # pragma: no cover

legacy_storage = LegacyStorage()

async def redirect_handler(url: str) -> None:
pass # pragma: no cover

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state" # pragma: no cover

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=legacy_storage, # type: ignore[arg-type]
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

test_request = httpx.Request("GET", "https://api.example.com/mcp")
auth_flow = provider.async_auth_flow(test_request)

# First yield: ``_initialize`` loads state from LegacyStorage (no tokens,
# no client info, no metadata fallback) and the original request goes out
# without an auth header.
request = await auth_flow.__anext__()
assert "Authorization" not in request.headers

# 401 → triggers full OAuth flow
unauthorized_response = httpx.Response(
401,
headers={
"WWW-Authenticate": (
'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
)
},
request=test_request,
)
prm_request = await auth_flow.asend(unauthorized_response)
assert "oauth-protected-resource" in str(prm_request.url)

# PRM discovery response
prm_response = httpx.Response(
200,
content=(
b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}'
),
request=prm_request,
)
asm_request = await auth_flow.asend(prm_response)
assert str(asm_request.url).startswith("https://auth.example.com/")

# OASM discovery response — this is where our set_oauth_metadata
# fallback (meta_setter is None) executes for LegacyStorage.
asm_response = httpx.Response(
200,
content=(
b'{"issuer": "https://auth.example.com", '
b'"authorization_endpoint": "https://auth.example.com/authorize", '
b'"token_endpoint": "https://auth.example.com/token", '
b'"registration_endpoint": "https://auth.example.com/register"}'
),
request=asm_request,
)
next_request = await auth_flow.asend(asm_response)

# Discovery succeeded: flow advanced past metadata handling.
# (Legacy storage had no set_oauth_metadata, so persistence is skipped.)
assert next_request is not None
assert provider.context.oauth_metadata is not None
assert str(provider.context.oauth_metadata.token_endpoint) == "https://auth.example.com/token"

await auth_flow.aclose()