Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!
Set secure cookie if reverse proxy indicates HTTPS
package main import ( "context" "fmt" "html/template"
"mime"
"net/http"
)
const internalTokenScope = "_sinwon"
type contextKey string
const (
contextKeyDB = "db"
contextKeyTemplate = "template"
contextKeyLoginToken = "login-token"
)
func dbFromContext(ctx context.Context) *DB {
return ctx.Value(contextKeyDB).(*DB)
}
func templateFromContext(ctx context.Context) *template.Template {
return ctx.Value(contextKeyTemplate).(*template.Template)
}
func loginTokenFromContext(ctx context.Context) *AccessToken {
v := ctx.Value(contextKeyLoginToken)
if v == nil {
return nil
}
return v.(*AccessToken)
}
func newBaseContext(db *DB, tpl *template.Template) context.Context {
ctx := context.Background()
ctx = context.WithValue(ctx, contextKeyDB, db)
ctx = context.WithValue(ctx, contextKeyTemplate, tpl)
return ctx
}
func setLoginTokenCookie(w http.ResponseWriter, token *AccessToken, secret string) {
func setLoginTokenCookie(w http.ResponseWriter, req *http.Request, token *AccessToken, secret string) {
http.SetCookie(w, &http.Cookie{
Name: "sinwon-token",
Value: MarshalSecret(token.ID, secret),
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
// TODO: Secure
Secure: isForwardedHTTPS(req),
}) }
func unsetLoginTokenCookie(w http.ResponseWriter) {
func unsetLoginTokenCookie(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, &http.Cookie{
Name: "sinwon-token",
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Secure: isForwardedHTTPS(req),
MaxAge: -1, }) }
func isForwardedHTTPS(req *http.Request) bool {
if forwarded := req.Header.Get("Forwarded"); forwarded != "" {
_, params, _ := mime.ParseMediaType("_; " + forwarded)
return params["proto"] == "https"
}
if forwardedProto := req.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
return forwardedProto == "https"
}
return false
}
func loginTokenMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
cookie, _ := req.Cookie("sinwon-token")
if cookie == nil {
next.ServeHTTP(w, req)
return
}
ctx := req.Context()
db := dbFromContext(ctx)
tokenID, tokenSecret, _ := UnmarshalSecret[*AccessToken](cookie.Value)
token, err := db.FetchAccessToken(ctx, tokenID)
if err == errNoDBRows || (err == nil && !token.VerifySecret(tokenSecret)) {
unsetLoginTokenCookie(w)
unsetLoginTokenCookie(w, req)
next.ServeHTTP(w, req)
return
} else if err != nil {
httpError(w, fmt.Errorf("failed to fetch access token: %v", err))
return
}
if token.Scope != internalTokenScope {
http.Error(w, "Invalid login token scope", http.StatusForbidden)
return
}
if token.User == 0 {
panic("login token with zero user ID")
}
ctx = context.WithValue(ctx, contextKeyLoginToken, token)
req = req.WithContext(ctx)
next.ServeHTTP(w, req)
})
}
package main
import (
"fmt"
"log"
"net/http"
"net/url"
"github.com/go-chi/chi/v5"
)
func index(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
tpl := templateFromContext(ctx)
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
http.Redirect(w, req, "/login", http.StatusFound)
return
}
clients, err := db.ListClients(ctx, loginToken.User)
if err != nil {
httpError(w, err)
return
}
data := struct {
Clients []Client
Me ID[*User]
}{
Clients: clients,
Me: loginToken.User,
}
if err := tpl.ExecuteTemplate(w, "index.html", &data); err != nil {
panic(err)
}
}
func login(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
tpl := templateFromContext(ctx)
q := req.URL.Query()
rawRedirectURI := q.Get("redirect_uri")
if rawRedirectURI == "" {
rawRedirectURI = "/"
}
redirectURI, err := url.Parse(rawRedirectURI)
if err != nil || redirectURI.Scheme != "" || redirectURI.Opaque != "" || redirectURI.User != nil || redirectURI.Host != "" {
http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
return
}
if loginTokenFromContext(ctx) != nil {
http.Redirect(w, req, redirectURI.String(), http.StatusFound)
return
}
username := req.PostFormValue("username")
password := req.PostFormValue("password")
if username == "" {
if err := tpl.ExecuteTemplate(w, "login.html", nil); err != nil {
panic(err)
}
return
}
user, err := db.FetchUserByUsername(ctx, username)
if err != nil && err != errNoDBRows {
httpError(w, fmt.Errorf("failed to fetch user: %v", err))
return
}
if err == nil {
err = user.VerifyPassword(password)
}
if err != nil {
log.Printf("login failed for user %q: %v", username, err)
// TODO: show error message
if err := tpl.ExecuteTemplate(w, "login.html", nil); err != nil {
panic(err)
}
return
}
token := AccessToken{
User: user.ID,
Scope: internalTokenScope,
}
secret, err := token.Generate()
if err != nil {
httpError(w, fmt.Errorf("failed to generate access token: %v", err))
return
}
if err := db.CreateAccessToken(ctx, &token); err != nil {
httpError(w, fmt.Errorf("failed to create access token: %v", err))
return
}
setLoginTokenCookie(w, &token, secret)
setLoginTokenCookie(w, req, &token, secret)
http.Redirect(w, req, redirectURI.String(), http.StatusFound)
}
func logout(w http.ResponseWriter, req *http.Request) {
unsetLoginTokenCookie(w)
unsetLoginTokenCookie(w, req)
http.Redirect(w, req, "/login", http.StatusFound)
}
func updateUser(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
tpl := templateFromContext(ctx)
user := new(User)
if idStr := chi.URLParam(req, "id"); idStr != "" {
id, err := ParseID[*User](idStr)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
user, err = db.FetchUser(ctx, id)
if err != nil {
httpError(w, err)
return
}
}
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
http.Redirect(w, req, "/login", http.StatusFound)
return
}
if user.ID != 0 && loginToken.User != user.ID {
http.Error(w, "Access denied", http.StatusForbidden)
return
}
username := req.PostFormValue("username")
password := req.PostFormValue("password")
if username == "" {
if err := tpl.ExecuteTemplate(w, "update-user.html", user); err != nil {
panic(err)
}
return
}
user.Username = username
if password != "" {
if err := user.SetPassword(password); err != nil {
httpError(w, err)
return
}
}
if err := db.StoreUser(ctx, user); err != nil {
httpError(w, err)
return
}
http.Redirect(w, req, "/", http.StatusFound)
}