diff --git a/middleware/proxy.go b/middleware/proxy.go index 5bf296f65..a40d58130 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -365,6 +365,9 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if err != nil { return config.ErrorHandler(c, err) } + if tgt == nil || tgt.URL == nil { + return config.ErrorHandler(c, echo.NewHTTPError(http.StatusBadGateway, "no proxy target available")) + } c.Set(config.ContextKey, tgt) diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 3a1310ef3..5494b23ba 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -457,6 +457,100 @@ func TestRoundRobinBalancerWithNoTargets(t *testing.T) { assert.NoError(t, err) } +func TestProxyWithNoTargetReturnsBadGateway(t *testing.T) { + targetURL, _ := url.Parse("http://127.0.0.1:8080") + target := &ProxyTarget{Name: "target", URL: targetURL} + emptyAfterRemove := NewRoundRobinBalancer([]*ProxyTarget{target}) + assert.True(t, emptyAfterRemove.RemoveTarget("target")) + + testCases := []struct { + name string + balancer ProxyBalancer + }{ + { + name: "random balancer with nil targets", + balancer: NewRandomBalancer(nil), + }, + { + name: "round-robin balancer with nil targets", + balancer: NewRoundRobinBalancer(nil), + }, + { + name: "round-robin balancer after removing last target", + balancer: emptyAfterRemove, + }, + { + name: "custom balancer with nil target", + balancer: &customBalancer{}, + }, + { + name: "custom balancer with nil target URL", + balancer: &customBalancer{target: &ProxyTarget{Name: "target"}}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + errorHandlerCalled := false + e.Use(ProxyWithConfig(ProxyConfig{ + Balancer: tc.balancer, + ErrorHandler: func(c *echo.Context, err error) error { + errorHandlerCalled = true + httpErr, ok := err.(*echo.HTTPError) + assert.True(t, ok, "expected http error to be passed to handler") + assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler") + return err + }, + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + assert.NotPanics(t, func() { + e.ServeHTTP(rec, req) + }) + assert.True(t, errorHandlerCalled) + assert.Equal(t, http.StatusBadGateway, rec.Code) + }) + } +} + +func TestProxyWithNoTargetDoesNotRetry(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + targetURL, _ := url.Parse(server.URL) + + balancer := &sequenceBalancer{ + targets: []*ProxyTarget{ + nil, + {Name: "target", URL: targetURL}, + }, + } + + retryFilterCalled := false + e := echo.New() + e.Use(ProxyWithConfig(ProxyConfig{ + Balancer: balancer, + RetryCount: 1, + RetryFilter: func(c *echo.Context, err error) bool { + retryFilterCalled = true + return true + }, + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.False(t, retryFilterCalled) + assert.Equal(t, 1, balancer.calls) + assert.Equal(t, http.StatusBadGateway, rec.Code) +} + func TestProxyRetries(t *testing.T) { newServer := func(res int) (*url.URL, *httptest.Server) { server := httptest.NewServer( @@ -788,6 +882,25 @@ func (b *customBalancer) Next(c *echo.Context) (*ProxyTarget, error) { return b.target, nil } +type sequenceBalancer struct { + targets []*ProxyTarget + calls int +} + +func (b *sequenceBalancer) AddTarget(target *ProxyTarget) bool { + return false +} + +func (b *sequenceBalancer) RemoveTarget(name string) bool { + return false +} + +func (b *sequenceBalancer) Next(c *echo.Context) (*ProxyTarget, error) { + target := b.targets[b.calls] + b.calls++ + return target, nil +} + func TestModifyResponseUseContext(t *testing.T) { server := httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {