From 308ce6c10ce77835a9b4d2ca9a17d449260c6adb Mon Sep 17 00:00:00 2001 From: Runxi Yu Date: Sun, 17 Aug 2025 04:34:16 +0800 Subject: [PATCH] Add shutdown timeouts --- forged/internal/incoming/ssh/ssh.go | 39 ++++++++++++++++++++++----------------- forged/internal/incoming/web/web.go | 42 +++++++++++++++++++++++------------------- diff --git a/forged/internal/incoming/ssh/ssh.go b/forged/internal/incoming/ssh/ssh.go index 77812d12689997223f2750757b34bf5ebee25a1d..9f9bdff0f30a4421b6247e70a5409234333b354a 100644 --- a/forged/internal/incoming/ssh/ssh.go +++ b/forged/internal/incoming/ssh/ssh.go @@ -5,6 +5,7 @@ "context" "errors" "fmt" "os" + "time" gliderssh "github.com/gliderlabs/ssh" "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" @@ -12,27 +13,30 @@ gossh "golang.org/x/crypto/ssh" ) type Config struct { - Net string `scfg:"net"` - Addr string `scfg:"addr"` - Key string `scfg:"key"` - Root string `scfg:"root"` + Net string `scfg:"net"` + Addr string `scfg:"addr"` + Key string `scfg:"key"` + Root string `scfg:"root"` + ShutdownTimeout uint32 `scfg:"shutdown_timeout"` } type Server struct { - gliderServer *gliderssh.Server - privkey gossh.Signer - pubkeyString string - pubkeyFP string - net string - addr string - root string + gliderServer *gliderssh.Server + privkey gossh.Signer + pubkeyString string + pubkeyFP string + net string + addr string + root string + shutdownTimeout uint32 } func New(config Config) (server *Server, err error) { server = &Server{ - net: config.Net, - addr: config.Addr, - root: config.Root, + net: config.Net, + addr: config.Addr, + root: config.Root, + shutdownTimeout: config.ShutdownTimeout, } var privkeyBytes []byte @@ -63,9 +67,10 @@ }() go func() { <-ctx.Done() - _ = server.gliderServer.Close() - _ = listener.Close() // unnecessary? - // TODO: Log the error + shCtx, cancel := context.WithTimeout(context.Background(), time.Duration(server.shutdownTimeout)*time.Second) + defer cancel() + _ = server.gliderServer.Shutdown(shCtx) + _ = listener.Close() }() if err = server.gliderServer.Serve(listener); err != nil { diff --git a/forged/internal/incoming/web/web.go b/forged/internal/incoming/web/web.go index f66ad64255865171e039c789b03eed1b70a8c7e5..391f6ff65a07cf5b54c3b28b9d602dd555c15626 100644 --- a/forged/internal/incoming/web/web.go +++ b/forged/internal/incoming/web/web.go @@ -10,10 +10,11 @@ "go.lindenii.runxiyu.org/forge/forged/internal/common/misc" ) type Server struct { - net string - addr string - root string - httpServer *http.Server + net string + addr string + root string + httpServer *http.Server + shutdownTimeout uint32 } type handler struct{} @@ -22,23 +23,25 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } type Config struct { - Net string `scfg:"net"` - Addr string `scfg:"addr"` - Root string `scfg:"root"` - CookieExpiry int `scfg:"cookie_expiry"` - ReadTimeout uint32 `scfg:"read_timeout"` - WriteTimeout uint32 `scfg:"write_timeout"` - IdleTimeout uint32 `scfg:"idle_timeout"` - MaxHeaderBytes int `scfg:"max_header_bytes"` - ReverseProxy bool `scfg:"reverse_proxy"` + Net string `scfg:"net"` + Addr string `scfg:"addr"` + Root string `scfg:"root"` + CookieExpiry int `scfg:"cookie_expiry"` + ReadTimeout uint32 `scfg:"read_timeout"` + WriteTimeout uint32 `scfg:"write_timeout"` + IdleTimeout uint32 `scfg:"idle_timeout"` + MaxHeaderBytes int `scfg:"max_header_bytes"` + ReverseProxy bool `scfg:"reverse_proxy"` + ShutdownTimeout uint32 `scfg:"shutdown_timeout"` } func New(config Config) (server *Server) { handler := &handler{} return &Server{ - net: config.Net, - addr: config.Addr, - root: config.Root, + net: config.Net, + addr: config.Addr, + root: config.Root, + shutdownTimeout: config.ShutdownTimeout, httpServer: &http.Server{ Handler: handler, ReadTimeout: time.Duration(config.ReadTimeout) * time.Second, @@ -57,9 +60,10 @@ }() go func() { <-ctx.Done() - _ = server.httpServer.Close() - _ = listener.Close() // unnecessary? - // TODO: Log the error + shCtx, cancel := context.WithTimeout(context.Background(), time.Duration(server.shutdownTimeout)*time.Second) + defer cancel() + _ = server.httpServer.Shutdown(shCtx) + _ = listener.Close() }() if err = server.httpServer.Serve(listener); err != nil { -- 2.48.1