Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!
Check redirect URI param in token endpoint
package main
import (
"context"
"database/sql"
_ "embed"
"fmt"
"time"
"github.com/mattn/go-sqlite3"
)
//go:embed schema.sql
var schema string
var errNoDBRows = sql.ErrNoRows
type DB struct {
db *sql.DB
}
func openDB(filename string) (*DB, error) {
sqlDB, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
db := &DB{sqlDB}
if err := db.init(context.TODO()); err != nil {
db.Close()
return nil, err
}
return db, nil
}
func (db *DB) init(ctx context.Context) error {
var n int
if err := db.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM sqlite_schema").Scan(&n); err != nil {
return err
} else if n != 0 {
return nil
}
if _, err := db.db.ExecContext(ctx, schema); err != nil {
return err
}
// TODO: drop this
defaultUser := User{Username: "root"}
if err := defaultUser.SetPassword("root"); err != nil {
return err
}
return db.StoreUser(ctx, &defaultUser)
}
func (db *DB) Close() error {
return db.db.Close()
}
func (db *DB) FetchUser(ctx context.Context, id ID[*User]) (*User, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE id = ?", id)
if err != nil {
return nil, err
}
var user User
err = scanRow(&user, rows)
return &user, err
}
func (db *DB) FetchUserByUsername(ctx context.Context, username string) (*User, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE username = ?", username)
if err != nil {
return nil, err
}
var user User
err = scanRow(&user, rows)
return &user, err
}
func (db *DB) StoreUser(ctx context.Context, user *User) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO User(id, username, password_hash)
VALUES (:id, :username, :password_hash)
ON CONFLICT(id) DO UPDATE SET
username = :username,
password_hash = :password_hash
RETURNING id
`, entityArgs(user)...).Scan(&user.ID)
}
func (db *DB) ListUsers(ctx context.Context) ([]User, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM User")
if err != nil {
return nil, err
}
defer rows.Close()
var l []User
for rows.Next() {
var user User
if err := scan(&user, rows); err != nil {
return nil, err
}
l = append(l, user)
}
return l, rows.Close()
}
func (db *DB) FetchClient(ctx context.Context, id ID[*Client]) (*Client, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE id = ?", id)
if err != nil {
return nil, err
}
var client Client
err = scanRow(&client, rows)
return &client, err
}
func (db *DB) FetchClientByClientID(ctx context.Context, clientID string) (*Client, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE client_id = ?", clientID)
if err != nil {
return nil, err
}
var client Client
err = scanRow(&client, rows)
return &client, err
}
func (db *DB) StoreClient(ctx context.Context, client *Client) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO Client(id, client_id, client_secret_hash, owner,
redirect_uris, client_name, client_uri)
VALUES (:id, :client_id, :client_secret_hash, :owner,
:redirect_uris, :client_name, :client_uri)
ON CONFLICT(id) DO UPDATE SET
client_id = :client_id,
client_secret_hash = :client_secret_hash,
owner = :owner,
redirect_uris = :redirect_uris,
client_name = :client_name,
client_uri = :client_uri
RETURNING id
`, entityArgs(client)...).Scan(&client.ID)
}
func (db *DB) ListClients(ctx context.Context, owner ID[*User]) ([]Client, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE owner IS ?", owner)
if err != nil {
return nil, err
}
defer rows.Close()
var l []Client
for rows.Next() {
var client Client
if err := scan(&client, rows); err != nil {
return nil, err
}
l = append(l, client)
}
return l, rows.Close()
}
func (db *DB) ListAuthorizedClients(ctx context.Context, user ID[*User]) ([]AuthorizedClient, error) {
rows, err := db.db.QueryContext(ctx, `
SELECT id, client_id, client_name, client_uri, token.expires_at
FROM Client,
(
SELECT client, MAX(expires_at) as expires_at
FROM AccessToken
WHERE user = ?
GROUP BY client
) AS token
WHERE Client.id = token.client
`, user)
if err != nil {
return nil, err
}
var l []AuthorizedClient
for rows.Next() {
var authClient AuthorizedClient
columns := authClient.Client.columns()
var expiresAt string
err := rows.Scan(columns["id"], columns["client_id"], columns["client_name"], columns["client_uri"], &expiresAt)
if err != nil {
return nil, err
}
authClient.ExpiresAt, err = time.Parse(sqlite3.SQLiteTimestampFormats[0], expiresAt)
if err != nil {
return nil, err
}
l = append(l, authClient)
}
return l, rows.Close()
}
func (db *DB) DeleteClient(ctx context.Context, id ID[*Client]) error {
_, err := db.db.ExecContext(ctx, "DELETE FROM Client WHERE id = ?", id)
return err
}
func (db *DB) FetchAccessToken(ctx context.Context, id ID[*AccessToken]) (*AccessToken, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM AccessToken WHERE id = ?", id)
if err != nil {
return nil, err
}
var token AccessToken
err = scanRow(&token, rows)
return &token, err
}
func (db *DB) CreateAccessToken(ctx context.Context, token *AccessToken) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO AccessToken(hash, user, client, scope, issued_at, expires_at)
VALUES (:hash, :user, :client, :scope, :issued_at, :expires_at)
RETURNING id
`, entityArgs(token)...).Scan(&token.ID)
}
func (db *DB) DeleteAccessToken(ctx context.Context, id ID[*AccessToken]) error {
_, err := db.db.ExecContext(ctx, "DELETE FROM AccessToken WHERE id = ?", id)
return err
}
func (db *DB) RevokeAccessTokens(ctx context.Context, clientID ID[*Client], userID ID[*User]) error {
_, err := db.db.ExecContext(ctx, `
DELETE FROM AccessToken
WHERE client = ? AND user = ?
`, clientID, userID)
return err
}
func (db *DB) CreateAuthCode(ctx context.Context, code *AuthCode) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO AuthCode(hash, created_at, user, client, scope) VALUES (:hash, :created_at, :user, :client, :scope)
INSERT INTO AuthCode(hash, created_at, user, client, scope, redirect_uri) VALUES (:hash, :created_at, :user, :client, :scope, :redirect_uri)
RETURNING id
`, entityArgs(code)...).Scan(&code.ID)
}
func (db *DB) PopAuthCode(ctx context.Context, id ID[*AuthCode]) (*AuthCode, error) {
rows, err := db.db.QueryContext(ctx, `
DELETE FROM AuthCode
WHERE id = ?
RETURNING *
`, id)
if err != nil {
return nil, err
}
var authCode AuthCode
err = scanRow(&authCode, rows)
return &authCode, err
}
func (db *DB) Maintain(ctx context.Context) error {
_, err := db.db.ExecContext(ctx, `
DELETE FROM AccessToken
WHERE timediff('now', expires_at) > 0
`)
if err != nil {
return err
}
_, err = db.db.ExecContext(ctx, `
DELETE FROM AuthCode
WHERE timediff(?, created_at) > 0
`, time.Now().Add(-authCodeExpiration))
if err != nil {
return err
}
return nil
}
func scan(e entity, rows *sql.Rows) error {
columns := e.columns()
keys, err := rows.Columns()
if err != nil {
panic(err)
}
out := make([]interface{}, len(keys))
for i, k := range keys {
v, ok := columns[k]
if !ok {
panic(fmt.Errorf("unknown column %q", k))
}
out[i] = v
}
return rows.Scan(out...)
}
func scanRow(e entity, rows *sql.Rows) error {
if !rows.Next() {
return sql.ErrNoRows
}
if err := scan(e, rows); err != nil {
return err
}
return rows.Close()
}
func entityArgs(e entity) []interface{} {
columns := e.columns()
l := make([]interface{}, 0, len(columns))
for k, v := range columns {
l = append(l, sql.Named(k, v))
}
return l
}
package main
import (
"crypto/rand"
"crypto/sha512"
"crypto/subtle"
"database/sql"
"database/sql/driver"
"encoding/base64"
"fmt"
"strconv"
"strings"
"time"
"golang.org/x/crypto/bcrypt"
)
const (
accessTokenExpiration = 30 * 24 * time.Hour
authCodeExpiration = 10 * time.Minute
)
type entity interface {
columns() map[string]interface{}
}
var (
_ entity = (*User)(nil)
_ entity = (*Client)(nil)
_ entity = (*AccessToken)(nil)
_ entity = (*AuthCode)(nil)
)
type ID[T entity] int64
var (
_ sql.Scanner = (*ID[*User])(nil)
_ driver.Valuer = ID[*User](0)
)
func ParseID[T entity](s string) (ID[T], error) {
u, _ := strconv.ParseUint(s, 10, 63)
if u == 0 {
return 0, fmt.Errorf("invalid ID")
}
return ID[T](u), nil
}
func (ptr *ID[T]) Scan(v interface{}) error {
if v == nil {
*ptr = 0
return nil
}
id, ok := v.(int64)
if !ok {
return fmt.Errorf("cannot scan ID from %T", v)
}
*ptr = ID[T](id)
return nil
}
func (id ID[T]) Value() (driver.Value, error) {
if id == 0 {
return nil, nil
} else {
return int64(id), nil
}
}
type nullString string
var (
_ sql.Scanner = (*nullString)(nil)
_ driver.Valuer = (*nullString)(nil)
)
func (ptr *nullString) Scan(v interface{}) error {
if v == nil {
*ptr = ""
return nil
}
s, ok := v.(string)
if !ok {
return fmt.Errorf("cannot scan nullStringPtr from %T", v)
}
*ptr = nullString(s)
return nil
}
func (ptr *nullString) Value() (driver.Value, error) {
if *ptr == "" {
return nil, nil
} else {
return string(*ptr), nil
}
}
type User struct {
ID ID[*User]
Username string
PasswordHash string
Admin bool
}
func (user *User) columns() map[string]interface{} {
return map[string]interface{}{
"id": &user.ID,
"username": &user.Username,
"password_hash": (*nullString)(&user.PasswordHash),
"admin": &user.Admin,
}
}
func (user *User) VerifyPassword(password string) error {
return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
}
func (user *User) SetPassword(password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
user.PasswordHash = string(hash)
return nil
}
func (user *User) PasswordNeedsRehash() bool {
cost, _ := bcrypt.Cost([]byte(user.PasswordHash))
return cost != bcrypt.DefaultCost
}
type Client struct {
ID ID[*Client]
ClientID string
ClientSecretHash []byte
Owner ID[*User]
RedirectURIs string
ClientName string
ClientURI string
}
func (client *Client) Generate(isPublic bool) (secret string, err error) {
id, err := generateUID()
if err != nil {
return "", fmt.Errorf("failed to generate client ID: %v", err)
}
client.ClientID = id
if !isPublic {
var hash []byte
secret, hash, err = generateSecret()
if err != nil {
return "", fmt.Errorf("failed to generate client secret: %v", err)
}
client.ClientSecretHash = hash
}
return secret, nil
}
func (client *Client) columns() map[string]interface{} {
return map[string]interface{}{
"id": &client.ID,
"client_id": &client.ClientID,
"client_secret_hash": &client.ClientSecretHash,
"owner": &client.Owner,
"redirect_uris": (*nullString)(&client.RedirectURIs),
"client_name": (*nullString)(&client.ClientName),
"client_uri": (*nullString)(&client.ClientURI),
}
}
func (client *Client) VerifySecret(secret string) bool {
return verifyHash(client.ClientSecretHash, secret)
}
func (client *Client) IsPublic() bool {
return client.ClientSecretHash == nil
}
type AccessToken struct {
ID ID[*AccessToken]
Hash []byte
User ID[*User]
Client ID[*Client]
Scope string
IssuedAt time.Time
ExpiresAt time.Time
}
func (token *AccessToken) Generate(expiration time.Duration) (secret string, err error) {
secret, hash, err := generateSecret()
if err != nil {
return "", fmt.Errorf("failed to generate access token secret: %v", err)
}
token.Hash = hash
token.IssuedAt = time.Now()
token.ExpiresAt = time.Now().Add(expiration)
return secret, nil
}
func NewAccessTokenFromAuthCode(authCode *AuthCode) (token *AccessToken, secret string, err error) {
token = &AccessToken{
User: authCode.User,
Client: authCode.Client,
Scope: authCode.Scope,
}
secret, err = token.Generate(accessTokenExpiration)
return token, secret, err
}
func (token *AccessToken) columns() map[string]interface{} {
return map[string]interface{}{
"id": &token.ID,
"hash": &token.Hash,
"user": &token.User,
"client": &token.Client,
"scope": (*nullString)(&token.Scope),
"issued_at": &token.IssuedAt,
"expires_at": &token.ExpiresAt,
}
}
func (token *AccessToken) VerifySecret(secret string) bool {
return verifyHash(token.Hash, secret) && verifyExpiration(token.ExpiresAt)
}
type AuthorizedClient struct {
Client Client
ExpiresAt time.Time
}
type AuthCode struct {
ID ID[*AuthCode] Hash []byte CreatedAt time.Time User ID[*User] Client ID[*Client] Scope string
ID ID[*AuthCode] Hash []byte CreatedAt time.Time User ID[*User] Client ID[*Client] Scope string RedirectURI string
}
func NewAuthCode(user ID[*User], client ID[*Client], scope string) (code *AuthCode, secret string, err error) {
func (code *AuthCode) Generate() (secret string, err error) {
secret, hash, err := generateSecret()
if err != nil {
return nil, "", fmt.Errorf("failed to generate authentication code secret: %v", err)
return "", fmt.Errorf("failed to generate authentication code secret: %v", err)
}
code = &AuthCode{
Hash: hash,
CreatedAt: time.Now(),
User: user,
Client: client,
Scope: scope,
}
return code, secret, nil
code.Hash = hash code.CreatedAt = time.Now() return secret, nil
}
func (code *AuthCode) columns() map[string]interface{} {
return map[string]interface{}{
"id": &code.ID, "hash": &code.Hash, "created_at": &code.CreatedAt, "user": &code.User, "client": &code.Client, "scope": (*nullString)(&code.Scope),
"id": &code.ID, "hash": &code.Hash, "created_at": &code.CreatedAt, "user": &code.User, "client": &code.Client, "scope": (*nullString)(&code.Scope), "redirect_uri": (*nullString)(&code.RedirectURI),
}
}
func (code *AuthCode) VerifySecret(secret string) bool {
return verifyHash(code.Hash, secret) && verifyExpiration(code.CreatedAt.Add(authCodeExpiration))
}
type SecretKind byte
const (
SecretKindAccessToken = SecretKind('a')
SecretKindAuthCode = SecretKind('c')
)
func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
kind, s, _ := strings.Cut(s, ".")
idStr, secret, ok := strings.Cut(s, ".")
if !ok || len(kind) != 1 {
return 0, "", fmt.Errorf("malformed secret")
}
switch SecretKind(kind[0]) {
case SecretKindAccessToken:
_, ok = interface{}(id).(ID[*AccessToken])
case SecretKindAuthCode:
_, ok = interface{}(id).(ID[*AuthCode])
}
if !ok {
return 0, "", fmt.Errorf("invalid secret kind %q", kind)
}
id, err = ParseID[T](idStr)
return id, secret, err
}
func MarshalSecret[T entity](id ID[T], secret string) string {
if id == 0 {
panic("cannot marshal zero ID")
}
var kind SecretKind
switch interface{}(id).(type) {
case ID[*AccessToken]:
kind = SecretKindAccessToken
case ID[*AuthCode]:
kind = SecretKindAuthCode
default:
panic(fmt.Sprintf("unsupported secret kind for ID type %T", id))
}
return fmt.Sprintf("%v.%v.%v", string(kind), int64(id), secret)
}
func generateUID() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
func generateSecret() (secret string, hash []byte, err error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", nil, err
}
secret = base64.RawURLEncoding.EncodeToString(b)
h := sha512.Sum512(b)
return secret, h[:], nil
}
func verifyHash(hash []byte, secret string) bool {
b, _ := base64.RawURLEncoding.DecodeString(secret)
h := sha512.Sum512(b)
return subtle.ConstantTimeCompare(hash, h[:]) == 1
}
func verifyExpiration(t time.Time) bool {
return time.Now().Before(t)
}
package main
import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"mime"
"net"
"net/http"
"net/url"
"strings"
"time"
"git.sr.ht/~emersion/go-oauth2"
)
func getOAuthServerMetadata(w http.ResponseWriter, req *http.Request) {
issuer := getIssuer(req)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(&oauth2.ServerMetadata{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
IntrospectionEndpoint: issuer + "/introspect",
RevocationEndpoint: issuer + "/revoke",
ResponseTypesSupported: []oauth2.ResponseType{oauth2.ResponseTypeCode},
ResponseModesSupported: []oauth2.ResponseMode{oauth2.ResponseModeQuery},
GrantTypesSupported: []oauth2.GrantType{oauth2.GrantTypeAuthorizationCode},
TokenEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
IntrospectionEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
RevocationEndpointAuthMethodsSupported: []oauth2.AuthMethod{oauth2.AuthMethodNone, oauth2.AuthMethodClientSecretBasic},
AuthorizationResponseIssParameterSupported: true,
})
}
func getIssuer(req *http.Request) string {
issuerURL := url.URL{
Scheme: "https",
Host: req.Host,
}
if !isForwardedHTTPS(req) && isLoopback(req) {
// TODO: add config option for allowed reverse proxy IPs
issuerURL.Scheme = "http"
}
return issuerURL.String()
}
func isLoopback(req *http.Request) bool {
host, _, _ := net.SplitHostPort(req.RemoteAddr)
ip := net.ParseIP(host)
return ip.IsLoopback()
}
func authorize(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
tpl := templateFromContext(ctx)
q := req.URL.Query()
respType := oauth2.ResponseType(q.Get("response_type"))
clientID := q.Get("client_id")
rawRedirectURI := q.Get("redirect_uri")
scope := q.Get("scope")
state := q.Get("state")
if clientID == "" {
http.Error(w, "Missing client ID", http.StatusBadRequest)
return
}
client, err := db.FetchClientByClientID(ctx, clientID)
if err == errNoDBRows {
http.Error(w, "Invalid client ID", http.StatusForbidden)
return
} else if err != nil {
httpError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
var allowedRedirectURIs []*url.URL
for _, s := range strings.Split(client.RedirectURIs, "\n") {
if s == "" {
continue
}
u, err := url.Parse(s)
if err != nil {
httpError(w, fmt.Errorf("failed to parse client redirect URI"))
return
}
allowedRedirectURIs = append(allowedRedirectURIs, u)
}
var redirectURI *url.URL
if rawRedirectURI != "" {
redirectURI, err = url.Parse(rawRedirectURI)
if err != nil {
http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
return
}
if !validateRedirectURI(redirectURI, allowedRedirectURIs) {
http.Error(w, "Invalid redirect URI", http.StatusBadRequest)
return
}
} else {
if len(allowedRedirectURIs) == 0 {
http.Error(w, "Missing redirect URI", http.StatusBadRequest)
return
}
redirectURI = allowedRedirectURIs[0]
}
if respType != oauth2.ResponseTypeCode {
redirectClientError(w, req, redirectURI, state, &oauth2.Error{
Code: oauth2.ErrorCodeUnsupportedResponseType,
})
return
}
// TODO: add support for scope
if scope != "" {
redirectClientError(w, req, redirectURI, state, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidScope,
})
return
}
loginToken := loginTokenFromContext(ctx)
if loginToken == nil {
q := make(url.Values)
q.Set("redirect_uri", req.URL.String())
u := url.URL{
Path: "/login",
RawQuery: q.Encode(),
}
http.Redirect(w, req, u.String(), http.StatusFound)
return
}
_ = req.ParseForm()
if _, ok := req.PostForm["deny"]; ok {
redirectClientError(w, req, redirectURI, state, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
})
return
}
if _, ok := req.PostForm["authorize"]; !ok {
data := struct {
Client *Client
}{
Client: client,
}
if err := tpl.ExecuteTemplate(w, "authorize.html", data); err != nil {
panic(err)
}
return
}
authCode, secret, err := NewAuthCode(loginToken.User, client.ID, scope)
authCode := AuthCode{
User: loginToken.User,
Client: client.ID,
Scope: scope,
RedirectURI: rawRedirectURI,
}
secret, err := authCode.Generate()
if err != nil {
httpError(w, fmt.Errorf("failed to generate authentication code: %v", err))
return
}
if err := db.CreateAuthCode(ctx, authCode); err != nil {
if err := db.CreateAuthCode(ctx, &authCode); err != nil {
httpError(w, fmt.Errorf("failed to create authentication code: %v", err))
return
}
code := MarshalSecret(authCode.ID, secret)
values := make(url.Values)
values.Set("code", code)
if state != "" {
values.Set("state", state)
}
redirectClient(w, req, redirectURI, values)
}
func exchangeToken(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
values, err := parseRequestBody(req)
if err != nil {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: err.Error(),
})
return
}
clientID := values.Get("client_id")
grantType := oauth2.GrantType(values.Get("grant_type"))
scope := values.Get("scope")
redirectURI := values.Get("redirect_uri")
authClientID, clientSecret, _ := req.BasicAuth()
if clientID == "" && authClientID == "" {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "Missing client ID",
})
return
} else if clientID == "" {
clientID = authClientID
} else if clientID != authClientID {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: "Client ID in request body doesn't match Authorization header field",
})
return
}
client, err := db.FetchClientByClientID(ctx, clientID)
if err == errNoDBRows {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidClient,
Description: "Invalid client ID",
})
return
} else if err != nil {
oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
if !client.IsPublic() {
if !client.VerifySecret(clientSecret) {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid client secret",
})
return
}
}
if grantType != oauth2.GrantTypeAuthorizationCode {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeUnsupportedGrantType,
Description: "Unsupported grant type",
})
return
}
codeID, codeSecret, _ := UnmarshalSecret[*AuthCode](values.Get("code"))
authCode, err := db.PopAuthCode(ctx, codeID)
if err == errNoDBRows || (err == nil && !authCode.VerifySecret(codeSecret)) || authCode.Client != client.ID {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid authorization code",
})
return
} else if err != nil {
oauthError(w, fmt.Errorf("failed to fetch authorization code: %v", err))
return
}
if scope != authCode.Scope {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid scope",
})
return
}
// TODO: check redirect_uri
if redirectURI != authCode.RedirectURI {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeAccessDenied,
Description: "Invalid redirect URI",
})
return
}
token, secret, err := NewAccessTokenFromAuthCode(authCode)
if err != nil {
oauthError(w, err)
return
}
if err := db.CreateAccessToken(ctx, token); err != nil {
oauthError(w, fmt.Errorf("failed to create access token: %v", err))
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
json.NewEncoder(w).Encode(&oauth2.TokenResp{
AccessToken: MarshalSecret(token.ID, secret),
TokenType: oauth2.TokenTypeBearer,
ExpiresIn: time.Until(token.ExpiresAt),
Scope: strings.Split(token.Scope, " "),
})
}
func introspectToken(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
values, err := parseRequestBody(req)
if err != nil {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: err.Error(),
})
return
}
client, err := maybeAuthenticateClient(w, req)
if err != nil {
oauthError(w, err)
return
}
tokenID, secret, _ := UnmarshalSecret[*AccessToken](values.Get("token"))
token, err := db.FetchAccessToken(ctx, tokenID)
if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) {
token = nil
} else if err != nil {
oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
return
}
var resp oauth2.IntrospectionResp
if token != nil {
if client == nil {
client, err = db.FetchClient(ctx, token.Client)
if err != nil {
oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
if !client.IsPublic() {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidClient,
Description: "Missing client ID and secret",
})
return
}
}
if client.ID != token.Client {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidClient,
Description: "Invalid client ID or secret",
})
return
}
user, err := db.FetchUser(ctx, token.User)
if err != nil {
oauthError(w, fmt.Errorf("failed to fetch user: %v", err))
return
}
resp.Active = true
resp.TokenType = oauth2.TokenTypeBearer
resp.ExpiresAt = token.ExpiresAt
resp.IssuedAt = token.IssuedAt
resp.ClientID = client.ClientID
resp.Username = user.Username
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(&resp)
}
func revokeToken(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
db := dbFromContext(ctx)
values, err := parseRequestBody(req)
if err != nil {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidRequest,
Description: err.Error(),
})
return
}
client, err := maybeAuthenticateClient(w, req)
if err != nil {
oauthError(w, err)
return
}
tokenID, secret, _ := UnmarshalSecret[*AccessToken](values.Get("token"))
token, err := db.FetchAccessToken(ctx, tokenID)
if err == errNoDBRows || (err == nil && !token.VerifySecret(secret)) {
return // ignore
} else if err != nil {
oauthError(w, fmt.Errorf("failed to fetch access token: %v", err))
return
}
if client == nil {
client, err = db.FetchClient(ctx, token.Client)
if err != nil {
oauthError(w, fmt.Errorf("failed to fetch client: %v", err))
return
}
if !client.IsPublic() {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidClient,
Description: "Missing client ID and secret",
})
return
}
}
if client.ID != token.Client {
oauthError(w, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidClient,
Description: "Invalid client ID or secret",
})
return
}
if err := db.DeleteAccessToken(ctx, token.ID); err != nil {
oauthError(w, err)
return
}
}
func parseRequestBody(req *http.Request) (url.Values, error) {
ct := req.Header.Get("Content-Type")
if ct != "" {
mimeType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
if err != nil {
return nil, fmt.Errorf("malformed Content-Type header field")
} else if mimeType != "application/x-www-form-urlencoded" {
return nil, fmt.Errorf("unsupported request content type")
}
}
r := io.LimitReader(req.Body, 10<<20)
b, err := io.ReadAll(r)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %v", err)
}
values, err := url.ParseQuery(string(b))
if err != nil {
return nil, fmt.Errorf("failed to parse request body: %v", err)
}
return values, nil
}
func oauthError(w http.ResponseWriter, err error) {
var oauthErr *oauth2.Error
if !errors.As(err, &oauthErr) {
oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
log.Print(err)
}
statusCode := http.StatusInternalServerError
switch oauthErr.Code {
case oauth2.ErrorCodeInvalidRequest, oauth2.ErrorCodeUnsupportedResponseType, oauth2.ErrorCodeInvalidScope, oauth2.ErrorCodeInvalidClient, oauth2.ErrorCodeInvalidGrant, oauth2.ErrorCodeUnsupportedGrantType:
statusCode = http.StatusBadRequest
case oauth2.ErrorCodeUnauthorizedClient, oauth2.ErrorCodeAccessDenied:
statusCode = http.StatusForbidden
case oauth2.ErrorCodeTemporarilyUnavailable:
statusCode = http.StatusServiceUnavailable
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(oauthErr)
}
func redirectClient(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, values url.Values) {
q := redirectURI.Query()
for k, v := range values {
q[k] = v
}
q.Set("iss", getIssuer(req))
u := *redirectURI
u.RawQuery = q.Encode()
http.Redirect(w, req, u.String(), http.StatusFound)
}
func redirectClientError(w http.ResponseWriter, req *http.Request, redirectURI *url.URL, state string, err error) {
var oauthErr *oauth2.Error
if !errors.As(err, &oauthErr) {
oauthErr = &oauth2.Error{Code: oauth2.ErrorCodeServerError}
log.Print(err)
}
values := make(url.Values)
values.Set("error", string(oauthErr.Code))
if oauthErr.Description != "" {
values.Set("error_description", oauthErr.Description)
}
if oauthErr.URI != "" {
values.Set("error_uri", oauthErr.URI)
}
if state != "" {
values.Set("state", state)
}
redirectClient(w, req, redirectURI, values)
}
func validateRedirectURI(u *url.URL, allowedURIs []*url.URL) bool {
// Loopback interface, see RFC 8252 section 7.3
host, _, _ := net.SplitHostPort(u.Host)
ip := net.ParseIP(host)
if u.Scheme == "http" && ip.IsLoopback() {
uu := *u
uu.Host = "localhost"
u = &uu
}
for _, allowed := range allowedURIs {
if u.String() == allowed.String() {
return true
}
}
return false
}
func maybeAuthenticateClient(w http.ResponseWriter, req *http.Request) (*Client, error) {
ctx := req.Context()
db := dbFromContext(ctx)
clientID, clientSecret, ok := req.BasicAuth()
if !ok {
return nil, nil
}
client, err := db.FetchClientByClientID(ctx, clientID)
if err == errNoDBRows || (err == nil && !client.VerifySecret(clientSecret)) {
return nil, &oauth2.Error{
Code: oauth2.ErrorCodeInvalidClient,
Description: "Invalid client ID or secret",
}
} else if err != nil {
return nil, fmt.Errorf("failed to fetch client: %v", err)
}
return client, nil
}
PRAGMA user_version = 1; CREATE TABLE User ( id INTEGER PRIMARY KEY, username TEXT NOT NULL UNIQUE, password_hash TEXT, admin INTEGER NOT NULL DEFAULT 0 ); CREATE TABLE Client ( id INTEGER PRIMARY KEY, client_id TEXT NOT NULL UNIQUE, client_secret_hash BLOB, owner INTEGER, redirect_uris TEXT, client_name TEXT, client_uri TEXT, FOREIGN KEY(owner) REFERENCES User(id) ); CREATE TABLE AccessToken ( id INTEGER PRIMARY KEY, hash BLOB NOT NULL UNIQUE, user INTEGER NOT NULL, client INTEGER, scope TEXT, issued_at datetime NOT NULL, expires_at datetime NOT NULL, FOREIGN KEY(user) REFERENCES User(id), FOREIGN KEY(client) REFERENCES Client(id) ); CREATE TABLE AuthCode ( id INTEGER PRIMARY KEY, hash BLOB NOT NULL UNIQUE, created_at datetime NOT NULL, user INTEGER NOT NULL, client INTEGER NOT NULL,
redirect_uri TEXT,
scope TEXT, FOREIGN KEY(user) REFERENCES User(id), FOREIGN KEY(client) REFERENCES Client(id) );