aboutsummaryrefslogtreecommitdiff
path: root/pkg/atlas
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/atlas')
-rw-r--r--pkg/atlas/config.go3
-rw-r--r--pkg/atlas/server.go57
-rw-r--r--pkg/atlas/util.go45
3 files changed, 103 insertions, 2 deletions
diff --git a/pkg/atlas/config.go b/pkg/atlas/config.go
index 8e68bb4..cc48a9c 100644
--- a/pkg/atlas/config.go
+++ b/pkg/atlas/config.go
@@ -176,7 +176,8 @@ type Config struct {
// The path to use for static website files. If a file named redirects.json
// exists, it is read at startup, reloaded on SIGHUP, and used as a mapping
- // of top-level names to URLs.
+ // of top-level names to URLs. Custom error pages can be named
+ // {status}.html.
Web string `env:"ATLAS_WEB"`
// The path to the IP2Location database, which should contain at least the
diff --git a/pkg/atlas/server.go b/pkg/atlas/server.go
index 3c881dd..fe52441 100644
--- a/pkg/atlas/server.go
+++ b/pkg/atlas/server.go
@@ -71,6 +71,7 @@ func NewServer(c *Config) (*Server, error) {
if c.Web != "" {
if p, err := filepath.Abs(c.Web); err == nil {
var redirects sync.Map
+ var errpages sync.Map
var err1 error
reload := func() {
@@ -92,18 +93,72 @@ func NewServer(c *Config) (*Server, error) {
redirects.Store(strings.Trim(p, "/"), u)
}
}
+ if es, err := os.ReadDir(filepath.Join(p)); err != nil {
+ if !errors.Is(err, os.ErrNotExist) {
+ err1 = fmt.Errorf("read error pages: %w", err)
+ return
+ }
+ } else {
+ sc := map[int][]byte{}
+ for _, e := range es {
+ a, b, _ := strings.Cut(e.Name(), ".")
+ if b != "html" {
+ continue
+ }
+ s, err := strconv.ParseUint(a, 10, 64)
+ if err != nil || s < 400 || s >= 600 {
+ continue
+ }
+ c, err := os.ReadFile(filepath.Join(p, e.Name()))
+ if err != nil {
+ err1 = fmt.Errorf("read error page for %d: %w", s, err)
+ return
+ }
+ sc[int(s)] = c
+ }
+ errpages.Range(func(key, _ any) bool {
+ errpages.Delete(key)
+ return true
+ })
+ for s, c := range sc {
+ errpages.Store(s, c)
+ }
+ }
+
}
if reload(); err1 != nil {
return nil, fmt.Errorf("initialize web: %w", err)
}
s.reload = append(s.reload, reload)
+ fsrv := &statusInterceptor{
+ Handler: http.FileServer(http.Dir(c.Web)),
+ Error: func(s int) http.Handler {
+ switch s {
+ case http.StatusNotFound, http.StatusInternalServerError, http.StatusForbidden:
+ if c, ok := errpages.Load(s); ok {
+ b := c.([]byte)
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Cache-Control", "private, no-cache, no-store, max-age=0, must-revalidate")
+ w.Header().Set("Expires", "0")
+ w.Header().Set("Pragma", "no-cache")
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ w.Header().Set("Content-Length", strconv.Itoa(len(b)))
+ w.WriteHeader(s)
+ w.Write(b)
+ })
+ }
+ }
+ return nil
+ },
+ }
+
s.Web = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if v, ok := redirects.Load(strings.Trim(r.URL.Path, "/")); ok {
http.Redirect(w, r, v.(string), http.StatusTemporaryRedirect)
return
}
- http.FileServer(http.Dir(c.Web)).ServeHTTP(w, r)
+ fsrv.ServeHTTP(w, r)
})
} else {
return nil, fmt.Errorf("initialize web: resolve path: %w", err)
diff --git a/pkg/atlas/util.go b/pkg/atlas/util.go
index 7ea8b14..7794483 100644
--- a/pkg/atlas/util.go
+++ b/pkg/atlas/util.go
@@ -120,3 +120,48 @@ func (ms *middlewares) Then(h http.Handler) http.Handler {
}
return h
}
+
+type statusInterceptor struct {
+ Handler http.Handler
+ Error func(s int) http.Handler
+}
+
+type statusInterceptorResponse struct {
+ i *statusInterceptor
+ w http.ResponseWriter
+ r *http.Request
+ hdr bool
+ done bool
+}
+
+func (i *statusInterceptor) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ w = &statusInterceptorResponse{i: i, w: w, r: r}
+ i.Handler.ServeHTTP(w, r)
+}
+
+func (i *statusInterceptorResponse) Header() http.Header {
+ return i.w.Header()
+}
+
+func (i *statusInterceptorResponse) Write(b []byte) (int, error) {
+ if i.done {
+ return 0, nil
+ }
+ i.hdr = true
+ return i.w.Write(b)
+}
+
+func (i *statusInterceptorResponse) WriteHeader(statusCode int) {
+ if i.done {
+ return
+ }
+ if !i.hdr {
+ if h := i.i.Error(statusCode); h != nil {
+ i.done, i.hdr = true, true
+ h.ServeHTTP(i.w, i.r)
+ return
+ }
+ }
+ i.hdr = true
+ i.w.WriteHeader(statusCode)
+}