go-zero/rest/router/patrouter.go

128 lines
2.7 KiB
Go
Raw Normal View History

2020-07-29 18:00:04 +08:00
package router
2020-07-26 17:09:05 +08:00
import (
	"errors"
	"net/http"
	"path"
	"sort"
	"strings"

	"github.com/tal-tech/go-zero/core/search"
	"github.com/tal-tech/go-zero/rest/httpx"
	"github.com/tal-tech/go-zero/rest/internal/context"
)
// allowHeader is the response header that lists the methods a resource
// supports; it is set on 405 responses produced by the router.
const allowHeader = "Allow"

// allowMethodSeparator joins the individual method names inside the
// allowHeader value.
const allowMethodSeparator = ", "
2020-07-31 12:13:30 +08:00
var (
	// ErrInvalidMethod is returned by Handle when the HTTP method is not one
	// of the methods the router supports.
	ErrInvalidMethod = errors.New("not a valid http method")

	// ErrInvalidPath is returned by Handle when the route path is empty or
	// does not begin with '/'.
	ErrInvalidPath = errors.New("path must begin with '/'")
)
2020-10-20 14:23:21 +08:00
type patRouter struct {
2020-10-21 14:10:29 +08:00
trees map[string]*search.Tree
notFound http.Handler
notAllowed http.Handler
2020-07-26 17:09:05 +08:00
}
2021-03-01 19:15:35 +08:00
// NewRouter returns a httpx.Router.
2020-10-20 14:23:21 +08:00
func NewRouter() httpx.Router {
return &patRouter{
2020-07-26 17:09:05 +08:00
trees: make(map[string]*search.Tree),
}
}
2020-10-20 14:23:21 +08:00
func (pr *patRouter) Handle(method, reqPath string, handler http.Handler) error {
2020-07-26 17:09:05 +08:00
if !validMethod(method) {
return ErrInvalidMethod
}
if len(reqPath) == 0 || reqPath[0] != '/' {
return ErrInvalidPath
}
cleanPath := path.Clean(reqPath)
2021-02-09 13:50:21 +08:00
tree, ok := pr.trees[method]
if ok {
2020-07-26 17:09:05 +08:00
return tree.Add(cleanPath, handler)
}
2021-02-09 13:50:21 +08:00
tree = search.NewTree()
pr.trees[method] = tree
return tree.Add(cleanPath, handler)
2020-07-26 17:09:05 +08:00
}
2020-10-20 14:23:21 +08:00
func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
2020-07-26 17:09:05 +08:00
reqPath := path.Clean(r.URL.Path)
if tree, ok := pr.trees[r.Method]; ok {
if result, ok := tree.Search(reqPath); ok {
if len(result.Params) > 0 {
2020-07-29 18:00:04 +08:00
r = context.WithPathVars(r, result.Params)
2020-07-26 17:09:05 +08:00
}
result.Item.(http.Handler).ServeHTTP(w, r)
return
}
}
allows, ok := pr.methodsAllowed(r.Method, reqPath)
2020-10-21 14:10:29 +08:00
if !ok {
pr.handleNotFound(w, r)
return
}
if pr.notAllowed != nil {
pr.notAllowed.ServeHTTP(w, r)
} else {
w.Header().Set(allowHeader, allows)
2020-07-26 17:09:05 +08:00
w.WriteHeader(http.StatusMethodNotAllowed)
}
}
2020-10-20 14:23:21 +08:00
func (pr *patRouter) SetNotFoundHandler(handler http.Handler) {
2020-07-26 17:09:05 +08:00
pr.notFound = handler
}
2020-10-21 14:10:29 +08:00
// SetNotAllowedHandler installs handler as the response for requests whose
// path matches a route only under other methods. Passing nil restores the
// default behavior of setting the Allow header and writing 405.
func (pr *patRouter) SetNotAllowedHandler(handler http.Handler) {
	pr.notAllowed = handler
}
2020-10-20 14:23:21 +08:00
func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
2020-07-26 17:09:05 +08:00
if pr.notFound != nil {
pr.notFound.ServeHTTP(w, r)
} else {
http.NotFound(w, r)
}
}
func (pr *patRouter) methodsAllowed(method, path string) (string, bool) {
2020-07-26 17:09:05 +08:00
var allows []string
for treeMethod, tree := range pr.trees {
if treeMethod == method {
continue
}
_, ok := tree.Search(path)
if ok {
allows = append(allows, treeMethod)
}
}
if len(allows) > 0 {
return strings.Join(allows, allowMethodSeparator), true
}
2021-02-09 13:50:21 +08:00
return "", false
2020-07-26 17:09:05 +08:00
}
// validMethod reports whether method is one of the HTTP methods the router
// accepts for registration: DELETE, GET, HEAD, OPTIONS, PATCH, POST or PUT.
// The comparison is case-sensitive.
func validMethod(method string) bool {
	switch method {
	case http.MethodDelete, http.MethodGet, http.MethodHead, http.MethodOptions,
		http.MethodPatch, http.MethodPost, http.MethodPut:
		return true
	default:
		return false
	}
}