From 5e63e2c874703b471c67cfdbf13b89f764c2606b Mon Sep 17 00:00:00 2001 From: puneetdixit200 <236133619+puneetdixit200@users.noreply.github.com> Date: Fri, 22 May 2026 11:05:05 +0530 Subject: [PATCH] Run middleware for OPTIONS fallbacks --- middleware/cors_test.go | 29 +++++++++++++++++++++ router.go | 56 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/middleware/cors_test.go b/middleware/cors_test.go index 5de4ca063..35337ea45 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -581,6 +581,35 @@ func TestCorsHeaders(t *testing.T) { } } +func TestCORSWithConfig_GroupPreflightWithoutOptionsRoute(t *testing.T) { + e := echo.New() + g := e.Group("/myroute", CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"https://example.com"}, + AllowHeaders: []string{ + echo.HeaderOrigin, + echo.HeaderContentType, + echo.HeaderAccept, + echo.HeaderAuthorization, + }, + })) + g.GET("", func(c *echo.Context) error { + return c.String(http.StatusOK, "OK") + }) + + req := httptest.NewRequest(http.MethodOptions, "/myroute", nil) + req.Header.Set(echo.HeaderOrigin, "https://example.com") + req.Header.Set(echo.HeaderAccessControlRequestMethod, http.MethodGet) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNoContent, rec.Code) + assert.Equal(t, "https://example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, "OPTIONS, GET", rec.Header().Get(echo.HeaderAccessControlAllowMethods)) + assert.Equal(t, "Origin,Content-Type,Accept,Authorization", rec.Header().Get(echo.HeaderAccessControlAllowHeaders)) + assert.Equal(t, "OPTIONS, GET", rec.Header().Get(echo.HeaderAllow)) +} + func Test_allowOriginFunc(t *testing.T) { returnTrue := func(c *echo.Context, origin string) (string, bool, error) { return origin, true, nil diff --git a/router.go b/router.go index 68802e062..268834c66 100644 --- a/router.go +++ b/router.go @@ -142,6 +142,7 @@ const ( type routeMethod struct { *RouteInfo handler HandlerFunc + middlewares []MiddlewareFunc orgRouteInfo RouteInfo } @@ -298,6 +299,54 @@ func (m *routeMethods) updateAllowHeader() { m.allowHeader = buf.String() } +func (m *routeMethods) optionsFallbackHandler(requestedMethod string) *routeMethod { + if requestedMethod != "" { + if h := m.find(requestedMethod, true); h != nil { + return h + } + } + if m.connect != nil { + return m.connect + } + if m.delete != nil { + return m.delete + } + if m.get != nil { + return m.get + } + if m.head != nil { + return m.head + } + if m.options != nil { + return m.options + } + if m.patch != nil { + return m.patch + } + if m.post != nil { + return m.post + } + if m.propfind != nil { + return m.propfind + } + if m.put != nil { + return m.put + } + if m.trace != nil { + return m.trace + } + if m.report != nil { + return m.report + } + if m.any != nil { + return m.any + } + for _, r := range m.anyOther { + return r + } + return nil +} + func (m *routeMethods) isHandler() bool { return m.get != nil || m.post != nil || @@ -488,6 +537,7 @@ func (r *DefaultRouter) Add(route Route) (RouteInfo, error) { rm := routeMethod{ RouteInfo: &RouteInfo{Method: method, Path: originalPath, Parameters: paramNames, Name: route.Name}, handler: h, + middlewares: append([]MiddlewareFunc(nil), route.Middlewares...), orgRouteInfo: ri, } r.insert(paramKind, path[:i], method, rm) @@ -503,6 +553,7 @@ func (r *DefaultRouter) Add(route Route) (RouteInfo, error) { rm := routeMethod{ RouteInfo: &RouteInfo{Method: method, Path: originalPath, Parameters: paramNames, Name: route.Name}, handler: h, + middlewares: append([]MiddlewareFunc(nil), route.Middlewares...), orgRouteInfo: ri, } r.insert(anyKind, path[:i+1], method, rm) @@ -516,6 +567,7 @@ func (r *DefaultRouter) Add(route Route) (RouteInfo, error) { rm := routeMethod{ RouteInfo: &RouteInfo{Method: method, Path: originalPath, Parameters: paramNames, Name: route.Name}, handler: h, + middlewares: append([]MiddlewareFunc(nil), route.Middlewares...), orgRouteInfo: ri, } r.insert(staticKind, path, method, rm) @@ -1011,6 +1063,10 @@ func (r *DefaultRouter) Route(c *Context) HandlerFunc { rHandler = r.methodNotAllowedHandler if req.Method == http.MethodOptions { rHandler = r.optionsMethodHandler + requestedMethod := req.Header.Get(HeaderAccessControlRequestMethod) + if fallbackMethod := currentNode.methods.optionsFallbackHandler(requestedMethod); fallbackMethod != nil { + rHandler = applyMiddleware(rHandler, fallbackMethod.middlewares...) + } } } }