From 69434566ea5a8a0d39802a62cc6b254ea8bd667c Mon Sep 17 00:00:00 2001
From: Runxi Yu
Date: Thu, 25 Sep 2025 19:41:57 +0800
Subject: [PATCH] Implement basic OIDC and some fixes
---
client.go | 25 ++++++++++++++++++++-----
csrf.go | 106 +++++++++++++++++++++++++++++++++++++++++++++++++++++
db.go | 96 +++++++++++++++++++++++++++++++++++++++++++++++------
entity.go | 68 ++++++++++++++++++++++++++++++++++++++++--------------
go.mod | 1 +
go.sum | 8 ++------
main.go | 12 ++++++++++--
middleware.go | 25 +++++++++++++++++++------
oauth2.go | 384 +++++++++++++++++++++++++++++++++++++++++++++++------
oidc.go | 373 +++++++++++++++++++++++++++++++++++++++++++++++++++++
pkce.go | 40 ++++++++++++++++++++++++++++++++++++++++
schema.sql | 21 +++++++++++++++++++--
static/style.css | 8 ++++----
template/authorize.html | 1 +
template/head.html | 6 +++---
template/index.html | 12 +++++++++---
template/login.html | 1 +
template/manage-client.html | 10 ++++++++++
template/manage-user.html | 9 +++++++++
user.go | 18 ++++++++++++------
diff --git a/client.go b/client.go
index 63a914e76ef0393db9d2f8d4f8f7ccec55c568df..6d1c72ce6caad48fcbeef9220c4d0b4e03abc709 100644
--- a/client.go
+++ b/client.go
@@ -44,6 +44,12 @@ return
}
}
+ if normalized, err := normalizeClientPKCERequirement(client.PKCERequirement); err == nil {
+ client.PKCERequirement = normalized
+ } else {
+ client.PKCERequirement = pkceRequirementNone
+ }
+
if req.Method != http.MethodPost {
data := struct {
TemplateBaseData
@@ -51,7 +57,7 @@ Client *Client
}{
Client: client,
}
- tpl.MustExecuteTemplate(w, "manage-client.html", &data)
+ tpl.MustExecuteTemplate(req.Context(), w, "manage-client.html", &data)
return
}
@@ -79,6 +85,17 @@ client.ClientName = req.PostFormValue("client_name")
client.ClientURI = req.PostFormValue("client_uri")
client.RedirectURIs = req.PostFormValue("redirect_uris")
+ pkceRequirement := req.PostFormValue("pkce_requirement")
+ if !isPublic {
+ pkceRequirement = pkceRequirementNone
+ }
+ normalizedRequirement, err := normalizeClientPKCERequirement(pkceRequirement)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+ client.PKCERequirement = normalizedRequirement
+
if err := validateAllowedRedirectURIs(client.RedirectURIs); err != nil {
// TODO: nicer error message
http.Error(w, err.Error(), http.StatusBadRequest)
@@ -113,7 +130,7 @@ }{
ClientID: client.ClientID,
ClientSecret: clientSecret,
}
- tpl.MustExecuteTemplate(w, "client-secret.html", &data)
+ tpl.MustExecuteTemplate(req.Context(), w, "client-secret.html", &data)
}
func validateAllowedRedirectURIs(rawRedirectURIs string) error {
@@ -130,9 +147,7 @@ switch u.Scheme {
case "https":
// ok
case "http":
- if u.Host != "localhost" {
- return fmt.Errorf("Only http://localhost is allowed for insecure HTTP URIs")
- }
+ // insecure but let's just trust the admin
default:
if !strings.Contains(u.Scheme, ".") {
return fmt.Errorf("Only private-use URIs referring to domain names are allowed")
diff --git a/csrf.go b/csrf.go
new file mode 100644
index 0000000000000000000000000000000000000000..538484504a78aca42a99a5f8c1dee95bed5bab06
--- /dev/null
+++ b/csrf.go
@@ -0,0 +1,106 @@
+package main
+
+import (
+ "context"
+ "crypto/subtle"
+ "net/http"
+ "strings"
+)
+
+const (
+ csrfCookieName = "vireo-csrf"
+ csrfFormField = "_csrf"
+)
+
+func csrfTokenFromContext(ctx context.Context) string {
+ if ctx == nil {
+ return ""
+ }
+ if token, ok := ctx.Value(contextKeyCSRFToken).(string); ok {
+ return token
+ }
+ return ""
+}
+
+func ensureCSRFToken(w http.ResponseWriter, req *http.Request) (string, error) {
+ if cookie, err := req.Cookie(csrfCookieName); err == nil && cookie.Value != "" {
+ return cookie.Value, nil
+ }
+
+ token, err := generateUID()
+ if err != nil {
+ return "", err
+ }
+
+ http.SetCookie(w, &http.Cookie{
+ Name: csrfCookieName,
+ Value: token,
+ Path: "/",
+ HttpOnly: true,
+ SameSite: http.SameSiteStrictMode,
+ Secure: isForwardedHTTPS(req),
+ })
+
+ return token, nil
+}
+
+func csrfMiddleware(next http.Handler) http.Handler {
+ cop := http.NewCrossOriginProtection()
+ cop.SetDenyHandler(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ http.Error(w, "Forbidden", http.StatusForbidden)
+ }))
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+ if isCSRFBypass(req) || req.Header.Get("Authorization") != "" {
+ next.ServeHTTP(w, req)
+ return
+ }
+ token, err := ensureCSRFToken(w, req)
+ if err != nil {
+ httpError(w, err)
+ return
+ }
+
+ ctx := context.WithValue(req.Context(), contextKeyCSRFToken, token)
+ req = req.WithContext(ctx)
+
+ switch req.Method {
+ case http.MethodGet, http.MethodHead, http.MethodOptions:
+ next.ServeHTTP(w, req)
+ return
+ }
+
+ if err := req.ParseForm(); err != nil {
+ http.Error(w, "Bad request", http.StatusBadRequest)
+ return
+ }
+
+ formToken := req.PostFormValue(csrfFormField)
+ if formToken == "" {
+ formToken = req.Header.Get("X-CSRF-Token")
+ }
+ if formToken == "" || subtle.ConstantTimeCompare([]byte(formToken), []byte(token)) != 1 {
+ http.Error(w, "Invalid CSRF token", http.StatusForbidden)
+ return
+ }
+
+ next.ServeHTTP(w, req)
+ })
+
+ return cop.Handler(handler)
+}
+
+func isCSRFBypass(req *http.Request) bool {
+ path := req.URL.Path
+ switch {
+ case path == "/token",
+ path == "/introspect",
+ path == "/revoke",
+ path == "/userinfo",
+ strings.HasPrefix(path, "/static/"),
+ strings.HasPrefix(path, "/.well-known/"),
+ path == "/favicon.ico":
+ return true
+ }
+ return false
+}
diff --git a/db.go b/db.go
index 186f709b73904de64ba14652cd60a9bd9c6a1cd3..d42aaa8c22ce0a2b3f8667b2c748eaf8cd70aa8e 100644
--- a/db.go
+++ b/db.go
@@ -20,6 +20,42 @@ ALTER TABLE AccessToken ADD COLUMN refresh_hash BLOB;
ALTER TABLE AccessToken ADD COLUMN refresh_expires_at datetime;
CREATE UNIQUE INDEX access_token_refresh_hash ON AccessToken(refresh_hash);
`,
+ `
+ ALTER TABLE AuthCode ADD COLUMN nonce TEXT;
+ CREATE TABLE IF NOT EXISTS SigningKey (
+ id INTEGER PRIMARY KEY,
+ kid TEXT NOT NULL UNIQUE,
+ algorithm TEXT NOT NULL,
+ private_key BLOB NOT NULL,
+ created_at datetime NOT NULL
+ );
+ `,
+ `
+ ALTER TABLE AccessToken ADD COLUMN auth_time datetime;
+ ALTER TABLE SigningKey RENAME TO SigningKey_old;
+ CREATE TABLE SigningKey (
+ id INTEGER PRIMARY KEY,
+ kid TEXT NOT NULL UNIQUE,
+ algorithm TEXT NOT NULL,
+ private_key BLOB NOT NULL,
+ created_at datetime NOT NULL
+ );
+ INSERT INTO SigningKey(id, kid, algorithm, private_key, created_at)
+ SELECT id, kid, algorithm, private_key, created_at FROM SigningKey_old;
+ DROP TABLE SigningKey_old;
+ CREATE INDEX IF NOT EXISTS signing_key_created_at ON SigningKey(created_at);
+ `,
+ `
+ ALTER TABLE User ADD COLUMN email TEXT;
+ `,
+ `
+ ALTER TABLE User ADD COLUMN name TEXT;
+ `,
+ `
+ ALTER TABLE AuthCode ADD COLUMN code_challenge TEXT;
+ ALTER TABLE AuthCode ADD COLUMN code_challenge_method TEXT;
+ ALTER TABLE Client ADD COLUMN pkce_requirement TEXT;
+ `,
}
var errNoDBRows = sql.ErrNoRows
@@ -55,7 +91,7 @@ return nil
}
// TODO: drop this
- defaultUser := User{Username: "root", Admin: true}
+ defaultUser := User{Username: "root", Name: "Root User", Email: "root@example.invalid", Admin: true}
if err := defaultUser.SetPassword("root"); err != nil {
return err
}
@@ -126,10 +162,12 @@ }
func (db *DB) StoreUser(ctx context.Context, user *User) error {
return db.db.QueryRowContext(ctx, `
- INSERT INTO User(id, username, password_hash, admin)
- VALUES (:id, :username, :password_hash, :admin)
+ INSERT INTO User(id, username, name, email, password_hash, admin)
+ VALUES (:id, :username, :name, :email, :password_hash, :admin)
ON CONFLICT(id) DO UPDATE SET
username = :username,
+ name = :name,
+ email = :email,
password_hash = :password_hash,
admin = :admin
RETURNING id
@@ -178,16 +216,17 @@
func (db *DB) StoreClient(ctx context.Context, client *Client) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO Client(id, client_id, client_secret_hash, owner,
- redirect_uris, client_name, client_uri)
+ redirect_uris, client_name, client_uri, pkce_requirement)
VALUES (:id, :client_id, :client_secret_hash, :owner,
- :redirect_uris, :client_name, :client_uri)
+ :redirect_uris, :client_name, :client_uri, :pkce_requirement)
ON CONFLICT(id) DO UPDATE SET
client_id = :client_id,
client_secret_hash = :client_secret_hash,
owner = :owner,
redirect_uris = :redirect_uris,
client_name = :client_name,
- client_uri = :client_uri
+ client_uri = :client_uri,
+ pkce_requirement = :pkce_requirement
RETURNING id
`, entityArgs(client)...).Scan(&client.ID)
}
@@ -264,9 +303,9 @@
func (db *DB) StoreAccessToken(ctx context.Context, token *AccessToken) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO AccessToken(id, hash, user, client, scope, issued_at,
- expires_at, refresh_hash, refresh_expires_at)
+ expires_at, auth_time, refresh_hash, refresh_expires_at)
VALUES (:id, :hash, :user, :client, :scope, :issued_at, :expires_at,
- :refresh_hash, :refresh_expires_at)
+ :auth_time, :refresh_hash, :refresh_expires_at)
ON CONFLICT(id) DO UPDATE SET
hash = :hash,
user = :user,
@@ -274,6 +313,7 @@ client = :client,
scope = :scope,
issued_at = :issued_at,
expires_at = :expires_at,
+ auth_time = :auth_time,
refresh_hash = :refresh_hash,
refresh_expires_at = :refresh_expires_at
RETURNING id
@@ -313,8 +353,8 @@ }
func (db *DB) CreateAuthCode(ctx context.Context, code *AuthCode) error {
return db.db.QueryRowContext(ctx, `
- INSERT INTO AuthCode(hash, created_at, user, client, scope, redirect_uri)
- VALUES (:hash, :created_at, :user, :client, :scope, :redirect_uri)
+ INSERT INTO AuthCode(hash, created_at, user, client, scope, redirect_uri, nonce, code_challenge, code_challenge_method)
+ VALUES (:hash, :created_at, :user, :client, :scope, :redirect_uri, :nonce, :code_challenge, :code_challenge_method)
RETURNING id
`, entityArgs(code)...).Scan(&code.ID)
}
@@ -331,6 +371,42 @@ }
var authCode AuthCode
err = scanRow(&authCode, rows)
return &authCode, err
+}
+
+func (db *DB) FetchSigningKeys(ctx context.Context) ([]SigningKey, error) {
+ rows, err := db.db.QueryContext(ctx, `
+ SELECT * FROM SigningKey
+ ORDER BY created_at DESC
+ `)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var keys []SigningKey
+ for rows.Next() {
+ var key SigningKey
+ if err := scan(&key, rows); err != nil {
+ return nil, err
+ }
+ keys = append(keys, key)
+ }
+
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ if len(keys) == 0 {
+ return nil, errNoDBRows
+ }
+ return keys, nil
+}
+
+func (db *DB) StoreSigningKey(ctx context.Context, key *SigningKey) error {
+ return db.db.QueryRowContext(ctx, `
+ INSERT INTO SigningKey(kid, algorithm, private_key, created_at)
+ VALUES (:kid, :algorithm, :private_key, :created_at)
+ RETURNING id
+ `, sql.Named("kid", key.KID), sql.Named("algorithm", key.Algorithm), sql.Named("private_key", key.PrivateKey), sql.Named("created_at", key.CreatedAt)).Scan(&key.ID)
}
func (db *DB) Maintain(ctx context.Context) error {
diff --git a/entity.go b/entity.go
index eceb92320ca9d2b89d4e4d28a0d5beb69072f4b0..4edc14a929046db58aa6984ad8992a7df196ea43 100644
--- a/entity.go
+++ b/entity.go
@@ -31,6 +31,7 @@ _ entity = (*User)(nil)
_ entity = (*Client)(nil)
_ entity = (*AccessToken)(nil)
_ entity = (*AuthCode)(nil)
+ _ entity = (*SigningKey)(nil)
)
type ID[T entity] int64
@@ -105,6 +106,8 @@
type User struct {
ID ID[*User]
Username string
+ Name string
+ Email string
PasswordHash string
Admin bool
}
@@ -113,6 +116,8 @@ func (user *User) columns() map[string]interface{} {
return map[string]interface{}{
"id": &user.ID,
"username": &user.Username,
+ "name": nullValue{&user.Name},
+ "email": nullValue{&user.Email},
"password_hash": nullValue{&user.PasswordHash},
"admin": &user.Admin,
}
@@ -144,6 +149,7 @@ Owner ID[*User]
RedirectURIs string
ClientName string
ClientURI string
+ PKCERequirement string
}
func (client *Client) Generate(isPublic bool) (secret string, err error) {
@@ -174,6 +180,7 @@ "owner": &client.Owner,
"redirect_uris": nullValue{&client.RedirectURIs},
"client_name": nullValue{&client.ClientName},
"client_uri": nullValue{&client.ClientURI},
+ "pkce_requirement": nullValue{&client.PKCERequirement},
}
}
@@ -193,6 +200,7 @@ Client ID[*Client]
Scope string
IssuedAt time.Time
ExpiresAt time.Time
+ AuthTime time.Time
RefreshHash []byte
RefreshExpiresAt time.Time
@@ -221,9 +229,10 @@ }
func NewAccessTokenFromAuthCode(authCode *AuthCode) *AccessToken {
return &AccessToken{
- User: authCode.User,
- Client: authCode.Client,
- Scope: authCode.Scope,
+ User: authCode.User,
+ Client: authCode.Client,
+ Scope: authCode.Scope,
+ AuthTime: authCode.CreatedAt,
}
}
@@ -236,6 +245,7 @@ "client": &token.Client,
"scope": nullValue{&token.Scope},
"issued_at": &token.IssuedAt,
"expires_at": &token.ExpiresAt,
+ "auth_time": nullValue{&token.AuthTime},
"refresh_hash": &token.RefreshHash,
"refresh_expires_at": nullValue{&token.RefreshExpiresAt},
}
@@ -255,13 +265,16 @@ ExpiresAt time.Time
}
type AuthCode struct {
- ID ID[*AuthCode]
- Hash []byte
- CreatedAt time.Time
- User ID[*User]
- Client ID[*Client]
- Scope string
- RedirectURI string
+ ID ID[*AuthCode]
+ Hash []byte
+ CreatedAt time.Time
+ User ID[*User]
+ Client ID[*Client]
+ Scope string
+ RedirectURI string
+ Nonce string
+ CodeChallenge string
+ CodeChallengeMethod string
}
func (code *AuthCode) Generate() (secret string, err error) {
@@ -276,13 +289,16 @@ }
func (code *AuthCode) columns() map[string]interface{} {
return map[string]interface{}{
- "id": &code.ID,
- "hash": &code.Hash,
- "created_at": &code.CreatedAt,
- "user": &code.User,
- "client": &code.Client,
- "scope": nullValue{&code.Scope},
- "redirect_uri": nullValue{&code.RedirectURI},
+ "id": &code.ID,
+ "hash": &code.Hash,
+ "created_at": &code.CreatedAt,
+ "user": &code.User,
+ "client": &code.Client,
+ "scope": nullValue{&code.Scope},
+ "redirect_uri": nullValue{&code.RedirectURI},
+ "nonce": nullValue{&code.Nonce},
+ "code_challenge": nullValue{&code.CodeChallenge},
+ "code_challenge_method": nullValue{&code.CodeChallengeMethod},
}
}
@@ -297,6 +313,24 @@ SecretKindAccessToken = SecretKind('a')
SecretKindRefreshToken = SecretKind('r')
SecretKindAuthCode = SecretKind('c')
)
+
+type SigningKey struct {
+ ID ID[*SigningKey]
+ KID string
+ Algorithm string
+ PrivateKey []byte
+ CreatedAt time.Time
+}
+
+func (key *SigningKey) columns() map[string]interface{} {
+ return map[string]interface{}{
+ "id": &key.ID,
+ "kid": &key.KID,
+ "algorithm": &key.Algorithm,
+ "private_key": &key.PrivateKey,
+ "created_at": &key.CreatedAt,
+ }
+}
func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
kind, s, _ := strings.Cut(s, ".")
diff --git a/go.mod b/go.mod
index 6d54256cb8f5dc70ee73ef044dc0835ad24dd4e5..995cb68508c65a99e25856df5bc6fd1b25007551 100644
--- a/go.mod
+++ b/go.mod
@@ -6,6 +6,7 @@ require (
codeberg.org/emersion/go-scfg v0.1.0
github.com/emersion/go-oauth2 v0.0.0-20250228145955-eaead4148315
github.com/go-chi/chi/v5 v5.2.3
+ github.com/golang-jwt/jwt/v5 v5.2.1
github.com/mattn/go-sqlite3 v1.14.32
golang.org/x/crypto v0.42.0
)
diff --git a/go.sum b/go.sum
index 3b0c6e5fc5d6e7caca72a293a9f8cb6ede24fa96..30e4827711921e981d2e0bab25d419ff99b2e1c9 100644
--- a/go.sum
+++ b/go.sum
@@ -4,15 +4,11 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/emersion/go-oauth2 v0.0.0-20250228145955-eaead4148315 h1:sXzwA8yItbg3ji0UuTLkuO4NKPqQJjC035hPoZI40h8=
github.com/emersion/go-oauth2 v0.0.0-20250228145955-eaead4148315/go.mod h1:pSj8CBn/jb+ynRxt/ESIJisazza/Sh2DjwUn31l2tI0=
-github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
-github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
-github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
-github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
+github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
+github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
-golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
-golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
diff --git a/main.go b/main.go
index ee309209629eba3ae13ae71a4661c74fd89d8238..c9b938fbb4c571105430a3a6b3a4116e90441468 100644
--- a/main.go
+++ b/main.go
@@ -50,6 +50,11 @@ if err != nil {
log.Fatalf("Failed to load template: %v", err)
}
+ oidcProvider, err := newOIDCProvider(context.Background(), db)
+ if err != nil {
+ log.Fatalf("Failed to initialize OpenID Connect provider: %v", err)
+ }
+
mux := chi.NewRouter()
mux.Handle("/static/*", http.FileServer(http.FS(staticFS)))
mux.Get("/", index)
@@ -61,18 +66,21 @@ mux.Post("/client/{id}/revoke", revokeClient)
mux.HandleFunc("/user/new", manageUser)
mux.HandleFunc("/user/{id}", manageUser)
mux.Get("/.well-known/oauth-authorization-server", getOAuthServerMetadata)
+ mux.Get("/.well-known/openid-configuration", getOpenIDConfiguration)
+ mux.Get("/.well-known/jwks.json", getOIDCJWKS)
mux.HandleFunc("/authorize", authorize)
mux.Post("/token", exchangeToken)
mux.Post("/introspect", introspectToken)
mux.Post("/revoke", revokeToken)
+ mux.HandleFunc("/userinfo", userInfo)
go maintainDBLoop(db)
server := http.Server{
Addr: listenAddr,
- Handler: loginTokenMiddleware(mux),
+ Handler: csrfMiddleware(loginTokenMiddleware(mux)),
BaseContext: func(net.Listener) context.Context {
- return newBaseContext(db, tpl)
+ return newBaseContext(db, tpl, oidcProvider)
},
}
log.Printf("OAuth server listening on %v", server.Addr)
diff --git a/middleware.go b/middleware.go
index 3f6132095db3707811d56785f845d1e20bdf1d9f..ca550ffad78caf32ee9391db7a07cd27083c70f9 100644
--- a/middleware.go
+++ b/middleware.go
@@ -21,6 +21,8 @@ const (
contextKeyDB = "db"
contextKeyTemplate = "template"
contextKeyLoginToken = "login-token"
+ contextKeyOIDC = "oidc"
+ contextKeyCSRFToken = "csrf-token"
)
func dbFromContext(ctx context.Context) *DB {
@@ -31,6 +33,10 @@ func templateFromContext(ctx context.Context) *Template {
return ctx.Value(contextKeyTemplate).(*Template)
}
+func oidcProviderFromContext(ctx context.Context) *OIDCProvider {
+ return ctx.Value(contextKeyOIDC).(*OIDCProvider)
+}
+
func loginTokenFromContext(ctx context.Context) *AccessToken {
v := ctx.Value(contextKeyLoginToken)
if v == nil {
@@ -39,10 +45,11 @@ }
return v.(*AccessToken)
}
-func newBaseContext(db *DB, tpl *Template) context.Context {
+func newBaseContext(db *DB, tpl *Template, oidc *OIDCProvider) context.Context {
ctx := context.Background()
ctx = context.WithValue(ctx, contextKeyDB, db)
ctx = context.WithValue(ctx, contextKeyTemplate, tpl)
+ ctx = context.WithValue(ctx, contextKeyOIDC, oidc)
return ctx
}
@@ -51,7 +58,7 @@ http.SetCookie(w, &http.Cookie{
Name: loginCookieName,
Value: MarshalSecret(token.ID, SecretKindAccessToken, secret),
HttpOnly: true,
- SameSite: http.SameSiteStrictMode,
+ SameSite: http.SameSiteLaxMode,
Secure: isForwardedHTTPS(req),
})
}
@@ -60,7 +67,7 @@ func unsetLoginTokenCookie(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, &http.Cookie{
Name: loginCookieName,
HttpOnly: true,
- SameSite: http.SameSiteStrictMode,
+ SameSite: http.SameSiteLaxMode,
Secure: isForwardedHTTPS(req),
MaxAge: -1,
})
@@ -114,6 +121,7 @@ }
type TemplateBaseData struct {
ServerName string
+ CSRFToken string
}
func (data *TemplateBaseData) Base() *TemplateBaseData {
@@ -137,11 +145,16 @@ }
return &Template{tpl: tpl, baseData: baseData}, nil
}
-func (tpl *Template) MustExecuteTemplate(w io.Writer, filename string, data TemplateData) {
+func (tpl *Template) MustExecuteTemplate(ctx context.Context, w io.Writer, filename string, data TemplateData) {
+ baseCopy := *tpl.baseData
+ if token := csrfTokenFromContext(ctx); token != "" {
+ baseCopy.CSRFToken = token
+ }
if data == nil {
- data = tpl.baseData
+ base := baseCopy
+ data = &base
} else {
- *data.Base() = *tpl.baseData
+ *data.Base() = baseCopy
}
if err := tpl.tpl.ExecuteTemplate(w, filename, data); err != nil {
panic(err)
diff --git a/oauth2.go b/oauth2.go
index 4654df2e785a464559ae4729b711219d2430e058..cd4fdcbde59bbf91548db913b330e0df97c1ecb1 100644
--- a/oauth2.go
+++ b/oauth2.go
@@ -1,6 +1,8 @@
package main
import (
+ "crypto/sha256"
+ "encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -16,22 +18,151 @@
"github.com/emersion/go-oauth2"
)
+const (
+ scopeOpenID = "openid"
+ scopeProfile = "profile"
+ scopeEmail = "email"
+ scopeOfflineAccess = "offline_access"
+ pkceMethodPlain = "plain"
+ pkceMethodS256 = "S256"
+)
+
+var allowedScopes = map[string]struct{}{
+ scopeOpenID: {},
+ scopeProfile: {},
+ scopeEmail: {},
+ scopeOfflineAccess: {},
+}
+
+type oidcTokenResponse struct {
+ AccessToken string `json:"access_token"`
+ TokenType oauth2.TokenType `json:"token_type"`
+ ExpiresIn int64 `json:"expires_in,omitempty"`
+ RefreshToken string `json:"refresh_token,omitempty"`
+ Scope string `json:"scope,omitempty"`
+ IDToken string `json:"id_token,omitempty"`
+}
+
+func parseScopes(scope string) []string {
+ if scope == "" {
+ return nil
+ }
+ parts := strings.Fields(scope)
+ var scopes []string
+ seen := make(map[string]struct{}, len(parts))
+ for _, p := range parts {
+ if p == "" {
+ continue
+ }
+ p = strings.ToLower(p)
+ if _, ok := seen[p]; ok {
+ continue
+ }
+ seen[p] = struct{}{}
+ scopes = append(scopes, p)
+ }
+ return scopes
+}
+
+func normalizeScope(scope string) (string, []string) {
+ scopes := parseScopes(scope)
+ if len(scopes) == 0 {
+ return "", nil
+ }
+ return strings.Join(scopes, " "), scopes
+}
+
+func validateScopes(scopes []string) error {
+ for _, scope := range scopes {
+ if _, ok := allowedScopes[scope]; !ok {
+ return fmt.Errorf("unsupported scope %q", scope)
+ }
+ }
+ return nil
+}
+
+func normalizeCodeChallengeMethod(method string) (string, error) {
+ if method == "" {
+ return pkceMethodPlain, nil
+ }
+ switch {
+ case strings.EqualFold(method, pkceMethodPlain):
+ return pkceMethodPlain, nil
+ case strings.EqualFold(method, pkceMethodS256):
+ return pkceMethodS256, nil
+ default:
+ return "", fmt.Errorf("unsupported code_challenge_method")
+ }
+}
+
+func validateCodeVerifier(verifier string) error {
+ if verifier == "" {
+ return fmt.Errorf("missing code_verifier")
+ }
+ if len(verifier) < 43 || len(verifier) > 128 {
+ return fmt.Errorf("invalid code_verifier length")
+ }
+ for _, r := range verifier {
+ if !(r >= 'A' && r <= 'Z' || r >= 'a' && r <= 'z' || r >= '0' && r <= '9' || r == '-' || r == '.' || r == '_' || r == '~') {
+ return fmt.Errorf("invalid character in code_verifier")
+ }
+ }
+ return nil
+}
+
+func validateCodeChallenge(method, challenge string) error {
+ if challenge == "" {
+ return fmt.Errorf("missing code_challenge")
+ }
+ if err := validateCodeVerifier(challenge); err != nil {
+ return err
+ }
+ if method != pkceMethodPlain && method != pkceMethodS256 {
+ return fmt.Errorf("unsupported code_challenge_method")
+ }
+ return nil
+}
+
+func verifyCodeVerifier(method, challenge, verifier string) error {
+ if err := validateCodeVerifier(verifier); err != nil {
+ return err
+ }
+ switch method {
+ case "", pkceMethodPlain:
+ if challenge != verifier {
+ return fmt.Errorf("code_verifier mismatch")
+ }
+ case pkceMethodS256:
+ hash := sha256.Sum256([]byte(verifier))
+ expected := base64.RawURLEncoding.EncodeToString(hash[:])
+ if expected != challenge {
+ return fmt.Errorf("code_verifier mismatch")
+ }
+ default:
+ return fmt.Errorf("unsupported code_challenge_method")
+ }
+ return nil
+}
+
func getOAuthServerMetadata(w http.ResponseWriter, req *http.Request) {
issuer := getIssuer(req)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(&oauth2.ServerMetadata{
- Issuer: issuer,
- AuthorizationEndpoint: issuer + "/authorize",
- TokenEndpoint: issuer + "/token",
- IntrospectionEndpoint: issuer + "/introspect",
- RevocationEndpoint: issuer + "/revoke",
- ResponseTypesSupported: []oauth2.ResponseType{oauth2.ResponseTypeCode},
- ResponseModesSupported: []oauth2.ResponseMode{oauth2.ResponseModeQuery},
- GrantTypesSupported: []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode},
- TokenEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
+ Issuer: issuer,
+ AuthorizationEndpoint: issuer + "/authorize",
+ TokenEndpoint: issuer + "/token",
+ IntrospectionEndpoint: issuer + "/introspect",
+ RevocationEndpoint: issuer + "/revoke",
+ JWKSURI: issuer + "/.well-known/jwks.json",
+ ScopesSupported: []string{scopeOpenID, scopeProfile, scopeEmail, scopeOfflineAccess},
+ ResponseTypesSupported: []oauth2.ResponseType{oauth2.ResponseTypeCode},
+ ResponseModesSupported: []oauth2.ResponseMode{oauth2.ResponseModeQuery},
+ GrantTypesSupported: []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode, oauth2.GrantTypeRefreshToken},
+ TokenEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
IntrospectionEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
RevocationEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
+ CodeChallengeMethodsSupported: []string{pkceMethodPlain, pkceMethodS256},
AuthorizationResponseIssParameterSupported: true,
})
}
@@ -65,6 +196,12 @@ clientID := q.Get("client_id")
rawRedirectURI := q.Get("redirect_uri")
scope := q.Get("scope")
state := q.Get("state")
+ _, stateProvided := q["state"]
+ codeChallenge := q.Get("code_challenge")
+ codeChallengeMethod := q.Get("code_challenge_method")
+
+ var normalizedCodeChallengeMethod string
+ nonce := q.Get("nonce")
if clientID == "" {
http.Error(w, "Missing client ID", http.StatusBadRequest)
@@ -80,6 +217,12 @@ httpError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
+ requiredPKCE, err := normalizeClientPKCERequirement(client.PKCERequirement)
+ if err != nil {
+ httpError(w, fmt.Errorf("invalid PKCE requirement configuration: %v", err))
+ return
+ }
+
var allowedRedirectURIs []*url.URL
for _, s := range strings.Split(client.RedirectURIs, "\n") {
if s == "" {
@@ -112,20 +255,98 @@ }
redirectURI = allowedRedirectURIs[0]
}
+ if codeChallenge != "" {
+ method, err := normalizeCodeChallengeMethod(codeChallengeMethod)
+ if err != nil {
+ redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
+ Code: oauth2.ErrorCodeInvalidRequest,
+ Description: err.Error(),
+ })
+ return
+ }
+ if err := validateCodeChallenge(method, codeChallenge); err != nil {
+ redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
+ Code: oauth2.ErrorCodeInvalidRequest,
+ Description: err.Error(),
+ })
+ return
+ }
+ normalizedCodeChallengeMethod = method
+ }
+
+ if codeChallenge == "" && codeChallengeMethod != "" {
+ redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
+ Code: oauth2.ErrorCodeInvalidRequest,
+ Description: "code_challenge_method without code_challenge",
+ })
+ return
+ }
+
+ switch requiredPKCE {
+ case pkceRequirementPlain:
+ if codeChallenge == "" {
+ redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
+ Code: oauth2.ErrorCodeInvalidRequest,
+ Description: "PKCE is required",
+ })
+ return
+ }
+ if !allowPKCERequirement(normalizedCodeChallengeMethod, pkceRequirementPlain) {
+ redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
+ Code: oauth2.ErrorCodeInvalidRequest,
+ Description: "PKCE method does not satisfy requirement",
+ })
+ return
+ }
+ case pkceRequirementS256:
+ if codeChallenge == "" {
+ redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
+ Code: oauth2.ErrorCodeInvalidRequest,
+ Description: "PKCE (S256) is required",
+ })
+ return
+ }
+ if !allowPKCERequirement(normalizedCodeChallengeMethod, pkceRequirementS256) {
+ redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
+ Code: oauth2.ErrorCodeInvalidRequest,
+ Description: "PKCE (S256) is required",
+ })
+ return
+ }
+ }
+
+ codeChallengeMethod = normalizedCodeChallengeMethod
+
if respType != oauth2.ResponseTypeCode {
- redirectClientError(w, req, redirectURI, state, &oauth2.Error{
+ redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeUnsupportedResponseType,
})
return
}
- // TODO: add support for scope
- if scope != "" {
- redirectClientError(w, req, redirectURI, state, &oauth2.Error{
- Code: oauth2.ErrorCodeInvalidScope,
+ normalizedScope, scopes := normalizeScope(scope)
+ if len(scopes) == 0 {
+ redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
+ Code: oauth2.ErrorCodeInvalidScope,
+ Description: "Missing required openid scope",
})
return
}
+ if err := validateScopes(scopes); err != nil {
+ redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
+ Code: oauth2.ErrorCodeInvalidScope,
+ Description: err.Error(),
+ })
+ return
+ }
+ if !containsScope(scopes, scopeOpenID) {
+ redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
+ Code: oauth2.ErrorCodeInvalidScope,
+ Description: "Scope openid is required",
+ })
+ return
+ }
+ scope = normalizedScope
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
@@ -141,7 +362,7 @@ }
_ = req.ParseForm()
if _, ok := req.PostForm["deny"]; ok {
- redirectClientError(w, req, redirectURI, state, &oauth2.Error{
+ redirectClientError(w, req, redirectURI, state, stateProvided, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
})
return
@@ -153,15 +374,18 @@ Client *Client
}{
Client: client,
}
- tpl.MustExecuteTemplate(w, "authorize.html", &data)
+ tpl.MustExecuteTemplate(req.Context(), w, "authorize.html", &data)
return
}
authCode := AuthCode{
- User: loginToken.User,
- Client: client.ID,
- Scope: scope,
- RedirectURI: rawRedirectURI,
+ User: loginToken.User,
+ Client: client.ID,
+ Scope: scope,
+ RedirectURI: rawRedirectURI,
+ Nonce: nonce,
+ CodeChallenge: codeChallenge,
+ CodeChallengeMethod: codeChallengeMethod,
}
secret, err := authCode.Generate()
if err != nil {
@@ -178,7 +402,7 @@ code := MarshalSecret(authCode.ID, SecretKindAuthCode, secret)
values := make(url.Values)
values.Set("code", code)
- if state != "" {
+ if stateProvided {
values.Set("state", state)
}
redirectClient(w, req, redirectURI, values)
@@ -200,6 +424,7 @@
clientID := values.Get("client_id")
grantType := oauth2.GrantType(values.Get("grant_type"))
scope := values.Get("scope")
+ codeVerifier := values.Get("code_verifier")
authClientID, clientSecret, _ := req.BasicAuth()
if clientID == "" {
@@ -235,7 +460,13 @@ return
}
}
- var token *AccessToken
+ var (
+ token *AccessToken
+ authorizationCode *AuthCode
+ currentClient *Client
+ nonceValue string
+ )
+
switch grantType {
case oauth2.GrantTypeAuthorizationCode:
if client == nil {
@@ -247,8 +478,8 @@ return
}
codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
- authCode, err := db.PopAuthCode(ctx, codeID)
- if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
+ authorizationCode, err = db.PopAuthCode(ctx, codeID)
+ if err == errNoDBRows || (err == nil && !authorizationCode.VerifySecret(codeSecret)) || authorizationCode.Client != client.ID {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid authorization code",
@@ -259,22 +490,46 @@ oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
return
}
- if scope != authCode.Scope {
+ if scope != "" && scope != authorizationCode.Scope {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid scope",
})
return
}
- if values.Get("redirect_uri") != authCode.RedirectURI {
+ if values.Get("redirect_uri") != authorizationCode.RedirectURI {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid redirect URI",
})
return
}
+ if authorizationCode.CodeChallenge != "" {
+ if codeVerifier == "" {
+ oauthError(w, &oauth2.Error{
+ Code: oauth2.ErrorCodeInvalidRequest,
+ Description: "Missing code_verifier",
+ })
+ return
+ }
+ if err := verifyCodeVerifier(authorizationCode.CodeChallengeMethod, authorizationCode.CodeChallenge, codeVerifier); err != nil {
+ oauthError(w, &oauth2.Error{
+ Code: oauth2.ErrorCodeInvalidGrant,
+ Description: "Invalid code_verifier",
+ })
+ return
+ }
+ } else if codeVerifier != "" {
+ oauthError(w, &oauth2.Error{
+ Code: oauth2.ErrorCodeInvalidRequest,
+ Description: "Unexpected code_verifier",
+ })
+ return
+ }
- token = NewAccessTokenFromAuthCode(authCode)
+ token = NewAccessTokenFromAuthCode(authorizationCode)
+ currentClient = client
+ nonceValue = authorizationCode.Nonce
case oauth2.GrantTypeRefreshToken:
tokenID, refreshSecret, _ := UnmarshalSecret[*AccessToken](values.Get("refresh_token"))
token, err = db.FetchAccessToken(ctx, tokenID)
@@ -297,13 +552,13 @@ })
return
}
- tokenClient, err := db.FetchClient(ctx, token.Client)
+ currentClient, err = db.FetchClient(ctx, token.Client)
if err != nil {
oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
- if !tokenClient.IsPublic() && client == nil {
+ if !currentClient.IsPublic() && client == nil {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid client secret",
@@ -311,7 +566,7 @@ })
return
}
- if scope != token.Scope {
+ if scope != "" && scope != token.Scope {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid scope",
@@ -330,10 +585,19 @@ if err != nil {
oauthError(w, err)
return
}
- refreshSecret, err := token.GenerateRefresh()
- if err != nil {
- oauthError(w, err)
- return
+
+ tokenScopes := parseScopes(token.Scope)
+ issueRefresh := containsScope(tokenScopes, scopeOfflineAccess)
+ var refreshSecret string
+ if issueRefresh {
+ refreshSecret, err = token.GenerateRefresh()
+ if err != nil {
+ oauthError(w, err)
+ return
+ }
+ } else {
+ token.RefreshHash = nil
+ token.RefreshExpiresAt = time.Time{}
}
if err := db.StoreAccessToken(ctx, token); err != nil {
@@ -341,15 +605,49 @@ oauthError(w, fmt.Errorf("failed to create access token: %v", err))
return
}
+ accessTokenValue := MarshalSecret(token.ID, SecretKindAccessToken, secret)
+ if token.AuthTime.IsZero() {
+ token.AuthTime = token.IssuedAt
+ }
+
+ var idToken string
+ if containsScope(tokenScopes, scopeOpenID) {
+ if currentClient == nil {
+ currentClient, err = db.FetchClient(ctx, token.Client)
+ if err != nil {
+ oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
+ return
+ }
+ }
+ user, err := db.FetchUser(ctx, token.User)
+ if err != nil {
+ oauthError(w, fmt.Errorf("failed to fetch user: %v", err))
+ return
+ }
+
+ issuer := getIssuer(req)
+ oidcProvider := oidcProviderFromContext(ctx)
+ idToken, err = oidcProvider.MintIDToken(issuer, currentClient, user, token, tokenScopes, nonceValue, accessTokenValue, token.AuthTime)
+ if err != nil {
+ oauthError(w, fmt.Errorf("failed to mint ID token: %v", err))
+ return
+ }
+ }
+
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
- json.NewEncoder(w).Encode(&oauth2.TokenResp{
- AccessToken: MarshalSecret(token.ID, SecretKindAccessToken, secret),
- TokenType: oauth2.TokenTypeBearer,
- ExpiresIn: time.Until(token.ExpiresAt),
- Scope: strings.Split(token.Scope, " "),
- RefreshToken: MarshalSecret(token.ID, SecretKindRefreshToken, refreshSecret),
- })
+ resp := oidcTokenResponse{
+ AccessToken: accessTokenValue,
+ TokenType: oauth2.TokenTypeBearer,
+ ExpiresIn: int64(time.Until(token.ExpiresAt).Seconds()),
+ Scope: token.Scope,
+ IDToken: idToken,
+ }
+ if issueRefresh {
+ refreshTokenValue := MarshalSecret(token.ID, SecretKindRefreshToken, refreshSecret)
+ resp.RefreshToken = refreshTokenValue
+ }
+ json.NewEncoder(w).Encode(&resp)
}
func introspectToken(w http.ResponseWriter, req *http.Request) {
@@ -542,7 +840,7 @@
http.Redirect(w, req, u.String(), http.StatusFound)
}
-func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, err error) {
+func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, stateProvided bool, err error) {
var oauthErr *oauth2.Error
if !errors.As(err, &oauthErr) {
oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
@@ -557,7 +855,7 @@ }
if oauthErr.URI != "" {
values.Set("error_uri", oauthErr.URI)
}
- if state != "" {
+ if stateProvided {
values.Set("state", state)
}
redirectClient(w, req, redirectURI, values)
diff --git a/oidc.go b/oidc.go
new file mode 100644
index 0000000000000000000000000000000000000000..3dfbe9eeb09ece37b41b67b38d6c28244e1d7ee1
--- /dev/null
+++ b/oidc.go
@@ -0,0 +1,373 @@
+package main
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/sha256"
+ "crypto/x509"
+ "encoding/base64"
+ "encoding/json"
+ "encoding/pem"
+ "fmt"
+ "math/big"
+ "net/http"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+
+ oauth2 "github.com/emersion/go-oauth2"
+ "github.com/golang-jwt/jwt/v5"
+)
+
+type OIDCProvider struct {
+ signingKeys []*oidcSigningKey
+}
+
+type oidcSigningKey struct {
+ key *SigningKey
+ private *rsa.PrivateKey
+ publicJWK jwk
+}
+
+type jwk struct {
+ Kty string `json:"kty"`
+ Use string `json:"use,omitempty"`
+ Alg string `json:"alg,omitempty"`
+ Kid string `json:"kid,omitempty"`
+ N string `json:"n,omitempty"`
+ E string `json:"e,omitempty"`
+}
+
+type jwks struct {
+ Keys []jwk `json:"keys"`
+}
+
+const idTokenTTL = 15 * time.Minute
+
+func newOIDCProvider(ctx context.Context, db *DB) (*OIDCProvider, error) {
+ signingRecords, err := db.FetchSigningKeys(ctx)
+ if err == errNoDBRows {
+ generated, genErr := generateSigningKey()
+ if genErr != nil {
+ return nil, genErr
+ }
+ if storeErr := db.StoreSigningKey(ctx, generated); storeErr != nil {
+ return nil, fmt.Errorf("failed to persist signing key: %w", storeErr)
+ }
+ signingRecords = []SigningKey{*generated}
+ } else if err != nil {
+ return nil, fmt.Errorf("failed to fetch signing keys: %w", err)
+ }
+
+ signingKeys := make([]*oidcSigningKey, 0, len(signingRecords))
+ for i := range signingRecords {
+ material, convErr := toOIDCSigningKey(&signingRecords[i])
+ if convErr != nil {
+ return nil, convErr
+ }
+ signingKeys = append(signingKeys, material)
+ }
+
+ return &OIDCProvider{signingKeys: signingKeys}, nil
+}
+
+func generateSigningKey() (*SigningKey, error) {
+ priv, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate signing key: %w", err)
+ }
+
+ pemBlock := pem.EncodeToMemory(&pem.Block{
+ Type: "RSA PRIVATE KEY",
+ Bytes: x509.MarshalPKCS1PrivateKey(priv),
+ })
+ if pemBlock == nil {
+ return nil, fmt.Errorf("failed to encode signing key")
+ }
+
+ kid, err := generateUID()
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate signing key ID: %w", err)
+ }
+ return &SigningKey{
+ KID: kid,
+ Algorithm: "RS256",
+ PrivateKey: pemBlock,
+ CreatedAt: time.Now(),
+ }, nil
+}
+
+func toOIDCSigningKey(signing *SigningKey) (*oidcSigningKey, error) {
+ block, _ := pem.Decode(signing.PrivateKey)
+ if block == nil {
+ return nil, fmt.Errorf("failed to decode signing key PEM")
+ }
+ priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse signing key: %w", err)
+ }
+
+ jwk := jwk{
+ Kty: "RSA",
+ Use: "sig",
+ Alg: signing.Algorithm,
+ Kid: signing.KID,
+ N: base64.RawURLEncoding.EncodeToString(priv.N.Bytes()),
+ }
+
+ e := big.NewInt(int64(priv.E)).Bytes()
+ jwk.E = base64.RawURLEncoding.EncodeToString(e)
+
+ return &oidcSigningKey{
+ key: signing,
+ private: priv,
+ publicJWK: jwk,
+ }, nil
+}
+
+func (op *OIDCProvider) currentSigningKey() *oidcSigningKey {
+ if len(op.signingKeys) == 0 {
+ return nil
+ }
+ return op.signingKeys[0]
+}
+
+func (op *OIDCProvider) signingMethod() (*jwt.SigningMethodRSA, *oidcSigningKey, error) {
+ key := op.currentSigningKey()
+ if key == nil {
+ return nil, nil, fmt.Errorf("no signing key configured")
+ }
+
+ switch key.key.Algorithm {
+ case "RS256":
+ return jwt.SigningMethodRS256, key, nil
+ default:
+ return nil, nil, fmt.Errorf("unsupported signing algorithm %q", key.key.Algorithm)
+ }
+}
+
+func (op *OIDCProvider) MintIDToken(issuer string, client *Client, user *User, token *AccessToken, scopes []string, nonce string, accessToken string, authTime time.Time) (string, error) {
+ method, signingKey, err := op.signingMethod()
+ if err != nil {
+ return "", err
+ }
+
+ now := time.Now()
+ expiresAt := now.Add(idTokenTTL)
+ if token.ExpiresAt.Before(expiresAt) {
+ expiresAt = token.ExpiresAt
+ }
+ if expiresAt.Before(now) {
+ expiresAt = now
+ }
+
+ claims := jwt.MapClaims{
+ "iss": issuer,
+ "sub": subjectForUser(user),
+ "aud": client.ClientID,
+ "exp": jwt.NewNumericDate(expiresAt),
+ "iat": jwt.NewNumericDate(now),
+ }
+ if !authTime.IsZero() {
+ claims["auth_time"] = jwt.NewNumericDate(authTime)
+ }
+ if nonce != "" {
+ claims["nonce"] = nonce
+ }
+ if accessToken != "" {
+ claims["at_hash"] = computeAtHash(accessToken)
+ }
+ if containsScope(scopes, scopeProfile) {
+ displayName := user.Name
+ if displayName == "" {
+ displayName = user.Username
+ }
+ claims["preferred_username"] = user.Username
+ claims["name"] = displayName
+ }
+ if containsScope(scopes, scopeEmail) && user.Email != "" {
+ claims["email"] = user.Email
+ claims["email_verified"] = false
+ }
+
+ tokenJWT := jwt.NewWithClaims(method, claims)
+ tokenJWT.Header["kid"] = signingKey.key.KID
+
+ return tokenJWT.SignedString(signingKey.private)
+}
+
+func (op *OIDCProvider) JWKS() jwks {
+ keys := make([]jwk, 0, len(op.signingKeys))
+ for _, key := range op.signingKeys {
+ keys = append(keys, key.publicJWK)
+ }
+ return jwks{Keys: keys}
+}
+
+func subjectForUser(user *User) string {
+ return strconv.FormatInt(int64(user.ID), 10)
+}
+
+func computeAtHash(accessToken string) string {
+ sum := sha256.Sum256([]byte(accessToken))
+ return base64.RawURLEncoding.EncodeToString(sum[:len(sum)/2])
+}
+
+func containsScope(scopes []string, scope string) bool {
+ for _, s := range scopes {
+ if strings.EqualFold(s, scope) {
+ return true
+ }
+ }
+ return false
+}
+
+func getOpenIDConfiguration(w http.ResponseWriter, req *http.Request) {
+ ctx := req.Context()
+ oidc := oidcProviderFromContext(ctx)
+ issuer := getIssuer(req)
+ currentKey := oidc.currentSigningKey()
+
+ scopes := make([]string, 0, len(allowedScopes))
+ for scope := range allowedScopes {
+ scopes = append(scopes, scope)
+ }
+ sort.Strings(scopes)
+
+ idTokenAlgs := []string{"RS256"}
+ if currentKey != nil && currentKey.key.Algorithm != "" {
+ idTokenAlgs = []string{currentKey.key.Algorithm}
+ }
+
+ config := map[string]interface{}{
+ "issuer": issuer,
+ "authorization_endpoint": issuer + "/authorize",
+ "token_endpoint": issuer + "/token",
+ "userinfo_endpoint": issuer + "/userinfo",
+ "jwks_uri": issuer + "/.well-known/jwks.json",
+ "response_types_supported": []string{string(oauth2.ResponseTypeCode)},
+ "response_modes_supported": []string{string(oauth2.ResponseModeQuery)},
+ "grant_types_supported": []string{string(oauth2.GrantTypeAuthorizationCode), string(oauth2.GrantTypeRefreshToken)},
+ "subject_types_supported": []string{"public"},
+ "id_token_signing_alg_values_supported": idTokenAlgs,
+ "scopes_supported": scopes,
+ "claims_supported": []string{"sub", "preferred_username", "name", "email", "email_verified"},
+ "token_endpoint_auth_methods_supported": []string{string(oauth2.AuthMethodNone), string(oauth2.AuthMethodClientSecretBasic)},
+ "introspection_endpoint": issuer + "/introspect",
+ "revocation_endpoint": issuer + "/revoke",
+ "authorization_response_iss_parameter_supported": true,
+ "claims_parameter_supported": false,
+ "request_parameter_supported": false,
+ "request_uri_parameter_supported": false,
+ "code_challenge_methods_supported": []string{pkceMethodPlain, pkceMethodS256},
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(config)
+}
+
+func getOIDCJWKS(w http.ResponseWriter, req *http.Request) {
+ ctx := req.Context()
+ oidc := oidcProviderFromContext(ctx)
+
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Cache-Control", "public, max-age=300")
+ json.NewEncoder(w).Encode(oidc.JWKS())
+}
+
+func userInfo(w http.ResponseWriter, req *http.Request) {
+ ctx := req.Context()
+ db := dbFromContext(ctx)
+
+ if req.Method != http.MethodGet && req.Method != http.MethodPost {
+ w.Header().Set("Allow", "GET, POST")
+ http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
+ return
+ }
+
+ tokenValue, err := bearerTokenFromRequest(req)
+ if err != nil {
+ writeBearerError(w, http.StatusUnauthorized, "invalid_token", err.Error())
+ return
+ }
+
+ tokenID, secret, err := UnmarshalSecret[*AccessToken](tokenValue)
+ if err != nil {
+ writeBearerError(w, http.StatusUnauthorized, "invalid_token", "Malformed access token")
+ return
+ }
+
+ token, err := db.FetchAccessToken(ctx, tokenID)
+ if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) {
+ writeBearerError(w, http.StatusUnauthorized, "invalid_token", "Invalid access token")
+ return
+ } else if err != nil {
+ httpError(w, fmt.Errorf("failed to fetch access token: %v", err))
+ return
+ }
+
+ scopes := parseScopes(token.Scope)
+ if !containsScope(scopes, scopeOpenID) {
+ writeBearerError(w, http.StatusForbidden, "insufficient_scope", "Scope openid missing")
+ return
+ }
+
+ user, err := db.FetchUser(ctx, token.User)
+ if err != nil {
+ httpError(w, fmt.Errorf("failed to fetch user: %v", err))
+ return
+ }
+
+ resp := map[string]interface{}{
+ "sub": subjectForUser(user),
+ }
+ if containsScope(scopes, scopeProfile) {
+ displayName := user.Name
+ if displayName == "" {
+ displayName = user.Username
+ }
+ resp["preferred_username"] = user.Username
+ resp["name"] = displayName
+ }
+ if containsScope(scopes, scopeEmail) && user.Email != "" {
+ resp["email"] = user.Email
+ resp["email_verified"] = false
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(resp)
+}
+
+func bearerTokenFromRequest(req *http.Request) (string, error) {
+ authz := req.Header.Get("Authorization")
+ if authz == "" {
+ return "", fmt.Errorf("Authorization header missing")
+ }
+ if len(authz) < 7 || !strings.EqualFold(authz[:7], "Bearer ") {
+ return "", fmt.Errorf("Unsupported authorization scheme")
+ }
+ token := strings.TrimSpace(authz[7:])
+ if token == "" {
+ return "", fmt.Errorf("Missing access token")
+ }
+ return token, nil
+}
+
+func writeBearerError(w http.ResponseWriter, status int, code, description string) {
+ challenge := "Bearer"
+ if code != "" {
+ challenge += fmt.Sprintf(" error=\"%s\"", code)
+ }
+ if description != "" {
+ if code == "" {
+ challenge += " "
+ } else {
+ challenge += ", "
+ }
+ challenge += fmt.Sprintf("error_description=\"%s\"", description)
+ }
+ w.Header().Set("WWW-Authenticate", challenge)
+ http.Error(w, http.StatusText(status), status)
+}
diff --git a/pkce.go b/pkce.go
new file mode 100644
index 0000000000000000000000000000000000000000..a04486f8df385de65ce5f6dd4043f2f4821dabff
--- /dev/null
+++ b/pkce.go
@@ -0,0 +1,40 @@
+package main
+
+import (
+ "fmt"
+ "strings"
+)
+
+const (
+ pkceRequirementNone = ""
+ pkceRequirementPlain = pkceMethodPlain
+ pkceRequirementS256 = pkceMethodS256
+)
+
+func normalizeClientPKCERequirement(value string) (string, error) {
+ switch strings.ToUpper(strings.TrimSpace(value)) {
+ case "", "NONE":
+ return pkceRequirementNone, nil
+ case strings.ToUpper(pkceRequirementPlain):
+ return pkceRequirementPlain, nil
+ case pkceRequirementS256:
+ return pkceRequirementS256, nil
+ default:
+ return "", fmt.Errorf("invalid PKCE requirement")
+ }
+}
+
+func allowPKCERequirement(method, requirement string) bool {
+ requirement = strings.ToUpper(requirement)
+ method = strings.ToUpper(method)
+ switch requirement {
+ case "", "NONE":
+ return true
+ case strings.ToUpper(pkceRequirementPlain):
+ return method == strings.ToUpper(pkceMethodPlain) || method == strings.ToUpper(pkceMethodS256)
+ case pkceRequirementS256:
+ return method == strings.ToUpper(pkceMethodS256)
+ default:
+ return false
+ }
+}
diff --git a/schema.sql b/schema.sql
index 9ebc93035d8c818bf60d07295bffbd50931d82d7..92307df080034579691633de40ec67a82e0fe0c0 100644
--- a/schema.sql
+++ b/schema.sql
@@ -1,6 +1,8 @@
CREATE TABLE User (
id INTEGER PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
+ name TEXT,
+ email TEXT,
password_hash TEXT,
admin INTEGER NOT NULL DEFAULT 0
);
@@ -12,7 +14,8 @@ client_secret_hash BLOB,
owner INTEGER REFERENCES User(id) ON DELETE CASCADE,
redirect_uris TEXT,
client_name TEXT,
- client_uri TEXT
+ client_uri TEXT,
+ pkce_requirement TEXT
);
CREATE TABLE AccessToken (
@@ -23,6 +26,7 @@ client INTEGER REFERENCES Client(id) ON DELETE CASCADE,
scope TEXT,
issued_at datetime NOT NULL,
expires_at datetime NOT NULL,
+ auth_time datetime,
refresh_hash BLOB UNIQUE,
refresh_expires_at datetime
);
@@ -34,5 +38,18 @@ created_at datetime NOT NULL,
user INTEGER NOT NULL REFERENCES User(id) ON DELETE CASCADE,
client INTEGER NOT NULL REFERENCES Client(id) ON DELETE CASCADE,
redirect_uri TEXT,
- scope TEXT
+ scope TEXT,
+ nonce TEXT,
+ code_challenge TEXT,
+ code_challenge_method TEXT
+);
+
+CREATE TABLE SigningKey (
+ id INTEGER PRIMARY KEY,
+ kid TEXT NOT NULL UNIQUE,
+ algorithm TEXT NOT NULL,
+ private_key BLOB NOT NULL,
+ created_at datetime NOT NULL
);
+
+CREATE INDEX signing_key_created_at ON SigningKey(created_at);
diff --git a/static/style.css b/static/style.css
index 209f408813aff10ec4295ee3b550263a64c897aa..b14a861e763fe813d68b232f299d57039edb62d0 100644
--- a/static/style.css
+++ b/static/style.css
@@ -59,14 +59,14 @@ background-color: rgb(0, 150, 0);
border-color: rgb(0, 150, 0);
}
-input[type="text"], input[type="password"], input[type="url"], textarea {
+input[type="text"], input[type="email"], input[type="password"], input[type="url"], textarea {
border: 1px solid rgb(208, 210, 215);
border-radius: 4px;
padding: 6px;
margin: 4px 0;
color: #444;
}
-input[type="text"]:focus, input[type="password"]:focus, input[type="url"]:focus, textarea:focus {
+input[type="email"]:focus, input[type="text"]:focus, input[type="password"]:focus, input[type="url"]:focus, textarea:focus {
outline: none;
border-color: rgb(0, 128, 0);
}
@@ -75,7 +75,7 @@ label {
display: block;
margin: 15px 0;
}
-label input[type="text"], label input[type="password"], label input[type="url"] {
+label input[type="email"], label input[type="text"], label input[type="password"], label input[type="url"] {
display: block;
width: 100%;
max-width: 350px;
@@ -118,7 +118,7 @@ button:hover {
background-color: rgba(255, 255, 255, 0.02);
}
- input[type="text"], input[type="password"], input[type="url"], textarea {
+ input[type="email"], input[type="text"], input[type="password"], input[type="url"], textarea {
background-color: rgba(255, 255, 255, 0.05);
color: inherit;
}
diff --git a/template/authorize.html b/template/authorize.html
index 3146dc9877f1e2e44a5b2e59e1fa8458d5cc33f0..a8ebfe916591bb4c5a6fbba0bae7031ed74fafce 100644
--- a/template/authorize.html
+++ b/template/authorize.html
@@ -21,6 +21,7 @@ ?
diff --git a/template/head.html b/template/head.html
index a36f23450c00bee71bba1c80d9883e155ef8025f..702cff01cd395108fe807766facca2b14714d700 100644
--- a/template/head.html
+++ b/template/head.html
@@ -1,9 +1,9 @@
-
+
{{ .ServerName }}
-
-
+
+
diff --git a/template/index.html b/template/index.html
index ddd00bbec5810994b5910c58f2bf17a9e4948d63..48dd2a577615fc1d82c3a48341ce4e9f02427f54 100644
--- a/template/index.html
+++ b/template/index.html
@@ -6,6 +6,7 @@
Welcome, {{ .Me.Username }}!
@@ -38,9 +39,10 @@ {{ end }}
{{ .ExpiresAt }} |
-
+
|
{{ end }}
@@ -82,11 +84,15 @@
| Username |
+ Name |
+ Email |
Role |
{{ range .Users }}
| {{ .Username }} |
+ {{ .Name }} |
+ {{ .Email }} |
{{ if .Admin }}
Administrator
diff --git a/template/login.html b/template/login.html
index babcf43987c51756d117469441d3b5e712e5e911..2ccc52ecf2cd6206b7b7c6ac8ff9aa06912a9d5a 100644
--- a/template/login.html
+++ b/template/login.html
@@ -5,6 +5,7 @@
{{ .ServerName }}
|