mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-02-02 16:28:39 +08:00
feat: add rest.WithCustomCors to let caller customize the response (#1274)
This commit is contained in:
parent
86f9f63b46
commit
0395ba1816
@ -23,9 +23,12 @@ const (
|
|||||||
|
|
||||||
// NotAllowedHandler handles cross domain not allowed requests.
|
// NotAllowedHandler handles cross domain not allowed requests.
|
||||||
// At most one origin can be specified, other origins are ignored if given, default to be *.
|
// At most one origin can be specified, other origins are ignored if given, default to be *.
|
||||||
func NotAllowedHandler(origins ...string) http.Handler {
|
func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
checkAndSetHeaders(w, r, origins)
|
checkAndSetHeaders(w, r, origins)
|
||||||
|
if fn != nil {
|
||||||
|
fn(w)
|
||||||
|
}
|
||||||
|
|
||||||
if r.Method != http.MethodOptions {
|
if r.Method != http.MethodOptions {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
@ -36,10 +39,13 @@ func NotAllowedHandler(origins ...string) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Middleware returns a middleware that adds CORS headers to the response.
|
// Middleware returns a middleware that adds CORS headers to the response.
|
||||||
func Middleware(origins ...string) func(http.HandlerFunc) http.HandlerFunc {
|
func Middleware(fn func(w http.ResponseWriter), origins ...string) func(http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
checkAndSetHeaders(w, r, origins)
|
checkAndSetHeaders(w, r, origins)
|
||||||
|
if fn != nil {
|
||||||
|
fn(w)
|
||||||
|
}
|
||||||
|
|
||||||
if r.Method == http.MethodOptions {
|
if r.Method == http.MethodOptions {
|
||||||
w.WriteHeader(http.StatusNoContent)
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
@ -62,7 +62,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
|
|||||||
r := httptest.NewRequest(method, "http://localhost", nil)
|
r := httptest.NewRequest(method, "http://localhost", nil)
|
||||||
r.Header.Set(originHeader, test.reqOrigin)
|
r.Header.Set(originHeader, test.reqOrigin)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
handler := NotAllowedHandler(test.origins...)
|
handler := NotAllowedHandler(nil, test.origins...)
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
if method == http.MethodOptions {
|
if method == http.MethodOptions {
|
||||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||||
@ -71,6 +71,22 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
||||||
})
|
})
|
||||||
|
t.Run(test.name+"-handler-custom", func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(method, "http://localhost", nil)
|
||||||
|
r.Header.Set(originHeader, test.reqOrigin)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler := NotAllowedHandler(func(w http.ResponseWriter) {
|
||||||
|
w.Header().Set("foo", "bar")
|
||||||
|
}, test.origins...)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if method == http.MethodOptions {
|
||||||
|
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
|
||||||
|
}
|
||||||
|
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
||||||
|
assert.Equal(t, "bar", w.Header().Get("foo"))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,7 +97,7 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
|
|||||||
r := httptest.NewRequest(method, "http://localhost", nil)
|
r := httptest.NewRequest(method, "http://localhost", nil)
|
||||||
r.Header.Set(originHeader, test.reqOrigin)
|
r.Header.Set(originHeader, test.reqOrigin)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
handler := Middleware(test.origins...)(func(w http.ResponseWriter, r *http.Request) {
|
handler := Middleware(nil, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
@ -92,6 +108,24 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
||||||
})
|
})
|
||||||
|
t.Run(test.name+"-middleware-custom", func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(method, "http://localhost", nil)
|
||||||
|
r.Header.Set(originHeader, test.reqOrigin)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler := Middleware(func(w http.ResponseWriter) {
|
||||||
|
w.Header().Set("foo", "bar")
|
||||||
|
}, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if method == http.MethodOptions {
|
||||||
|
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, http.StatusOK, w.Result().StatusCode)
|
||||||
|
}
|
||||||
|
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
||||||
|
assert.Equal(t, "bar", w.Header().Get("foo"))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -99,8 +99,17 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
|||||||
// WithCors returns a func to enable CORS for given origin, or default to all origins (*).
|
// WithCors returns a func to enable CORS for given origin, or default to all origins (*).
|
||||||
func WithCors(origin ...string) RunOption {
|
func WithCors(origin ...string) RunOption {
|
||||||
return func(server *Server) {
|
return func(server *Server) {
|
||||||
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(origin...))
|
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, origin...))
|
||||||
server.Use(cors.Middleware(origin...))
|
server.Use(cors.Middleware(nil, origin...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*),
|
||||||
|
// fn lets caller customizing the response.
|
||||||
|
func WithCustomCors(fn func(http.ResponseWriter), origin ...string) RunOption {
|
||||||
|
return func(server *Server) {
|
||||||
|
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(fn, origin...))
|
||||||
|
server.Use(cors.Middleware(fn, origin...))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -310,3 +310,20 @@ Port: 54321
|
|||||||
opt := WithCors("local")
|
opt := WithCors("local")
|
||||||
opt(srv)
|
opt(srv)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWithCustomCors(t *testing.T) {
|
||||||
|
const configYaml = `
|
||||||
|
Name: foo
|
||||||
|
Port: 54321
|
||||||
|
`
|
||||||
|
var cnf RestConf
|
||||||
|
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
|
||||||
|
rt := router.NewRouter()
|
||||||
|
srv, err := NewServer(cnf, WithRouter(rt))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
opt := WithCustomCors(func(w http.ResponseWriter) {
|
||||||
|
w.Header().Set("foo", "bar")
|
||||||
|
}, "local")
|
||||||
|
opt(srv)
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user