feat: handle using root as the path of file server (#4255)

This commit is contained in:
Kevin Wan 2024-07-18 23:15:03 +08:00 committed by GitHub
parent 5dd6f2a43a
commit 4a14164be1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 73 additions and 11 deletions

View File

@ -3,18 +3,19 @@ package fileserver
import (
"net/http"
"strings"
"sync"
)
// Middleware returns a middleware that serves files from the given file system.
func Middleware(path string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc {
fileServer := http.FileServer(fs)
pathWithTrailSlash := ensureTrailingSlash(path)
pathWithoutTrailSlash := ensureNoTrailingSlash(path)
canServe := createServeChecker(path, fs)
return func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, pathWithTrailSlash) {
r.URL.Path = strings.TrimPrefix(r.URL.Path, pathWithoutTrailSlash)
if canServe(r) {
r.URL.Path = r.URL.Path[len(pathWithoutTrailSlash):]
fileServer.ServeHTTP(w, r)
} else {
next(w, r)
@ -23,6 +24,44 @@ func Middleware(path string, fs http.FileSystem) func(http.HandlerFunc) http.Han
}
}
func createFileChecker(fs http.FileSystem) func(string) bool {
var lock sync.RWMutex
fileChecker := make(map[string]bool)
return func(path string) bool {
lock.RLock()
exist, ok := fileChecker[path]
lock.RUnlock()
if ok {
return exist
}
lock.Lock()
defer lock.Unlock()
file, err := fs.Open(path)
exist = err == nil
fileChecker[path] = exist
if err != nil {
return false
}
_ = file.Close()
return true
}
}
func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) bool {
pathWithTrailSlash := ensureTrailingSlash(path)
fileChecker := createFileChecker(fs)
return func(r *http.Request) bool {
return r.Method == http.MethodGet &&
strings.HasPrefix(r.URL.Path, pathWithTrailSlash) &&
fileChecker(r.URL.Path[len(pathWithTrailSlash):])
}
}
func ensureTrailingSlash(path string) string {
if strings.HasSuffix(path, "/") {
return path

View File

@ -30,7 +30,7 @@ func TestMiddleware(t *testing.T) {
path: "/static/",
dir: "./testdata",
requestPath: "/other/path",
expectedStatus: http.StatusNotFound,
expectedStatus: http.StatusAlreadyReported,
},
{
name: "Directory with trailing slash",
@ -40,25 +40,48 @@ func TestMiddleware(t *testing.T) {
expectedStatus: http.StatusOK,
expectedContent: "2",
},
{
name: "Not exist file",
path: "/assets",
dir: "testdata",
requestPath: "/assets/not-exist.txt",
expectedStatus: http.StatusAlreadyReported,
},
{
name: "Not exist file in root",
path: "/",
dir: "testdata",
requestPath: "/not-exist.txt",
expectedStatus: http.StatusAlreadyReported,
},
{
name: "websocket request",
path: "/",
dir: "testdata",
requestPath: "/ws",
expectedStatus: http.StatusAlreadyReported,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
middleware := Middleware(tt.path, http.Dir(tt.dir))
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
w.WriteHeader(http.StatusAlreadyReported)
})
handlerToTest := middleware(nextHandler)
req := httptest.NewRequest("GET", tt.requestPath, nil)
rr := httptest.NewRecorder()
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
rr := httptest.NewRecorder()
handlerToTest.ServeHTTP(rr, req)
handlerToTest.ServeHTTP(rr, req)
assert.Equal(t, tt.expectedStatus, rr.Code)
if len(tt.expectedContent) > 0 {
assert.Equal(t, tt.expectedContent, rr.Body.String())
assert.Equal(t, tt.expectedStatus, rr.Code)
if len(tt.expectedContent) > 0 {
assert.Equal(t, tt.expectedContent, rr.Body.String())
}
}
})
}