mirror of
https://github.com/zeromicro/go-zero.git
synced 2025-02-02 16:28:39 +08:00
chore: refactor and add more tests (#3351)
This commit is contained in:
parent
1262266ac2
commit
f998803131
@ -13,40 +13,24 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errorHandler func(error) (int, any)
|
errorHandler func(context.Context, error) (int, any)
|
||||||
errorHandlerCtx func(context.Context, error) (int, any)
|
errorLock sync.RWMutex
|
||||||
commonHandler func(any) any
|
respHandler func(context.Context, any) any
|
||||||
commonHandlerCtx func(context.Context, any) any
|
respLock sync.RWMutex
|
||||||
lock sync.RWMutex
|
|
||||||
cLock sync.RWMutex
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Error writes err into w.
|
// Error writes err into w.
|
||||||
func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
|
func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
|
||||||
lock.RLock()
|
doHandleError(w, err, buildErrorHandler(context.Background()), WriteJson, fns...)
|
||||||
handler := errorHandler
|
|
||||||
lock.RUnlock()
|
|
||||||
|
|
||||||
doHandleError(w, err, handler, WriteJson, fns...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrorCtx writes err into w.
|
// ErrorCtx writes err into w.
|
||||||
func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error,
|
func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error,
|
||||||
fns ...func(w http.ResponseWriter, err error)) {
|
fns ...func(w http.ResponseWriter, err error)) {
|
||||||
lock.RLock()
|
|
||||||
handlerCtx := errorHandlerCtx
|
|
||||||
lock.RUnlock()
|
|
||||||
|
|
||||||
var handler func(error) (int, any)
|
|
||||||
if handlerCtx != nil {
|
|
||||||
handler = func(err error) (int, any) {
|
|
||||||
return handlerCtx(ctx, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
writeJson := func(w http.ResponseWriter, code int, v any) {
|
writeJson := func(w http.ResponseWriter, code int, v any) {
|
||||||
WriteJsonCtx(ctx, w, code, v)
|
WriteJsonCtx(ctx, w, code, v)
|
||||||
}
|
}
|
||||||
doHandleError(w, err, handler, writeJson, fns...)
|
doHandleError(w, err, buildErrorHandler(ctx), writeJson, fns...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ok writes HTTP 200 OK into w.
|
// Ok writes HTTP 200 OK into w.
|
||||||
@ -56,20 +40,20 @@ func Ok(w http.ResponseWriter) {
|
|||||||
|
|
||||||
// OkJson writes v into w with 200 OK.
|
// OkJson writes v into w with 200 OK.
|
||||||
func OkJson(w http.ResponseWriter, v any) {
|
func OkJson(w http.ResponseWriter, v any) {
|
||||||
cLock.RLock()
|
respLock.RLock()
|
||||||
handler := commonHandler
|
handler := respHandler
|
||||||
cLock.RUnlock()
|
respLock.RUnlock()
|
||||||
if handler != nil {
|
if handler != nil {
|
||||||
v = handler(v)
|
v = handler(context.Background(), v)
|
||||||
}
|
}
|
||||||
WriteJson(w, http.StatusOK, v)
|
WriteJson(w, http.StatusOK, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OkJsonCtx writes v into w with 200 OK.
|
// OkJsonCtx writes v into w with 200 OK.
|
||||||
func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) {
|
func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) {
|
||||||
cLock.RLock()
|
respLock.RLock()
|
||||||
handlerCtx := commonHandlerCtx
|
handlerCtx := respHandler
|
||||||
cLock.RUnlock()
|
respLock.RUnlock()
|
||||||
if handlerCtx != nil {
|
if handlerCtx != nil {
|
||||||
v = handlerCtx(ctx, v)
|
v = handlerCtx(ctx, v)
|
||||||
}
|
}
|
||||||
@ -77,31 +61,30 @@ func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetErrorHandler sets the error handler, which is called on calling Error.
|
// SetErrorHandler sets the error handler, which is called on calling Error.
|
||||||
|
// Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
|
||||||
|
// Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
|
||||||
func SetErrorHandler(handler func(error) (int, any)) {
|
func SetErrorHandler(handler func(error) (int, any)) {
|
||||||
lock.Lock()
|
errorLock.Lock()
|
||||||
defer lock.Unlock()
|
defer errorLock.Unlock()
|
||||||
errorHandler = handler
|
errorHandler = func(_ context.Context, err error) (int, any) {
|
||||||
|
return handler(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetErrorHandlerCtx sets the error handler, which is called on calling Error.
|
// SetErrorHandlerCtx sets the error handler, which is called on calling Error.
|
||||||
|
// Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
|
||||||
|
// Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
|
||||||
func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, any)) {
|
func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, any)) {
|
||||||
lock.Lock()
|
errorLock.Lock()
|
||||||
defer lock.Unlock()
|
defer errorLock.Unlock()
|
||||||
errorHandlerCtx = handlerCtx
|
errorHandler = handlerCtx
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetCommonHandler sets the common handler, which is called on calling OkJson.
|
// SetResponseHandler sets the response handler, which is called on calling OkJson and OkJsonCtx.
|
||||||
func SetCommonHandler(handler func(any) any) {
|
func SetResponseHandler(handler func(context.Context, any) any) {
|
||||||
cLock.Lock()
|
respLock.Lock()
|
||||||
defer cLock.Unlock()
|
defer respLock.Unlock()
|
||||||
commonHandler = handler
|
respHandler = handler
|
||||||
}
|
|
||||||
|
|
||||||
// SetCommonHandlerCtx sets the common handler, which is called on calling OkJson.
|
|
||||||
func SetCommonHandlerCtx(handlerCtx func(context.Context, any) any) {
|
|
||||||
cLock.Lock()
|
|
||||||
defer cLock.Unlock()
|
|
||||||
commonHandlerCtx = handlerCtx
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteJson writes v as json string into w with code.
|
// WriteJson writes v as json string into w with code.
|
||||||
@ -118,6 +101,21 @@ func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v any) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildErrorHandler(ctx context.Context) func(error) (int, any) {
|
||||||
|
errorLock.RLock()
|
||||||
|
handlerCtx := errorHandler
|
||||||
|
errorLock.RUnlock()
|
||||||
|
|
||||||
|
var handler func(error) (int, any)
|
||||||
|
if handlerCtx != nil {
|
||||||
|
handler = func(err error) (int, any) {
|
||||||
|
return handlerCtx(ctx, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return handler
|
||||||
|
}
|
||||||
|
|
||||||
func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, any),
|
func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, any),
|
||||||
writeJson func(w http.ResponseWriter, code int, v any),
|
writeJson func(w http.ResponseWriter, code int, v any),
|
||||||
fns ...func(w http.ResponseWriter, err error)) {
|
fns ...func(w http.ResponseWriter, err error)) {
|
||||||
|
@ -3,6 +3,7 @@ package httpx
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@ -80,14 +81,14 @@ func TestError(t *testing.T) {
|
|||||||
headers: make(map[string][]string),
|
headers: make(map[string][]string),
|
||||||
}
|
}
|
||||||
if test.errorHandler != nil {
|
if test.errorHandler != nil {
|
||||||
lock.RLock()
|
errorLock.RLock()
|
||||||
prev := errorHandler
|
prev := errorHandler
|
||||||
lock.RUnlock()
|
errorLock.RUnlock()
|
||||||
SetErrorHandler(test.errorHandler)
|
SetErrorHandler(test.errorHandler)
|
||||||
defer func() {
|
defer func() {
|
||||||
lock.Lock()
|
errorLock.Lock()
|
||||||
errorHandler = prev
|
errorHandler = prev
|
||||||
lock.Unlock()
|
errorLock.Unlock()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
Error(&w, errors.New(test.input))
|
Error(&w, errors.New(test.input))
|
||||||
@ -129,13 +130,71 @@ func TestOk(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestOkJson(t *testing.T) {
|
func TestOkJson(t *testing.T) {
|
||||||
w := tracedResponseWriter{
|
t.Run("no handler", func(t *testing.T) {
|
||||||
headers: make(map[string][]string),
|
w := tracedResponseWriter{
|
||||||
}
|
headers: make(map[string][]string),
|
||||||
msg := message{Name: "anyone"}
|
}
|
||||||
OkJson(&w, msg)
|
msg := message{Name: "anyone"}
|
||||||
assert.Equal(t, http.StatusOK, w.code)
|
OkJson(&w, msg)
|
||||||
assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
|
assert.Equal(t, http.StatusOK, w.code)
|
||||||
|
assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with handler", func(t *testing.T) {
|
||||||
|
respLock.RLock()
|
||||||
|
prev := respHandler
|
||||||
|
respLock.RUnlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
respLock.Lock()
|
||||||
|
respHandler = prev
|
||||||
|
respLock.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
SetResponseHandler(func(_ context.Context, v interface{}) any {
|
||||||
|
return fmt.Sprintf("hello %s", v.(message).Name)
|
||||||
|
})
|
||||||
|
w := tracedResponseWriter{
|
||||||
|
headers: make(map[string][]string),
|
||||||
|
}
|
||||||
|
msg := message{Name: "anyone"}
|
||||||
|
OkJson(&w, msg)
|
||||||
|
assert.Equal(t, http.StatusOK, w.code)
|
||||||
|
assert.Equal(t, `"hello anyone"`, w.builder.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOkJsonCtx(t *testing.T) {
|
||||||
|
t.Run("no handler", func(t *testing.T) {
|
||||||
|
w := tracedResponseWriter{
|
||||||
|
headers: make(map[string][]string),
|
||||||
|
}
|
||||||
|
msg := message{Name: "anyone"}
|
||||||
|
OkJsonCtx(context.Background(), &w, msg)
|
||||||
|
assert.Equal(t, http.StatusOK, w.code)
|
||||||
|
assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with handler", func(t *testing.T) {
|
||||||
|
respLock.RLock()
|
||||||
|
prev := respHandler
|
||||||
|
respLock.RUnlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
respLock.Lock()
|
||||||
|
respHandler = prev
|
||||||
|
respLock.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
SetResponseHandler(func(_ context.Context, v interface{}) any {
|
||||||
|
return fmt.Sprintf("hello %s", v.(message).Name)
|
||||||
|
})
|
||||||
|
w := tracedResponseWriter{
|
||||||
|
headers: make(map[string][]string),
|
||||||
|
}
|
||||||
|
msg := message{Name: "anyone"}
|
||||||
|
OkJsonCtx(context.Background(), &w, msg)
|
||||||
|
assert.Equal(t, http.StatusOK, w.code)
|
||||||
|
assert.Equal(t, `"hello anyone"`, w.builder.String())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWriteJsonTimeout(t *testing.T) {
|
func TestWriteJsonTimeout(t *testing.T) {
|
||||||
@ -275,14 +334,14 @@ func TestErrorCtx(t *testing.T) {
|
|||||||
headers: make(map[string][]string),
|
headers: make(map[string][]string),
|
||||||
}
|
}
|
||||||
if test.errorHandlerCtx != nil {
|
if test.errorHandlerCtx != nil {
|
||||||
lock.RLock()
|
errorLock.RLock()
|
||||||
prev := errorHandlerCtx
|
prev := errorHandler
|
||||||
lock.RUnlock()
|
errorLock.RUnlock()
|
||||||
SetErrorHandlerCtx(test.errorHandlerCtx)
|
SetErrorHandlerCtx(test.errorHandlerCtx)
|
||||||
defer func() {
|
defer func() {
|
||||||
lock.Lock()
|
errorLock.Lock()
|
||||||
test.errorHandlerCtx = prev
|
test.errorHandlerCtx = prev
|
||||||
lock.Unlock()
|
errorLock.Unlock()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
ErrorCtx(context.Background(), &w, errors.New(test.input))
|
ErrorCtx(context.Background(), &w, errors.New(test.input))
|
||||||
|
Loading…
Reference in New Issue
Block a user