diff --git a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs index 209d644d2..7c72f579c 100644 --- a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs @@ -75,11 +75,15 @@ private async Task InitializeAsync(JsonRpcMessage message, CancellationToken can } else { - // If the status code is not success, fall back to SSE + // Streamable HTTP failed. Capture the underlying error (status + body) before falling back to SSE + // so that, if SSE also fails, we can surface the real Streamable HTTP diagnostic to the caller + // instead of dropping it on the floor (see https://github.com/modelcontextprotocol/csharp-sdk/issues/1526). LogStreamableHttpFailed(_name, response.StatusCode); + var streamableHttpError = await CreateStreamableHttpErrorAsync(response, cancellationToken).ConfigureAwait(false); + await streamableHttpTransport.DisposeAsync().ConfigureAwait(false); - await InitializeSseTransportAsync(message, cancellationToken).ConfigureAwait(false); + await InitializeSseTransportAsync(message, streamableHttpError, cancellationToken).ConfigureAwait(false); } } catch @@ -91,7 +95,7 @@ private async Task InitializeAsync(JsonRpcMessage message, CancellationToken can } } - private async Task InitializeSseTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken) + private async Task InitializeSseTransportAsync(JsonRpcMessage message, HttpRequestException? streamableHttpError, CancellationToken cancellationToken) { if (_options.KnownSessionId is not null) { @@ -109,6 +113,20 @@ private async Task InitializeSseTransportAsync(JsonRpcMessage message, Cancellat LogUsingSSE(_name); ActiveTransport = sseTransport; } + catch (Exception sseError) when (streamableHttpError is not null && sseError is not OperationCanceledException) + { + // SSE fallback also failed. Surface the original Streamable HTTP error as the primary failure + // so the user sees the real server diagnostic (e.g. 415 Unsupported Media Type) instead of the + // unrelated SSE-fallback error (e.g. a 405 from a Streamable-HTTP-only server that doesn't accept GET). + await sseTransport.DisposeAsync().ConfigureAwait(false); + LogSseFallbackFailedAfterStreamableHttp(_name, sseError); + throw new AggregateException( + "Streamable HTTP transport failed and the SSE fallback also failed. " + + "The first inner exception is the original Streamable HTTP error (the real server response); " + + "the second is the SSE fallback failure.", + streamableHttpError, + sseError); + } catch { await sseTransport.DisposeAsync().ConfigureAwait(false); @@ -116,6 +134,31 @@ private async Task InitializeSseTransportAsync(JsonRpcMessage message, Cancellat } } + private static async Task CreateStreamableHttpErrorAsync(HttpResponseMessage response, CancellationToken cancellationToken) + { + // Best-effort read of the response body so the exception message includes the server's diagnostic + // (e.g. "Content-Type must be 'application/json'" on a 415). Mirrors EnsureSuccessStatusCodeWithResponseBodyAsync. + string? responseBody = null; + try + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(5)); + responseBody = await response.Content.ReadAsStringAsync(cts.Token).ConfigureAwait(false); + + const int MaxResponseBodyLength = 1024; + if (responseBody.Length > MaxResponseBodyLength) + { + responseBody = responseBody.Substring(0, MaxResponseBodyLength) + "..."; + } + } + catch + { + // Ignore all errors reading the response body (e.g., stream closed, timeout, cancellation) - we'll throw without it. + } + + return HttpResponseMessageExtensions.CreateHttpRequestException(response, responseBody); + } + public async ValueTask DisposeAsync() { try @@ -147,4 +190,7 @@ public async ValueTask DisposeAsync() [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} using SSE transport.")] private partial void LogUsingSSE(string endpointName); -} \ No newline at end of file + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} SSE fallback failed after Streamable HTTP also failed; surfacing both errors.")] + private partial void LogSseFallbackFailedAfterStreamableHttp(string endpointName, Exception sseError); +} diff --git a/tests/ModelContextProtocol.Tests/Transport/HttpClientTransportAutoDetectTests.cs b/tests/ModelContextProtocol.Tests/Transport/HttpClientTransportAutoDetectTests.cs index 768ebf7ea..71e38bb05 100644 --- a/tests/ModelContextProtocol.Tests/Transport/HttpClientTransportAutoDetectTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/HttpClientTransportAutoDetectTests.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; using ModelContextProtocol.Tests.Utils; using System.Net; @@ -42,12 +43,12 @@ public async Task AutoDetectMode_UsesStreamableHttp_WhenServerSupportsIt() }; await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); - + // The auto-detecting transport should be returned Assert.NotNull(session); } - [Fact] + [Fact] public async Task AutoDetectMode_FallsBackToSse_WhenStreamableHttpFails() { var options = new HttpClientTransportOptions @@ -102,8 +103,93 @@ public async Task AutoDetectMode_FallsBackToSse_WhenStreamableHttpFails() }; await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); - + // The auto-detecting transport should be returned Assert.NotNull(session); } -} \ No newline at end of file + + // Regression test for https://github.com/modelcontextprotocol/csharp-sdk/issues/1526 + // When Streamable HTTP returns 415 (e.g. wrong Content-Type) and the SSE fallback also fails + // (e.g. a Streamable-HTTP-only server returns 405 to the GET), the surfaced exception must + // preserve the original Streamable HTTP error rather than dropping it on the floor. + [Fact] + public async Task AutoDetectMode_PreservesOriginalError_WhenStreamableHttpReturns415AndSseFallbackFails() + { + var options = new HttpClientTransportOptions + { + Endpoint = new Uri("http://localhost"), + TransportMode = HttpTransportMode.AutoDetect, + Name = "AutoDetect test client" + }; + + using var mockHttpHandler = new MockHttpHandler(); + using var httpClient = new HttpClient(mockHttpHandler); + await using var transport = new HttpClientTransport(options, httpClient, LoggerFactory); + + const string streamableHttpBody = "Content-Type must be 'application/json'"; + + mockHttpHandler.RequestHandler = (request) => + { + if (request.Method == HttpMethod.Post) + { + // Streamable HTTP fails with 415 - this is the real server diagnostic the user needs to see. + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.UnsupportedMediaType, + Content = new StringContent(streamableHttpBody), + }); + } + + if (request.Method == HttpMethod.Get) + { + // Streamable-HTTP-only server: SSE GET is rejected with 405. Without the fix this is the + // ONLY error the user ever sees, masking the real 415 diagnostic above. + return Task.FromResult(new HttpResponseMessage + { + StatusCode = HttpStatusCode.MethodNotAllowed, + Content = new StringContent("Method not allowed"), + }); + } + + throw new InvalidOperationException($"Unexpected request: {request.Method}"); + }; + + // ConnectAsync only constructs the AutoDetect transport; the probe runs on the first SendMessageAsync. + await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); + + var ex = await Assert.ThrowsAnyAsync(() => + session.SendMessageAsync( + new JsonRpcRequest { Method = RequestMethods.Initialize, Id = new RequestId(1) }, + TestContext.Current.CancellationToken)); + + // Walk the exception chain and assert the original 415 (and its body) is somewhere in it. + // We don't pin the exact exception type so this stays robust to future error-shape tweaks, + // but the underlying status code and server body must reach the caller. + var combined = Flatten(ex); + Assert.Contains("415", combined); + Assert.Contains(streamableHttpBody, combined); + + static string Flatten(Exception e) + { + var sb = new System.Text.StringBuilder(); + void Walk(Exception? cur) + { + while (cur is not null) + { + sb.Append(cur.GetType().FullName).Append(": ").AppendLine(cur.Message); + if (cur is AggregateException agg) + { + foreach (var inner in agg.InnerExceptions) + { + Walk(inner); + } + return; + } + cur = cur.InnerException; + } + } + Walk(e); + return sb.ToString(); + } + } +}