feat: support file server in rest (#4244)

This commit is contained in:
Kevin Wan 2024-07-13 19:58:35 +08:00 committed by GitHub
parent e776b5d8ab
commit ec86f22cd6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 216 additions and 0 deletions

View File

@ -0,0 +1,39 @@
package fileserver
import (
"net/http"
"strings"
)
func Middleware(path, dir string) func(http.HandlerFunc) http.HandlerFunc {
fileServer := http.FileServer(http.Dir(dir))
pathWithTrailSlash := ensureTrailingSlash(path)
pathWithoutTrailSlash := ensureNoTrailingSlash(path)
return func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, pathWithTrailSlash) {
r.URL.Path = strings.TrimPrefix(r.URL.Path, pathWithoutTrailSlash)
fileServer.ServeHTTP(w, r)
} else {
next(w, r)
}
}
}
}
func ensureTrailingSlash(path string) string {
if strings.HasSuffix(path, "/") {
return path
}
return path + "/"
}
func ensureNoTrailingSlash(path string) string {
if strings.HasSuffix(path, "/") {
return path[:len(path)-1]
}
return path
}

View File

@ -0,0 +1,99 @@
package fileserver
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestMiddleware(t *testing.T) {
tests := []struct {
name string
path string
dir string
requestPath string
expectedStatus int
expectedContent string
}{
{
name: "Serve static file",
path: "/static/",
dir: "./testdata",
requestPath: "/static/example.txt",
expectedStatus: http.StatusOK,
expectedContent: "1",
},
{
name: "Pass through non-matching path",
path: "/static/",
dir: "./testdata",
requestPath: "/other/path",
expectedStatus: http.StatusNotFound,
},
{
name: "Directory with trailing slash",
path: "/assets",
dir: "testdata",
requestPath: "/assets/sample.txt",
expectedStatus: http.StatusOK,
expectedContent: "2",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
middleware := Middleware(tt.path, tt.dir)
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
handlerToTest := middleware(nextHandler)
req := httptest.NewRequest("GET", tt.requestPath, nil)
rr := httptest.NewRecorder()
handlerToTest.ServeHTTP(rr, req)
assert.Equal(t, tt.expectedStatus, rr.Code)
if len(tt.expectedContent) > 0 {
assert.Equal(t, tt.expectedContent, rr.Body.String())
}
})
}
}
func TestEnsureTrailingSlash(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"path", "path/"},
{"path/", "path/"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := ensureTrailingSlash(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestEnsureNoTrailingSlash(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"path", "path"},
{"path/", "path"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := ensureNoTrailingSlash(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@ -0,0 +1 @@
1

View File

@ -0,0 +1 @@
2

View File

@ -13,6 +13,7 @@ import (
"github.com/zeromicro/go-zero/rest/httpx"
"github.com/zeromicro/go-zero/rest/internal"
"github.com/zeromicro/go-zero/rest/internal/cors"
"github.com/zeromicro/go-zero/rest/internal/fileserver"
"github.com/zeromicro/go-zero/rest/router"
)
@ -170,6 +171,13 @@ func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(htt
}
}
// WithFileServer returns a RunOption to serve files from given dir with given path.
func WithFileServer(path, dir string) RunOption {
return func(server *Server) {
server.router = newFileServingRouter(server.router, path, dir)
}
}
// WithJwt returns a func to enable jwt authentication in given route.
func WithJwt(secret string) RouteOption {
return func(r *featuredRoutes) {
@ -337,3 +345,19 @@ func newCorsRouter(router httpx.Router, headerFn func(http.Header), origins ...s
func (c *corsRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.middleware(c.Router.ServeHTTP)(w, r)
}
type fileServingRouter struct {
httpx.Router
middleware Middleware
}
func newFileServingRouter(router httpx.Router, path, dir string) httpx.Router {
return &fileServingRouter{
Router: router,
middleware: fileserver.Middleware(path, dir),
}
}
func (f *fileServingRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
f.middleware(f.Router.ServeHTTP)(w, r)
}

View File

@ -184,6 +184,56 @@ func TestWithMiddleware(t *testing.T) {
}, m)
}
func TestWithFileServerMiddleware(t *testing.T) {
tests := []struct {
name string
path string
dir string
requestPath string
expectedStatus int
expectedContent string
}{
{
name: "Serve static file",
path: "/assets/",
dir: "./testdata",
requestPath: "/assets/example.txt",
expectedStatus: http.StatusOK,
expectedContent: "example content",
},
{
name: "Pass through non-matching path",
path: "/static/",
dir: "./testdata",
requestPath: "/other/path",
expectedStatus: http.StatusNotFound,
},
{
name: "Directory with trailing slash",
path: "/static",
dir: "testdata",
requestPath: "/static/sample.txt",
expectedStatus: http.StatusOK,
expectedContent: "sample content",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := MustNewServer(RestConf{}, WithFileServer(tt.path, tt.dir))
req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
rr := httptest.NewRecorder()
server.ServeHTTP(rr, req)
assert.Equal(t, tt.expectedStatus, rr.Code)
if len(tt.expectedContent) > 0 {
assert.Equal(t, tt.expectedContent, rr.Body.String())
}
})
}
}
func TestMultiMiddlewares(t *testing.T) {
m := make(map[string]string)
rt := router.NewRouter()

1
rest/testdata/example.txt vendored Normal file
View File

@ -0,0 +1 @@
example content

1
rest/testdata/sample.txt vendored Normal file
View File

@ -0,0 +1 @@
sample content