diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 902dc8576..b5308f1b5 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -1,12 +1,15 @@ +import asyncio import logging import os import sys +from collections.abc import AsyncIterator, Callable, Coroutine from contextlib import asynccontextmanager from pathlib import Path -from typing import Literal, TextIO +from typing import Any, Literal, TextIO import anyio import anyio.lowlevel +import sniffio from anyio.abc import Process from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.text import TextReceiveStream @@ -101,6 +104,89 @@ class StdioServerParameters(BaseModel): """ +@asynccontextmanager +async def _asyncio_background_tasks( + stdout_reader: Callable[[], Coroutine[Any, Any, None]], + stdin_writer: Callable[[], Coroutine[Any, Any, None]], + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], +) -> AsyncIterator[None]: + """Spawn the stdio reader/writer as top-level asyncio tasks (see #577). + + The tasks are detached from the caller's cancel-scope stack, which + is what lets callers clean up multiple transports in arbitrary + order without tripping anyio's LIFO cancel-scope check. + + If a background task crashes while the caller is still inside the + yield, the memory streams are closed via ``add_done_callback`` so + in-flight reads wake up with ``ClosedResourceError`` instead of + hanging forever. Any non-cancellation, non-closed-resource + exception from the tasks is re-raised on exit so crashes do not + go unnoticed — matching the exception propagation an anyio task + group would have given. + """ + + def _on_done(task: asyncio.Task[None]) -> None: + if task.cancelled(): # pragma: no cover - normal teardown exits reader/writer cleanly + return + exc = task.exception() + if exc is None: + return + logger.debug( + "stdio_client background task raised %s — closing streams to wake up caller", + type(exc).__name__, + exc_info=exc, + ) + for stream in (read_stream_writer, write_stream): + try: + stream.close() + except Exception: # pragma: no cover + pass + + stdout_task: asyncio.Task[None] = asyncio.create_task(stdout_reader()) + stdin_task: asyncio.Task[None] = asyncio.create_task(stdin_writer()) + stdout_task.add_done_callback(_on_done) + stdin_task.add_done_callback(_on_done) + tasks = (stdout_task, stdin_task) + try: + yield + finally: + for task in tasks: + if not task.done(): # pragma: no cover - tasks normally exit via stream close + task.cancel() + pending_exc: BaseException | None = None + for task in tasks: + try: + await task + except asyncio.CancelledError: # pragma: lax no cover - timing-dependent on teardown races + pass + except anyio.ClosedResourceError: # pragma: lax no cover - timing-dependent on teardown races + pass + except BaseException as exc: # noqa: BLE001 + if pending_exc is None: # pragma: no branch + pending_exc = exc + if pending_exc is not None: + raise pending_exc + + +@asynccontextmanager +async def _anyio_task_group_background( + stdout_reader: Callable[[], Coroutine[Any, Any, None]], + stdin_writer: Callable[[], Coroutine[Any, Any, None]], +) -> AsyncIterator[None]: + """Structured-concurrency fallback for backends other than asyncio. + + Trio forbids orphan tasks by design, so the historical task-group + pattern is retained here. Callers on trio must clean up multiple + transports in LIFO order; cross-task cleanup (#577) cannot be + fixed on that backend without violating its concurrency model. + """ + async with anyio.create_task_group() as tg: + tg.start_soon(stdout_reader) + tg.start_soon(stdin_writer) + yield + + @asynccontextmanager async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr): """Client transport for stdio: this will connect to a server by spawning a @@ -177,38 +263,51 @@ async def stdin_writer(): except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() - async with anyio.create_task_group() as tg, process: - tg.start_soon(stdout_reader) - tg.start_soon(stdin_writer) + async def _cleanup_process_and_streams() -> None: + # MCP spec: stdio shutdown sequence + # 1. Close input stream to server + # 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time + # 3. Send SIGKILL if still not exited + if process.stdin: # pragma: no branch + try: + await process.stdin.aclose() + except Exception: # pragma: no cover + # stdin might already be closed, which is fine + pass + + try: + # Give the process time to exit gracefully after stdin closes + with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT): + await process.wait() + except TimeoutError: + # Process didn't exit from stdin closure, use platform-specific termination + # which handles SIGTERM -> SIGKILL escalation + await _terminate_process_tree(process) + except ProcessLookupError: # pragma: no cover + # Process already exited, which is fine + pass + await read_stream.aclose() + await write_stream.aclose() + await read_stream_writer.aclose() + await write_stream_reader.aclose() + + # On asyncio we spawn the reader / writer with asyncio.create_task rather + # than an anyio task group, so their cancel scopes are not bound to the + # caller's task. That is what lets callers clean up multiple transports + # in arbitrary order — see #577. On structured-concurrency backends + # (trio), we keep the task group: orphan tasks are disallowed there by + # design, and cross-task cleanup is fundamentally incompatible with + # that model, so callers on trio still have to clean up LIFO. + if sniffio.current_async_library() == "asyncio": + bg_cm = _asyncio_background_tasks(stdout_reader, stdin_writer, read_stream_writer, write_stream) + else: + bg_cm = _anyio_task_group_background(stdout_reader, stdin_writer) + + async with bg_cm, process: try: yield read_stream, write_stream finally: - # MCP spec: stdio shutdown sequence - # 1. Close input stream to server - # 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time - # 3. Send SIGKILL if still not exited - if process.stdin: # pragma: no branch - try: - await process.stdin.aclose() - except Exception: # pragma: no cover - # stdin might already be closed, which is fine - pass - - try: - # Give the process time to exit gracefully after stdin closes - with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT): - await process.wait() - except TimeoutError: - # Process didn't exit from stdin closure, use platform-specific termination - # which handles SIGTERM -> SIGKILL escalation - await _terminate_process_tree(process) - except ProcessLookupError: # pragma: no cover - # Process already exited, which is fine - pass - await read_stream.aclose() - await write_stream.aclose() - await read_stream_writer.aclose() - await write_stream_reader.aclose() + await _cleanup_process_and_streams() def _get_executable_command(command: str) -> str: diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 06e2cba4b..3140c6d0b 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -558,3 +558,146 @@ def sigterm_handler(signum, frame): f"stdio_client cleanup took {elapsed:.1f} seconds for stdin-ignoring process. " f"Expected between 2-4 seconds (2s stdin timeout + termination time)." ) + + +# A stub MCP-ish server: exits cleanly as soon as stdin closes. We only need +# the stdio_client to be able to stand the transport up; we do not exercise +# any MCP protocol traffic in the FIFO-cleanup tests below. +_QUIET_STDIN_STUB = textwrap.dedent( + """ + import sys + for _ in sys.stdin: + pass + """ +) + + +@pytest.mark.anyio +@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +async def test_stdio_client_supports_fifo_cleanup_on_asyncio(anyio_backend: str): + """Regression for https://github.com/modelcontextprotocol/python-sdk/issues/577. + + Prior to the fix, closing two ``stdio_client`` transports in the order + they were opened (FIFO) crashed with:: + + RuntimeError: Attempted to exit a cancel scope that isn't the + current task's current cancel scope + + because each stdio_client bound a task-group cancel scope to the + caller's task, and anyio enforces a strict LIFO stack on those + scopes. On asyncio the fix uses ``asyncio.create_task`` for the + reader / writer so the transports are independent and the cleanup + order no longer matters. + + (Trio intentionally forbids orphan tasks — there is no equivalent + fix on that backend, so this test is asyncio-only.) + """ + params = StdioServerParameters(command=sys.executable, args=["-c", _QUIET_STDIN_STUB]) + + s1 = AsyncExitStack() + s2 = AsyncExitStack() + + await s1.__aenter__() + await s2.__aenter__() + + await s1.enter_async_context(stdio_client(params)) + await s2.enter_async_context(stdio_client(params)) + + # Close in FIFO order — the opposite of what anyio's structured + # concurrency would normally require. + await s1.aclose() + await s2.aclose() + + +@pytest.mark.anyio +@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +async def test_stdio_client_supports_lifo_cleanup_on_asyncio(anyio_backend: str): + """Sanity check for the fix above: the historical LIFO cleanup path + must still work unchanged on asyncio. + """ + params = StdioServerParameters(command=sys.executable, args=["-c", _QUIET_STDIN_STUB]) + + s1 = AsyncExitStack() + s2 = AsyncExitStack() + await s1.__aenter__() + await s2.__aenter__() + + await s1.enter_async_context(stdio_client(params)) + await s2.enter_async_context(stdio_client(params)) + + # LIFO cleanup — the last-opened transport closes first. + await s2.aclose() + await s1.aclose() + + +@pytest.mark.anyio +@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +async def test_stdio_client_shared_exit_stack_fifo_on_asyncio(anyio_backend: str): + """The same story, but through a single ``AsyncExitStack`` that the + caller then closes once. ExitStack runs callbacks in LIFO order on + its own, so this case already worked — the test pins that behavior + so a future refactor of the asyncio branch cannot silently break + it. + """ + params = StdioServerParameters(command=sys.executable, args=["-c", _QUIET_STDIN_STUB]) + + async with AsyncExitStack() as stack: + await stack.enter_async_context(stdio_client(params)) + await stack.enter_async_context(stdio_client(params)) + + +@pytest.mark.anyio +@pytest.mark.parametrize("anyio_backend", ["trio"]) +async def test_stdio_client_supports_lifo_cleanup_on_trio(anyio_backend: str): + """Coverage for the structured-concurrency branch of ``stdio_client``. + + On trio the historical anyio task-group is kept — the FIFO fix from + #577 is asyncio-only because trio forbids orphan tasks by design. + This test just exercises the trio code path end-to-end with LIFO + cleanup (which works the same way it always has) so the trio + branch is not dead code under coverage. + """ + params = StdioServerParameters(command=sys.executable, args=["-c", _QUIET_STDIN_STUB]) + + async with AsyncExitStack() as stack: + await stack.enter_async_context(stdio_client(params)) + await stack.enter_async_context(stdio_client(params)) + + +@pytest.mark.anyio +@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +async def test_stdio_client_reader_crash_propagates_on_asyncio( + anyio_backend: str, monkeypatch: pytest.MonkeyPatch +) -> None: + """Guardrail for the asyncio branch of #577: moving the reader/writer + out of an anyio task group must NOT silently swallow exceptions + they raise. An anyio task group would have re-raised those through + the async ``with`` block; the asyncio path has to reproduce that + contract by collecting the task results in ``__aexit__`` and + re-raising anything that isn't cancellation / closed-resource. + """ + from typing import Any + + from mcp.client import stdio as stdio_mod + + class _BoomTextStream: + def __init__(self, *args: Any, **kwargs: Any) -> None: + pass + + def __aiter__(self) -> "_BoomTextStream": + return self + + async def __anext__(self) -> str: + raise RuntimeError("deliberate reader crash for #577 regression test") + + monkeypatch.setattr(stdio_mod, "TextReceiveStream", _BoomTextStream) + + params = StdioServerParameters(command=sys.executable, args=["-c", _QUIET_STDIN_STUB]) + + with pytest.raises(RuntimeError, match="deliberate reader crash"): + async with stdio_client(params): + # Give the reader a chance to raise. The crash should close + # the streams out from under us, so we just wait a moment + # and then exit the context — the exception is surfaced on + # the way out. + await anyio.sleep(0.2)