diff --git a/rest/httpx/responses.go b/rest/httpx/responses.go index 29f54a5a..104f04aa 100644 --- a/rest/httpx/responses.go +++ b/rest/httpx/responses.go @@ -13,40 +13,24 @@ import ( ) var ( - errorHandler func(error) (int, any) - errorHandlerCtx func(context.Context, error) (int, any) - commonHandler func(any) any - commonHandlerCtx func(context.Context, any) any - lock sync.RWMutex - cLock sync.RWMutex + errorHandler func(context.Context, error) (int, any) + errorLock sync.RWMutex + respHandler func(context.Context, any) any + respLock sync.RWMutex ) // Error writes err into w. func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) { - lock.RLock() - handler := errorHandler - lock.RUnlock() - - doHandleError(w, err, handler, WriteJson, fns...) + doHandleError(w, err, buildErrorHandler(context.Background()), WriteJson, fns...) } // ErrorCtx writes err into w. func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) { - lock.RLock() - handlerCtx := errorHandlerCtx - lock.RUnlock() - - var handler func(error) (int, any) - if handlerCtx != nil { - handler = func(err error) (int, any) { - return handlerCtx(ctx, err) - } - } writeJson := func(w http.ResponseWriter, code int, v any) { WriteJsonCtx(ctx, w, code, v) } - doHandleError(w, err, handler, writeJson, fns...) + doHandleError(w, err, buildErrorHandler(ctx), writeJson, fns...) } // Ok writes HTTP 200 OK into w. @@ -56,20 +40,20 @@ func Ok(w http.ResponseWriter) { // OkJson writes v into w with 200 OK. func OkJson(w http.ResponseWriter, v any) { - cLock.RLock() - handler := commonHandler - cLock.RUnlock() + respLock.RLock() + handler := respHandler + respLock.RUnlock() if handler != nil { - v = handler(v) + v = handler(context.Background(), v) } WriteJson(w, http.StatusOK, v) } // OkJsonCtx writes v into w with 200 OK. func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) { - cLock.RLock() - handlerCtx := commonHandlerCtx - cLock.RUnlock() + respLock.RLock() + handlerCtx := respHandler + respLock.RUnlock() if handlerCtx != nil { v = handlerCtx(ctx, v) } @@ -77,31 +61,30 @@ func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v any) { } // SetErrorHandler sets the error handler, which is called on calling Error. +// Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler. +// Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility. func SetErrorHandler(handler func(error) (int, any)) { - lock.Lock() - defer lock.Unlock() - errorHandler = handler + errorLock.Lock() + defer errorLock.Unlock() + errorHandler = func(_ context.Context, err error) (int, any) { + return handler(err) + } } // SetErrorHandlerCtx sets the error handler, which is called on calling Error. +// Notice: SetErrorHandler and SetErrorHandlerCtx set the same error handler. +// Keeping both SetErrorHandler and SetErrorHandlerCtx is for backward compatibility. func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, any)) { - lock.Lock() - defer lock.Unlock() - errorHandlerCtx = handlerCtx + errorLock.Lock() + defer errorLock.Unlock() + errorHandler = handlerCtx } -// SetCommonHandler sets the common handler, which is called on calling OkJson. -func SetCommonHandler(handler func(any) any) { - cLock.Lock() - defer cLock.Unlock() - commonHandler = handler -} - -// SetCommonHandlerCtx sets the common handler, which is called on calling OkJson. -func SetCommonHandlerCtx(handlerCtx func(context.Context, any) any) { - cLock.Lock() - defer cLock.Unlock() - commonHandlerCtx = handlerCtx +// SetResponseHandler sets the response handler, which is called on calling OkJson and OkJsonCtx. +func SetResponseHandler(handler func(context.Context, any) any) { + respLock.Lock() + defer respLock.Unlock() + respHandler = handler } // WriteJson writes v as json string into w with code. @@ -118,6 +101,21 @@ func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v any) { } } +func buildErrorHandler(ctx context.Context) func(error) (int, any) { + errorLock.RLock() + handlerCtx := errorHandler + errorLock.RUnlock() + + var handler func(error) (int, any) + if handlerCtx != nil { + handler = func(err error) (int, any) { + return handlerCtx(ctx, err) + } + } + + return handler +} + func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, any), writeJson func(w http.ResponseWriter, code int, v any), fns ...func(w http.ResponseWriter, err error)) { diff --git a/rest/httpx/responses_test.go b/rest/httpx/responses_test.go index a881dbad..4ef823ca 100644 --- a/rest/httpx/responses_test.go +++ b/rest/httpx/responses_test.go @@ -3,6 +3,7 @@ package httpx import ( "context" "errors" + "fmt" "net/http" "strings" "testing" @@ -80,14 +81,14 @@ func TestError(t *testing.T) { headers: make(map[string][]string), } if test.errorHandler != nil { - lock.RLock() + errorLock.RLock() prev := errorHandler - lock.RUnlock() + errorLock.RUnlock() SetErrorHandler(test.errorHandler) defer func() { - lock.Lock() + errorLock.Lock() errorHandler = prev - lock.Unlock() + errorLock.Unlock() }() } Error(&w, errors.New(test.input)) @@ -129,13 +130,71 @@ func TestOk(t *testing.T) { } func TestOkJson(t *testing.T) { - w := tracedResponseWriter{ - headers: make(map[string][]string), - } - msg := message{Name: "anyone"} - OkJson(&w, msg) - assert.Equal(t, http.StatusOK, w.code) - assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String()) + t.Run("no handler", func(t *testing.T) { + w := tracedResponseWriter{ + headers: make(map[string][]string), + } + msg := message{Name: "anyone"} + OkJson(&w, msg) + assert.Equal(t, http.StatusOK, w.code) + assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String()) + }) + + t.Run("with handler", func(t *testing.T) { + respLock.RLock() + prev := respHandler + respLock.RUnlock() + t.Cleanup(func() { + respLock.Lock() + respHandler = prev + respLock.Unlock() + }) + + SetResponseHandler(func(_ context.Context, v interface{}) any { + return fmt.Sprintf("hello %s", v.(message).Name) + }) + w := tracedResponseWriter{ + headers: make(map[string][]string), + } + msg := message{Name: "anyone"} + OkJson(&w, msg) + assert.Equal(t, http.StatusOK, w.code) + assert.Equal(t, `"hello anyone"`, w.builder.String()) + }) +} + +func TestOkJsonCtx(t *testing.T) { + t.Run("no handler", func(t *testing.T) { + w := tracedResponseWriter{ + headers: make(map[string][]string), + } + msg := message{Name: "anyone"} + OkJsonCtx(context.Background(), &w, msg) + assert.Equal(t, http.StatusOK, w.code) + assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String()) + }) + + t.Run("with handler", func(t *testing.T) { + respLock.RLock() + prev := respHandler + respLock.RUnlock() + t.Cleanup(func() { + respLock.Lock() + respHandler = prev + respLock.Unlock() + }) + + SetResponseHandler(func(_ context.Context, v interface{}) any { + return fmt.Sprintf("hello %s", v.(message).Name) + }) + w := tracedResponseWriter{ + headers: make(map[string][]string), + } + msg := message{Name: "anyone"} + OkJsonCtx(context.Background(), &w, msg) + assert.Equal(t, http.StatusOK, w.code) + assert.Equal(t, `"hello anyone"`, w.builder.String()) + }) } func TestWriteJsonTimeout(t *testing.T) { @@ -275,14 +334,14 @@ func TestErrorCtx(t *testing.T) { headers: make(map[string][]string), } if test.errorHandlerCtx != nil { - lock.RLock() - prev := errorHandlerCtx - lock.RUnlock() + errorLock.RLock() + prev := errorHandler + errorLock.RUnlock() SetErrorHandlerCtx(test.errorHandlerCtx) defer func() { - lock.Lock() + errorLock.Lock() test.errorHandlerCtx = prev - lock.Unlock() + errorLock.Unlock() }() } ErrorCtx(context.Background(), &w, errors.New(test.input))