chore: refactor and add more tests (#3351)

This commit is contained in:
Kevin Wan 2023-06-16 01:04:58 +08:00 committed by GitHub
parent 1262266ac2
commit f998803131
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 120 additions and 63 deletions

View File

@ -13,40 +13,24 @@ import (
) )
var ( var (
errorHandler func(error) (int, any) errorHandler func(context.Context, error) (int, any)
errorHandlerCtx func(context.Context, error) (int, any) errorLock sync.RWMutex
commonHandler func(any) any respHandler func(context.Context, any) any
commonHandlerCtx func(context.Context, any) any respLock sync.RWMutex
lock sync.RWMutex
cLock sync.RWMutex
) )
// Error writes err into w. // Error writes err into w.
func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) { func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
lock.RLock() doHandleError(w, err, buildErrorHandler(context.Background()), WriteJson, fns...)
handler := errorHandler
lock.RUnlock()
doHandleError(w, err, handler, WriteJson, fns...)
} }
// ErrorCtx writes err into w. // ErrorCtx writes err into w.
func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error,
fns ...func(w http.ResponseWriter, err error)) { fns ...func(w http.ResponseWriter, err error)) {
lock.RLock()
handlerCtx := errorHandlerCtx
lock.RUnlock()
var handler func(error) (int, any)
if handlerCtx != nil {
handler = func(err error) (int, any) {
return handlerCtx(ctx, err)
}
}
writeJson := func(w http.ResponseWriter, code int, v any) { writeJson := func(w http.ResponseWriter, code int, v any) {
WriteJsonCtx(ctx, w, code, v) WriteJsonCtx(ctx, w, code, v)
} }
doHandleError(w, err, handler, writeJson, fns...) doHandleError(w, err, buildErrorHandler(ctx), writeJson, fns...)
} }
// Ok writes HTTP 200 OK into w. // Ok writes HTTP 200 OK into w.
@ -56,20 +40,20 @@ func Ok(w http.ResponseWriter) {
// OkJson writes v into w with 200 OK. // OkJson writes v into w with 200 OK.
func OkJson(w http.ResponseWriter, v any) { func OkJson(w http.ResponseWriter, v any) {
cLock.RLock() respLock.RLock()
handler := commonHandler handler := respHandler
cLock.RUnlock() respLock.RUnlock()
if handler != nil { if handler != nil {
v = handler(v) v = handler(context.Background(), v)
} }
WriteJson(w, http.StatusOK, v) WriteJson(w, http.StatusOK, v)
} }
// OkJsonCtx writes v into w with 200 OK. // OkJsonCtx writes v into w with 200 OK.
func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) { func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) {
cLock.RLock() respLock.RLock()
handlerCtx := commonHandlerCtx handlerCtx := respHandler
cLock.RUnlock() respLock.RUnlock()
if handlerCtx != nil { if handlerCtx != nil {
v = handlerCtx(ctx, v) v = handlerCtx(ctx, v)
} }
@ -77,31 +61,30 @@ func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) {
} }
// SetErrorHandler sets the error handler, which is called on calling Error. // SetErrorHandler sets the error handler, which is called on calling Error.
// Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
// Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
func SetErrorHandler(handler func(error) (int, any)) { func SetErrorHandler(handler func(error) (int, any)) {
lock.Lock() errorLock.Lock()
defer lock.Unlock() defer errorLock.Unlock()
errorHandler = handler errorHandler = func(_ context.Context, err error) (int, any) {
return handler(err)
}
} }
// SetErrorHandlerCtx sets the error handler, which is called on calling Error. // SetErrorHandlerCtx sets the error handler, which is called on calling Error.
// Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler.
// Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility.
func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, any)) { func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, any)) {
lock.Lock() errorLock.Lock()
defer lock.Unlock() defer errorLock.Unlock()
errorHandlerCtx = handlerCtx errorHandler = handlerCtx
} }
// SetCommonHandler sets the common handler, which is called on calling OkJson. // SetResponseHandler sets the response handler, which is called on calling OkJson and OkJsonCtx.
func SetCommonHandler(handler func(any) any) { func SetResponseHandler(handler func(context.Context, any) any) {
cLock.Lock() respLock.Lock()
defer cLock.Unlock() defer respLock.Unlock()
commonHandler = handler respHandler = handler
}
// SetCommonHandlerCtx sets the common handler, which is called on calling OkJson.
func SetCommonHandlerCtx(handlerCtx func(context.Context, any) any) {
cLock.Lock()
defer cLock.Unlock()
commonHandlerCtx = handlerCtx
} }
// WriteJson writes v as json string into w with code. // WriteJson writes v as json string into w with code.
@ -118,6 +101,21 @@ func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v any) {
} }
} }
func buildErrorHandler(ctx context.Context) func(error) (int, any) {
errorLock.RLock()
handlerCtx := errorHandler
errorLock.RUnlock()
var handler func(error) (int, any)
if handlerCtx != nil {
handler = func(err error) (int, any) {
return handlerCtx(ctx, err)
}
}
return handler
}
func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, any), func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, any),
writeJson func(w http.ResponseWriter, code int, v any), writeJson func(w http.ResponseWriter, code int, v any),
fns ...func(w http.ResponseWriter, err error)) { fns ...func(w http.ResponseWriter, err error)) {

View File

@ -3,6 +3,7 @@ package httpx
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net/http" "net/http"
"strings" "strings"
"testing" "testing"
@ -80,14 +81,14 @@ func TestError(t *testing.T) {
headers: make(map[string][]string), headers: make(map[string][]string),
} }
if test.errorHandler != nil { if test.errorHandler != nil {
lock.RLock() errorLock.RLock()
prev := errorHandler prev := errorHandler
lock.RUnlock() errorLock.RUnlock()
SetErrorHandler(test.errorHandler) SetErrorHandler(test.errorHandler)
defer func() { defer func() {
lock.Lock() errorLock.Lock()
errorHandler = prev errorHandler = prev
lock.Unlock() errorLock.Unlock()
}() }()
} }
Error(&w, errors.New(test.input)) Error(&w, errors.New(test.input))
@ -129,13 +130,71 @@ func TestOk(t *testing.T) {
} }
func TestOkJson(t *testing.T) { func TestOkJson(t *testing.T) {
w := tracedResponseWriter{ t.Run("no handler", func(t *testing.T) {
headers: make(map[string][]string), w := tracedResponseWriter{
} headers: make(map[string][]string),
msg := message{Name: "anyone"} }
OkJson(&w, msg) msg := message{Name: "anyone"}
assert.Equal(t, http.StatusOK, w.code) OkJson(&w, msg)
assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String()) assert.Equal(t, http.StatusOK, w.code)
assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
})
t.Run("with handler", func(t *testing.T) {
respLock.RLock()
prev := respHandler
respLock.RUnlock()
t.Cleanup(func() {
respLock.Lock()
respHandler = prev
respLock.Unlock()
})
SetResponseHandler(func(_ context.Context, v interface{}) any {
return fmt.Sprintf("hello %s", v.(message).Name)
})
w := tracedResponseWriter{
headers: make(map[string][]string),
}
msg := message{Name: "anyone"}
OkJson(&w, msg)
assert.Equal(t, http.StatusOK, w.code)
assert.Equal(t, `"hello anyone"`, w.builder.String())
})
}
func TestOkJsonCtx(t *testing.T) {
t.Run("no handler", func(t *testing.T) {
w := tracedResponseWriter{
headers: make(map[string][]string),
}
msg := message{Name: "anyone"}
OkJsonCtx(context.Background(), &w, msg)
assert.Equal(t, http.StatusOK, w.code)
assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
})
t.Run("with handler", func(t *testing.T) {
respLock.RLock()
prev := respHandler
respLock.RUnlock()
t.Cleanup(func() {
respLock.Lock()
respHandler = prev
respLock.Unlock()
})
SetResponseHandler(func(_ context.Context, v interface{}) any {
return fmt.Sprintf("hello %s", v.(message).Name)
})
w := tracedResponseWriter{
headers: make(map[string][]string),
}
msg := message{Name: "anyone"}
OkJsonCtx(context.Background(), &w, msg)
assert.Equal(t, http.StatusOK, w.code)
assert.Equal(t, `"hello anyone"`, w.builder.String())
})
} }
func TestWriteJsonTimeout(t *testing.T) { func TestWriteJsonTimeout(t *testing.T) {
@ -275,14 +334,14 @@ func TestErrorCtx(t *testing.T) {
headers: make(map[string][]string), headers: make(map[string][]string),
} }
if test.errorHandlerCtx != nil { if test.errorHandlerCtx != nil {
lock.RLock() errorLock.RLock()
prev := errorHandlerCtx prev := errorHandler
lock.RUnlock() errorLock.RUnlock()
SetErrorHandlerCtx(test.errorHandlerCtx) SetErrorHandlerCtx(test.errorHandlerCtx)
defer func() { defer func() {
lock.Lock() errorLock.Lock()
test.errorHandlerCtx = prev test.errorHandlerCtx = prev
lock.Unlock() errorLock.Unlock()
}() }()
} }
ErrorCtx(context.Background(), &w, errors.New(test.input)) ErrorCtx(context.Background(), &w, errors.New(test.input))