Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!
Add DB migration infrastructure
package main import ( "context" "database/sql" _ "embed" "fmt" "time" "github.com/mattn/go-sqlite3" ) //go:embed schema.sql var schema string
var migrations = []string{
"", // migration #0 is reserved for schema initialization
}
var errNoDBRows = sql.ErrNoRows
type DB struct {
db *sql.DB
}
func openDB(filename string) (*DB, error) {
sqlDB, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
db := &DB{sqlDB}
if err := db.init(context.TODO()); err != nil {
db.Close()
return nil, err
}
return db, nil
}
func (db *DB) init(ctx context.Context) error {
var n int
if err := db.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM sqlite_schema").Scan(&n); err != nil {
version, err := db.upgrade(ctx)
if err != nil {
return err
} else if n != 0 {
return nil
}
if _, err := db.db.ExecContext(ctx, schema); err != nil {
return err
if version > 0 {
return nil
}
// TODO: drop this
defaultUser := User{Username: "root", Admin: true}
if err := defaultUser.SetPassword("root"); err != nil {
return err
}
return db.StoreUser(ctx, &defaultUser)
}
func (db *DB) upgrade(ctx context.Context) (version int, err error) {
if err := db.db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
return 0, fmt.Errorf("failed to query schema version: %v", err)
}
if version == len(migrations) {
return version, nil
} else if version > len(migrations) {
return version, fmt.Errorf("sinwon (version %d) older than schema (version %d)", len(migrations), version)
}
tx, err := db.db.Begin()
if err != nil {
return version, err
}
defer tx.Rollback()
if version == 0 {
if _, err := tx.Exec(schema); err != nil {
return version, fmt.Errorf("failed to initialize schema: %v", err)
}
} else {
for i := version; i < len(migrations); i++ {
if _, err := tx.Exec(migrations[i]); err != nil {
return version, fmt.Errorf("failed to execute migration #%v: %v", i, err)
}
}
}
// For some reason prepared statements don't work here
_, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
if err != nil {
return version, fmt.Errorf("failed to bump schema version: %v", err)
}
return version, tx.Commit()
}
func (db *DB) Close() error {
return db.db.Close()
}
func (db *DB) FetchUser(ctx context.Context, id ID[*User]) (*User, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE id = ?", id)
if err != nil {
return nil, err
}
var user User
err = scanRow(&user, rows)
return &user, err
}
func (db *DB) FetchUserByUsername(ctx context.Context, username string) (*User, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE username = ?", username)
if err != nil {
return nil, err
}
var user User
err = scanRow(&user, rows)
return &user, err
}
func (db *DB) StoreUser(ctx context.Context, user *User) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO User(id, username, password_hash, admin)
VALUES (:id, :username, :password_hash, :admin)
ON CONFLICT(id) DO UPDATE SET
username = :username,
password_hash = :password_hash,
admin = :admin
RETURNING id
`, entityArgs(user)...).Scan(&user.ID)
}
func (db *DB) ListUsers(ctx context.Context) ([]User, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM User")
if err != nil {
return nil, err
}
defer rows.Close()
var l []User
for rows.Next() {
var user User
if err := scan(&user, rows); err != nil {
return nil, err
}
l = append(l, user)
}
return l, rows.Close()
}
func (db *DB) FetchClient(ctx context.Context, id ID[*Client]) (*Client, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE id = ?", id)
if err != nil {
return nil, err
}
var client Client
err = scanRow(&client, rows)
return &client, err
}
func (db *DB) FetchClientByClientID(ctx context.Context, clientID string) (*Client, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE client_id = ?", clientID)
if err != nil {
return nil, err
}
var client Client
err = scanRow(&client, rows)
return &client, err
}
func (db *DB) StoreClient(ctx context.Context, client *Client) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO Client(id, client_id, client_secret_hash, owner,
redirect_uris, client_name, client_uri)
VALUES (:id, :client_id, :client_secret_hash, :owner,
:redirect_uris, :client_name, :client_uri)
ON CONFLICT(id) DO UPDATE SET
client_id = :client_id,
client_secret_hash = :client_secret_hash,
owner = :owner,
redirect_uris = :redirect_uris,
client_name = :client_name,
client_uri = :client_uri
RETURNING id
`, entityArgs(client)...).Scan(&client.ID)
}
func (db *DB) ListClients(ctx context.Context, owner ID[*User]) ([]Client, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE owner IS ?", owner)
if err != nil {
return nil, err
}
defer rows.Close()
var l []Client
for rows.Next() {
var client Client
if err := scan(&client, rows); err != nil {
return nil, err
}
l = append(l, client)
}
return l, rows.Close()
}
func (db *DB) ListAuthorizedClients(ctx context.Context, user ID[*User]) ([]AuthorizedClient, error) {
rows, err := db.db.QueryContext(ctx, `
SELECT id, client_id, client_name, client_uri, token.expires_at
FROM Client,
(
SELECT client, MAX(expires_at) as expires_at
FROM AccessToken
WHERE user = ?
GROUP BY client
) AS token
WHERE Client.id = token.client
`, user)
if err != nil {
return nil, err
}
var l []AuthorizedClient
for rows.Next() {
var authClient AuthorizedClient
columns := authClient.Client.columns()
var expiresAt string
err := rows.Scan(columns["id"], columns["client_id"], columns["client_name"], columns["client_uri"], &expiresAt)
if err != nil {
return nil, err
}
authClient.ExpiresAt, err = time.Parse(sqlite3.SQLiteTimestampFormats[0], expiresAt)
if err != nil {
return nil, err
}
l = append(l, authClient)
}
return l, rows.Close()
}
func (db *DB) DeleteClient(ctx context.Context, id ID[*Client]) error {
_, err := db.db.ExecContext(ctx, "DELETE FROM Client WHERE id = ?", id)
return err
}
func (db *DB) FetchAccessToken(ctx context.Context, id ID[*AccessToken]) (*AccessToken, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM AccessToken WHERE id = ?", id)
if err != nil {
return nil, err
}
var token AccessToken
err = scanRow(&token, rows)
return &token, err
}
func (db *DB) CreateAccessToken(ctx context.Context, token *AccessToken) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO AccessToken(hash, user, client, scope, issued_at, expires_at)
VALUES (:hash, :user, :client, :scope, :issued_at, :expires_at)
RETURNING id
`, entityArgs(token)...).Scan(&token.ID)
}
func (db *DB) DeleteAccessToken(ctx context.Context, id ID[*AccessToken]) error {
_, err := db.db.ExecContext(ctx, "DELETE FROM AccessToken WHERE id = ?", id)
return err
}
func (db *DB) RevokeAccessTokens(ctx context.Context, clientID ID[*Client], userID ID[*User]) error {
_, err := db.db.ExecContext(ctx, `
DELETE FROM AccessToken
WHERE client = ? AND user = ?
`, clientID, userID)
return err
}
func (db *DB) CreateAuthCode(ctx context.Context, code *AuthCode) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO AuthCode(hash, created_at, user, client, scope, redirect_uri)
VALUES (:hash, :created_at, :user, :client, :scope, :redirect_uri)
RETURNING id
`, entityArgs(code)...).Scan(&code.ID)
}
func (db *DB) PopAuthCode(ctx context.Context, id ID[*AuthCode]) (*AuthCode, error) {
rows, err := db.db.QueryContext(ctx, `
DELETE FROM AuthCode
WHERE id = ?
RETURNING *
`, id)
if err != nil {
return nil, err
}
var authCode AuthCode
err = scanRow(&authCode, rows)
return &authCode, err
}
func (db *DB) Maintain(ctx context.Context) error {
_, err := db.db.ExecContext(ctx, `
DELETE FROM AccessToken
WHERE timediff('now', expires_at) > 0
`)
if err != nil {
return err
}
_, err = db.db.ExecContext(ctx, `
DELETE FROM AuthCode
WHERE timediff(?, created_at) > 0
`, time.Now().Add(-authCodeExpiration))
if err != nil {
return err
}
return nil
}
func scan(e entity, rows *sql.Rows) error {
columns := e.columns()
keys, err := rows.Columns()
if err != nil {
panic(err)
}
out := make([]interface{}, len(keys))
for i, k := range keys {
v, ok := columns[k]
if !ok {
panic(fmt.Errorf("unknown column %q", k))
}
out[i] = v
}
return rows.Scan(out...)
}
func scanRow(e entity, rows *sql.Rows) error {
if !rows.Next() {
return sql.ErrNoRows
}
if err := scan(e, rows); err != nil {
return err
}
return rows.Close()
}
func entityArgs(e entity) []interface{} {
columns := e.columns()
l := make([]interface{}, 0, len(columns))
for k, v := range columns {
l = append(l, sql.Named(k, v))
}
return l
}
PRAGMA user_version = 1;
CREATE TABLE User ( id INTEGER PRIMARY KEY, username TEXT NOT NULL UNIQUE, password_hash TEXT, admin INTEGER NOT NULL DEFAULT 0 ); CREATE TABLE Client ( id INTEGER PRIMARY KEY, client_id TEXT NOT NULL UNIQUE, client_secret_hash BLOB, owner INTEGER, redirect_uris TEXT, client_name TEXT, client_uri TEXT, FOREIGN KEY(owner) REFERENCES User(id) ); CREATE TABLE AccessToken ( id INTEGER PRIMARY KEY, hash BLOB NOT NULL UNIQUE, user INTEGER NOT NULL, client INTEGER, scope TEXT, issued_at datetime NOT NULL, expires_at datetime NOT NULL, FOREIGN KEY(user) REFERENCES User(id), FOREIGN KEY(client) REFERENCES Client(id) ); CREATE TABLE AuthCode ( id INTEGER PRIMARY KEY, hash BLOB NOT NULL UNIQUE, created_at datetime NOT NULL, user INTEGER NOT NULL, client INTEGER NOT NULL, redirect_uri TEXT, scope TEXT, FOREIGN KEY(user) REFERENCES User(id), FOREIGN KEY(client) REFERENCES Client(id) );