diff --git a/mcp/server.go b/mcp/server.go index c13955c5..a0563211 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1028,6 +1028,13 @@ func (s *Server) Connect(ctx context.Context, t Transport, opts *ServerSessionOp s.opts.Logger.Error("server connect error", "error", err) return nil, err } + + // Start keepalive before returning the session to avoid race conditions with Close. + // This is safe because the spec allows sending pings before initialization (see ServerSession.handle for details). + if s.opts.KeepAlive > 0 { + ss.startKeepalive(ss.server.opts.KeepAlive) + } + return ss, nil } @@ -1055,9 +1062,6 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar ss.server.opts.Logger.Error("duplicate initialized notification") return nil, fmt.Errorf("duplicate %q received", notificationInitialized) } - if ss.server.opts.KeepAlive > 0 { - ss.startKeepalive(ss.server.opts.KeepAlive) - } if h := ss.server.opts.InitializedHandler; h != nil { h(ctx, serverRequestFor(ss, params)) } @@ -1107,7 +1111,7 @@ type ServerSession struct { server *Server conn *jsonrpc2.Connection mcpConn Connection - keepaliveCancel context.CancelFunc // TODO: theory around why keepaliveCancel need not be guarded + keepaliveCancel context.CancelFunc mu sync.Mutex state ServerSessionState diff --git a/mcp/server_test.go b/mcp/server_test.go index e57af1e2..07f6404d 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -508,54 +508,6 @@ func TestServerAddResourceTemplate(t *testing.T) { } } -// TestServerSessionkeepaliveCancelOverwritten is to verify that `ServerSession.keepaliveCancel` is assigned exactly once, -// ensuring that only a single goroutine is responsible for the session's keepalive ping mechanism. 
-func TestServerSessionkeepaliveCancelOverwritten(t *testing.T) { - // Set KeepAlive to a long duration to ensure the keepalive - // goroutine stays alive for the duration of the test without actually sending - // ping requests, since we don't have a real client connection established. - server := NewServer(testImpl, &ServerOptions{KeepAlive: 5 * time.Second}) - ss := &ServerSession{server: server} - - // 1. Initialize the session. - _, err := ss.initialize(context.Background(), &InitializeParams{}) - if err != nil { - t.Fatalf("ServerSession initialize failed: %v", err) - } - - // 2. Call 'initialized' for the first time. This should start the keepalive mechanism. - _, err = ss.initialized(context.Background(), &InitializedParams{}) - if err != nil { - t.Fatalf("First initialized call failed: %v", err) - } - if ss.keepaliveCancel == nil { - t.Fatalf("expected ServerSession.keepaliveCancel to be set after the first call of initialized") - } - - // Save the cancel function and use defer to ensure resources are cleaned up. - firstCancel := ss.keepaliveCancel - defer firstCancel() - - // 3. Manually set the field to nil. - // Do this to facilitate the test's core assertion. The goal is to verify that - // 'ss.keepaliveCancel' is not assigned a second time. By setting it to nil, - // we can easily check after the next call if a new keepalive goroutine was started. - ss.keepaliveCancel = nil - - // 4. Call 'initialized' for the second time. This should return an error. - _, err = ss.initialized(context.Background(), &InitializedParams{}) - if err == nil { - t.Fatalf("Expected 'duplicate initialized received' error on second call, got nil") - } - - // 5. Re-check the field to ensure it remains nil. - // Since 'initialized' correctly returned an error and did not call - // 'startKeepalive', the field should remain unchanged. 
- if ss.keepaliveCancel != nil { - t.Fatal("expected ServerSession.keepaliveCancel to be nil after we manually niled it and re-initialized") - } -} - // panicks reports whether f() panics. func panics(f func()) (b bool) { defer func() { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index d1dc482a..7a6c733d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -367,6 +367,34 @@ func TestStreamableServerShutdown(t *testing.T) { } } +// TestStreamableStatelessKeepaliveRace verifies that there is no data race between +// ServerSession.startKeepalive and ServerSession.Close in stateless servers. +func TestStreamableStatelessKeepaliveRace(t *testing.T) { + ctx := context.Background() + server := NewServer(testImpl, &ServerOptions{KeepAlive: time.Hour}) + AddTool(server, &Tool{Name: "greet"}, sayHi) + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return server }, + &StreamableHTTPOptions{Stateless: true}, + ) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + for range 50 { + cs, err := NewClient(testImpl, nil).Connect(ctx, &StreamableClientTransport{ + Endpoint: httpServer.URL, + }, nil) + if err != nil { + t.Fatalf("Connect() failed: %v", err) + } + _, _ = cs.CallTool(ctx, &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"Name": "world"}, + }) + _ = cs.Close() + } +} + // TestClientReplay verifies that the client can recover from a mid-stream // network failure and receive replayed messages (if replay is configured). It // uses a proxy that is killed and restarted to simulate a recoverable network