From f828acac387aacadd2884837402b0e32b2368470 Mon Sep 17 00:00:00 2001 From: Runxi Yu Date: Wed, 12 Feb 2025 19:16:41 +0800 Subject: [PATCH] *.go: Use the database for repo info, and fix ssh cloning repo --- git_misc.go | 17 +++++++---------- handle_group_index.go | 29 ++++++++++++++++++----------- handle_repo_commit.go | 2 +- handle_repo_index.go | 2 +- handle_repo_log.go | 2 +- handle_repo_raw.go | 2 +- handle_repo_tree.go | 2 +- router_ssh.go | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++++ ssh.go | 10 ++++++++-- url_misc.go | 5 ++++- diff --git a/git_misc.go b/git_misc.go index 882c63193c9085fe272eef6e39fc1aeeef89de4b..2d4c4d3036e9addd061f848a913722a686dd0593 100644 --- a/git_misc.go +++ b/git_misc.go @@ -1,9 +1,9 @@ package main import ( + "context" "errors" "io" - "path/filepath" "strings" "github.com/go-git/go-git/v5" @@ -19,16 +19,13 @@ err_getting_patch_of_commit = errors.New("Error getting patch of commit") err_getting_parent_commit_object = errors.New("Error getting parent commit object") ) -func open_git_repo(group_name, repo_name string) (*git.Repository, error) { - group_name, group_name_ok := misc.Sanitize_path(group_name) - if !group_name_ok { - return nil, err_unsafe_path +func open_git_repo(ctx context.Context, group_name, repo_name string) (*git.Repository, error) { + var fs_path string + err := database.QueryRow(ctx, "SELECT r.filesystem_path FROM repos r JOIN groups g ON r.group_id = g.id WHERE g.name = $1 AND r.name = $2;", group_name, repo_name).Scan(&fs_path) + if err != nil { + return nil, err } - repo_name, repo_name_ok := misc.Sanitize_path(repo_name) - if !repo_name_ok { - return nil, err_unsafe_path - } - return git.PlainOpen(filepath.Join(config.Git.Root, group_name, repo_name+".git")) + return git.PlainOpen(fs_path) } type display_git_tree_entry_t struct { diff --git a/handle_group_index.go b/handle_group_index.go index bc7a7f45d04ff701d81975e4a90d99ec85b83f38..0bb4a5783e790c8a1b6f515fac0c6b18f20e01c4 100644 --- a/handle_group_index.go +++ b/handle_group_index.go @@ -2,29 +2,36 @@ package main import ( "net/http" - "os" - "path/filepath" - "strings" ) func handle_group_repos(w http.ResponseWriter, r *http.Request, params map[string]string) { data := make(map[string]any) group_name := params["group_name"] data["group_name"] = group_name - entries, err := os.ReadDir(filepath.Join(config.Git.Root, group_name)) + + var names []string + rows, err := database.Query(r.Context(), "SELECT r.name FROM repos r JOIN groups g ON r.group_id = g.id WHERE g.name = $1;", group_name) if err != nil { - _, _ = w.Write([]byte("Error listing repos: " + err.Error())) + _, _ = w.Write([]byte("Error getting groups: " + err.Error())) return } + defer rows.Close() - repos := []string{} - for _, entry := range entries { - this_name := entry.Name() - if strings.HasSuffix(this_name, ".git") { - repos = append(repos, strings.TrimSuffix(this_name, ".git")) + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + _, _ = w.Write([]byte("Error scanning row: " + err.Error())) + return } + names = append(names, name) } - data["repos"] = repos + + if err := rows.Err(); err != nil { + _, _ = w.Write([]byte("Error iterating over rows: " + err.Error())) + return + } + + data["repos"] = names err = templates.ExecuteTemplate(w, "group_repos", data) if err != nil { diff --git a/handle_repo_commit.go b/handle_repo_commit.go index aefd58beab7c0c709da52f8b5d19da5f933f7754..b567baa78e2895f25ccedb2bad233062c7328cea 100644 --- a/handle_repo_commit.go +++ b/handle_repo_commit.go @@ -20,7 +20,7 @@ func handle_repo_commit(w http.ResponseWriter, r *http.Request, params map[string]string) { data := make(map[string]any) group_name, repo_name, commit_id_specified_string := params["group_name"], params["repo_name"], params["commit_id"] data["group_name"], data["repo_name"] = group_name, repo_name - repo, err := open_git_repo(group_name, repo_name) + repo, err := open_git_repo(r.Context(), group_name, repo_name) if err != nil { _, _ = w.Write([]byte("Error opening repo: " + err.Error())) return diff --git a/handle_repo_index.go b/handle_repo_index.go index 6372b03c40d492f8ce7f7f0ce89f964adf4d85c8..c0bef4a2a55b9b3dbb700e23d29faa731e651c6d 100644 --- a/handle_repo_index.go +++ b/handle_repo_index.go @@ -8,7 +8,7 @@ func handle_repo_index(w http.ResponseWriter, r *http.Request, params map[string]string) { data := make(map[string]any) group_name, repo_name := params["group_name"], params["repo_name"] data["group_name"], data["repo_name"] = group_name, repo_name - repo, err := open_git_repo(group_name, repo_name) + repo, err := open_git_repo(r.Context(), group_name, repo_name) if err != nil { _, _ = w.Write([]byte("Error opening repo: " + err.Error())) return diff --git a/handle_repo_log.go b/handle_repo_log.go index eff58590c7622078e2aff2f5f1b75e582dce163f..1c32862ec658589c3fc113331ff3ce3098b24a47 100644 --- a/handle_repo_log.go +++ b/handle_repo_log.go @@ -11,7 +11,7 @@ func handle_repo_log(w http.ResponseWriter, r *http.Request, params map[string]string) { data := make(map[string]any) group_name, repo_name, ref_name := params["group_name"], params["repo_name"], params["ref"] data["group_name"], data["repo_name"], data["ref"] = group_name, repo_name, ref_name - repo, err := open_git_repo(group_name, repo_name) + repo, err := open_git_repo(r.Context(), group_name, repo_name) if err != nil { _, _ = w.Write([]byte("Error opening repo: " + err.Error())) return diff --git a/handle_repo_raw.go b/handle_repo_raw.go index d335f6a770a19800a74e6eba03ad7ca07c1b4120..4cf7d1a1e1db0fb5eee39afdf20ecd71d132bf85 100644 --- a/handle_repo_raw.go +++ b/handle_repo_raw.go @@ -26,7 +26,7 @@ } data["ref_type"], data["ref"], data["group_name"], data["repo_name"], data["path_spec"] = ref_type, ref_name, group_name, repo_name, path_spec - repo, err := open_git_repo(group_name, repo_name) + repo, err := open_git_repo(r.Context(), group_name, repo_name) if err != nil { _, _ = w.Write([]byte("Error opening repo: " + err.Error())) return diff --git a/handle_repo_tree.go b/handle_repo_tree.go index f95e9452f213a6beed3a6cdf1ec135e59cb36d0a..8076ed6f75cdefb4a0a72cb1282fa3b26354e8d3 100644 --- a/handle_repo_tree.go +++ b/handle_repo_tree.go @@ -28,7 +28,7 @@ return } } data["ref_type"], data["ref"], data["group_name"], data["repo_name"], data["path_spec"] = ref_type, ref_name, group_name, repo_name, path_spec - repo, err := open_git_repo(group_name, repo_name) + repo, err := open_git_repo(r.Context(), group_name, repo_name) if err != nil { _, _ = w.Write([]byte("Error opening repo: " + err.Error())) return diff --git a/router_ssh.go b/router_ssh.go new file mode 100644 index 0000000000000000000000000000000000000000..6b5280be5a8273cd91e97fd7dd5277632ae80d30 --- /dev/null +++ b/router_ssh.go @@ -0,0 +1,56 @@ +package main + +import ( + "context" + "errors" + "net/url" + "strings" +) + +var err_ssh_illegal_endpoint = errors.New("Illegal endpoint during SSH access") + +func get_repo_path_from_ssh_path(ctx context.Context, ssh_path string) (repo_path string, err error) { + segments := strings.Split(strings.TrimPrefix(ssh_path, "/"), "/") + + for i, segment := range segments { + var err error + segments[i], err = url.QueryUnescape(segment) + if err != nil { + return "", err + } + } + + if segments[0] == ":" { + return "", err_ssh_illegal_endpoint + } + + separator_index := -1 + for i, part := range segments { + if part == ":" { + separator_index = i + break + } + } + if segments[len(segments)-1] == "" { + segments = segments[:len(segments)-1] + } + + switch { + case separator_index == -1: + return "", err_ssh_illegal_endpoint + case len(segments) <= separator_index+2: + return "", err_ssh_illegal_endpoint + } + + group_name := segments[0] + module_type := segments[separator_index+1] + module_name := segments[separator_index+2] + switch module_type { + case "repos": + var fs_path string + err := database.QueryRow(ctx, "SELECT r.filesystem_path FROM repos r JOIN groups g ON r.group_id = g.id WHERE g.name = $1 AND r.name = $2;", group_name, module_name).Scan(&fs_path) + return fs_path, err + default: + return "", err_ssh_illegal_endpoint + } +} diff --git a/ssh.go b/ssh.go index e1b9ff19a70ce4ace164cb1d1b4f7c62f45dcadf..4d49fc982bcbefcaf292b578c9895116b9da1580 100644 --- a/ssh.go +++ b/ssh.go @@ -43,12 +43,18 @@ fmt.Fprintln(session.Stderr(), "Unsupported command") return } - proc := exec.CommandContext(session.Context(), cmd[0], "/home/runxiyu/git/forge.git") + fs_path, err := get_repo_path_from_ssh_path(session.Context(), cmd[1]) + if err != nil { + fmt.Fprintln(session.Stderr(), "Error while getting repo path:", err) + return + } + + proc := exec.CommandContext(session.Context(), cmd[0], fs_path) proc.Stdin = session proc.Stdout = session proc.Stderr = session.Stderr() - err := proc.Start() + err = proc.Start() if err != nil { fmt.Fprintln(session.Stderr(), "Error while starting process:", err) return diff --git a/url_misc.go b/url_misc.go index e4bfd92191e32f8e9e805b5681a187bb69723aa2..7dc0ad57def05f2b49dcb83e73a67be08eef144d 100644 --- a/url_misc.go +++ b/url_misc.go @@ -50,7 +50,10 @@ segments = strings.Split(strings.TrimPrefix(path, "/"), "/") for i, segment := range segments { - segments[i], _ = url.QueryUnescape(segment) + segments[i], err = url.QueryUnescape(segment) + if err != nil { + return nil, nil, misc.Wrap_one_error(err_bad_request, err) + } } params, err = url.ParseQuery(params_string) -- 2.48.1