feat: add rest.WithCustomCors to let caller customize the response (#1274)

This commit is contained in:
Kevin Wan 2021-11-25 23:03:37 +08:00 committed by GitHub
parent 86f9f63b46
commit 0395ba1816
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 72 additions and 6 deletions

View File

@ -23,9 +23,12 @@ const (
// NotAllowedHandler handles cross domain not allowed requests.
// At most one origin can be specified, other origins are ignored if given, default to be *.
func NotAllowedHandler(origins ...string) http.Handler {
func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
checkAndSetHeaders(w, r, origins)
if fn != nil {
fn(w)
}
if r.Method != http.MethodOptions {
w.WriteHeader(http.StatusNotFound)
@ -36,10 +39,13 @@ func NotAllowedHandler(origins ...string) http.Handler {
}
// Middleware returns a middleware that adds CORS headers to the response.
func Middleware(origins ...string) func(http.HandlerFunc) http.HandlerFunc {
func Middleware(fn func(w http.ResponseWriter), origins ...string) func(http.HandlerFunc) http.HandlerFunc {
return func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
checkAndSetHeaders(w, r, origins)
if fn != nil {
fn(w)
}
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)

View File

@ -62,7 +62,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
r := httptest.NewRequest(method, "http://localhost", nil)
r.Header.Set(originHeader, test.reqOrigin)
w := httptest.NewRecorder()
handler := NotAllowedHandler(test.origins...)
handler := NotAllowedHandler(nil, test.origins...)
handler.ServeHTTP(w, r)
if method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
@ -71,6 +71,22 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
}
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
})
t.Run(test.name+"-handler-custom", func(t *testing.T) {
r := httptest.NewRequest(method, "http://localhost", nil)
r.Header.Set(originHeader, test.reqOrigin)
w := httptest.NewRecorder()
handler := NotAllowedHandler(func(w http.ResponseWriter) {
w.Header().Set("foo", "bar")
}, test.origins...)
handler.ServeHTTP(w, r)
if method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
} else {
assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
}
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
assert.Equal(t, "bar", w.Header().Get("foo"))
})
}
}
@ -81,7 +97,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
r := httptest.NewRequest(method, "http://localhost", nil)
r.Header.Set(originHeader, test.reqOrigin)
w := httptest.NewRecorder()
handler := Middleware(test.origins...)(func(w http.ResponseWriter, r *http.Request) {
handler := Middleware(nil, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler.ServeHTTP(w, r)
@ -92,6 +108,24 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
}
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
})
t.Run(test.name+"-middleware-custom", func(t *testing.T) {
r := httptest.NewRequest(method, "http://localhost", nil)
r.Header.Set(originHeader, test.reqOrigin)
w := httptest.NewRecorder()
handler := Middleware(func(w http.ResponseWriter) {
w.Header().Set("foo", "bar")
}, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler.ServeHTTP(w, r)
if method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
} else {
assert.Equal(t, http.StatusOK, w.Result().StatusCode)
}
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
assert.Equal(t, "bar", w.Header().Get("foo"))
})
}
}
}

View File

@ -99,8 +99,17 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
// WithCors returns a func to enable CORS for given origin, or default to all origins (*).
func WithCors(origin ...string) RunOption {
return func(server *Server) {
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(origin...))
server.Use(cors.Middleware(origin...))
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, origin...))
server.Use(cors.Middleware(nil, origin...))
}
}
// WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*),
// fn lets caller customizing the response.
func WithCustomCors(fn func(http.ResponseWriter), origin ...string) RunOption {
return func(server *Server) {
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(fn, origin...))
server.Use(cors.Middleware(fn, origin...))
}
}

View File

@ -310,3 +310,20 @@ Port: 54321
opt := WithCors("local")
opt(srv)
}
func TestWithCustomCors(t *testing.T) {
const configYaml = `
Name: foo
Port: 54321
`
var cnf RestConf
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
rt := router.NewRouter()
srv, err := NewServer(cnf, WithRouter(rt))
assert.Nil(t, err)
opt := WithCustomCors(func(w http.ResponseWriter) {
w.Header().Set("foo", "bar")
}, "local")
opt(srv)
}