diff --git a/rest/internal/fileserver/filehandler.go b/rest/internal/fileserver/filehandler.go index 00e49a47..5a0d15f7 100644 --- a/rest/internal/fileserver/filehandler.go +++ b/rest/internal/fileserver/filehandler.go @@ -5,8 +5,9 @@ import ( "strings" ) -func Middleware(path, dir string) func(http.HandlerFunc) http.HandlerFunc { - fileServer := http.FileServer(http.Dir(dir)) +// Middleware returns a middleware that serves files from the given file system. +func Middleware(path string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc { + fileServer := http.FileServer(fs) pathWithTrailSlash := ensureTrailingSlash(path) pathWithoutTrailSlash := ensureNoTrailingSlash(path) diff --git a/rest/internal/fileserver/filehandler_test.go b/rest/internal/fileserver/filehandler_test.go index e3319078..b14466be 100644 --- a/rest/internal/fileserver/filehandler_test.go +++ b/rest/internal/fileserver/filehandler_test.go @@ -44,7 +44,7 @@ func TestMiddleware(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - middleware := Middleware(tt.path, tt.dir) + middleware := Middleware(tt.path, http.Dir(tt.dir)) nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) }) diff --git a/rest/server.go b/rest/server.go index 3f4fcec9..747bb2c3 100644 --- a/rest/server.go +++ b/rest/server.go @@ -172,9 +172,9 @@ func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(htt } // WithFileServer returns a RunOption to serve files from given dir with given path. -func WithFileServer(path, dir string) RunOption { +func WithFileServer(path string, fs http.FileSystem) RunOption { return func(server *Server) { - server.router = newFileServingRouter(server.router, path, dir) + server.router = newFileServingRouter(server.router, path, fs) } } @@ -351,10 +351,10 @@ type fileServingRouter struct { middleware Middleware } -func newFileServingRouter(router httpx.Router, path, dir string) httpx.Router { +func newFileServingRouter(router httpx.Router, path string, fs http.FileSystem) httpx.Router { return &fileServingRouter{ Router: router, - middleware: fileserver.Middleware(path, dir), + middleware: fileserver.Middleware(path, fs), } } diff --git a/rest/server_test.go b/rest/server_test.go index c84db046..3f01fd3f 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -2,8 +2,10 @@ package rest import ( "crypto/tls" + "embed" "fmt" "io" + "io/fs" "net/http" "net/http/httptest" "os" @@ -21,6 +23,11 @@ import ( "github.com/zeromicro/go-zero/rest/router" ) +const ( + exampleContent = "example content" + sampleContent = "sample content" +) + func TestNewServer(t *testing.T) { logtest.Discard(t) @@ -199,7 +206,7 @@ func TestWithFileServerMiddleware(t *testing.T) { dir: "./testdata", requestPath: "/assets/example.txt", expectedStatus: http.StatusOK, - expectedContent: "example content", + expectedContent: exampleContent, }, { name: "Pass through non-matching path", @@ -214,13 +221,13 @@ func TestWithFileServerMiddleware(t *testing.T) { dir: "testdata", requestPath: "/static/sample.txt", expectedStatus: http.StatusOK, - expectedContent: "sample content", + expectedContent: sampleContent, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - server := MustNewServer(RestConf{}, WithFileServer(tt.path, tt.dir)) + server := MustNewServer(RestConf{}, WithFileServer(tt.path, http.Dir(tt.dir))) req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil) rr := httptest.NewRecorder() @@ -688,3 +695,18 @@ Port: 54321 }) } } + +//go:embed testdata +var content embed.FS + +func TestServerEmbedFileSystem(t *testing.T) { + filesys, err := fs.Sub(content, "testdata") + assert.NoError(t, err) + + server := MustNewServer(RestConf{}, WithFileServer("/assets", http.FS(filesys))) + req, err := http.NewRequest(http.MethodGet, "/assets/sample.txt", http.NoBody) + assert.Nil(t, err) + rr := httptest.NewRecorder() + server.ServeHTTP(rr, req) + assert.Equal(t, sampleContent, rr.Body.String()) +}