From 0bd2a0656c290a6fc227442820e33204b980de65 Mon Sep 17 00:00:00 2001 From: jichangyun Date: Mon, 28 Dec 2020 21:30:24 +0800 Subject: [PATCH] The ResponseWriters defined in rest.handler add Flush interface. (#318) --- rest/handler/authhandler.go | 6 ++++ rest/handler/authhandler_test.go | 4 +++ rest/handler/cryptionhandler.go | 6 ++++ rest/handler/cryptionhandler_test.go | 17 ++++++++++ rest/handler/loghandler.go | 12 +++++++ rest/handler/loghandler_test.go | 4 +++ .../security/withcoderesponsewriter.go | 6 ++++ .../security/withcoderesponsewriter_test.go | 33 +++++++++++++++++++ 8 files changed, 88 insertions(+) create mode 100644 rest/internal/security/withcoderesponsewriter_test.go diff --git a/rest/handler/authhandler.go b/rest/handler/authhandler.go index 5ed6caf1..ab65bd9b 100644 --- a/rest/handler/authhandler.go +++ b/rest/handler/authhandler.go @@ -138,3 +138,9 @@ func (grw *guardedResponseWriter) WriteHeader(statusCode int) { grw.wroteHeader = true grw.writer.WriteHeader(statusCode) } + +func (grw *guardedResponseWriter) Flush() { + if flusher, ok := grw.writer.(http.Flusher); ok { + flusher.Flush() + } +} diff --git a/rest/handler/authhandler_test.go b/rest/handler/authhandler_test.go index c197218f..1bd22649 100644 --- a/rest/handler/authhandler_test.go +++ b/rest/handler/authhandler_test.go @@ -41,6 +41,10 @@ func TestAuthHandler(t *testing.T) { w.Header().Set("X-Test", "test") _, err := w.Write([]byte("content")) assert.Nil(t, err) + + flusher, ok := w.(http.Flusher) + assert.Equal(t, ok, true) + flusher.Flush() })) resp := httptest.NewRecorder() diff --git a/rest/handler/cryptionhandler.go b/rest/handler/cryptionhandler.go index efa73bcf..58ee396c 100644 --- a/rest/handler/cryptionhandler.go +++ b/rest/handler/cryptionhandler.go @@ -95,6 +95,12 @@ func (w *cryptionResponseWriter) WriteHeader(statusCode int) { w.ResponseWriter.WriteHeader(statusCode) } +func (w *cryptionResponseWriter) Flush() { + if flusher, ok := w.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + func (w *cryptionResponseWriter) flush(key []byte) { if w.buf.Len() == 0 { return diff --git a/rest/handler/cryptionhandler_test.go b/rest/handler/cryptionhandler_test.go index 0be4697e..f7aab395 100644 --- a/rest/handler/cryptionhandler_test.go +++ b/rest/handler/cryptionhandler_test.go @@ -87,3 +87,20 @@ func TestCryptionHandlerWriteHeader(t *testing.T) { handler.ServeHTTP(recorder, req) assert.Equal(t, http.StatusServiceUnavailable, recorder.Code) } + +func TestCryptionHandlerFlush(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/any", nil) + handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(respText)) + + flusher, ok := w.(http.Flusher) + assert.Equal(t, ok, true) + flusher.Flush() + })) + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + + expect, err := codec.EcbEncrypt(aesKey, []byte(respText)) + assert.Nil(t, err) + assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String()) +} diff --git a/rest/handler/loghandler.go b/rest/handler/loghandler.go index a45991de..1f7da6ee 100644 --- a/rest/handler/loghandler.go +++ b/rest/handler/loghandler.go @@ -38,6 +38,12 @@ func (w *LoggedResponseWriter) WriteHeader(code int) { w.code = code } +func (w *LoggedResponseWriter) Flush() { + if flusher, ok := w.w.(http.Flusher); ok { + flusher.Flush() + } +} + func LogHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { timer := utils.NewElapsedTimer() @@ -81,6 +87,12 @@ func (w *DetailLoggedResponseWriter) WriteHeader(code int) { w.writer.WriteHeader(code) } +func (w *DetailLoggedResponseWriter) Flush() { + if flusher, ok := http.ResponseWriter(w.writer).(http.Flusher); ok { + flusher.Flush() + } +} + func DetailedLogHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { timer := utils.NewElapsedTimer() diff --git a/rest/handler/loghandler_test.go b/rest/handler/loghandler_test.go index 4cc3dd42..24d1643b 100644 --- a/rest/handler/loghandler_test.go +++ b/rest/handler/loghandler_test.go @@ -30,6 +30,10 @@ func TestLogHandler(t *testing.T) { w.WriteHeader(http.StatusServiceUnavailable) _, err := w.Write([]byte("content")) assert.Nil(t, err) + + flusher, ok := w.(http.Flusher) + assert.Equal(t, ok, true) + flusher.Flush() })) resp := httptest.NewRecorder() diff --git a/rest/internal/security/withcoderesponsewriter.go b/rest/internal/security/withcoderesponsewriter.go index 41d61fd9..795e1b2d 100644 --- a/rest/internal/security/withcoderesponsewriter.go +++ b/rest/internal/security/withcoderesponsewriter.go @@ -19,3 +19,9 @@ func (w *WithCodeResponseWriter) WriteHeader(code int) { w.Writer.WriteHeader(code) w.Code = code } + +func (w *WithCodeResponseWriter) Flush() { + if flusher, ok := w.Writer.(http.Flusher); ok { + flusher.Flush() + } +} diff --git a/rest/internal/security/withcoderesponsewriter_test.go b/rest/internal/security/withcoderesponsewriter_test.go new file mode 100644 index 00000000..229878cd --- /dev/null +++ b/rest/internal/security/withcoderesponsewriter_test.go @@ -0,0 +1,33 @@ +package security + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWithCodeResponseWriter(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cw := &WithCodeResponseWriter{Writer: w} + + cw.Header().Set("X-Test", "test") + cw.WriteHeader(http.StatusServiceUnavailable) + assert.Equal(t, cw.Code, http.StatusServiceUnavailable) + + _, err := cw.Write([]byte("content")) + assert.Nil(t, err) + + flusher, ok := http.ResponseWriter(cw).(http.Flusher) + assert.Equal(t, ok, true) + flusher.Flush() + }) + + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusServiceUnavailable, resp.Code) + assert.Equal(t, "test", resp.Header().Get("X-Test")) + assert.Equal(t, "content", resp.Body.String()) +}