support cors in rest server

This commit is contained in:
kevin 2020-10-21 14:10:29 +08:00
parent 1c1e4bca86
commit fe0d0687f5
8 changed files with 122 additions and 8 deletions

View File

@ -56,7 +56,7 @@ func main() {
Port: *port,
Timeout: *timeout,
MaxConns: 500,
})
}, rest.WithNotAllowedHandler(rest.CorsHandler()))
defer engine.Stop()
engine.Use(first)

29
rest/handlers.go Normal file
View File

@ -0,0 +1,29 @@
package rest
import (
"net/http"
"strings"
)
const (
allowOrigin = "Access-Control-Allow-Origin"
allOrigin = "*"
allowMethods = "Access-Control-Allow-Methods"
allowHeaders = "Access-Control-Allow-Headers"
headers = "Content-Type, Content-Length, Origin"
methods = "GET, HEAD, POST, PATCH, PUT, DELETE"
separator = ", "
)
func CorsHandler(origins ...string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if len(origins) > 0 {
w.Header().Set(allowOrigin, strings.Join(origins, separator))
} else {
w.Header().Set(allowOrigin, allOrigin)
}
w.Header().Set(allowMethods, methods)
w.Header().Set(allowHeaders, headers)
w.WriteHeader(http.StatusNoContent)
})
}

27
rest/handlers_test.go Normal file
View File

@ -0,0 +1,27 @@
package rest
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCorsHandler(t *testing.T) {
w := httptest.NewRecorder()
handler := CorsHandler()
handler.ServeHTTP(w, nil)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
assert.Equal(t, allOrigin, w.Header().Get(allowOrigin))
}
func TestCorsHandlerWithOrigins(t *testing.T) {
origins := []string{"local", "remote"}
w := httptest.NewRecorder()
handler := CorsHandler(origins...)
handler.ServeHTTP(w, nil)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
assert.Equal(t, strings.Join(origins, separator), w.Header().Get(allowOrigin))
}

View File

@ -6,4 +6,5 @@ type Router interface {
http.Handler
Handle(method string, path string, handler http.Handler) error
SetNotFoundHandler(handler http.Handler)
SetNotAllowedHandler(handler http.Handler)
}

View File

@ -22,8 +22,9 @@ var (
)
type patRouter struct {
trees map[string]*search.Tree
notFound http.Handler
trees map[string]*search.Tree
notFound http.Handler
notAllowed http.Handler
}
func NewRouter() httpx.Router {
@ -63,11 +64,17 @@ func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
if allow, ok := pr.methodNotAllowed(r.Method, reqPath); ok {
allow, ok := pr.methodNotAllowed(r.Method, reqPath)
if !ok {
pr.handleNotFound(w, r)
return
}
if pr.notAllowed != nil {
pr.notAllowed.ServeHTTP(w, r)
} else {
w.Header().Set(allowHeader, allow)
w.WriteHeader(http.StatusMethodNotAllowed)
} else {
pr.handleNotFound(w, r)
}
}
@ -75,6 +82,10 @@ func (pr *patRouter) SetNotFoundHandler(handler http.Handler) {
pr.notFound = handler
}
func (pr *patRouter) SetNotAllowedHandler(handler http.Handler) {
pr.notAllowed = handler
}
func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
if pr.notFound != nil {
pr.notFound.ServeHTTP(w, r)

View File

@ -60,13 +60,30 @@ func TestPatRouterNotFound(t *testing.T) {
router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
notFound = true
}))
router.Handle(http.MethodGet, "/a/b", nil)
err := router.Handle(http.MethodGet, "/a/b",
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
assert.Nil(t, err)
r, _ := http.NewRequest(http.MethodGet, "/b/c", nil)
w := new(mockedResponseWriter)
router.ServeHTTP(w, r)
assert.True(t, notFound)
}
func TestPatRouterNotAllowed(t *testing.T) {
var notAllowed bool
router := NewRouter()
router.SetNotAllowedHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
notAllowed = true
}))
err := router.Handle(http.MethodGet, "/a/b",
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
assert.Nil(t, err)
r, _ := http.NewRequest(http.MethodPost, "/a/b", nil)
w := new(mockedResponseWriter)
router.ServeHTTP(w, r)
assert.True(t, notAllowed)
}
func TestPatRouter(t *testing.T) {
tests := []struct {
method string

View File

@ -1,12 +1,14 @@
package rest
import (
"errors"
"log"
"net/http"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/rest/handler"
"github.com/tal-tech/go-zero/rest/httpx"
"github.com/tal-tech/go-zero/rest/router"
)
type (
@ -32,6 +34,10 @@ func MustNewServer(c RestConf, opts ...RunOption) *Server {
}
func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
if len(opts) > 1 {
return nil, errors.New("only one RunOption is allowed")
}
if err := c.SetUp(); err != nil {
return nil, err
}
@ -125,6 +131,18 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route {
return routes
}
func WithNotFoundHandler(handler http.Handler) RunOption {
rt := router.NewRouter()
rt.SetNotFoundHandler(handler)
return WithRouter(rt)
}
func WithNotAllowedHandler(handler http.Handler) RunOption {
rt := router.NewRouter()
rt.SetNotAllowedHandler(handler)
return WithRouter(rt)
}
func WithPriority() RouteOption {
return func(r *featuredRoutes) {
r.priority = true

View File

@ -12,6 +12,11 @@ import (
"github.com/tal-tech/go-zero/rest/router"
)
func TestNewServer(t *testing.T) {
_, err := NewServer(RestConf{}, WithNotFoundHandler(nil), WithNotAllowedHandler(nil))
assert.NotNil(t, err)
}
func TestWithMiddleware(t *testing.T) {
m := make(map[string]string)
router := router.NewRouter()
@ -69,7 +74,7 @@ func TestWithMiddleware(t *testing.T) {
}, m)
}
func TestMultiMiddleware(t *testing.T) {
func TestMultiMiddlewares(t *testing.T) {
m := make(map[string]string)
router := router.NewRouter()
handler := func(w http.ResponseWriter, r *http.Request) {
@ -140,3 +145,9 @@ func TestMultiMiddleware(t *testing.T) {
"whatever": "200000200000",
}, m)
}
func TestWithPriority(t *testing.T) {
var fr featuredRoutes
WithPriority()(&fr)
assert.True(t, fr.priority)
}