mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-01-23 09:00:20 +08:00
chore: refactor and add more tests (#3351)
This commit is contained in:
parent
1262266ac2
commit
f998803131
@ -13,40 +13,24 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
errorHandler func(error) (int, any)
|
||||
errorHandlerCtx func(context.Context, error) (int, any)
|
||||
commonHandler func(any) any
|
||||
commonHandlerCtx func(context.Context, any) any
|
||||
lock sync.RWMutex
|
||||
cLock sync.RWMutex
|
||||
errorHandler func(context.Context, error) (int, any)
|
||||
errorLock sync.RWMutex
|
||||
respHandler func(context.Context, any) any
|
||||
respLock sync.RWMutex
|
||||
)
|
||||
|
||||
// Error writes err into w.
|
||||
func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
|
||||
lock.RLock()
|
||||
handler := errorHandler
|
||||
lock.RUnlock()
|
||||
|
||||
doHandleError(w, err, handler, WriteJson, fns...)
|
||||
doHandleError(w, err, buildErrorHandler(context.Background()), WriteJson, fns...)
|
||||
}
|
||||
|
||||
// ErrorCtx writes err into w.
|
||||
func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error,
|
||||
fns ...func(w http.ResponseWriter, err error)) {
|
||||
lock.RLock()
|
||||
handlerCtx := errorHandlerCtx
|
||||
lock.RUnlock()
|
||||
|
||||
var handler func(error) (int, any)
|
||||
if handlerCtx != nil {
|
||||
handler = func(err error) (int, any) {
|
||||
return handlerCtx(ctx, err)
|
||||
}
|
||||
}
|
||||
writeJson := func(w http.ResponseWriter, code int, v any) {
|
||||
WriteJsonCtx(ctx, w, code, v)
|
||||
}
|
||||
doHandleError(w, err, handler, writeJson, fns...)
|
||||
doHandleError(w, err, buildErrorHandler(ctx), writeJson, fns...)
|
||||
}
|
||||
|
||||
// Ok writes HTTP 200 OK into w.
|
||||
@ -56,20 +40,20 @@ func Ok(w http.ResponseWriter) {
|
||||
|
||||
// OkJson writes v into w with 200 OK.
|
||||
func OkJson(w http.ResponseWriter, v any) {
|
||||
cLock.RLock()
|
||||
handler := commonHandler
|
||||
cLock.RUnlock()
|
||||
respLock.RLock()
|
||||
handler := respHandler
|
||||
respLock.RUnlock()
|
||||
if handler != nil {
|
||||
v = handler(v)
|
||||
v = handler(context.Background(), v)
|
||||
}
|
||||
WriteJson(w, http.StatusOK, v)
|
||||
}
|
||||
|
||||
// OkJsonCtx writes v into w with 200 OK.
|
||||
func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) {
|
||||
cLock.RLock()
|
||||
handlerCtx := commonHandlerCtx
|
||||
cLock.RUnlock()
|
||||
respLock.RLock()
|
||||
handlerCtx := respHandler
|
||||
respLock.RUnlock()
|
||||
if handlerCtx != nil {
|
||||
v = handlerCtx(ctx, v)
|
||||
}
|
||||
@ -77,31 +61,30 @@ func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) {
|
||||
}
|
||||
|
||||
// SetErrorHandler sets the error handler, which is called on calling Error.
|
||||
// Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
|
||||
// Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
|
||||
func SetErrorHandler(handler func(error) (int, any)) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
errorHandler = handler
|
||||
errorLock.Lock()
|
||||
defer errorLock.Unlock()
|
||||
errorHandler = func(_ context.Context, err error) (int, any) {
|
||||
return handler(err)
|
||||
}
|
||||
}
|
||||
|
||||
// SetErrorHandlerCtx sets the error handler, which is called on calling Error.
|
||||
// Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
|
||||
// Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
|
||||
func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, any)) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
errorHandlerCtx = handlerCtx
|
||||
errorLock.Lock()
|
||||
defer errorLock.Unlock()
|
||||
errorHandler = handlerCtx
|
||||
}
|
||||
|
||||
// SetCommonHandler sets the common handler, which is called on calling OkJson.
|
||||
func SetCommonHandler(handler func(any) any) {
|
||||
cLock.Lock()
|
||||
defer cLock.Unlock()
|
||||
commonHandler = handler
|
||||
}
|
||||
|
||||
// SetCommonHandlerCtx sets the common handler, which is called on calling OkJson.
|
||||
func SetCommonHandlerCtx(handlerCtx func(context.Context, any) any) {
|
||||
cLock.Lock()
|
||||
defer cLock.Unlock()
|
||||
commonHandlerCtx = handlerCtx
|
||||
// SetResponseHandler sets the response handler, which is called on calling OkJson and OkJsonCtx.
|
||||
func SetResponseHandler(handler func(context.Context, any) any) {
|
||||
respLock.Lock()
|
||||
defer respLock.Unlock()
|
||||
respHandler = handler
|
||||
}
|
||||
|
||||
// WriteJson writes v as json string into w with code.
|
||||
@ -118,6 +101,21 @@ func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v any) {
|
||||
}
|
||||
}
|
||||
|
||||
func buildErrorHandler(ctx context.Context) func(error) (int, any) {
|
||||
errorLock.RLock()
|
||||
handlerCtx := errorHandler
|
||||
errorLock.RUnlock()
|
||||
|
||||
var handler func(error) (int, any)
|
||||
if handlerCtx != nil {
|
||||
handler = func(err error) (int, any) {
|
||||
return handlerCtx(ctx, err)
|
||||
}
|
||||
}
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, any),
|
||||
writeJson func(w http.ResponseWriter, code int, v any),
|
||||
fns ...func(w http.ResponseWriter, err error)) {
|
||||
|
@ -3,6 +3,7 @@ package httpx
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -80,14 +81,14 @@ func TestError(t *testing.T) {
|
||||
headers: make(map[string][]string),
|
||||
}
|
||||
if test.errorHandler != nil {
|
||||
lock.RLock()
|
||||
errorLock.RLock()
|
||||
prev := errorHandler
|
||||
lock.RUnlock()
|
||||
errorLock.RUnlock()
|
||||
SetErrorHandler(test.errorHandler)
|
||||
defer func() {
|
||||
lock.Lock()
|
||||
errorLock.Lock()
|
||||
errorHandler = prev
|
||||
lock.Unlock()
|
||||
errorLock.Unlock()
|
||||
}()
|
||||
}
|
||||
Error(&w, errors.New(test.input))
|
||||
@ -129,6 +130,7 @@ func TestOk(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestOkJson(t *testing.T) {
|
||||
t.Run("no handler", func(t *testing.T) {
|
||||
w := tracedResponseWriter{
|
||||
headers: make(map[string][]string),
|
||||
}
|
||||
@ -136,6 +138,63 @@ func TestOkJson(t *testing.T) {
|
||||
OkJson(&w, msg)
|
||||
assert.Equal(t, http.StatusOK, w.code)
|
||||
assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
|
||||
})
|
||||
|
||||
t.Run("with handler", func(t *testing.T) {
|
||||
respLock.RLock()
|
||||
prev := respHandler
|
||||
respLock.RUnlock()
|
||||
t.Cleanup(func() {
|
||||
respLock.Lock()
|
||||
respHandler = prev
|
||||
respLock.Unlock()
|
||||
})
|
||||
|
||||
SetResponseHandler(func(_ context.Context, v interface{}) any {
|
||||
return fmt.Sprintf("hello %s", v.(message).Name)
|
||||
})
|
||||
w := tracedResponseWriter{
|
||||
headers: make(map[string][]string),
|
||||
}
|
||||
msg := message{Name: "anyone"}
|
||||
OkJson(&w, msg)
|
||||
assert.Equal(t, http.StatusOK, w.code)
|
||||
assert.Equal(t, `"hello anyone"`, w.builder.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestOkJsonCtx(t *testing.T) {
|
||||
t.Run("no handler", func(t *testing.T) {
|
||||
w := tracedResponseWriter{
|
||||
headers: make(map[string][]string),
|
||||
}
|
||||
msg := message{Name: "anyone"}
|
||||
OkJsonCtx(context.Background(), &w, msg)
|
||||
assert.Equal(t, http.StatusOK, w.code)
|
||||
assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
|
||||
})
|
||||
|
||||
t.Run("with handler", func(t *testing.T) {
|
||||
respLock.RLock()
|
||||
prev := respHandler
|
||||
respLock.RUnlock()
|
||||
t.Cleanup(func() {
|
||||
respLock.Lock()
|
||||
respHandler = prev
|
||||
respLock.Unlock()
|
||||
})
|
||||
|
||||
SetResponseHandler(func(_ context.Context, v interface{}) any {
|
||||
return fmt.Sprintf("hello %s", v.(message).Name)
|
||||
})
|
||||
w := tracedResponseWriter{
|
||||
headers: make(map[string][]string),
|
||||
}
|
||||
msg := message{Name: "anyone"}
|
||||
OkJsonCtx(context.Background(), &w, msg)
|
||||
assert.Equal(t, http.StatusOK, w.code)
|
||||
assert.Equal(t, `"hello anyone"`, w.builder.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestWriteJsonTimeout(t *testing.T) {
|
||||
@ -275,14 +334,14 @@ func TestErrorCtx(t *testing.T) {
|
||||
headers: make(map[string][]string),
|
||||
}
|
||||
if test.errorHandlerCtx != nil {
|
||||
lock.RLock()
|
||||
prev := errorHandlerCtx
|
||||
lock.RUnlock()
|
||||
errorLock.RLock()
|
||||
prev := errorHandler
|
||||
errorLock.RUnlock()
|
||||
SetErrorHandlerCtx(test.errorHandlerCtx)
|
||||
defer func() {
|
||||
lock.Lock()
|
||||
errorLock.Lock()
|
||||
test.errorHandlerCtx = prev
|
||||
lock.Unlock()
|
||||
errorLock.Unlock()
|
||||
}()
|
||||
}
|
||||
ErrorCtx(context.Background(), &w, errors.New(test.input))
|
||||
|
Loading…
Reference in New Issue
Block a user