Hi… I am well aware that this diff view is very suboptimal. It will be fixed when the refactored server comes along!
Cleanup expired tokens and codes from DB Closes: https://todo.sr.ht/~emersion/sinwon/4
package main import ( "context" "database/sql" _ "embed" "fmt"
"time"
_ "github.com/mattn/go-sqlite3"
)
//go:embed schema.sql
var schema string
var errNoDBRows = sql.ErrNoRows
type DB struct {
db *sql.DB
}
func openDB(filename string) (*DB, error) {
sqlDB, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
db := &DB{sqlDB}
if err := db.init(context.TODO()); err != nil {
db.Close()
return nil, err
}
return db, nil
}
func (db *DB) init(ctx context.Context) error {
var n int
if err := db.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM sqlite_schema").Scan(&n); err != nil {
return err
} else if n != 0 {
return nil
}
if _, err := db.db.ExecContext(ctx, schema); err != nil {
return err
}
// TODO: drop this
defaultUser := User{Username: "root"}
if err := defaultUser.SetPassword("root"); err != nil {
return err
}
return db.StoreUser(ctx, &defaultUser)
}
func (db *DB) Close() error {
return db.db.Close()
}
func (db *DB) FetchUser(ctx context.Context, id ID[*User]) (*User, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE id = ?", id)
if err != nil {
return nil, err
}
var user User
err = scanRow(&user, rows)
return &user, err
}
func (db *DB) FetchUserByUsername(ctx context.Context, username string) (*User, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM User WHERE username = ?", username)
if err != nil {
return nil, err
}
var user User
err = scanRow(&user, rows)
return &user, err
}
func (db *DB) StoreUser(ctx context.Context, user *User) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO User(id, username, password_hash)
VALUES (:id, :username, :password_hash)
ON CONFLICT(id) DO UPDATE SET
username = :username,
password_hash = :password_hash
RETURNING id
`, entityArgs(user)...).Scan(&user.ID)
}
func (db *DB) FetchClient(ctx context.Context, clientID string) (*Client, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE client_id = ?", clientID)
if err != nil {
return nil, err
}
var client Client
err = scanRow(&client, rows)
return &client, err
}
func (db *DB) StoreClient(ctx context.Context, client *Client) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO Client(id, client_id, client_secret_hash, owner)
VALUES (:id, :client_id, :client_secret_hash, :owner)
ON CONFLICT(id) DO UPDATE SET
client_id = :client_id,
client_secret_hash = :client_secret_hash,
owner = :owner
RETURNING id
`, entityArgs(client)...).Scan(&client.ID)
}
func (db *DB) ListClients(ctx context.Context, owner ID[*User]) ([]Client, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM Client WHERE owner IS ?", owner)
if err != nil {
return nil, err
}
defer rows.Close()
var l []Client
for rows.Next() {
var client Client
if err := scan(&client, rows); err != nil {
return nil, err
}
l = append(l, client)
}
return l, rows.Close()
}
func (db *DB) FetchAccessToken(ctx context.Context, id ID[*AccessToken]) (*AccessToken, error) {
rows, err := db.db.QueryContext(ctx, "SELECT * FROM AccessToken WHERE id = ?", id)
if err != nil {
return nil, err
}
var token AccessToken
err = scanRow(&token, rows)
return &token, err
}
func (db *DB) CreateAccessToken(ctx context.Context, token *AccessToken) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO AccessToken(hash, user, client, scope, issued_at, expires_at)
VALUES (:hash, :user, :client, :scope, :issued_at, :expires_at)
RETURNING id
`, entityArgs(token)...).Scan(&token.ID)
}
func (db *DB) CreateAuthCode(ctx context.Context, code *AuthCode) error {
return db.db.QueryRowContext(ctx, `
INSERT INTO AuthCode(hash, created_at, user, client, scope)
VALUES (:hash, :created_at, :user, :client, :scope)
RETURNING id
`, entityArgs(code)...).Scan(&code.ID)
}
func (db *DB) PopAuthCode(ctx context.Context, id ID[*AuthCode]) (*AuthCode, error) {
rows, err := db.db.QueryContext(ctx, `
DELETE FROM AuthCode
WHERE id = ?
RETURNING *
`, id)
if err != nil {
return nil, err
}
var authCode AuthCode
err = scanRow(&authCode, rows)
return &authCode, err
}
func (db *DB) Maintain(ctx context.Context) error {
_, err := db.db.ExecContext(ctx, `
DELETE FROM AccessToken
WHERE timediff('now', expires_at) > 0
`)
if err != nil {
return err
}
_, err = db.db.ExecContext(ctx, `
DELETE FROM AuthCode
WHERE timediff(?, created_at) > 0
`, time.Now().Add(-authCodeExpiration))
if err != nil {
return err
}
return nil
}
func scan(e entity, rows *sql.Rows) error {
columns := e.columns()
keys, err := rows.Columns()
if err != nil {
panic(err)
}
out := make([]interface{}, len(keys))
for i, k := range keys {
v, ok := columns[k]
if !ok {
panic(fmt.Errorf("unknown column %q", k))
}
out[i] = v
}
return rows.Scan(out...)
}
func scanRow(e entity, rows *sql.Rows) error {
if !rows.Next() {
return sql.ErrNoRows
}
if err := scan(e, rows); err != nil {
return err
}
return rows.Close()
}
func entityArgs(e entity) []interface{} {
columns := e.columns()
l := make([]interface{}, 0, len(columns))
for k, v := range columns {
l = append(l, sql.Named(k, v))
}
return l
}
package main import ( "crypto/rand" "crypto/sha512" "crypto/subtle" "database/sql" "database/sql/driver" "encoding/base64" "fmt" "strconv" "strings" "time" "golang.org/x/crypto/bcrypt" )
const authCodeExpiration = 10 * time.Minute
type entity interface {
columns() map[string]interface{}
}
var (
_ entity = (*User)(nil)
_ entity = (*Client)(nil)
_ entity = (*AccessToken)(nil)
_ entity = (*AuthCode)(nil)
)
type ID[T entity] int64
var (
_ sql.Scanner = (*ID[*User])(nil)
_ driver.Valuer = ID[*User](0)
)
func ParseID[T entity](s string) (ID[T], error) {
u, _ := strconv.ParseUint(s, 10, 63)
if u == 0 {
return 0, fmt.Errorf("invalid ID")
}
return ID[T](u), nil
}
func (ptr *ID[T]) Scan(v interface{}) error {
if v == nil {
*ptr = 0
return nil
}
id, ok := v.(int64)
if !ok {
return fmt.Errorf("cannot scan ID from %T", v)
}
*ptr = ID[T](id)
return nil
}
func (id ID[T]) Value() (driver.Value, error) {
if id == 0 {
return nil, nil
} else {
return int64(id), nil
}
}
type User struct {
ID ID[*User]
Username string
PasswordHash string
}
func (user *User) columns() map[string]interface{} {
return map[string]interface{}{
"id": &user.ID,
"username": &user.Username,
"password_hash": &user.PasswordHash,
}
}
func (user *User) VerifyPassword(password string) error {
// TODO: upgrade hash
return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
}
func (user *User) SetPassword(password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
user.PasswordHash = string(hash)
return nil
}
type Client struct {
ID ID[*Client]
ClientID string
ClientSecretHash []byte
Owner ID[*User]
}
func NewClient(owner ID[*User]) (client *Client, secret string, err error) {
id, err := generateUID()
if err != nil {
return nil, "", fmt.Errorf("failed to generate client ID: %v", err)
}
secret, hash, err := generateSecret()
if err != nil {
return nil, "", fmt.Errorf("failed to generate client secret: %v", err)
}
client = &Client{
ClientID: id,
ClientSecretHash: hash,
Owner: owner,
}
return client, secret, nil
}
func (client *Client) columns() map[string]interface{} {
return map[string]interface{}{
"id": &client.ID,
"client_id": &client.ClientID,
"client_secret_hash": &client.ClientSecretHash,
"owner": &client.Owner,
}
}
func (client *Client) VerifySecret(secret string) bool {
return verifyHash(client.ClientSecretHash, secret)
}
type AccessToken struct {
ID ID[*AccessToken]
Hash []byte
User ID[*User]
Client ID[*Client]
Scope string
IssuedAt time.Time
ExpiresAt time.Time
}
func (token *AccessToken) Generate() (secret string, err error) {
secret, hash, err := generateSecret()
if err != nil {
return "", fmt.Errorf("failed to generate access token secret: %v", err)
}
token.Hash = hash
token.IssuedAt = time.Now()
token.ExpiresAt = time.Now().Add(2 * time.Hour)
return secret, nil
}
func NewAccessTokenFromAuthCode(authCode *AuthCode) (token *AccessToken, secret string, err error) {
token = &AccessToken{
User: authCode.User,
Client: authCode.Client,
Scope: authCode.Scope,
}
secret, err = token.Generate()
return token, secret, err
}
func (token *AccessToken) columns() map[string]interface{} {
return map[string]interface{}{
"id": &token.ID,
"hash": &token.Hash,
"user": &token.User,
"client": &token.Client,
"scope": &token.Scope,
"issued_at": &token.IssuedAt,
"expires_at": &token.ExpiresAt,
}
}
func (token *AccessToken) VerifySecret(secret string) bool {
return verifyHash(token.Hash, secret) && verifyExpiration(token.ExpiresAt)
}
type AuthCode struct {
ID ID[*AuthCode]
Hash []byte
CreatedAt time.Time
User ID[*User]
Client ID[*Client]
Scope string
}
func NewAuthCode(user ID[*User], client ID[*Client], scope string) (code *AuthCode, secret string, err error) {
secret, hash, err := generateSecret()
if err != nil {
return nil, "", fmt.Errorf("failed to generate authentication code secret: %v", err)
}
code = &AuthCode{
Hash: hash,
CreatedAt: time.Now(),
User: user,
Client: client,
Scope: scope,
}
return code, secret, nil
}
func (code *AuthCode) columns() map[string]interface{} {
return map[string]interface{}{
"id": &code.ID,
"hash": &code.Hash,
"created_at": &code.CreatedAt,
"user": &code.User,
"client": &code.Client,
"scope": &code.Scope,
}
}
func (code *AuthCode) VerifySecret(secret string) bool {
return verifyHash(code.Hash, secret) && verifyExpiration(code.CreatedAt.Add(10*time.Minute))
return verifyHash(code.Hash, secret) && verifyExpiration(code.CreatedAt.Add(authCodeExpiration))
}
func UnmarshalSecret[T entity](s string) (id ID[T], secret string, err error) {
idStr, secret, _ := strings.Cut(s, ".")
id, err = ParseID[T](idStr)
return id, secret, err
}
func MarshalSecret[T entity](id ID[T], secret string) string {
if id == 0 {
panic("cannot marshal zero ID")
}
return fmt.Sprintf("%v.%v", int64(id), secret)
}
func generateUID() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
func generateSecret() (secret string, hash []byte, err error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", nil, err
}
secret = base64.RawURLEncoding.EncodeToString(b)
h := sha512.Sum512(b)
return secret, h[:], nil
}
func verifyHash(hash []byte, secret string) bool {
b, _ := base64.RawURLEncoding.DecodeString(secret)
h := sha512.Sum512(b)
return subtle.ConstantTimeCompare(hash, h[:]) == 1
}
func verifyExpiration(t time.Time) bool {
return time.Now().Before(t)
}
package main import ( "context" "embed" "flag" "html/template" "log" "net" "net/http"
"time"
"github.com/go-chi/chi/v5"
)
var (
//go:embed template
templateFS embed.FS
//go:embed static
staticFS embed.FS
)
func main() {
var configFilename, listenAddr string
flag.StringVar(&configFilename, "config", "/etc/sinwon/config", "Configuration filename")
flag.StringVar(&listenAddr, "listen", ":8080", "HTTP listen address")
flag.Parse()
cfg, err := loadConfig(configFilename)
if err != nil {
log.Fatalf("Failed to load config file: %v", err)
}
if listenAddr == "" {
listenAddr = cfg.Listen
}
if listenAddr == "" {
log.Fatalf("Missing listen configuration")
}
if cfg.Database == "" {
log.Fatalf("Missing database configuration")
}
db, err := openDB(cfg.Database)
if err != nil {
log.Fatalf("Failed to open DB: %v", err)
}
tpl := template.Must(template.ParseFS(templateFS, "template/*.html"))
mux := chi.NewRouter()
mux.Handle("/static/*", http.FileServer(http.FS(staticFS)))
mux.Get("/", index)
mux.Post("/client/new", createClient)
mux.HandleFunc("/login", login)
mux.Post("/logout", logout)
mux.HandleFunc("/user/new", updateUser)
mux.HandleFunc("/user/{id}", updateUser)
mux.HandleFunc("/authorize", authorize)
mux.Post("/token", exchangeToken)
go maintainDBLoop(db)
server := http.Server{
Addr: listenAddr,
Handler: loginTokenMiddleware(mux),
BaseContext: func(net.Listener) context.Context {
return newBaseContext(db, tpl)
},
}
log.Printf("OAuth server listening on %v", server.Addr)
if err := server.ListenAndServe(); err != nil {
log.Fatalf("Failed to listen and serve: %v", err)
}
}
func httpError(w http.ResponseWriter, err error) {
log.Print(err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
func maintainDBLoop(db *DB) {
ticker := time.NewTicker(15 * time.Minute)
defer ticker.Stop()
for range ticker.C {
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
if err := db.Maintain(ctx); err != nil {
log.Printf("Failed to perform database maintenance: %v", err)
}
cancel()
}
}