Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!
Split HTTP handlers in separate files
package main
import (
"net/http"
)
func createClient(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
tpl := templateFromContext(ctx)
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
http.Redirect(w, req, "/login", http.StatusFound)
return
}
client, clientSecret, err := NewClient(loginToken.User)
if err != nil {
httpError(w, err)
return
}
if err := db.StoreClient(ctx, client); err != nil {
httpError(w, err)
return
}
data := struct {
ClientID string
ClientSecret string
}{
ClientID: client.ClientID,
ClientSecret: clientSecret,
}
if err := tpl.ExecuteTemplate(w, "client-secret.html", &data); err != nil {
panic(err)
}
}
package main import ( "context" "embed"
"encoding/json" "errors"
"flag"
"fmt"
"html/template"
"io"
"log"
"mime"
"net" "net/http"
"net/url" "strings" "time"
"git.sr.ht/~emersion/go-oauth2"
"github.com/go-chi/chi/v5"
)
var (
//go:embed template
templateFS embed.FS
//go:embed static
staticFS embed.FS
)
func main() {
var listenAddr string
flag.StringVar(&listenAddr, "listen", ":8080", "HTTP listen address")
flag.Parse()
tpl := template.Must(template.ParseFS(templateFS, "template/*.html"))
db, err := openDB("sinwon.db")
if err != nil {
log.Fatalf("Failed to open DB: %v", err)
}
mux := chi.NewRouter()
mux.Handle("/static/*", http.FileServer(http.FS(staticFS)))
mux.Get("/", func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
http.Redirect(w, req, "/login", http.StatusFound)
return
}
clients, err := db.ListClients(ctx, loginToken.User)
if err != nil {
httpError(w, err)
return
}
if err := tpl.ExecuteTemplate(w, "index.html", clients); err != nil {
panic(err)
}
})
mux.Post("/client/new", func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
http.Redirect(w, req, "/login", http.StatusFound)
return
}
client, clientSecret, err := NewClient(loginToken.User)
if err != nil {
httpError(w, err)
return
}
if err := db.StoreClient(ctx, client); err != nil {
httpError(w, err)
return
}
data := struct {
ClientID string
ClientSecret string
}{
ClientID: client.ClientID,
ClientSecret: clientSecret,
}
if err := tpl.ExecuteTemplate(w, "client-secret.html", &data); err != nil {
panic(err)
}
})
mux.HandleFunc("/login", func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
q := req.URL.Query()
rawRedirectURI := q.Get("redirect_uri")
if rawRedirectURI == "" {
rawRedirectURI = "/"
}
redirectURI, err := url.Parse(rawRedirectURI)
if err != nil || redirectURI.Scheme != "" || redirectURI.Opaque != "" || redirectURI.User != nil || redirectURI.Host != "" {
http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
return
}
if loginTokenFromContext(ctx) != nil {
http.Redirect(w, req, redirectURI.String(), http.StatusFound)
return
}
username := req.PostFormValue("username")
password := req.PostFormValue("password")
if username == "" {
if err := tpl.ExecuteTemplate(w, "login.html", nil); err != nil {
panic(err)
}
return
}
user, err := db.FetchUser(ctx, username)
if err != nil && err != errNoDBRows {
httpError(w, fmt.Errorf("failed to fetch user: %v", err))
return
}
if err == nil {
err = user.VerifyPassword(password)
}
if err != nil {
log.Printf("login failed for user %q: %v", username, err)
// TODO: show error message
if err := tpl.ExecuteTemplate(w, "login.html", nil); err != nil {
panic(err)
}
return
}
token := AccessToken{
User: user.ID,
Scope: internalTokenScope,
}
secret, err := token.Generate()
if err != nil {
httpError(w, fmt.Errorf("failed to generate access token: %v", err))
return
}
if err := db.CreateAccessToken(ctx, &token); err != nil {
httpError(w, fmt.Errorf("failed to create access token: %v", err))
return
}
setLoginTokenCookie(w, &token, secret)
http.Redirect(w, req, redirectURI.String(), http.StatusFound)
})
mux.HandleFunc("/authorize", func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
q := req.URL.Query()
respType := oauth2.ResponseType(q.Get("response_type"))
clientID := q.Get("client_id")
rawRedirectURI := q.Get("redirect_uri")
scope := q.Get("scope")
state := q.Get("state")
if clientID == "" {
http.Error(w, "Missing client ID", http.StatusBadRequest)
return
}
client, err := db.FetchClient(ctx, clientID)
if err == errNoDBRows {
http.Error(w, "Invalid client ID", http.StatusForbidden)
return
} else if err != nil {
httpError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
// TODO: validate redirect URI with client
// TODO: make redirect URI optional
redirectURI, err := url.Parse(rawRedirectURI)
if err != nil {
http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
return
}
if respType != oauth2.ResponseTypeCode {
redirectClientError(w, req, redirectURI, state, &oauth2.Error{
Code: oauth2.ErrorCodeUnsupportedResponseType,
})
return
}
// TODO: add support for scope
if scope != "" {
redirectClientError(w, req, redirectURI, state, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidScope,
})
return
}
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
q := make(url.Values)
q.Set("redirect_uri", req.URL.String())
u := url.URL{
Path: "/login",
RawQuery: q.Encode(),
}
http.Redirect(w, req, u.String(), http.StatusFound)
return
}
_ = req.ParseForm()
if _, ok := req.PostForm["authorize"]; !ok {
if err := tpl.ExecuteTemplate(w, "authorize.html", nil); err != nil {
panic(err)
}
return
}
authCode, secret, err := NewAuthCode(loginToken.User, client.ID, scope)
if err != nil {
httpError(w, fmt.Errorf("failed to generate authentication code: %v", err))
return
}
if err := db.CreateAuthCode(ctx, authCode); err != nil {
httpError(w, fmt.Errorf("failed to create authentication code: %v", err))
return
}
code := MarshalSecret(authCode.ID, secret)
values := make(url.Values)
values.Set("code", code)
if state != "" {
values.Set("state", state)
}
redirectClient(w, req, redirectURI, values)
})
mux.Post("/token", func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
values, err := parseRequestBody(req)
if err != nil {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: err.Error(),
})
return
}
clientID := values.Get("client_id")
grantType := oauth2.GrantType(values.Get("grant_type"))
scope := values.Get("scope")
authClientID, clientSecret, _ := req.BasicAuth()
if clientID == "" && authClientID == "" {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "Missing client ID",
})
return
} else if clientID == "" {
clientID = authClientID
} else if clientID != authClientID {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "Client ID in request body doesn't match Authorization header field",
})
return
}
client, err := db.FetchClient(ctx, clientID)
if err == errNoDBRows {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidClient,
Description: "Invalid client ID",
})
return
} else if err != nil {
oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
if client.ClientSecretHash != nil {
if !client.VerifySecret(clientSecret) {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid client secret",
})
return
}
}
if grantType != oauth2.GrantTypeAuthorizationCode {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeUnsupportedGrantType,
Description: "Unsupported grant type",
})
return
}
codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
authCode, err := db.PopAuthCode(ctx, codeID)
if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid authorization code",
})
return
} else if err != nil {
oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
return
}
if scope != authCode.Scope {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid scope",
})
return
}
// TODO: check redirect_uri
token, secret, err := NewAccessTokenFromAuthCode(authCode)
if err != nil {
oauthError(w, err)
return
}
if err := db.CreateAccessToken(ctx, token); err != nil {
oauthError(w, fmt.Errorf("failed to create access token: %v", err))
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
json.NewEncoder(w).Encode(&oauth2.TokenResp{
AccessToken: secret,
TokenType: oauth2.TokenTypeBearer,
ExpiresIn: time.Until(token.ExpiresAt),
Scope: strings.Split(token.Scope, " "),
})
})
mux.Get("/", index)
mux.Post("/client/new", createClient)
mux.HandleFunc("/login", login)
mux.HandleFunc("/authorize", authorize)
mux.Post("/token", exchangeToken)
server := http.Server{
Addr: listenAddr,
Handler: loginTokenMiddleware(mux),
BaseContext: func(net.Listener) context.Context {
return newBaseContext(db, tpl)
},
}
log.Printf("OAuth server listening on %v", server.Addr)
if err := server.ListenAndServe(); err != nil {
log.Fatalf("Failed to listen and serve: %v", err)
}
}
func parseRequestBody(req *http.Request) (url.Values, error) {
ct := req.Header.Get("Content-Type")
if ct != "" {
mimeType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
if err != nil {
return nil, fmt.Errorf("malformed Content-Type header field")
} else if mimeType != "application/x-www-form-urlencoded" {
return nil, fmt.Errorf("unsupported request content type")
}
}
r := io.LimitReader(req.Body, 10<<20)
b, err := io.ReadAll(r)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %v", err)
}
values, err := url.ParseQuery(string(b))
if err != nil {
return nil, fmt.Errorf("failed to parse request body: %v", err)
}
return values, nil
}
func httpError(w http.ResponseWriter, err error) {
log.Print(err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
func oauthError(w http.ResponseWriter, err error) {
var oauthErr *oauth2.Error
if !errors.As(err, &oauthErr) {
oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
log.Print(err)
}
statusCode := http.StatusInternalServerError
switch oauthErr.Code {
case oauth2.ErrorCodeInvalidRequest, oauth2.ErrorCodeUnsupportedResponseType, oauth2.ErrorCodeInvalidScope, oauth2.ErrorCodeInvalidClient, oauth2.ErrorCodeInvalidGrant, oauth2.ErrorCodeUnsupportedGrantType:
statusCode = http.StatusBadRequest
case oauth2.ErrorCodeUnauthorizedClient, oauth2.ErrorCodeAccessDenied:
statusCode = http.StatusForbidden
case oauth2.ErrorCodeTemporarilyUnavailable:
statusCode = http.StatusServiceUnavailable
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(oauthErr)
}
func redirectClient(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, values url.Values) {
q := redirectURI.Query()
for k, v := range values {
q[k] = v
}
u := *redirectURI
u.RawQuery = q.Encode()
http.Redirect(w, req, u.String(), http.StatusFound)
}
func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, err error) {
var oauthErr *oauth2.Error
if !errors.As(err, &oauthErr) {
oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
log.Print(err)
}
values := make(url.Values)
values.Set("error", string(oauthErr.Code))
if oauthErr.Description != "" {
values.Set("error_description", oauthErr.Description)
}
if oauthErr.URI != "" {
values.Set("error_uri", oauthErr.URI)
}
if state != "" {
values.Set("state", state)
}
redirectClient(w, req, redirectURI, values)
}
package main
import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"mime"
"net/http"
"net/url"
"strings"
"time"
"git.sr.ht/~emersion/go-oauth2"
)
func authorize(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
tpl := templateFromContext(ctx)
q := req.URL.Query()
respType := oauth2.ResponseType(q.Get("response_type"))
clientID := q.Get("client_id")
rawRedirectURI := q.Get("redirect_uri")
scope := q.Get("scope")
state := q.Get("state")
if clientID == "" {
http.Error(w, "Missing client ID", http.StatusBadRequest)
return
}
client, err := db.FetchClient(ctx, clientID)
if err == errNoDBRows {
http.Error(w, "Invalid client ID", http.StatusForbidden)
return
} else if err != nil {
httpError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
// TODO: validate redirect URI with client
// TODO: make redirect URI optional
redirectURI, err := url.Parse(rawRedirectURI)
if err != nil {
http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
return
}
if respType != oauth2.ResponseTypeCode {
redirectClientError(w, req, redirectURI, state, &oauth2.Error{
Code: oauth2.ErrorCodeUnsupportedResponseType,
})
return
}
// TODO: add support for scope
if scope != "" {
redirectClientError(w, req, redirectURI, state, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidScope,
})
return
}
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
q := make(url.Values)
q.Set("redirect_uri", req.URL.String())
u := url.URL{
Path: "/login",
RawQuery: q.Encode(),
}
http.Redirect(w, req, u.String(), http.StatusFound)
return
}
_ = req.ParseForm()
if _, ok := req.PostForm["authorize"]; !ok {
if err := tpl.ExecuteTemplate(w, "authorize.html", nil); err != nil {
panic(err)
}
return
}
authCode, secret, err := NewAuthCode(loginToken.User, client.ID, scope)
if err != nil {
httpError(w, fmt.Errorf("failed to generate authentication code: %v", err))
return
}
if err := db.CreateAuthCode(ctx, authCode); err != nil {
httpError(w, fmt.Errorf("failed to create authentication code: %v", err))
return
}
code := MarshalSecret(authCode.ID, secret)
values := make(url.Values)
values.Set("code", code)
if state != "" {
values.Set("state", state)
}
redirectClient(w, req, redirectURI, values)
}
func exchangeToken(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
values, err := parseRequestBody(req)
if err != nil {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: err.Error(),
})
return
}
clientID := values.Get("client_id")
grantType := oauth2.GrantType(values.Get("grant_type"))
scope := values.Get("scope")
authClientID, clientSecret, _ := req.BasicAuth()
if clientID == "" && authClientID == "" {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "Missing client ID",
})
return
} else if clientID == "" {
clientID = authClientID
} else if clientID != authClientID {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "Client ID in request body doesn't match Authorization header field",
})
return
}
client, err := db.FetchClient(ctx, clientID)
if err == errNoDBRows {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidClient,
Description: "Invalid client ID",
})
return
} else if err != nil {
oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
if client.ClientSecretHash != nil {
if !client.VerifySecret(clientSecret) {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid client secret",
})
return
}
}
if grantType != oauth2.GrantTypeAuthorizationCode {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeUnsupportedGrantType,
Description: "Unsupported grant type",
})
return
}
codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
authCode, err := db.PopAuthCode(ctx, codeID)
if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid authorization code",
})
return
} else if err != nil {
oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
return
}
if scope != authCode.Scope {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid scope",
})
return
}
// TODO: check redirect_uri
token, secret, err := NewAccessTokenFromAuthCode(authCode)
if err != nil {
oauthError(w, err)
return
}
if err := db.CreateAccessToken(ctx, token); err != nil {
oauthError(w, fmt.Errorf("failed to create access token: %v", err))
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
json.NewEncoder(w).Encode(&oauth2.TokenResp{
AccessToken: secret,
TokenType: oauth2.TokenTypeBearer,
ExpiresIn: time.Until(token.ExpiresAt),
Scope: strings.Split(token.Scope, " "),
})
}
func parseRequestBody(req *http.Request) (url.Values, error) {
ct := req.Header.Get("Content-Type")
if ct != "" {
mimeType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
if err != nil {
return nil, fmt.Errorf("malformed Content-Type header field")
} else if mimeType != "application/x-www-form-urlencoded" {
return nil, fmt.Errorf("unsupported request content type")
}
}
r := io.LimitReader(req.Body, 10<<20)
b, err := io.ReadAll(r)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %v", err)
}
values, err := url.ParseQuery(string(b))
if err != nil {
return nil, fmt.Errorf("failed to parse request body: %v", err)
}
return values, nil
}
func oauthError(w http.ResponseWriter, err error) {
var oauthErr *oauth2.Error
if !errors.As(err, &oauthErr) {
oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
log.Print(err)
}
statusCode := http.StatusInternalServerError
switch oauthErr.Code {
case oauth2.ErrorCodeInvalidRequest, oauth2.ErrorCodeUnsupportedResponseType, oauth2.ErrorCodeInvalidScope, oauth2.ErrorCodeInvalidClient, oauth2.ErrorCodeInvalidGrant, oauth2.ErrorCodeUnsupportedGrantType:
statusCode = http.StatusBadRequest
case oauth2.ErrorCodeUnauthorizedClient, oauth2.ErrorCodeAccessDenied:
statusCode = http.StatusForbidden
case oauth2.ErrorCodeTemporarilyUnavailable:
statusCode = http.StatusServiceUnavailable
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(oauthErr)
}
func redirectClient(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, values url.Values) {
q := redirectURI.Query()
for k, v := range values {
q[k] = v
}
u := *redirectURI
u.RawQuery = q.Encode()
http.Redirect(w, req, u.String(), http.StatusFound)
}
func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, err error) {
var oauthErr *oauth2.Error
if !errors.As(err, &oauthErr) {
oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
log.Print(err)
}
values := make(url.Values)
values.Set("error", string(oauthErr.Code))
if oauthErr.Description != "" {
values.Set("error_description", oauthErr.Description)
}
if oauthErr.URI != "" {
values.Set("error_uri", oauthErr.URI)
}
if state != "" {
values.Set("state", state)
}
redirectClient(w, req, redirectURI, values)
}
package main
import (
"fmt"
"log"
"net/http"
"net/url"
)
func index(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
tpl := templateFromContext(ctx)
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
http.Redirect(w, req, "/login", http.StatusFound)
return
}
clients, err := db.ListClients(ctx, loginToken.User)
if err != nil {
httpError(w, err)
return
}
if err := tpl.ExecuteTemplate(w, "index.html", clients); err != nil {
panic(err)
}
}
func login(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
tpl := templateFromContext(ctx)
q := req.URL.Query()
rawRedirectURI := q.Get("redirect_uri")
if rawRedirectURI == "" {
rawRedirectURI = "/"
}
redirectURI, err := url.Parse(rawRedirectURI)
if err != nil || redirectURI.Scheme != "" || redirectURI.Opaque != "" || redirectURI.User != nil || redirectURI.Host != "" {
http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
return
}
if loginTokenFromContext(ctx) != nil {
http.Redirect(w, req, redirectURI.String(), http.StatusFound)
return
}
username := req.PostFormValue("username")
password := req.PostFormValue("password")
if username == "" {
if err := tpl.ExecuteTemplate(w, "login.html", nil); err != nil {
panic(err)
}
return
}
user, err := db.FetchUser(ctx, username)
if err != nil && err != errNoDBRows {
httpError(w, fmt.Errorf("failed to fetch user: %v", err))
return
}
if err == nil {
err = user.VerifyPassword(password)
}
if err != nil {
log.Printf("login failed for user %q: %v", username, err)
// TODO: show error message
if err := tpl.ExecuteTemplate(w, "login.html", nil); err != nil {
panic(err)
}
return
}
token := AccessToken{
User: user.ID,
Scope: internalTokenScope,
}
secret, err := token.Generate()
if err != nil {
httpError(w, fmt.Errorf("failed to generate access token: %v", err))
return
}
if err := db.CreateAccessToken(ctx, &token); err != nil {
httpError(w, fmt.Errorf("failed to create access token: %v", err))
return
}
setLoginTokenCookie(w, &token, secret)
http.Redirect(w, req, redirectURI.String(), http.StatusFound)
}