From 7a75dce465ec049b1d11205f12cd4f2ec5b69772 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Wed, 14 Dec 2022 23:36:56 +0800 Subject: [PATCH] refactor: remove duplicated code (#2705) --- rest/httpx/responses.go | 173 +++++++++++++++++++--------------------- 1 file changed, 80 insertions(+), 93 deletions(-) diff --git a/rest/httpx/responses.go b/rest/httpx/responses.go index 9f2efa0a..fa6ca808 100644 --- a/rest/httpx/responses.go +++ b/rest/httpx/responses.go @@ -3,6 +3,7 @@ package httpx import ( "context" "encoding/json" + "fmt" "net/http" "sync" @@ -13,8 +14,8 @@ import ( var ( errorHandler func(error) (int, interface{}) - lock sync.RWMutex errorHandlerCtx func(context.Context, error) (int, interface{}) + lock sync.RWMutex ) // Error writes err into w. @@ -23,9 +24,79 @@ func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, handler := errorHandler lock.RUnlock() + doHandleError(w, err, handler, WriteJson, fns...) +} + +// ErrorCtx writes err into w. +func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, + fns ...func(w http.ResponseWriter, err error)) { + lock.RLock() + handlerCtx := errorHandlerCtx + lock.RUnlock() + + var handler func(error) (int, interface{}) + if handlerCtx != nil { + handler = func(err error) (int, interface{}) { + return handlerCtx(ctx, err) + } + } + writeJson := func(w http.ResponseWriter, code int, v interface{}) { + WriteJsonCtx(ctx, w, code, v) + } + doHandleError(w, err, handler, writeJson, fns...) +} + +// Ok writes HTTP 200 OK into w. +func Ok(w http.ResponseWriter) { + w.WriteHeader(http.StatusOK) +} + +// OkJson writes v into w with 200 OK. +func OkJson(w http.ResponseWriter, v interface{}) { + WriteJson(w, http.StatusOK, v) +} + +// OkJsonCtx writes v into w with 200 OK. +func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v interface{}) { + WriteJsonCtx(ctx, w, http.StatusOK, v) +} + +// SetErrorHandler sets the error handler, which is called on calling Error. +func SetErrorHandler(handler func(error) (int, interface{})) { + lock.Lock() + defer lock.Unlock() + errorHandler = handler +} + +// SetErrorHandlerCtx sets the error handler, which is called on calling Error. +func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, interface{})) { + lock.Lock() + defer lock.Unlock() + errorHandlerCtx = handlerCtx +} + +// WriteJson writes v as json string into w with code. +func WriteJson(w http.ResponseWriter, code int, v interface{}) { + if err := doWriteJson(w, code, v); err != nil { + logx.Error(err) + } +} + +// WriteJsonCtx writes v as json string into w with code. +func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v interface{}) { + if err := doWriteJson(w, code, v); err != nil { + logx.WithContext(ctx).Error(err) + } +} + +func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, interface{}), + writeJson func(w http.ResponseWriter, code int, v interface{}), + fns ...func(w http.ResponseWriter, err error)) { if handler == nil { if len(fns) > 0 { - fns[0](w, err) + for _, fn := range fns { + fn(w, err) + } } else if errcode.IsGrpcError(err) { // don't unwrap error and get status.Message(), // it hides the rpc error headers. @@ -47,33 +118,15 @@ func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, if ok { http.Error(w, e.Error(), code) } else { - WriteJson(w, code, body) + writeJson(w, code, body) } } -// Ok writes HTTP 200 OK into w. -func Ok(w http.ResponseWriter) { - w.WriteHeader(http.StatusOK) -} - -// OkJson writes v into w with 200 OK. -func OkJson(w http.ResponseWriter, v interface{}) { - WriteJson(w, http.StatusOK, v) -} - -// SetErrorHandler sets the error handler, which is called on calling Error. -func SetErrorHandler(handler func(error) (int, interface{})) { - lock.Lock() - defer lock.Unlock() - errorHandler = handler -} - -// WriteJson writes v as json string into w with code. -func WriteJson(w http.ResponseWriter, code int, v interface{}) { +func doWriteJson(w http.ResponseWriter, code int, v interface{}) error { bs, err := json.Marshal(v) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) - return + return fmt.Errorf("marshal json failed, error: %w", err) } w.Header().Set(ContentType, header.JsonContentType) @@ -83,77 +136,11 @@ func WriteJson(w http.ResponseWriter, code int, v interface{}) { // http.ErrHandlerTimeout has been handled by http.TimeoutHandler, // so it's ignored here. if err != http.ErrHandlerTimeout { - logx.Errorf("write response failed, error: %s", err) + return fmt.Errorf("write response failed, error: %w", err) } } else if n < len(bs) { - logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n) + return fmt.Errorf("actual bytes: %d, written bytes: %d", len(bs), n) } -} - -// Error writes err into w. -func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) { - lock.RLock() - handlerCtx := errorHandlerCtx - lock.RUnlock() - - if handlerCtx == nil { - if len(fns) > 0 { - fns[0](w, err) - } else if errcode.IsGrpcError(err) { - // don't unwrap error and get status.Message(), - // it hides the rpc error headers. - http.Error(w, err.Error(), errcode.CodeFromGrpcError(err)) - } else { - http.Error(w, err.Error(), http.StatusBadRequest) - } - - return - } - - code, body := handlerCtx(ctx, err) - if body == nil { - w.WriteHeader(code) - return - } - - e, ok := body.(error) - if ok { - http.Error(w, e.Error(), code) - } else { - WriteJsonCtx(ctx, w, code, body) - } -} - -// OkJson writes v into w with 200 OK. -func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v interface{}) { - WriteJsonCtx(ctx, w, http.StatusOK, v) -} - -// WriteJson writes v as json string into w with code. -func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v interface{}) { - bs, err := json.Marshal(v) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Header().Set(ContentType, header.JsonContentType) - w.WriteHeader(code) - - if n, err := w.Write(bs); err != nil { - // http.ErrHandlerTimeout has been handled by http.TimeoutHandler, - // so it's ignored here. - if err != http.ErrHandlerTimeout { - logx.WithContext(ctx).Errorf("write response failed, error: %s", err) - } - } else if n < len(bs) { - logx.WithContext(ctx).Errorf("actual bytes: %d, written bytes: %d", len(bs), n) - } -} - -// SetErrorHandler sets the error handler, which is called on calling Error. -func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, interface{})) { - lock.Lock() - defer lock.Unlock() - errorHandlerCtx = handlerCtx + + return nil }