From 20b4fe0c59357a433042732d46e38da9c3d14c3b Mon Sep 17 00:00:00 2001 From: Runxi Yu Date: Sat, 05 Apr 2025 20:26:57 +0800 Subject: [PATCH] database shall no longer be a global variable --- acl.go | 2 +- config.go | 2 +- database.go | 8 +------- fedauth.go | 2 +- git_hooks_handle_linux.go | 6 +++--- git_hooks_handle_other.go | 6 +++--- git_misc.go | 4 ++-- http_auth.go | 4 ++-- http_handle_group_index.go | 12 ++++++------ http_handle_login.go | 4 ++-- http_handle_repo_contrib_index.go | 4 ++-- http_handle_repo_contrib_one.go | 4 ++-- http_handle_repo_info.go | 4 ++-- http_handle_repo_upload_pack.go | 2 +- http_server.go | 10 +++++----- lmtp_handle_patch.go | 4 ++-- lmtp_server.go | 2 +- server.go | 7 +++++++ ssh_handle_receive_pack.go | 2 +- users.go | 4 ++-- diff --git a/acl.go b/acl.go index 44cd04b644a5473d8ac825112cf4a9af520a1680..dfe128a3f7857e7885d4e12259aaee2907c87cb3 100644 --- a/acl.go +++ b/acl.go @@ -14,7 +14,7 @@ // given repo and a provided ssh public key. // // TODO: Revamp. func (s *server) getRepoInfo(ctx context.Context, groupPath []string, repoName, sshPubkey string) (repoID int, fsPath string, access bool, contribReq, userType string, userID int, err error) { - err = database.QueryRow(ctx, ` + err = s.database.QueryRow(ctx, ` WITH RECURSIVE group_path_cte AS ( -- Start: match the first name in the path where parent_group IS NULL SELECT diff --git a/config.go b/config.go index 1bbc3a104b89cd54b6aff7cb7a2092dfd4e46b89..773a223fac2c74d9f2d57662ef04130f8258b9a3 100644 --- a/config.go +++ b/config.go @@ -92,7 +92,7 @@ if s.config.DB.Type != "postgres" { return errors.New("unsupported database type") } - if database, err = pgxpool.New(context.Background(), s.config.DB.Conn); err != nil { + if s.database, err = pgxpool.New(context.Background(), s.config.DB.Conn); err != nil { return err } diff --git a/database.go b/database.go index 18e753fee640e9c8b1cee19eacbacd085223fe9b..1ea075301c2327063048b4dbb2a774798d88baa2 100644 --- a/database.go +++ b/database.go @@ -7,7 +7,6 @@ import ( "context" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgxpool" ) // TODO: All database handling logic in all request handlers must be revamped. @@ -16,18 +15,13 @@ // to exceptions if appropriate) so they get a consistent view of the database // at a single point. A failure to do so may cause things as serious as // privilege escalation. -// database serves as the primary database handle for this entire application. -// Transactions or single reads may be used from it. A [pgxpool.Pool] is -// necessary to safely use pgx concurrently; pgx.Conn, etc. are insufficient. -var database *pgxpool.Pool - // queryNameDesc is a helper function that executes a query and returns a // list of nameDesc results. The query must return two string arguments, i.e. a // name and a description. func (s *server) queryNameDesc(ctx context.Context, query string, args ...any) (result []nameDesc, err error) { var rows pgx.Rows - if rows, err = database.Query(ctx, query, args...); err != nil { + if rows, err = s.database.Query(ctx, query, args...); err != nil { return nil, err } defer rows.Close() diff --git a/fedauth.go b/fedauth.go index 46290e502c5b60b2027a6cab0ab06dfb6c206c51..43cb4e32cd9c02152d9a95c3fa8bda8777291589 100644 --- a/fedauth.go +++ b/fedauth.go @@ -77,7 +77,7 @@ return false, nil } var txn pgx.Tx - if txn, err = database.Begin(ctx); err != nil { + if txn, err = s.database.Begin(ctx); err != nil { return false, err } defer func() { diff --git a/git_hooks_handle_linux.go b/git_hooks_handle_linux.go index 37afba106405c98ab1da3c6031e3ed685daf31b0..ca262e316dda5c319d49158bc4cbd3dfa5bdb4a5 100644 --- a/git_hooks_handle_linux.go +++ b/git_hooks_handle_linux.go @@ -233,12 +233,12 @@ fmt.Fprintln(sshStderr, ansiec.Blue+"POK"+ansiec.Reset, refName) var newMRLocalID int if packPass.userID != 0 { - err = database.QueryRow(ctx, + err = s.database.QueryRow(ctx, "INSERT INTO merge_requests (repo_id, creator, source_ref, status) VALUES ($1, $2, $3, 'open') RETURNING repo_local_id", packPass.repoID, packPass.userID, strings.TrimPrefix(refName, "refs/heads/"), ).Scan(&newMRLocalID) } else { - err = database.QueryRow(ctx, + err = s.database.QueryRow(ctx, "INSERT INTO merge_requests (repo_id, source_ref, status) VALUES ($1, $2, 'open') RETURNING repo_local_id", packPass.repoID, strings.TrimPrefix(refName, "refs/heads/"), ).Scan(&newMRLocalID) @@ -259,7 +259,7 @@ } else { // Existing contrib branch var existingMRUser int var isAncestor bool - err = database.QueryRow(ctx, + err = s.database.QueryRow(ctx, "SELECT COALESCE(creator, 0) FROM merge_requests WHERE source_ref = $1 AND repo_id = $2", strings.TrimPrefix(refName, "refs/heads/"), packPass.repoID, ).Scan(&existingMRUser) diff --git a/git_hooks_handle_other.go b/git_hooks_handle_other.go index 6d5b08daeace239046700415bc590c38f4c1ffef..ed75e7ae2924e2a3ad86ece7f504867a0bd00000 100644 --- a/git_hooks_handle_other.go +++ b/git_hooks_handle_other.go @@ -211,12 +211,12 @@ fmt.Fprintln(sshStderr, ansiec.Blue+"POK"+ansiec.Reset, refName) var newMRLocalID int if packPass.userID != 0 { - err = database.QueryRow(ctx, + err = s.database.QueryRow(ctx, "INSERT INTO merge_requests (repo_id, creator, source_ref, status) VALUES ($1, $2, $3, 'open') RETURNING repo_local_id", packPass.repoID, packPass.userID, strings.TrimPrefix(refName, "refs/heads/"), ).Scan(&newMRLocalID) } else { - err = database.QueryRow(ctx, + err = s.database.QueryRow(ctx, "INSERT INTO merge_requests (repo_id, source_ref, status) VALUES ($1, $2, 'open') RETURNING repo_local_id", packPass.repoID, strings.TrimPrefix(refName, "refs/heads/"), ).Scan(&newMRLocalID) @@ -237,7 +237,7 @@ } else { // Existing contrib branch var existingMRUser int var isAncestor bool - err = database.QueryRow(ctx, + err = s.database.QueryRow(ctx, "SELECT COALESCE(creator, 0) FROM merge_requests WHERE source_ref = $1 AND repo_id = $2", strings.TrimPrefix(refName, "refs/heads/"), packPass.repoID, ).Scan(&existingMRUser) diff --git a/git_misc.go b/git_misc.go index 8e72d0cebc476265ce54fc1e409655610a3d6bfd..8dda01cd885ec7ea5a39096d47cf5e4809759ba7 100644 --- a/git_misc.go +++ b/git_misc.go @@ -22,8 +22,8 @@ // // TODO: This should be deprecated in favor of doing it in the relevant // request/router context in the future, as it cannot cover the nuance of // fields needed. -func openRepo(ctx context.Context, groupPath []string, repoName string) (repo *git.Repository, description string, repoID int, fsPath string, err error) { - err = database.QueryRow(ctx, ` +func (s *server) openRepo(ctx context.Context, groupPath []string, repoName string) (repo *git.Repository, description string, repoID int, fsPath string, err error) { + err = s.database.QueryRow(ctx, ` WITH RECURSIVE group_path_cte AS ( -- Start: match the first name in the path where parent_group IS NULL SELECT diff --git a/http_auth.go b/http_auth.go index 03b7e2b24bf08cb03211f305a409f8292f80d739..5f0dc66cb2d8a55bf78aa6e079dd016fdf78609a 100644 --- a/http_auth.go +++ b/http_auth.go @@ -9,14 +9,14 @@ ) // getUserFromRequest returns the user ID and username associated with the // session cookie in a given [http.Request]. -func getUserFromRequest(request *http.Request) (id int, username string, err error) { +func (s *server) getUserFromRequest(request *http.Request) (id int, username string, err error) { var sessionCookie *http.Cookie if sessionCookie, err = request.Cookie("session"); err != nil { return } - err = database.QueryRow( + err = s.database.QueryRow( request.Context(), "SELECT user_id, COALESCE(username, '') FROM users u JOIN sessions s ON u.id = s.user_id WHERE s.session_id = $1;", sessionCookie.Value, diff --git a/http_handle_group_index.go b/http_handle_group_index.go index 16120a89e91d2ee1711aed6e601765019e425cc0..46f1f6a0f68937da5fc8e541aac721aa4897dc79 100644 --- a/http_handle_group_index.go +++ b/http_handle_group_index.go @@ -28,7 +28,7 @@ groupPath = params["group_path"].([]string) // The group itself - err = database.QueryRow(request.Context(), ` + err = s.database.QueryRow(request.Context(), ` WITH RECURSIVE group_path_cte AS ( SELECT id, @@ -69,7 +69,7 @@ } // ACL var count int - err = database.QueryRow(request.Context(), ` + err = s.database.QueryRow(request.Context(), ` SELECT COUNT(*) FROM user_group_roles WHERE user_id = $1 @@ -96,7 +96,7 @@ return } var newRepoID int - err := database.QueryRow( + err := s.database.QueryRow( request.Context(), `INSERT INTO repos (name, description, group_id, contrib_requirements) VALUES ($1, $2, $3, $4) @@ -113,7 +113,7 @@ } filePath := filepath.Join(s.config.Git.RepoDir, strconv.Itoa(newRepoID)+".git") - _, err = database.Exec( + _, err = s.database.Exec( request.Context(), `UPDATE repos SET filesystem_path = $1 @@ -137,7 +137,7 @@ } // Repos var rows pgx.Rows - rows, err = database.Query(request.Context(), ` + rows, err = s.database.Query(request.Context(), ` SELECT name, COALESCE(description, '') FROM repos WHERE group_id = $1 @@ -162,7 +162,7 @@ return } // Subgroups - rows, err = database.Query(request.Context(), ` + rows, err = s.database.Query(request.Context(), ` SELECT name, COALESCE(description, '') FROM groups WHERE parent_group = $1 diff --git a/http_handle_login.go b/http_handle_login.go index ea1dbae6454d8ee7026f76170cbcac798393bdf7..10bfdcd46a4a5f4325ae8fda55452ad2a551d211 100644 --- a/http_handle_login.go +++ b/http_handle_login.go @@ -35,7 +35,7 @@ username = request.PostFormValue("username") password = request.PostFormValue("password") - err = database.QueryRow(request.Context(), + err = s.database.QueryRow(request.Context(), "SELECT id, COALESCE(password, '') FROM users WHERE username = $1", username, ).Scan(&userID, &passwordHash) @@ -85,7 +85,7 @@ } //exhaustruct:ignore http.SetCookie(writer, &cookie) - _, err = database.Exec(request.Context(), "INSERT INTO sessions (user_id, session_id) VALUES ($1, $2)", userID, cookieValue) + _, err = s.database.Exec(request.Context(), "INSERT INTO sessions (user_id, session_id) VALUES ($1, $2)", userID, cookieValue) if err != nil { errorPage500(writer, params, "Error inserting session: "+err.Error()) return diff --git a/http_handle_repo_contrib_index.go b/http_handle_repo_contrib_index.go index ee7b9561fc5424717f56d7a1a1eb5b926a3cf55d..e0c8478962c42772b0d36f3bdd7432add2fa203e 100644 --- a/http_handle_repo_contrib_index.go +++ b/http_handle_repo_contrib_index.go @@ -18,12 +18,12 @@ Status string } // httpHandleRepoContribIndex provides an index to merge requests of a repo. -func httpHandleRepoContribIndex(writer http.ResponseWriter, request *http.Request, params map[string]any) { +func (s *server) httpHandleRepoContribIndex(writer http.ResponseWriter, request *http.Request, params map[string]any) { var rows pgx.Rows var result []idTitleStatus var err error - if rows, err = database.Query(request.Context(), + if rows, err = s.database.Query(request.Context(), "SELECT repo_local_id, COALESCE(title, 'Untitled'), status FROM merge_requests WHERE repo_id = $1", params["repo_id"], ); err != nil { diff --git a/http_handle_repo_contrib_one.go b/http_handle_repo_contrib_one.go index dcd0e0df99f285dc684fa3a5d3b665fe23ee2146..0df749148527c1bb02893c3b8acc7629349ed35f 100644 --- a/http_handle_repo_contrib_one.go +++ b/http_handle_repo_contrib_one.go @@ -14,7 +14,7 @@ ) // httpHandleRepoContribOne provides an interface to each merge request of a // repo. -func httpHandleRepoContribOne(writer http.ResponseWriter, request *http.Request, params map[string]any) { +func (s *server) httpHandleRepoContribOne(writer http.ResponseWriter, request *http.Request, params map[string]any) { var mrIDStr string var mrIDInt int var err error @@ -33,7 +33,7 @@ return } mrIDInt = int(mrIDInt64) - if err = database.QueryRow(request.Context(), + if err = s.database.QueryRow(request.Context(), "SELECT COALESCE(title, ''), status, source_ref, COALESCE(destination_branch, '') FROM merge_requests WHERE repo_id = $1 AND repo_local_id = $2", params["repo_id"], mrIDInt, ).Scan(&title, &status, &srcRefStr, &dstBranchStr); err != nil { diff --git a/http_handle_repo_info.go b/http_handle_repo_info.go index 3f1787eae2137b5d43a593d9ffc21ab88c059acf..b7b743800dc6a83c38ac0406008feaaee86f782b 100644 --- a/http_handle_repo_info.go +++ b/http_handle_repo_info.go @@ -16,12 +16,12 @@ // httpHandleRepoInfo provides advertised refs of a repo for use in Git's Smart // HTTP protocol. // // TODO: Reject access from web browsers. -func httpHandleRepoInfo(writer http.ResponseWriter, request *http.Request, params map[string]any) (err error) { +func (s *server) httpHandleRepoInfo(writer http.ResponseWriter, request *http.Request, params map[string]any) (err error) { groupPath := params["group_path"].([]string) repoName := params["repo_name"].(string) var repoPath string - if err := database.QueryRow(request.Context(), ` + if err := s.database.QueryRow(request.Context(), ` WITH RECURSIVE group_path_cte AS ( -- Start: match the first name in the path where parent_group IS NULL SELECT diff --git a/http_handle_repo_upload_pack.go b/http_handle_repo_upload_pack.go index 3d9170cefe91ffabb06a3d9455dc8afcc7dfb356..a6580a7720a2969944774d9ccdb3a06fc7e68262 100644 --- a/http_handle_repo_upload_pack.go +++ b/http_handle_repo_upload_pack.go @@ -24,7 +24,7 @@ var cmd *exec.Cmd groupPath, repoName = params["group_path"].([]string), params["repo_name"].(string) - if err := database.QueryRow(request.Context(), ` + if err := s.database.QueryRow(request.Context(), ` WITH RECURSIVE group_path_cte AS ( -- Start: match the first name in the path where parent_group IS NULL SELECT diff --git a/http_server.go b/http_server.go index 5c78533e524bb8a70d22a4dfe7bedbd5489f5888..ae822417c9a742141cb506120d2d215d94080d18 100644 --- a/http_server.go +++ b/http_server.go @@ -52,7 +52,7 @@ params["url_segments"] = segments params["dir_mode"] = dirMode params["global"] = globalData var userID int // 0 for none - userID, params["username"], err = getUserFromRequest(request) + userID, params["username"], err = s.getUserFromRequest(request) params["user_id"] = userID if err != nil && !errors.Is(err, http.ErrNoCookie) && !errors.Is(err, pgx.ErrNoRows) { errorPage500(writer, params, "Error getting user info from request: "+err.Error()) @@ -152,7 +152,7 @@ if len(segments) > sepIndex+3 { switch segments[sepIndex+3] { case "info": - if err = httpHandleRepoInfo(writer, request, params); err != nil { + if err = s.httpHandleRepoInfo(writer, request, params); err != nil { errorPage500(writer, params, err.Error()) } return @@ -173,7 +173,7 @@ return } } - if params["repo"], params["repo_description"], params["repo_id"], _, err = openRepo(request.Context(), groupPath, moduleName); err != nil { + if params["repo"], params["repo_description"], params["repo_id"], _, err = s.openRepo(request.Context(), groupPath, moduleName); err != nil { errorPage500(writer, params, "Error opening repo: "+err.Error()) return } @@ -256,10 +256,10 @@ return } switch len(segments) { case sepIndex + 4: - httpHandleRepoContribIndex(writer, request, params) + s.httpHandleRepoContribIndex(writer, request, params) case sepIndex + 5: params["mr_id"] = segments[sepIndex+4] - httpHandleRepoContribOne(writer, request, params) + s.httpHandleRepoContribOne(writer, request, params) default: errorPage400(writer, params, "Too many parameters") } diff --git a/lmtp_handle_patch.go b/lmtp_handle_patch.go index 45d146a74d093e72fab244e3e9e456aaff8f8813..ab846aa88bbf1240d21b57f59c85a728902f999e 100644 --- a/lmtp_handle_patch.go +++ b/lmtp_handle_patch.go @@ -19,7 +19,7 @@ "github.com/go-git/go-git/v5" "go.lindenii.runxiyu.org/forge/misc" ) -func lmtpHandlePatch(session *lmtpSession, groupPath []string, repoName string, mbox io.Reader) (err error) { +func (s *server) lmtpHandlePatch(session *lmtpSession, groupPath []string, repoName string, mbox io.Reader) (err error) { var diffFiles []*gitdiff.File var preamble string if diffFiles, preamble, err = gitdiff.Parse(mbox); err != nil { @@ -33,7 +33,7 @@ } var repo *git.Repository var fsPath string - repo, _, _, fsPath, err = openRepo(session.ctx, groupPath, repoName) + repo, _, _, fsPath, err = s.openRepo(session.ctx, groupPath, repoName) if err != nil { return fmt.Errorf("failed to open repo: %w", err) } diff --git a/lmtp_server.go b/lmtp_server.go index e97ca55e37a2f5c704b1049ba7d02e63b2d869da..8191766f8f9d02038c4ffe9c393ca07ef903632d 100644 --- a/lmtp_server.go +++ b/lmtp_server.go @@ -177,7 +177,7 @@ moduleType := segments[sepIndex+1] moduleName := segments[sepIndex+2] switch moduleType { case "repos": - err = lmtpHandlePatch(session, groupPath, moduleName, &mbox) + err = session.s.lmtpHandlePatch(session, groupPath, moduleName, &mbox) if err != nil { slog.Error("error handling patch", "error", err) goto end diff --git a/server.go b/server.go index 8f35913ab8b538743d699ed26227890b7181157c..1113740355878c7bcbfc56cf84ee1b0de4e4b554 100644 --- a/server.go +++ b/server.go @@ -1,5 +1,12 @@ package main +import "github.com/jackc/pgx/v5/pgxpool" + type server struct { config Config + + // database serves as the primary database handle for this entire application. + // Transactions or single reads may be used from it. A [pgxpool.Pool] is + // necessary to safely use pgx concurrently; pgx.Conn, etc. are insufficient. + database *pgxpool.Pool } diff --git a/ssh_handle_receive_pack.go b/ssh_handle_receive_pack.go index ed7ef40476b031b8c92fc1a44f44e65fadf4ce8e..317609f8c260da152b7b0df5330dbe0b7b5dd396 100644 --- a/ssh_handle_receive_pack.go +++ b/ssh_handle_receive_pack.go @@ -76,7 +76,7 @@ if pubkey == "" { return errors.New("you need to have an SSH public key to push to this repo") } if userType == "" { - userID, err = addUserSSH(session.Context(), pubkey) + userID, err = s.addUserSSH(session.Context(), pubkey) if err != nil { return err } diff --git a/users.go b/users.go index f0dabce3de9018098bacd3c30aa5598486dee2e2..1b31f3a1f9b66d61a9d9f76dd6d47ee323277168 100644 --- a/users.go +++ b/users.go @@ -12,10 +12,10 @@ // addUserSSH adds a new user solely based on their SSH public key. // // TODO: Audit all users of this function. -func addUserSSH(ctx context.Context, pubkey string) (userID int, err error) { +func (s *server) addUserSSH(ctx context.Context, pubkey string) (userID int, err error) { var txn pgx.Tx - if txn, err = database.Begin(ctx); err != nil { + if txn, err = s.database.Begin(ctx); err != nil { return } defer func() { -- 2.48.1