From 2588a36555dc366274f374c1f74c9bcbdc46dcf8 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Tue, 30 Jul 2024 17:29:44 +0800 Subject: [PATCH] feat: support rest.WithCorsHeaders to customize cors headers (#4284) --- rest/internal/cors/handlers.go | 5 +++ rest/internal/cors/handlers_test.go | 69 +++++++++++++++++++++++++++++ rest/server.go | 12 +++++ rest/server_test.go | 58 ++++++++++++++++++++++++ 4 files changed, 144 insertions(+) diff --git a/rest/internal/cors/handlers.go b/rest/internal/cors/handlers.go index 133b47dd..d0e7c9f7 100644 --- a/rest/internal/cors/handlers.go +++ b/rest/internal/cors/handlers.go @@ -26,6 +26,11 @@ const ( originHeader = "Origin" ) +// AddAllowHeaders sets the allowed headers. +func AddAllowHeaders(header http.Header, headers ...string) { + header.Add(allowHeaders, strings.Join(headers, ", ")) +} + // NotAllowedHandler handles cross domain not allowed requests. // At most one origin can be specified, other origins are ignored if given, default to be *. func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.Handler { diff --git a/rest/internal/cors/handlers_test.go b/rest/internal/cors/handlers_test.go index c9de97a3..b8679f51 100644 --- a/rest/internal/cors/handlers_test.go +++ b/rest/internal/cors/handlers_test.go @@ -3,11 +3,80 @@ package cors import ( "net/http" "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/assert" ) +func TestAddAllowHeaders(t *testing.T) { + tests := []struct { + name string + initial string + headers []string + expected string + }{ + { + name: "single header", + initial: "", + headers: []string{"Content-Type"}, + expected: "Content-Type", + }, + { + name: "multiple headers", + initial: "", + headers: []string{"Content-Type", "Authorization", "X-Requested-With"}, + expected: "Content-Type, Authorization, X-Requested-With", + }, + { + name: "add to existing headers", + initial: "Origin, Accept", + headers: []string{"Content-Type"}, + expected: "Origin, Accept, Content-Type", + }, + { + name: "no headers", + initial: "", + headers: []string{}, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + header := http.Header{} + headers := make(map[string]struct{}) + if tt.initial != "" { + header.Set(allowHeaders, tt.initial) + vals := strings.Split(tt.initial, ", ") + for _, v := range vals { + headers[v] = struct{}{} + } + } + for _, h := range tt.headers { + headers[h] = struct{}{} + } + AddAllowHeaders(header, tt.headers...) + var actual []string + vals := header.Values(allowHeaders) + for _, v := range vals { + bunch := strings.Split(v, ", ") + for _, b := range bunch { + if len(b) > 0 { + actual = append(actual, b) + } + } + } + + var expect []string + for k := range headers { + expect = append(expect, k) + } + assert.ElementsMatch(t, expect, actual) + }) + } +} + func TestCorsHandlerWithOrigins(t *testing.T) { tests := []struct { name string diff --git a/rest/server.go b/rest/server.go index 747bb2c3..b1e5487b 100644 --- a/rest/server.go +++ b/rest/server.go @@ -161,6 +161,18 @@ func WithCors(origin ...string) RunOption { } } +// WithCorsHeaders returns a RunOption to enable CORS with given headers. +func WithCorsHeaders(headers ...string) RunOption { + const allDomains = "*" + + return func(server *Server) { + server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, allDomains)) + server.router = newCorsRouter(server.router, func(header http.Header) { + cors.AddAllowHeaders(header, headers...) + }, allDomains) + } +} + // WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*), // fn lets caller customizing the response. func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(http.ResponseWriter), diff --git a/rest/server_test.go b/rest/server_test.go index 3f01fd3f..9a92d58f 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -420,6 +420,64 @@ Port: 54321 opt(svr) } +func TestWithCorsHeaders(t *testing.T) { + tests := []struct { + name string + headers []string + }{ + { + name: "single header", + headers: []string{"UserHeader"}, + }, + { + name: "multiple headers", + headers: []string{"UserHeader", "X-Requested-With"}, + }, + { + name: "no headers", + headers: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + const configYaml = ` +Name: foo +Port: 54321 +` + var cnf RestConf + assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf)) + rt := router.NewRouter() + svr, err := NewServer(cnf, WithRouter(rt)) + assert.Nil(t, err) + defer svr.Stop() + option := WithCorsHeaders(tt.headers...) + option(svr) + + // Assuming newCorsRouter sets headers correctly, + // we would need to verify the behavior here. Since we don't have + // direct access to headers, we'll mock newCorsRouter to capture it. + w := httptest.NewRecorder() + svr.ServeHTTP(w, httptest.NewRequest(http.MethodOptions, "/", nil)) + + vals := w.Header().Values("Access-Control-Allow-Headers") + respHeaders := make(map[string]struct{}) + for _, header := range vals { + headers := strings.Split(header, ", ") + for _, h := range headers { + if len(h) > 0 { + respHeaders[h] = struct{}{} + } + } + } + for _, h := range tt.headers { + _, ok := respHeaders[h] + assert.Truef(t, ok, "expected header %s not found", h) + } + }) + } +} + func TestServer_PrintRoutes(t *testing.T) { const ( configYaml = `