diff --git a/rest/internal/fileserver/filehandler.go b/rest/internal/fileserver/filehandler.go new file mode 100644 index 00000000..00e49a47 --- /dev/null +++ b/rest/internal/fileserver/filehandler.go @@ -0,0 +1,39 @@ +package fileserver + +import ( + "net/http" + "strings" +) + +func Middleware(path, dir string) func(http.HandlerFunc) http.HandlerFunc { + fileServer := http.FileServer(http.Dir(dir)) + pathWithTrailSlash := ensureTrailingSlash(path) + pathWithoutTrailSlash := ensureNoTrailingSlash(path) + + return func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, pathWithTrailSlash) { + r.URL.Path = strings.TrimPrefix(r.URL.Path, pathWithoutTrailSlash) + fileServer.ServeHTTP(w, r) + } else { + next(w, r) + } + } + } +} + +func ensureTrailingSlash(path string) string { + if strings.HasSuffix(path, "/") { + return path + } + + return path + "/" +} + +func ensureNoTrailingSlash(path string) string { + if strings.HasSuffix(path, "/") { + return path[:len(path)-1] + } + + return path +} diff --git a/rest/internal/fileserver/filehandler_test.go b/rest/internal/fileserver/filehandler_test.go new file mode 100644 index 00000000..e3319078 --- /dev/null +++ b/rest/internal/fileserver/filehandler_test.go @@ -0,0 +1,99 @@ +package fileserver + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMiddleware(t *testing.T) { + tests := []struct { + name string + path string + dir string + requestPath string + expectedStatus int + expectedContent string + }{ + { + name: "Serve static file", + path: "/static/", + dir: "./testdata", + requestPath: "/static/example.txt", + expectedStatus: http.StatusOK, + expectedContent: "1", + }, + { + name: "Pass through non-matching path", + path: "/static/", + dir: "./testdata", + requestPath: "/other/path", + expectedStatus: http.StatusNotFound, + }, + { + name: "Directory with trailing slash", + path: "/assets", + dir: "testdata", + requestPath: "/assets/sample.txt", + expectedStatus: http.StatusOK, + expectedContent: "2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware := Middleware(tt.path, tt.dir) + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }) + + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", tt.requestPath, nil) + rr := httptest.NewRecorder() + + handlerToTest.ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatus, rr.Code) + if len(tt.expectedContent) > 0 { + assert.Equal(t, tt.expectedContent, rr.Body.String()) + } + }) + } +} + +func TestEnsureTrailingSlash(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"path", "path/"}, + {"path/", "path/"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := ensureTrailingSlash(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestEnsureNoTrailingSlash(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"path", "path"}, + {"path/", "path"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := ensureNoTrailingSlash(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/rest/internal/fileserver/testdata/example.txt b/rest/internal/fileserver/testdata/example.txt new file mode 100644 index 00000000..56a6051c --- /dev/null +++ b/rest/internal/fileserver/testdata/example.txt @@ -0,0 +1 @@ +1 \ No newline at end of file diff --git a/rest/internal/fileserver/testdata/sample.txt b/rest/internal/fileserver/testdata/sample.txt new file mode 100644 index 00000000..d8263ee9 --- /dev/null +++ b/rest/internal/fileserver/testdata/sample.txt @@ -0,0 +1 @@ +2 \ No newline at end of file diff --git a/rest/server.go b/rest/server.go index bbf6ff39..937dde33 100644 --- a/rest/server.go +++ b/rest/server.go @@ -13,6 +13,7 @@ import ( "github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/internal" "github.com/zeromicro/go-zero/rest/internal/cors" + "github.com/zeromicro/go-zero/rest/internal/fileserver" "github.com/zeromicro/go-zero/rest/router" ) @@ -170,6 +171,13 @@ func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(htt } } +// WithFileServer returns a RunOption to serve files from given dir with given path. +func WithFileServer(path, dir string) RunOption { + return func(server *Server) { + server.router = newFileServingRouter(server.router, path, dir) + } +} + // WithJwt returns a func to enable jwt authentication in given route. func WithJwt(secret string) RouteOption { return func(r *featuredRoutes) { @@ -337,3 +345,19 @@ func newCorsRouter(router httpx.Router, headerFn func(http.Header), origins ...s func (c *corsRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.middleware(c.Router.ServeHTTP)(w, r) } + +type fileServingRouter struct { + httpx.Router + middleware Middleware +} + +func newFileServingRouter(router httpx.Router, path, dir string) httpx.Router { + return &fileServingRouter{ + Router: router, + middleware: fileserver.Middleware(path, dir), + } +} + +func (f *fileServingRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + f.middleware(f.Router.ServeHTTP)(w, r) +} diff --git a/rest/server_test.go b/rest/server_test.go index 3a4bdded..c84db046 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -184,6 +184,56 @@ func TestWithMiddleware(t *testing.T) { }, m) } +func TestWithFileServerMiddleware(t *testing.T) { + tests := []struct { + name string + path string + dir string + requestPath string + expectedStatus int + expectedContent string + }{ + { + name: "Serve static file", + path: "/assets/", + dir: "./testdata", + requestPath: "/assets/example.txt", + expectedStatus: http.StatusOK, + expectedContent: "example content", + }, + { + name: "Pass through non-matching path", + path: "/static/", + dir: "./testdata", + requestPath: "/other/path", + expectedStatus: http.StatusNotFound, + }, + { + name: "Directory with trailing slash", + path: "/static", + dir: "testdata", + requestPath: "/static/sample.txt", + expectedStatus: http.StatusOK, + expectedContent: "sample content", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := MustNewServer(RestConf{}, WithFileServer(tt.path, tt.dir)) + req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil) + rr := httptest.NewRecorder() + + server.ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatus, rr.Code) + if len(tt.expectedContent) > 0 { + assert.Equal(t, tt.expectedContent, rr.Body.String()) + } + }) + } +} + func TestMultiMiddlewares(t *testing.T) { m := make(map[string]string) rt := router.NewRouter() diff --git a/rest/testdata/example.txt b/rest/testdata/example.txt new file mode 100644 index 00000000..76b3034f --- /dev/null +++ b/rest/testdata/example.txt @@ -0,0 +1 @@ +example content \ No newline at end of file diff --git a/rest/testdata/sample.txt b/rest/testdata/sample.txt new file mode 100644 index 00000000..0d6f364b --- /dev/null +++ b/rest/testdata/sample.txt @@ -0,0 +1 @@ +sample content \ No newline at end of file