chore: refactor and add more tests (#3351)

This commit is contained in:
Kevin Wan 2023-06-16 01:04:58 +08:00 committed by GitHub
parent 1262266ac2
commit f998803131
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 120 additions and 63 deletions

View File

@ -13,40 +13,24 @@ import (
)
var (
errorHandler func(error) (int, any)
errorHandlerCtx func(context.Context, error) (int, any)
commonHandler func(any) any
commonHandlerCtx func(context.Context, any) any
lock sync.RWMutex
cLock sync.RWMutex
errorHandler func(context.Context, error) (int, any)
errorLock sync.RWMutex
respHandler func(context.Context, any) any
respLock sync.RWMutex
)
// Error writes err into w.
func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
lock.RLock()
handler := errorHandler
lock.RUnlock()
doHandleError(w, err, handler, WriteJson, fns...)
doHandleError(w, err, buildErrorHandler(context.Background()), WriteJson, fns...)
}
// ErrorCtx writes err into w.
func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error,
fns ...func(w http.ResponseWriter, err error)) {
lock.RLock()
handlerCtx := errorHandlerCtx
lock.RUnlock()
var handler func(error) (int, any)
if handlerCtx != nil {
handler = func(err error) (int, any) {
return handlerCtx(ctx, err)
}
}
writeJson := func(w http.ResponseWriter, code int, v any) {
WriteJsonCtx(ctx, w, code, v)
}
doHandleError(w, err, handler, writeJson, fns...)
doHandleError(w, err, buildErrorHandler(ctx), writeJson, fns...)
}
// Ok writes HTTP 200 OK into w.
@ -56,20 +40,20 @@ func Ok(w http.ResponseWriter) {
// OkJson writes v into w with 200 OK.
func OkJson(w http.ResponseWriter, v any) {
cLock.RLock()
handler := commonHandler
cLock.RUnlock()
respLock.RLock()
handler := respHandler
respLock.RUnlock()
if handler != nil {
v = handler(v)
v = handler(context.Background(), v)
}
WriteJson(w, http.StatusOK, v)
}
// OkJsonCtx writes v into w with 200 OK.
func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) {
cLock.RLock()
handlerCtx := commonHandlerCtx
cLock.RUnlock()
respLock.RLock()
handlerCtx := respHandler
respLock.RUnlock()
if handlerCtx != nil {
v = handlerCtx(ctx, v)
}
@ -77,31 +61,30 @@ func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) {
}
// SetErrorHandler sets the error handler, which is called on calling Error.
// Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
// Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
func SetErrorHandler(handler func(error) (int, any)) {
lock.Lock()
defer lock.Unlock()
errorHandler = handler
errorLock.Lock()
defer errorLock.Unlock()
errorHandler = func(_ context.Context, err error) (int, any) {
return handler(err)
}
}
// SetErrorHandlerCtx sets the error handler, which is called on calling Error.
// Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
// Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, any)) {
lock.Lock()
defer lock.Unlock()
errorHandlerCtx = handlerCtx
errorLock.Lock()
defer errorLock.Unlock()
errorHandler = handlerCtx
}
// SetCommonHandler sets the common handler, which is called on calling OkJson.
func SetCommonHandler(handler func(any) any) {
cLock.Lock()
defer cLock.Unlock()
commonHandler = handler
}
// SetCommonHandlerCtx sets the common handler, which is called on calling OkJson.
func SetCommonHandlerCtx(handlerCtx func(context.Context, any) any) {
cLock.Lock()
defer cLock.Unlock()
commonHandlerCtx = handlerCtx
// SetResponseHandler sets the response handler, which is called on calling OkJson and OkJsonCtx.
func SetResponseHandler(handler func(context.Context, any) any) {
respLock.Lock()
defer respLock.Unlock()
respHandler = handler
}
// WriteJson writes v as json string into w with code.
@ -118,6 +101,21 @@ func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v any) {
}
}
func buildErrorHandler(ctx context.Context) func(error) (int, any) {
errorLock.RLock()
handlerCtx := errorHandler
errorLock.RUnlock()
var handler func(error) (int, any)
if handlerCtx != nil {
handler = func(err error) (int, any) {
return handlerCtx(ctx, err)
}
}
return handler
}
func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, any),
writeJson func(w http.ResponseWriter, code int, v any),
fns ...func(w http.ResponseWriter, err error)) {

View File

@ -3,6 +3,7 @@ package httpx
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"testing"
@ -80,14 +81,14 @@ func TestError(t *testing.T) {
headers: make(map[string][]string),
}
if test.errorHandler != nil {
lock.RLock()
errorLock.RLock()
prev := errorHandler
lock.RUnlock()
errorLock.RUnlock()
SetErrorHandler(test.errorHandler)
defer func() {
lock.Lock()
errorLock.Lock()
errorHandler = prev
lock.Unlock()
errorLock.Unlock()
}()
}
Error(&w, errors.New(test.input))
@ -129,6 +130,7 @@ func TestOk(t *testing.T) {
}
func TestOkJson(t *testing.T) {
t.Run("no handler", func(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
@ -136,6 +138,63 @@ func TestOkJson(t *testing.T) {
OkJson(&w, msg)
assert.Equal(t, http.StatusOK, w.code)
assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
})
t.Run("with handler", func(t *testing.T) {
respLock.RLock()
prev := respHandler
respLock.RUnlock()
t.Cleanup(func() {
respLock.Lock()
respHandler = prev
respLock.Unlock()
})
SetResponseHandler(func(_ context.Context, v interface{}) any {
return fmt.Sprintf("hello %s", v.(message).Name)
})
w := tracedResponseWriter{
headers: make(map[string][]string),
}
msg := message{Name: "anyone"}
OkJson(&w, msg)
assert.Equal(t, http.StatusOK, w.code)
assert.Equal(t, `"hello anyone"`, w.builder.String())
})
}
func TestOkJsonCtx(t *testing.T) {
t.Run("no handler", func(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
msg := message{Name: "anyone"}
OkJsonCtx(context.Background(), &w, msg)
assert.Equal(t, http.StatusOK, w.code)
assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
})
t.Run("with handler", func(t *testing.T) {
respLock.RLock()
prev := respHandler
respLock.RUnlock()
t.Cleanup(func() {
respLock.Lock()
respHandler = prev
respLock.Unlock()
})
SetResponseHandler(func(_ context.Context, v interface{}) any {
return fmt.Sprintf("hello %s", v.(message).Name)
})
w := tracedResponseWriter{
headers: make(map[string][]string),
}
msg := message{Name: "anyone"}
OkJsonCtx(context.Background(), &w, msg)
assert.Equal(t, http.StatusOK, w.code)
assert.Equal(t, `"hello anyone"`, w.builder.String())
})
}
func TestWriteJsonTimeout(t *testing.T) {
@ -275,14 +334,14 @@ func TestErrorCtx(t *testing.T) {
headers: make(map[string][]string),
}
if test.errorHandlerCtx != nil {
lock.RLock()
prev := errorHandlerCtx
lock.RUnlock()
errorLock.RLock()
prev := errorHandler
errorLock.RUnlock()
SetErrorHandlerCtx(test.errorHandlerCtx)
defer func() {
lock.Lock()
errorLock.Lock()
test.errorHandlerCtx = prev
lock.Unlock()
errorLock.Unlock()
}()
}
ErrorCtx(context.Background(), &w, errors.New(test.input))