diff options
-rw-r--r-- | pkg/atlas/config.go | 3 | ||||
-rw-r--r-- | pkg/atlas/server.go | 57 | ||||
-rw-r--r-- | pkg/atlas/util.go | 45 |
3 files changed, 103 insertions, 2 deletions
diff --git a/pkg/atlas/config.go b/pkg/atlas/config.go index 8e68bb4..cc48a9c 100644 --- a/pkg/atlas/config.go +++ b/pkg/atlas/config.go @@ -176,7 +176,8 @@ type Config struct { // The path to use for static website files. If a file named redirects.json // exists, it is read at startup, reloaded on SIGHUP, and used as a mapping - // of top-level names to URLs. + // of top-level names to URLs. Custom error pages can be named + // {status}.html. Web string `env:"ATLAS_WEB"` // The path to the IP2Location database, which should contain at least the diff --git a/pkg/atlas/server.go b/pkg/atlas/server.go index 3c881dd..fe52441 100644 --- a/pkg/atlas/server.go +++ b/pkg/atlas/server.go @@ -71,6 +71,7 @@ func NewServer(c *Config) (*Server, error) { if c.Web != "" { if p, err := filepath.Abs(c.Web); err == nil { var redirects sync.Map + var errpages sync.Map var err1 error reload := func() { @@ -92,18 +93,72 @@ func NewServer(c *Config) (*Server, error) { redirects.Store(strings.Trim(p, "/"), u) } } + if es, err := os.ReadDir(filepath.Join(p)); err != nil { + if !errors.Is(err, os.ErrNotExist) { + err1 = fmt.Errorf("read error pages: %w", err) + return + } + } else { + sc := map[int][]byte{} + for _, e := range es { + a, b, _ := strings.Cut(e.Name(), ".") + if b != "html" { + continue + } + s, err := strconv.ParseUint(a, 10, 64) + if err != nil || s < 400 || s >= 600 { + continue + } + c, err := os.ReadFile(filepath.Join(p, e.Name())) + if err != nil { + err1 = fmt.Errorf("read error page for %d: %w", s, err) + return + } + sc[int(s)] = c + } + errpages.Range(func(key, _ any) bool { + errpages.Delete(key) + return true + }) + for s, c := range sc { + errpages.Store(s, c) + } + } + } if reload(); err1 != nil { return nil, fmt.Errorf("initialize web: %w", err) } s.reload = append(s.reload, reload) + fsrv := &statusInterceptor{ + Handler: http.FileServer(http.Dir(c.Web)), + Error: func(s int) http.Handler { + switch s { + case http.StatusNotFound, http.StatusInternalServerError, http.StatusForbidden: + if c, ok := errpages.Load(s); ok { + b := c.([]byte) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "private, no-cache, no-store, max-age=0, must-revalidate") + w.Header().Set("Expires", "0") + w.Header().Set("Pragma", "no-cache") + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(b))) + w.WriteHeader(s) + w.Write(b) + }) + } + } + return nil + }, + } + s.Web = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if v, ok := redirects.Load(strings.Trim(r.URL.Path, "/")); ok { http.Redirect(w, r, v.(string), http.StatusTemporaryRedirect) return } - http.FileServer(http.Dir(c.Web)).ServeHTTP(w, r) + fsrv.ServeHTTP(w, r) }) } else { return nil, fmt.Errorf("initialize web: resolve path: %w", err) diff --git a/pkg/atlas/util.go b/pkg/atlas/util.go index 7ea8b14..7794483 100644 --- a/pkg/atlas/util.go +++ b/pkg/atlas/util.go @@ -120,3 +120,48 @@ func (ms *middlewares) Then(h http.Handler) http.Handler { } return h } + +type statusInterceptor struct { + Handler http.Handler + Error func(s int) http.Handler +} + +type statusInterceptorResponse struct { + i *statusInterceptor + w http.ResponseWriter + r *http.Request + hdr bool + done bool +} + +func (i *statusInterceptor) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w = &statusInterceptorResponse{i: i, w: w, r: r} + i.Handler.ServeHTTP(w, r) +} + +func (i *statusInterceptorResponse) Header() http.Header { + return i.w.Header() +} + +func (i *statusInterceptorResponse) Write(b []byte) (int, error) { + if i.done { + return 0, nil + } + i.hdr = true + return i.w.Write(b) +} + +func (i *statusInterceptorResponse) WriteHeader(statusCode int) { + if i.done { + return + } + if !i.hdr { + if h := i.i.Error(statusCode); h != nil { + i.done, i.hdr = true, true + h.ServeHTTP(i.w, i.r) + return + } + } + i.hdr = true + i.w.WriteHeader(statusCode) +} |