From 3b0a2adf296791bdbd0c1b0fba6399910c1fac82 Mon Sep 17 00:00:00 2001 From: Runxi Yu Date: Sun, 17 Aug 2025 04:08:24 +0800 Subject: [PATCH] Context fixes --- forged/internal/incoming/hooks/hooks.go | 13 ++++++++++++- forged/internal/incoming/lmtp/config.go | 12 +++++++++++- forged/internal/incoming/ssh/ssh.go | 14 +++++++++++++- forged/internal/incoming/web/web.go | 14 +++++++++++++- forged/internal/server/server.go | 46 ++++++++++++++++++++++++++-------------------- forged/main.go | 4 ++-- diff --git a/forged/internal/incoming/hooks/hooks.go b/forged/internal/incoming/hooks/hooks.go index 18b97426ecfa50ef16c0ce400af3ac0f5dc4cf29..c02140608d5b7ad312cf2fbe2f32ce048058927f 100644 --- a/forged/internal/incoming/hooks/hooks.go +++ b/forged/internal/incoming/hooks/hooks.go @@ -1,6 +1,7 @@ package hooks import ( + "context" "fmt" "net" @@ -40,15 +41,25 @@ executablesPath: config.Execs, } } -func (server *Server) Run() error { +func (server *Server) Run(ctx context.Context) error { listener, _, err := misc.ListenUnixSocket(server.socketPath) if err != nil { return fmt.Errorf("listen unix socket for hooks: %w", err) } + defer func() { + _ = listener.Close() + }() + + go func() { + <-ctx.Done() + _ = listener.Close() + // TODO: Log the error + }() for { conn, err := listener.Accept() if err != nil { + // TODO: Handle errors caused by context cancel return fmt.Errorf("accept conn: %w", err) } diff --git a/forged/internal/incoming/lmtp/config.go b/forged/internal/incoming/lmtp/config.go index e6db5a693bedff5820c5507ab4fcebc1d463c93b..c622648488bcbe029c8335c7c323f80cbd76ee01 100644 --- a/forged/internal/incoming/lmtp/config.go +++ b/forged/internal/incoming/lmtp/config.go @@ -1,6 +1,7 @@ package lmtp import ( + "context" "fmt" "net" @@ -33,11 +34,20 @@ readTimeout: config.ReadTimeout, } } -func (server *Server) Run() error { +func (server *Server) Run(ctx context.Context) error { listener, _, err := misc.ListenUnixSocket(server.socket) if err != nil { return fmt.Errorf("listen unix socket for LMTP: %w", err) } + defer func() { + _ = listener.Close() + }() + + go func() { + <-ctx.Done() + _ = listener.Close() + // TODO: Log the error + }() for { conn, err := listener.Accept() diff --git a/forged/internal/incoming/ssh/ssh.go b/forged/internal/incoming/ssh/ssh.go index b0269e5afa0f15d1211afaad1e0360aa452e932c..efa24f11ab828bcea93e34bdf54bc5033e164a9e 100644 --- a/forged/internal/incoming/ssh/ssh.go +++ b/forged/internal/incoming/ssh/ssh.go @@ -1,6 +1,7 @@ package ssh import ( + "context" "fmt" "os" @@ -53,8 +54,19 @@ return } -func (server *Server) Run() (err error) { +func (server *Server) Run(ctx context.Context) (err error) { listener, err := misc.Listen(server.net, server.addr) + defer func() { + _ = listener.Close() + }() + + go func() { + <-ctx.Done() + _ = server.gliderServer.Close() + _ = listener.Close() // unnecessary? + // TODO: Log the error + }() + if err = server.gliderServer.Serve(listener); err != nil { return fmt.Errorf("serve SSH: %w", err) } diff --git a/forged/internal/incoming/web/web.go b/forged/internal/incoming/web/web.go index e4675f032af19873c26c29941dd2b55d726a9251..a95b650f73ad2ecbcc2b101aec9887a8f9e80627 100644 --- a/forged/internal/incoming/web/web.go +++ b/forged/internal/incoming/web/web.go @@ -1,6 +1,7 @@ package web import ( + "context" "fmt" "net/http" "time" @@ -48,8 +49,19 @@ }, } } -func (server *Server) Run() (err error) { +func (server *Server) Run(ctx context.Context) (err error) { listener, err := misc.Listen(server.net, server.addr) + defer func() { + _ = listener.Close() + }() + + go func() { + <-ctx.Done() + _ = server.httpServer.Close() + _ = listener.Close() // unnecessary? + // TODO: Log the error + }() + if err = server.httpServer.Serve(listener); err != nil { return fmt.Errorf("serve web: %w", err) } diff --git a/forged/internal/server/server.go b/forged/internal/server/server.go index b3cd92a7b3e02cdc8bf946c1ae7703e9176e6e74..0b14e62bf28951cef4f6f23985dddd2dae8c5723 100644 --- a/forged/internal/server/server.go +++ b/forged/internal/server/server.go @@ -3,7 +3,6 @@ import ( "context" "fmt" - "log" "go.lindenii.runxiyu.org/forge/forged/internal/config" "go.lindenii.runxiyu.org/forge/forged/internal/database" @@ -29,7 +28,7 @@ Version string } } -func New(ctx context.Context, configPath string) (server *Server, err error) { +func New(configPath string) (server *Server, err error) { server = &Server{} server.config, err = config.Open(configPath) @@ -37,11 +36,6 @@ if err != nil { return server, fmt.Errorf("open config: %w", err) } - server.database, err = database.Open(ctx, server.config.DB) - if err != nil { - return server, fmt.Errorf("open database: %w", err) - } - server.hookServer = hooks.New(server.config.Hooks) server.lmtpServer = lmtp.New(server.config.LMTP) server.webServer = web.New(server.config.Web) @@ -53,33 +47,45 @@ return server, nil } -func (s *Server) Run() error { +func (server *Server) Run(ctx context.Context) (err error) { // TODO: Not running git2d because it should be run separately. // This needs to be documented somewhere, hence a TODO here for now. + subCtx, cancel := context.WithCancel(ctx) + defer cancel() + + server.database, err = database.Open(subCtx, server.config.DB) + if err != nil { + return fmt.Errorf("open database: %w", err) + } + + errCh := make(chan error) + go func() { - if err := s.hookServer.Run(); err != nil { - log.Fatalf("run hook server: %v", err) - } + err := server.hookServer.Run(subCtx) + errCh <- err }() go func() { - if err := s.lmtpServer.Run(); err != nil { - log.Fatalf("run LMTP server: %v", err) - } + err := server.lmtpServer.Run(subCtx) + errCh <- err }() go func() { - if err := s.webServer.Run(); err != nil { - log.Fatalf("run web server: %v", err) - } + err := server.webServer.Run(subCtx) + errCh <- err }() go func() { - if err := s.sshServer.Run(); err != nil { - log.Fatalf("run SSH server: %v", err) - } + err := server.sshServer.Run(subCtx) + errCh <- err }() + + select { + case err := <-errCh: + return fmt.Errorf("server error: %w", err) + case <-ctx.Done(): + } return nil } diff --git a/forged/main.go b/forged/main.go index e9609f327f6e1ace508b36d19c052cb636bfb983..38e22ff24725e0e43f4016cac5b74fbdd4bc122f 100644 --- a/forged/main.go +++ b/forged/main.go @@ -19,10 +19,10 @@ "path to configuration file", ) flag.Parse() - s, err := server.New(context.Background(), *configPath) + s, err := server.New(*configPath) if err != nil { panic(err) } - panic(s.Run()) + panic(s.Run(context.Background())) } -- 2.48.1