diff --git a/rest/internal/cors/handlers.go b/rest/internal/cors/handlers.go index 9c6d6051..6b27ef6c 100644 --- a/rest/internal/cors/handlers.go +++ b/rest/internal/cors/handlers.go @@ -23,9 +23,12 @@ const ( // NotAllowedHandler handles cross domain not allowed requests. // At most one origin can be specified, other origins are ignored if given, default to be *. -func NotAllowedHandler(origins ...string) http.Handler { +func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { checkAndSetHeaders(w, r, origins) + if fn != nil { + fn(w) + } if r.Method != http.MethodOptions { w.WriteHeader(http.StatusNotFound) @@ -36,10 +39,13 @@ func NotAllowedHandler(origins ...string) http.Handler { } // Middleware returns a middleware that adds CORS headers to the response. -func Middleware(origins ...string) func(http.HandlerFunc) http.HandlerFunc { +func Middleware(fn func(w http.ResponseWriter), origins ...string) func(http.HandlerFunc) http.HandlerFunc { return func(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { checkAndSetHeaders(w, r, origins) + if fn != nil { + fn(w) + } if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) diff --git a/rest/internal/cors/handlers_test.go b/rest/internal/cors/handlers_test.go index 03052b29..0e112b93 100644 --- a/rest/internal/cors/handlers_test.go +++ b/rest/internal/cors/handlers_test.go @@ -62,7 +62,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) { r := httptest.NewRequest(method, "http://localhost", nil) r.Header.Set(originHeader, test.reqOrigin) w := httptest.NewRecorder() - handler := NotAllowedHandler(test.origins...) + handler := NotAllowedHandler(nil, test.origins...) handler.ServeHTTP(w, r) if method == http.MethodOptions { assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) @@ -71,6 +71,22 @@ func TestCorsHandlerWithOrigins(t *testing.T) { } assert.Equal(t, test.expect, w.Header().Get(allowOrigin)) }) + t.Run(test.name+"-handler-custom", func(t *testing.T) { + r := httptest.NewRequest(method, "http://localhost", nil) + r.Header.Set(originHeader, test.reqOrigin) + w := httptest.NewRecorder() + handler := NotAllowedHandler(func(w http.ResponseWriter) { + w.Header().Set("foo", "bar") + }, test.origins...) + handler.ServeHTTP(w, r) + if method == http.MethodOptions { + assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) + } else { + assert.Equal(t, http.StatusNotFound, w.Result().StatusCode) + } + assert.Equal(t, test.expect, w.Header().Get(allowOrigin)) + assert.Equal(t, "bar", w.Header().Get("foo")) + }) } } @@ -81,7 +97,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) { r := httptest.NewRequest(method, "http://localhost", nil) r.Header.Set(originHeader, test.reqOrigin) w := httptest.NewRecorder() - handler := Middleware(test.origins...)(func(w http.ResponseWriter, r *http.Request) { + handler := Middleware(nil, test.origins...)(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) handler.ServeHTTP(w, r) @@ -92,6 +108,24 @@ func TestCorsHandlerWithOrigins(t *testing.T) { } assert.Equal(t, test.expect, w.Header().Get(allowOrigin)) }) + t.Run(test.name+"-middleware-custom", func(t *testing.T) { + r := httptest.NewRequest(method, "http://localhost", nil) + r.Header.Set(originHeader, test.reqOrigin) + w := httptest.NewRecorder() + handler := Middleware(func(w http.ResponseWriter) { + w.Header().Set("foo", "bar") + }, test.origins...)(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler.ServeHTTP(w, r) + if method == http.MethodOptions { + assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) + } else { + assert.Equal(t, http.StatusOK, w.Result().StatusCode) + } + assert.Equal(t, test.expect, w.Header().Get(allowOrigin)) + assert.Equal(t, "bar", w.Header().Get("foo")) + }) } } } diff --git a/rest/server.go b/rest/server.go index e847ca71..dd7d5b6b 100644 --- a/rest/server.go +++ b/rest/server.go @@ -99,8 +99,17 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware { // WithCors returns a func to enable CORS for given origin, or default to all origins (*). func WithCors(origin ...string) RunOption { return func(server *Server) { - server.router.SetNotAllowedHandler(cors.NotAllowedHandler(origin...)) - server.Use(cors.Middleware(origin...)) + server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, origin...)) + server.Use(cors.Middleware(nil, origin...)) + } +} + +// WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*), +// fn lets caller customizing the response. +func WithCustomCors(fn func(http.ResponseWriter), origin ...string) RunOption { + return func(server *Server) { + server.router.SetNotAllowedHandler(cors.NotAllowedHandler(fn, origin...)) + server.Use(cors.Middleware(fn, origin...)) } } diff --git a/rest/server_test.go b/rest/server_test.go index da164c41..1f36f916 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -310,3 +310,20 @@ Port: 54321 opt := WithCors("local") opt(srv) } + +func TestWithCustomCors(t *testing.T) { + const configYaml = ` +Name: foo +Port: 54321 +` + var cnf RestConf + assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf)) + rt := router.NewRouter() + srv, err := NewServer(cnf, WithRouter(rt)) + assert.Nil(t, err) + + opt := WithCustomCors(func(w http.ResponseWriter) { + w.Header().Set("foo", "bar") + }, "local") + opt(srv) +}