mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-01-23 09:00:20 +08:00
refactor
This commit is contained in:
parent
121323b8c3
commit
ca3934582a
@ -8,11 +8,11 @@ import (
|
||||
"time"
|
||||
|
||||
"zero/core/executors"
|
||||
"zero/core/httpx"
|
||||
"zero/core/logx"
|
||||
"zero/example/graceful/dns/api/svc"
|
||||
"zero/example/graceful/dns/api/types"
|
||||
"zero/example/graceful/dns/rpc/graceful"
|
||||
"zero/ngin/httpx"
|
||||
)
|
||||
|
||||
func gracefulHandler(ctx *svc.ServiceContext) http.HandlerFunc {
|
||||
|
@ -8,11 +8,11 @@ import (
|
||||
"time"
|
||||
|
||||
"zero/core/executors"
|
||||
"zero/core/httpx"
|
||||
"zero/core/logx"
|
||||
"zero/example/graceful/etcd/api/svc"
|
||||
"zero/example/graceful/etcd/api/types"
|
||||
"zero/example/graceful/etcd/rpc/graceful"
|
||||
"zero/ngin/httpx"
|
||||
)
|
||||
|
||||
func gracefulHandler(ctx *svc.ServiceContext) http.HandlerFunc {
|
||||
|
@ -4,10 +4,10 @@ import (
|
||||
"flag"
|
||||
"net/http"
|
||||
|
||||
"zero/core/httpx"
|
||||
"zero/core/logx"
|
||||
"zero/core/service"
|
||||
"zero/ngin"
|
||||
"zero/ngin/httpx"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -5,10 +5,10 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"zero/core/httpx"
|
||||
"zero/core/logx"
|
||||
"zero/core/service"
|
||||
"zero/ngin"
|
||||
"zero/ngin/httpx"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -6,10 +6,10 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"zero/core/httpx"
|
||||
"zero/core/logx"
|
||||
"zero/core/service"
|
||||
"zero/ngin"
|
||||
"zero/ngin/httpx"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -5,10 +5,10 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"zero/core/httpx"
|
||||
"zero/core/logx"
|
||||
"zero/core/service"
|
||||
"zero/ngin"
|
||||
"zero/ngin/httpx"
|
||||
)
|
||||
|
||||
var keyPem = flag.String("prikey", "private.pem", "the private key file")
|
||||
|
@ -9,8 +9,8 @@ import (
|
||||
"time"
|
||||
|
||||
"zero/core/conf"
|
||||
"zero/core/httpx"
|
||||
"zero/ngin"
|
||||
"zero/ngin/httpx"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/dgrijalva/jwt-go/request"
|
||||
|
@ -6,11 +6,11 @@ import (
|
||||
"net/http"
|
||||
|
||||
"zero/core/conf"
|
||||
"zero/core/httpx"
|
||||
"zero/core/logx"
|
||||
"zero/core/service"
|
||||
"zero/example/tracing/remote/portal"
|
||||
"zero/ngin"
|
||||
"zero/ngin/httpx"
|
||||
"zero/rpcx"
|
||||
)
|
||||
|
||||
|
@ -1,12 +1,12 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
|
||||
"zero/core/httpsecurity"
|
||||
"zero/core/logx"
|
||||
"zero/ngin/internal"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
@ -37,7 +37,7 @@ func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.H
|
||||
opt(&authOpts)
|
||||
}
|
||||
|
||||
parser := httpsecurity.NewTokenParser()
|
||||
parser := internal.NewTokenParser()
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token, err := parser.ParseToken(r, secret, authOpts.PrevSecret)
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -6,10 +6,10 @@ import (
|
||||
"strings"
|
||||
|
||||
"zero/core/breaker"
|
||||
"zero/core/httphandler/internal"
|
||||
"zero/core/httpx"
|
||||
"zero/core/logx"
|
||||
"zero/core/stat"
|
||||
"zero/ngin/internal"
|
||||
"zero/ngin/internal/security"
|
||||
)
|
||||
|
||||
const breakerSeparator = "://"
|
||||
@ -22,12 +22,12 @@ func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handle
|
||||
if err != nil {
|
||||
metrics.AddDrop()
|
||||
logx.Errorf("[http] dropped, %s - %s - %s",
|
||||
r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent())
|
||||
r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent())
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
cw := &internal.WithCodeResponseWriter{Writer: w}
|
||||
cw := &security.WithCodeResponseWriter{Writer: w}
|
||||
defer func() {
|
||||
if cw.Code < http.StatusInternalServerError {
|
||||
promise.Accept()
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
@ -1,13 +1,13 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"zero/core/codec"
|
||||
"zero/core/httphandler/internal"
|
||||
"zero/core/httpx"
|
||||
"zero/core/logx"
|
||||
"zero/ngin/httpx"
|
||||
"zero/ngin/internal/security"
|
||||
)
|
||||
|
||||
const contentSecurity = "X-Content-Security"
|
||||
@ -24,12 +24,12 @@ func ContentSecurityHandler(decrypters map[string]codec.RsaDecrypter, tolerance
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodDelete, http.MethodGet, http.MethodPost, http.MethodPut:
|
||||
header, err := internal.ParseContentSecurity(decrypters, r)
|
||||
header, err := security.ParseContentSecurity(decrypters, r)
|
||||
if err != nil {
|
||||
logx.Infof("Signature parse failed, X-Content-Security: %s, error: %s",
|
||||
r.Header.Get(contentSecurity), err.Error())
|
||||
executeCallbacks(w, r, next, strict, httpx.CodeSignatureInvalidHeader, callbacks)
|
||||
} else if code := internal.VerifySignature(r, header, tolerance); code != httpx.CodeSignaturePass {
|
||||
} else if code := security.VerifySignature(r, header, tolerance); code != httpx.CodeSignaturePass {
|
||||
logx.Infof("Signature verification failed, X-Content-Security: %s",
|
||||
r.Header.Get(contentSecurity))
|
||||
executeCallbacks(w, r, next, strict, code, callbacks)
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -18,7 +18,7 @@ import (
|
||||
"time"
|
||||
|
||||
"zero/core/codec"
|
||||
"zero/core/httpx"
|
||||
"zero/ngin/httpx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
@ -1,11 +1,11 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"zero/core/httpx"
|
||||
"zero/ngin/httpx"
|
||||
)
|
||||
|
||||
const gzipEncoding = "gzip"
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -10,7 +10,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"zero/core/codec"
|
||||
"zero/core/httpx"
|
||||
"zero/ngin/httpx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -9,12 +9,11 @@ import (
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"zero/core/httplog"
|
||||
"zero/core/httpx"
|
||||
"zero/core/iox"
|
||||
"zero/core/logx"
|
||||
"zero/core/timex"
|
||||
"zero/core/utils"
|
||||
"zero/ngin/internal"
|
||||
)
|
||||
|
||||
const slowThreshold = time.Millisecond * 500
|
||||
@ -41,7 +40,7 @@ func (w *LoggedResponseWriter) WriteHeader(code int) {
|
||||
func LogHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
timer := utils.NewElapsedTimer()
|
||||
logs := new(httplog.LogCollector)
|
||||
logs := new(internal.LogCollector)
|
||||
lrw := LoggedResponseWriter{
|
||||
w: w,
|
||||
r: r,
|
||||
@ -50,7 +49,7 @@ func LogHandler(next http.Handler) http.Handler {
|
||||
|
||||
var dup io.ReadCloser
|
||||
r.Body, dup = iox.DupReadCloser(r.Body)
|
||||
next.ServeHTTP(&lrw, r.WithContext(context.WithValue(r.Context(), httplog.LogContext, logs)))
|
||||
next.ServeHTTP(&lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs)))
|
||||
r.Body = dup
|
||||
logBrief(r, lrw.code, timer, logs)
|
||||
})
|
||||
@ -93,8 +92,8 @@ func DetailedLogHandler(next http.Handler) http.Handler {
|
||||
|
||||
var dup io.ReadCloser
|
||||
r.Body, dup = iox.DupReadCloser(r.Body)
|
||||
logs := new(httplog.LogCollector)
|
||||
next.ServeHTTP(lrw, r.WithContext(context.WithValue(r.Context(), httplog.LogContext, logs)))
|
||||
logs := new(internal.LogCollector)
|
||||
next.ServeHTTP(lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs)))
|
||||
r.Body = dup
|
||||
logDetails(r, lrw, timer, logs)
|
||||
})
|
||||
@ -109,14 +108,14 @@ func dumpRequest(r *http.Request) string {
|
||||
}
|
||||
}
|
||||
|
||||
func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *httplog.LogCollector) {
|
||||
func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *internal.LogCollector) {
|
||||
var buf bytes.Buffer
|
||||
duration := timer.Duration()
|
||||
buf.WriteString(fmt.Sprintf("%d - %s - %s - %s - %s",
|
||||
code, r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration)))
|
||||
code, r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration)))
|
||||
if duration > slowThreshold {
|
||||
logx.Slowf("[HTTP] %d - %s - %s - %s - slowcall(%s)",
|
||||
code, r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration))
|
||||
code, r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration))
|
||||
}
|
||||
|
||||
ok := isOkResponse(code)
|
||||
@ -137,7 +136,7 @@ func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *httplo
|
||||
}
|
||||
|
||||
func logDetails(r *http.Request, response *DetailLoggedResponseWriter, timer *utils.ElapsedTimer,
|
||||
logs *httplog.LogCollector) {
|
||||
logs *internal.LogCollector) {
|
||||
var buf bytes.Buffer
|
||||
duration := timer.Duration()
|
||||
buf.WriteString(fmt.Sprintf("%d - %s - %s\n=> %s\n",
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
@ -8,9 +8,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"zero/ngin/internal"
|
||||
|
||||
"zero/core/httplog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -26,7 +26,7 @@ func TestLogHandler(t *testing.T) {
|
||||
for _, logHandler := range handlers {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
r.Context().Value(httplog.LogContext).(*httplog.LogCollector).Append("anything")
|
||||
r.Context().Value(internal.LogContext).(*internal.LogCollector).Append("anything")
|
||||
w.Header().Set("X-Test", "test")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, err := w.Write([]byte("content"))
|
@ -1,9 +1,9 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"zero/core/httplog"
|
||||
"zero/ngin/internal"
|
||||
)
|
||||
|
||||
func MaxBytesHandler(n int64) func(http.Handler) http.Handler {
|
||||
@ -16,7 +16,7 @@ func MaxBytesHandler(n int64) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.ContentLength > n {
|
||||
httplog.Errorf(r, "request entity too large, limit is %d, but got %d, rejected with code %d",
|
||||
internal.Errorf(r, "request entity too large, limit is %d, but got %d, rejected with code %d",
|
||||
n, r.ContentLength, http.StatusRequestEntityTooLarge)
|
||||
w.WriteHeader(http.StatusRequestEntityTooLarge)
|
||||
} else {
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
@ -1,11 +1,11 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"zero/core/httplog"
|
||||
"zero/core/logx"
|
||||
"zero/core/syncx"
|
||||
"zero/ngin/internal"
|
||||
)
|
||||
|
||||
func MaxConns(n int) func(http.Handler) http.Handler {
|
||||
@ -28,7 +28,7 @@ func MaxConns(n int) func(http.Handler) http.Handler {
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
} else {
|
||||
httplog.Errorf(r, "Concurrent connections over %d, rejected with code %d",
|
||||
internal.Errorf(r, "Concurrent connections over %d, rejected with code %d",
|
||||
n, http.StatusServiceUnavailable)
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
@ -1,13 +1,13 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"zero/core/httphandler/internal"
|
||||
"zero/core/metric"
|
||||
"zero/core/timex"
|
||||
"zero/ngin/internal/security"
|
||||
)
|
||||
|
||||
const serverNamespace = "http_server"
|
||||
@ -35,7 +35,7 @@ func PromMetricHandler(path string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := timex.Now()
|
||||
cw := &internal.WithCodeResponseWriter{Writer: w}
|
||||
cw := &security.WithCodeResponseWriter{Writer: w}
|
||||
defer func() {
|
||||
metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), path)
|
||||
metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code))
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
@ -1,18 +1,18 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
"zero/core/httplog"
|
||||
"zero/ngin/internal"
|
||||
)
|
||||
|
||||
func RecoverHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if result := recover(); result != nil {
|
||||
httplog.Error(r, fmt.Sprintf("%v\n%s", result, debug.Stack()))
|
||||
internal.Error(r, fmt.Sprintf("%v\n%s", result, debug.Stack()))
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
@ -1,14 +1,14 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"zero/core/httphandler/internal"
|
||||
"zero/core/httpx"
|
||||
"zero/core/load"
|
||||
"zero/core/logx"
|
||||
"zero/core/stat"
|
||||
"zero/ngin/internal"
|
||||
"zero/ngin/internal/security"
|
||||
)
|
||||
|
||||
const serviceType = "api"
|
||||
@ -35,12 +35,12 @@ func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Hand
|
||||
metrics.AddDrop()
|
||||
sheddingStat.IncrementDrop()
|
||||
logx.Errorf("[http] dropped, %s - %s - %s",
|
||||
r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent())
|
||||
r.RequestURI, internal.GetRemoteAddr(r), r.UserAgent())
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
cw := &internal.WithCodeResponseWriter{Writer: w}
|
||||
cw := &security.WithCodeResponseWriter{Writer: w}
|
||||
defer func() {
|
||||
if cw.Code == http.StatusServiceUnavailable {
|
||||
promise.Fail()
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
@ -1,4 +1,4 @@
|
||||
package httphandler
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
@ -1,18 +1,16 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"zero/core/httprouter"
|
||||
"zero/core/mapping"
|
||||
"zero/ngin/internal/context"
|
||||
)
|
||||
|
||||
const (
|
||||
multipartFormData = "multipart/form-data"
|
||||
xForwardFor = "X-Forward-For"
|
||||
formKey = "form"
|
||||
pathKey = "path"
|
||||
emptyJson = "{}"
|
||||
@ -23,21 +21,10 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBodylessRequest = errors.New("not a POST|PUT|PATCH request")
|
||||
|
||||
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
|
||||
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
|
||||
)
|
||||
|
||||
// Returns the peer address, supports X-Forward-For
|
||||
func GetRemoteAddr(r *http.Request) string {
|
||||
v := r.Header.Get(xForwardFor)
|
||||
if len(v) > 0 {
|
||||
return v
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
func Parse(r *http.Request, v interface{}) error {
|
||||
if err := ParsePath(r, v); err != nil {
|
||||
return err
|
||||
@ -110,7 +97,7 @@ func ParseJsonBody(r *http.Request, v interface{}) error {
|
||||
// Parses the symbols reside in url path.
|
||||
// Like http://localhost/bag/:name
|
||||
func ParsePath(r *http.Request, v interface{}) error {
|
||||
vars := httprouter.Vars(r)
|
||||
vars := context.Vars(r)
|
||||
m := make(map[string]interface{}, len(vars))
|
||||
for k, v := range vars {
|
||||
m[k] = v
|
@ -10,7 +10,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"zero/core/httprouter"
|
||||
"zero/ngin/internal/router"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@ -20,15 +20,6 @@ const (
|
||||
contentLength = "Content-Length"
|
||||
)
|
||||
|
||||
func TestGetRemoteAddr(t *testing.T) {
|
||||
host := "8.8.8.8"
|
||||
r, err := http.NewRequest(http.MethodGet, "/", strings.NewReader(""))
|
||||
assert.Nil(t, err)
|
||||
|
||||
r.Header.Set(xForwardFor, host)
|
||||
assert.Equal(t, host, GetRemoteAddr(r))
|
||||
}
|
||||
|
||||
func TestParseForm(t *testing.T) {
|
||||
var v struct {
|
||||
Name string `form:"name"`
|
||||
@ -135,8 +126,8 @@ func TestParseSlice(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rt := router.NewPatRouter()
|
||||
err = rt.Handle(http.MethodPost, "/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
Names []string `form:"names"`
|
||||
}{}
|
||||
@ -150,7 +141,7 @@ func TestParseSlice(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, r)
|
||||
rt.ServeHTTP(rr, r)
|
||||
}
|
||||
|
||||
func TestParseJsonPost(t *testing.T) {
|
||||
@ -159,7 +150,7 @@ func TestParseJsonPost(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set(ContentType, ApplicationJson)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(
|
||||
w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -191,7 +182,7 @@ func TestParseJsonPostWithIntSlice(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set(ContentType, ApplicationJson)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(
|
||||
w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -219,7 +210,7 @@ func TestParseJsonPostError(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set(ContentType, ApplicationJson)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -247,7 +238,7 @@ func TestParseJsonPostInvalidRequest(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set(ContentType, ApplicationJson)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -269,7 +260,7 @@ func TestParseJsonPostRequired(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set(ContentType, ApplicationJson)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -292,7 +283,7 @@ func TestParsePath(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -317,7 +308,7 @@ func TestParsePathRequired(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin", nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodGet, "/:name/", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -338,7 +329,7 @@ func TestParseQuery(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -363,7 +354,7 @@ func TestParseQueryRequired(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever", nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
Nickname string `form:"nickname"`
|
||||
@ -383,7 +374,7 @@ func TestParseOptional(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -424,7 +415,7 @@ func TestParseNestedInRequestEmpty(t *testing.T) {
|
||||
}
|
||||
)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
var v WrappedRequest
|
||||
@ -463,7 +454,7 @@ func TestParsePtrInRequest(t *testing.T) {
|
||||
}
|
||||
)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
var v WrappedRequest
|
||||
@ -494,7 +485,7 @@ func TestParsePtrInRequestEmpty(t *testing.T) {
|
||||
}
|
||||
)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/kevin", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
var v WrappedRequest
|
||||
@ -511,7 +502,7 @@ func TestParseQueryOptional(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -536,7 +527,7 @@ func TestParse(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -574,7 +565,7 @@ func TestParseWrappedRequest(t *testing.T) {
|
||||
}
|
||||
)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
var v WrappedRequest
|
||||
@ -606,7 +597,7 @@ func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
|
||||
}
|
||||
)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
var v WrappedRequest
|
||||
@ -639,7 +630,7 @@ func TestParseWrappedHeadRequestWithJsonHeader(t *testing.T) {
|
||||
}
|
||||
)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodHead, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
var v WrappedRequest
|
||||
@ -671,7 +662,7 @@ func TestParseWrappedRequestPtr(t *testing.T) {
|
||||
}
|
||||
)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
var v WrappedRequest
|
||||
@ -694,7 +685,7 @@ func TestParseWithAll(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set(ContentType, ApplicationJson)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
Name string `path:"name"`
|
||||
@ -725,7 +716,7 @@ func TestParseWithAllUtf8(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set(ContentType, applicationJsonWithUtf8)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -756,7 +747,7 @@ func TestParseWithMissingForm(t *testing.T) {
|
||||
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
|
||||
assert.Nil(t, err)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -783,7 +774,7 @@ func TestParseWithMissingAllForms(t *testing.T) {
|
||||
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
|
||||
assert.Nil(t, err)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -809,7 +800,7 @@ func TestParseWithMissingJson(t *testing.T) {
|
||||
bytes.NewBufferString(`{"location": "shanghai"}`))
|
||||
assert.Nil(t, err)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -835,7 +826,7 @@ func TestParseWithMissingAllJsons(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
|
||||
assert.Nil(t, err)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -862,7 +853,7 @@ func TestParseWithMissingPath(t *testing.T) {
|
||||
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
|
||||
assert.Nil(t, err)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -889,7 +880,7 @@ func TestParseWithMissingAllPaths(t *testing.T) {
|
||||
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
|
||||
assert.Nil(t, err)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -916,7 +907,7 @@ func TestParseGetWithContentLengthHeader(t *testing.T) {
|
||||
r.Header.Set(ContentType, ApplicationJson)
|
||||
r.Header.Set(contentLength, "1024")
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -943,7 +934,7 @@ func TestParseJsonPostWithTypeMismatch(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set(ContentType, applicationJsonWithUtf8)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
@ -969,7 +960,7 @@ func TestParseJsonPostWithInt2String(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set(ContentType, applicationJsonWithUtf8)
|
||||
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router.NewPatRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
21
ngin/internal/context/params.go
Normal file
21
ngin/internal/context/params.go
Normal file
@ -0,0 +1,21 @@
|
||||
package context
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const pathVars = "pathVars"
|
||||
|
||||
func Vars(r *http.Request) map[string]string {
|
||||
vars, ok := r.Context().Value(pathVars).(map[string]string)
|
||||
if ok {
|
||||
return vars
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func WithPathVars(r *http.Request, params map[string]string) *http.Request {
|
||||
return r.WithContext(context.WithValue(r.Context(), pathVars, params))
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package httplog
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -6,7 +6,6 @@ import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"zero/core/httpx"
|
||||
"zero/core/logx"
|
||||
)
|
||||
|
||||
@ -80,5 +79,5 @@ func formatf(r *http.Request, format string, v ...interface{}) string {
|
||||
}
|
||||
|
||||
func formatWithReq(r *http.Request, v string) string {
|
||||
return fmt.Sprintf("(%s - %s) %s", r.RequestURI, httpx.GetRemoteAddr(r), v)
|
||||
return fmt.Sprintf("(%s - %s) %s", r.RequestURI, GetRemoteAddr(r), v)
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package httplog
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
@ -1,18 +1,17 @@
|
||||
package httprouter
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"zero/core/search"
|
||||
"zero/ngin/internal/context"
|
||||
)
|
||||
|
||||
const (
|
||||
allowHeader = "Allow"
|
||||
allowMethodSeparator = ", "
|
||||
pathVars = "pathVars"
|
||||
)
|
||||
|
||||
type PatRouter struct {
|
||||
@ -50,7 +49,7 @@ func (pr *PatRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if tree, ok := pr.trees[r.Method]; ok {
|
||||
if result, ok := tree.Search(reqPath); ok {
|
||||
if len(result.Params) > 0 {
|
||||
r = r.WithContext(context.WithValue(r.Context(), pathVars, result.Params))
|
||||
r = context.WithPathVars(r, result.Params)
|
||||
}
|
||||
result.Item.(http.Handler).ServeHTTP(w, r)
|
||||
return
|
||||
@ -98,15 +97,6 @@ func (pr *PatRouter) methodNotAllowed(method, path string) (string, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func Vars(r *http.Request) map[string]string {
|
||||
vars, ok := r.Context().Value(pathVars).(map[string]string)
|
||||
if ok {
|
||||
return vars
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validMethod(method string) bool {
|
||||
return method == http.MethodDelete || method == http.MethodGet ||
|
||||
method == http.MethodHead || method == http.MethodOptions ||
|
@ -1,10 +1,12 @@
|
||||
package httprouter
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"zero/ngin/internal/context"
|
||||
)
|
||||
|
||||
type mockedResponseWriter struct {
|
||||
@ -78,12 +80,12 @@ func TestPatRouter(t *testing.T) {
|
||||
router := NewPatRouter()
|
||||
err := router.Handle(test.method, "/a/:b", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
routed = true
|
||||
assert.Equal(t, 1, len(Vars(r)))
|
||||
assert.Equal(t, 1, len(context.Vars(r)))
|
||||
}))
|
||||
assert.Nil(t, err)
|
||||
err = router.Handle(test.method, "/a/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
routed = true
|
||||
assert.Nil(t, Vars(r))
|
||||
assert.Nil(t, context.Vars(r))
|
||||
}))
|
||||
assert.Nil(t, err)
|
||||
err = router.Handle(test.method, "/b/c", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
@ -1,4 +1,4 @@
|
||||
package httprouter
|
||||
package router
|
||||
|
||||
import (
|
||||
"errors"
|
@ -1,4 +1,4 @@
|
||||
package internal
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
@ -13,9 +13,9 @@ import (
|
||||
"time"
|
||||
|
||||
"zero/core/codec"
|
||||
"zero/core/httpx"
|
||||
"zero/core/iox"
|
||||
"zero/core/logx"
|
||||
"zero/ngin/httpx"
|
||||
)
|
||||
|
||||
const (
|
@ -1,4 +1,4 @@
|
||||
package internal
|
||||
package security
|
||||
|
||||
import "net/http"
|
||||
|
@ -1,4 +1,4 @@
|
||||
package httpserver
|
||||
package internal
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
@ -1,4 +1,4 @@
|
||||
package httpserver
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
@ -1,4 +1,4 @@
|
||||
package httpsecurity
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/http"
|
@ -1,4 +1,4 @@
|
||||
package httpsecurity
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/http"
|
14
ngin/internal/util.go
Normal file
14
ngin/internal/util.go
Normal file
@ -0,0 +1,14 @@
|
||||
package internal
|
||||
|
||||
import "net/http"
|
||||
|
||||
const xForwardFor = "X-Forward-For"
|
||||
|
||||
// Returns the peer address, supports X-Forward-For
|
||||
func GetRemoteAddr(r *http.Request) string {
|
||||
v := r.Header.Get(xForwardFor)
|
||||
if len(v) > 0 {
|
||||
return v
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
19
ngin/internal/util_test.go
Normal file
19
ngin/internal/util_test.go
Normal file
@ -0,0 +1,19 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetRemoteAddr(t *testing.T) {
|
||||
host := "8.8.8.8"
|
||||
r, err := http.NewRequest(http.MethodGet, "/", strings.NewReader(""))
|
||||
assert.Nil(t, err)
|
||||
|
||||
r.Header.Set(xForwardFor, host)
|
||||
assert.Equal(t, host, GetRemoteAddr(r))
|
||||
}
|
||||
|
10
ngin/ngin.go
10
ngin/ngin.go
@ -4,9 +4,9 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"zero/core/httphandler"
|
||||
"zero/core/httprouter"
|
||||
"zero/core/logx"
|
||||
"zero/ngin/handler"
|
||||
"zero/ngin/internal/router"
|
||||
)
|
||||
|
||||
type (
|
||||
@ -124,7 +124,7 @@ func WithPriority() RouteOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithRouter(router httprouter.Router) RunOption {
|
||||
func WithRouter(router router.Router) RunOption {
|
||||
return func(engine *Engine) {
|
||||
engine.opts.start = func(srv *server) error {
|
||||
return srv.StartWithRouter(router)
|
||||
@ -141,13 +141,13 @@ func WithSignature(signature SignatureConf) RouteOption {
|
||||
}
|
||||
}
|
||||
|
||||
func WithUnauthorizedCallback(callback httphandler.UnauthorizedCallback) RunOption {
|
||||
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
|
||||
return func(engine *Engine) {
|
||||
engine.srv.SetUnauthorizedCallback(callback)
|
||||
}
|
||||
}
|
||||
|
||||
func WithUnsignedCallback(callback httphandler.UnsignedCallback) RunOption {
|
||||
func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
|
||||
return func(engine *Engine) {
|
||||
engine.srv.SetUnsignedCallback(callback)
|
||||
}
|
||||
|
@ -7,15 +7,15 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"zero/core/httprouter"
|
||||
"zero/core/httpx"
|
||||
"zero/ngin/httpx"
|
||||
router2 "zero/ngin/internal/router"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWithMiddleware(t *testing.T) {
|
||||
m := make(map[string]string)
|
||||
router := httprouter.NewPatRouter()
|
||||
router := router2.NewPatRouter()
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
var v struct {
|
||||
Nickname string `form:"nickname"`
|
||||
|
@ -7,11 +7,11 @@ import (
|
||||
"time"
|
||||
|
||||
"zero/core/codec"
|
||||
"zero/core/httphandler"
|
||||
"zero/core/httprouter"
|
||||
"zero/core/httpserver"
|
||||
"zero/core/load"
|
||||
"zero/core/stat"
|
||||
"zero/ngin/handler"
|
||||
"zero/ngin/internal"
|
||||
"zero/ngin/internal/router"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
)
|
||||
@ -27,8 +27,8 @@ type (
|
||||
server struct {
|
||||
conf NgConf
|
||||
routes []featuredRoutes
|
||||
unauthorizedCallback httphandler.UnauthorizedCallback
|
||||
unsignedCallback httphandler.UnsignedCallback
|
||||
unauthorizedCallback handler.UnauthorizedCallback
|
||||
unsignedCallback handler.UnsignedCallback
|
||||
middlewares []Middleware
|
||||
shedder load.Shedder
|
||||
priorityShedder load.Shedder
|
||||
@ -52,43 +52,43 @@ func (s *server) AddRoutes(r featuredRoutes) {
|
||||
s.routes = append(s.routes, r)
|
||||
}
|
||||
|
||||
func (s *server) SetUnauthorizedCallback(callback httphandler.UnauthorizedCallback) {
|
||||
func (s *server) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
|
||||
s.unauthorizedCallback = callback
|
||||
}
|
||||
|
||||
func (s *server) SetUnsignedCallback(callback httphandler.UnsignedCallback) {
|
||||
func (s *server) SetUnsignedCallback(callback handler.UnsignedCallback) {
|
||||
s.unsignedCallback = callback
|
||||
}
|
||||
|
||||
func (s *server) Start() error {
|
||||
return s.StartWithRouter(httprouter.NewPatRouter())
|
||||
return s.StartWithRouter(router.NewPatRouter())
|
||||
}
|
||||
|
||||
func (s *server) StartWithRouter(router httprouter.Router) error {
|
||||
func (s *server) StartWithRouter(router router.Router) error {
|
||||
if err := s.bindRoutes(router); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return httpserver.StartHttp(s.conf.Host, s.conf.Port, router)
|
||||
return internal.StartHttp(s.conf.Host, s.conf.Port, router)
|
||||
}
|
||||
|
||||
func (s *server) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
|
||||
verifier func(alice.Chain) alice.Chain) alice.Chain {
|
||||
if fr.jwt.enabled {
|
||||
if len(fr.jwt.prevSecret) == 0 {
|
||||
chain = chain.Append(httphandler.Authorize(fr.jwt.secret,
|
||||
httphandler.WithUnauthorizedCallback(s.unauthorizedCallback)))
|
||||
chain = chain.Append(handler.Authorize(fr.jwt.secret,
|
||||
handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
|
||||
} else {
|
||||
chain = chain.Append(httphandler.Authorize(fr.jwt.secret,
|
||||
httphandler.WithPrevSecret(fr.jwt.prevSecret),
|
||||
httphandler.WithUnauthorizedCallback(s.unauthorizedCallback)))
|
||||
chain = chain.Append(handler.Authorize(fr.jwt.secret,
|
||||
handler.WithPrevSecret(fr.jwt.prevSecret),
|
||||
handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
|
||||
}
|
||||
}
|
||||
|
||||
return verifier(chain)
|
||||
}
|
||||
|
||||
func (s *server) bindFeaturedRoutes(router httprouter.Router, fr featuredRoutes, metrics *stat.Metrics) error {
|
||||
func (s *server) bindFeaturedRoutes(router router.Router, fr featuredRoutes, metrics *stat.Metrics) error {
|
||||
verifier, err := s.signatureVerifier(fr.signature)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -103,20 +103,20 @@ func (s *server) bindFeaturedRoutes(router httprouter.Router, fr featuredRoutes,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *server) bindRoute(fr featuredRoutes, router httprouter.Router, metrics *stat.Metrics,
|
||||
func (s *server) bindRoute(fr featuredRoutes, router router.Router, metrics *stat.Metrics,
|
||||
route Route, verifier func(chain alice.Chain) alice.Chain) error {
|
||||
chain := alice.New(
|
||||
httphandler.TracingHandler,
|
||||
handler.TracingHandler,
|
||||
s.getLogHandler(),
|
||||
httphandler.MaxConns(s.conf.MaxConns),
|
||||
httphandler.BreakerHandler(route.Method, route.Path, metrics),
|
||||
httphandler.SheddingHandler(s.getShedder(fr.priority), metrics),
|
||||
httphandler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond),
|
||||
httphandler.RecoverHandler,
|
||||
httphandler.MetricHandler(metrics),
|
||||
httphandler.PromMetricHandler(route.Path),
|
||||
httphandler.MaxBytesHandler(s.conf.MaxBytes),
|
||||
httphandler.GunzipHandler,
|
||||
handler.MaxConns(s.conf.MaxConns),
|
||||
handler.BreakerHandler(route.Method, route.Path, metrics),
|
||||
handler.SheddingHandler(s.getShedder(fr.priority), metrics),
|
||||
handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond),
|
||||
handler.RecoverHandler,
|
||||
handler.MetricHandler(metrics),
|
||||
handler.PromMetricHandler(route.Path),
|
||||
handler.MaxBytesHandler(s.conf.MaxBytes),
|
||||
handler.GunzipHandler,
|
||||
)
|
||||
chain = s.appendAuthHandler(fr, chain, verifier)
|
||||
|
||||
@ -128,7 +128,7 @@ func (s *server) bindRoute(fr featuredRoutes, router httprouter.Router, metrics
|
||||
return router.Handle(route.Method, route.Path, handle)
|
||||
}
|
||||
|
||||
func (s *server) bindRoutes(router httprouter.Router) error {
|
||||
func (s *server) bindRoutes(router router.Router) error {
|
||||
metrics := s.createMetrics()
|
||||
|
||||
for _, fr := range s.routes {
|
||||
@ -154,9 +154,9 @@ func (s *server) createMetrics() *stat.Metrics {
|
||||
|
||||
func (s *server) getLogHandler() func(http.Handler) http.Handler {
|
||||
if s.conf.Verbose {
|
||||
return httphandler.DetailedLogHandler
|
||||
return handler.DetailedLogHandler
|
||||
} else {
|
||||
return httphandler.LogHandler
|
||||
return handler.LogHandler
|
||||
}
|
||||
}
|
||||
|
||||
@ -198,10 +198,10 @@ func (s *server) signatureVerifier(signature signatureSetting) (func(chain alice
|
||||
|
||||
return func(chain alice.Chain) alice.Chain {
|
||||
if s.unsignedCallback != nil {
|
||||
return chain.Append(httphandler.ContentSecurityHandler(
|
||||
return chain.Append(handler.ContentSecurityHandler(
|
||||
decrypters, signature.Expiry, signature.Strict, s.unsignedCallback))
|
||||
} else {
|
||||
return chain.Append(httphandler.ContentSecurityHandler(
|
||||
return chain.Append(handler.ContentSecurityHandler(
|
||||
decrypters, signature.Expiry, signature.Strict))
|
||||
}
|
||||
}, nil
|
||||
|
Loading…
Reference in New Issue
Block a user