From 4a14164be1db9e50af96ea9fc8dbc767a3747e8a Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Thu, 18 Jul 2024 23:15:03 +0800 Subject: [PATCH] feat: handle using root as the path of file server (#4255) --- rest/internal/fileserver/filehandler.go | 45 ++++++++++++++++++-- rest/internal/fileserver/filehandler_test.go | 39 +++++++++++++---- 2 files changed, 73 insertions(+), 11 deletions(-) diff --git a/rest/internal/fileserver/filehandler.go b/rest/internal/fileserver/filehandler.go index 5a0d15f7..163cdea4 100644 --- a/rest/internal/fileserver/filehandler.go +++ b/rest/internal/fileserver/filehandler.go @@ -3,18 +3,19 @@ package fileserver import ( "net/http" "strings" + "sync" ) // Middleware returns a middleware that serves files from the given file system. func Middleware(path string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc { fileServer := http.FileServer(fs) - pathWithTrailSlash := ensureTrailingSlash(path) pathWithoutTrailSlash := ensureNoTrailingSlash(path) + canServe := createServeChecker(path, fs) return func(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, pathWithTrailSlash) { - r.URL.Path = strings.TrimPrefix(r.URL.Path, pathWithoutTrailSlash) + if canServe(r) { + r.URL.Path = r.URL.Path[len(pathWithoutTrailSlash):] fileServer.ServeHTTP(w, r) } else { next(w, r) @@ -23,6 +24,44 @@ func Middleware(path string, fs http.FileSystem) func(http.HandlerFunc) http.Han } } +func createFileChecker(fs http.FileSystem) func(string) bool { + var lock sync.RWMutex + fileChecker := make(map[string]bool) + + return func(path string) bool { + lock.RLock() + exist, ok := fileChecker[path] + lock.RUnlock() + if ok { + return exist + } + + lock.Lock() + defer lock.Unlock() + + file, err := fs.Open(path) + exist = err == nil + fileChecker[path] = exist + if err != nil { + return false + } + + _ = file.Close() + return true + } +} + +func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) bool { + pathWithTrailSlash := ensureTrailingSlash(path) + fileChecker := createFileChecker(fs) + + return func(r *http.Request) bool { + return r.Method == http.MethodGet && + strings.HasPrefix(r.URL.Path, pathWithTrailSlash) && + fileChecker(r.URL.Path[len(pathWithTrailSlash):]) + } +} + func ensureTrailingSlash(path string) string { if strings.HasSuffix(path, "/") { return path diff --git a/rest/internal/fileserver/filehandler_test.go b/rest/internal/fileserver/filehandler_test.go index b14466be..37308f72 100644 --- a/rest/internal/fileserver/filehandler_test.go +++ b/rest/internal/fileserver/filehandler_test.go @@ -30,7 +30,7 @@ func TestMiddleware(t *testing.T) { path: "/static/", dir: "./testdata", requestPath: "/other/path", - expectedStatus: http.StatusNotFound, + expectedStatus: http.StatusAlreadyReported, }, { name: "Directory with trailing slash", @@ -40,25 +40,48 @@ func TestMiddleware(t *testing.T) { expectedStatus: http.StatusOK, expectedContent: "2", }, + { + name: "Not exist file", + path: "/assets", + dir: "testdata", + requestPath: "/assets/not-exist.txt", + expectedStatus: http.StatusAlreadyReported, + }, + { + name: "Not exist file in root", + path: "/", + dir: "testdata", + requestPath: "/not-exist.txt", + expectedStatus: http.StatusAlreadyReported, + }, + { + name: "websocket request", + path: "/", + dir: "testdata", + requestPath: "/ws", + expectedStatus: http.StatusAlreadyReported, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { middleware := Middleware(tt.path, http.Dir(tt.dir)) nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) + w.WriteHeader(http.StatusAlreadyReported) }) handlerToTest := middleware(nextHandler) - req := httptest.NewRequest("GET", tt.requestPath, nil) - rr := httptest.NewRecorder() + for i := 0; i < 2; i++ { + req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil) + rr := httptest.NewRecorder() - handlerToTest.ServeHTTP(rr, req) + handlerToTest.ServeHTTP(rr, req) - assert.Equal(t, tt.expectedStatus, rr.Code) - if len(tt.expectedContent) > 0 { - assert.Equal(t, tt.expectedContent, rr.Body.String()) + assert.Equal(t, tt.expectedStatus, rr.Code) + if len(tt.expectedContent) > 0 { + assert.Equal(t, tt.expectedContent, rr.Body.String()) + } } }) }