Lindenii Project Forge
Login

hare-ds

Data structures for Hare
Commit info
ID
956e56b5bde140ce7ef54f7e0aaa9725e6ab3f2b
Author
Runxi Yu <me@runxiyu.org>
Author date
Tue, 16 Sep 2025 19:38:54 +0800
Committer
Runxi Yu <me@runxiyu.org>
Committer date
Tue, 16 Sep 2025 19:38:54 +0800
Actions
Add map_rbtree

Blindly translated from an old Go implementation I had laying around...
map_rbtree: key-value map implemented as a red–black tree
// Deletes an item from a [[map]]. Returns the removed value or void.
export fn del(m: *map, key: []u8) (*opaque | void) = {
	let z = find_node(m, key);
	match (z) {
	case null => return;
	case let nodez: *node =>
		let ret = nodez.val;

		let y = nodez;
		let y_orig = y.color;

		let x: nullable *node = null;
		let p_for_fix: nullable *node = null;

		if (nodez.left == null) {
			x = nodez.right;
			p_for_fix = nodez.parent;
			transplant(m, nodez, nodez.right);
		} else if (nodez.right == null) {
			x = nodez.left;
			p_for_fix = nodez.parent;
			transplant(m, nodez, nodez.left);
		} else {
			let r = match (nodez.right) {
			case let rr: *node => yield rr;
			case null => abort("rb invariant violated: del: right is null");
			};
			let s = subtree_min(r);
			y = s;
			let yor = y.color;
			y_orig = yor;

			x = y.right;
			if (y.parent == (nodez: nullable *node)) {
				p_for_fix = y;
				set_parent(x, y);
			} else {
				p_for_fix = y.parent;
				transplant(m, y, y.right);
				y.right = nodez.right;
				set_parent(y.right, y);
			};

			transplant(m, nodez, y);
			y.left = nodez.left;
			set_parent(y.left, y);
			y.color = nodez.color;
		};

		free(nodez);

		if (y_orig == color::BLACK) {
			delete_fixup(m, x, p_for_fix);
		};

		return ret;
	};
};
fn free_subtree(n: nullable *node) void = {
	match (n) {
	case null => return;
	case let p: *node =>
		free_subtree(p.left);
		free_subtree(p.right);
		free(p);
	};
};

// Frees resources associated with a [[map]].
export fn finish(m: *map) void = {
	free_subtree(m.root);
	free(m);
};
// SPDX-License-Identifier: MPL-2.0

use bytes;

// Gets an item from a [[map]] by key, returning void if not found.
export fn get(m: *map, key: []u8) (*opaque | void) = {
	let n = find_node(m, key);
	match (n) {
	case null => return;
	case let p: *node => return p.val;
	};
};
use bytes;

fn keycmp(a: []u8, b: []u8) int = {
	let n = if (len(a) < len(b)) len(a) else len(b);
	for (let i = 0z; i < n; i += 1) {
		if (a[i] < b[i]) return -1;
		if (a[i] > b[i]) return 1;
	};
	if (len(a) < len(b)) return -1;
	if (len(a) > len(b)) return 1;
	return 0;
};

fn is_red(p: nullable *node) bool = {
	match (p) {
	case null => return false;
	case let n: *node => return n.color == color::RED;
	};
};

fn is_black(p: nullable *node) bool = !is_red(p);

fn set_color(p: nullable *node, c: color) void = {
	match (p) {
	case null => return;
	case let n: *node => n.color = c;
	};
};

fn set_parent(ch: nullable *node, pa: nullable *node) void = {
	match (ch) {
	case null => return;
	case let n: *node => n.parent = pa;
	};
};

fn subtree_min(n: *node) *node = {
	let cur = n;
	for (true) {
		match (cur.left) {
		case null => return cur;
		case let l: *node => cur = l;
		};
	};
};

fn transplant(m: *map, u: *node, v: nullable *node) void = {
	match (u.parent) {
	case null =>
		m.root = v;
	case let p: *node =>
		if (p.left == (u: nullable *node)) {
			p.left = v;
		} else {
			p.right = v;
		};
	};
	set_parent(v, u.parent);
};

fn rotate_left(m: *map, x: *node) void = {
	let y = match (x.right) {
	case let r: *node => yield r;
	case null => abort("rb invariant violated: rotate_left with null right");
	};

	x.right = y.left;
	set_parent(x.right, x);

	y.parent = x.parent;
	match (x.parent) {
	case null =>
		m.root = y;
	case let p: *node =>
		if (p.left == (x: nullable *node)) {
			p.left = y;
		} else {
			p.right = y;
		};
	};

	y.left = x;
	x.parent = y;
};

fn rotate_right(m: *map, x: *node) void = {
	let y = match (x.left) {
	case let l: *node => yield l;
	case null => abort("rb invariant violated: rotate_right with null left");
	};

	x.left = y.right;
	set_parent(x.left, x);

	y.parent = x.parent;
	match (x.parent) {
	case null =>
		m.root = y;
	case let p: *node =>
		if (p.left == (x: nullable *node)) {
			p.left = y;
		} else {
			p.right = y;
		};
	};

	y.right = x;
	x.parent = y;
};

fn insert_fixup(m: *map, z: *node) void = {
	let cur = z;
	for (true) {
		let p = match (cur.parent) {
		case null => break;
		case let pp: *node => yield pp;
		};
		if (p.color == color::BLACK) break;

		let gp = match (p.parent) {
		case null => break;
		case let g: *node => yield g;
		};

		let uncle: nullable *node = if (gp.left == (p: nullable *node))
			gp.right else gp.left;

		if (is_red(uncle)) {
			set_color(p, color::BLACK);
			set_color(uncle, color::BLACK);
			set_color(gp, color::RED);
			cur = gp;
			continue;
		};

		if (gp.left == (p: nullable *node)) {
			if (p.right == (cur: nullable *node)) {
				rotate_left(m, p);
				cur = p;
			};
			let p2 = match (cur.parent) {
			case null => break;
			case let pp2: *node => yield pp2;
			};
			let g2 = match (p2.parent) {
			case null => break;
			case let gg2: *node => yield gg2;
			};
			set_color(p2, color::BLACK);
			set_color(g2, color::RED);
			rotate_right(m, g2);
		} else {
			if (p.left == (cur: nullable *node)) {
				rotate_right(m, p);
				cur = p;
			};
			let p2 = match (cur.parent) {
			case null => break;
			case let pp2: *node => yield pp2;
			};
			let g2 = match (p2.parent) {
			case null => break;
			case let gg2: *node => yield gg2;
			};
			set_color(p2, color::BLACK);
			set_color(g2, color::RED);
			rotate_left(m, g2);
		};
	};
	match (m.root) {
	case null => void;
	case let r: *node => r.color = color::BLACK;
	};
};

fn find_node(m: *map, key: []u8) nullable *node = {
	let cur = m.root;
	for (true) {
		match (cur) {
		case null => return null;
		case let n: *node =>
			let c = keycmp(key, n.key);
			if (c == 0) return n;
			cur = if (c < 0) n.left else n.right;
		};
	};
};

fn delete_fixup(m: *map, x0: nullable *node, p0: nullable *node) void = {
	let x = x0;
	let p = p0;

	for (x != m.root && is_black(x)) {
		let pp = match (p) {
		case null => break;
		case let q: *node => yield q;
		};

		if (pp.left == x) {
			let mutw = pp.right;
			let w = match (mutw) {
			case null => break;
			case let w_: *node => yield w_;
			};
			if (is_red(w)) {
				set_color(w, color::BLACK);
				set_color(pp, color::RED);
				rotate_left(m, pp);
				mutw = pp.right;
			};
			let wr = match (mutw) {
			case null => yield null;
			case let w2: *node => yield w2.right;
			};
			let wl = match (mutw) {
			case null => yield null;
			case let w2: *node => yield w2.left;
			};
			if (is_black(wl) && is_black(wr)) {
				set_color(mutw, color::RED);
				x = pp;
				p = pp.parent;
			} else {
				if (is_black(wr)) {
					set_color(wl, color::BLACK);
					set_color(mutw, color::RED);
					match (mutw) {
					case null => void;
					case let ww: *node => rotate_right(m, ww);
					};
					mutw = pp.right;
				};
				match (mutw) {
				case null => void;
				case let w3: *node =>
					w3.color = pp.color;
					set_color(pp, color::BLACK);
					set_color(w3.right, color::BLACK);
					rotate_left(m, pp);
				};
				x = m.root;
				p = null;
			};
		} else {
			let mutw = pp.left;
			let w = match (mutw) {
			case null => break;
			case let w_: *node => yield w_;
			};
			if (is_red(w)) {
				set_color(w, color::BLACK);
				set_color(pp, color::RED);
				rotate_right(m, pp);
				mutw = pp.left;
			};
			let wl = match (mutw) {
			case null => yield null;
			case let w2: *node => yield w2.left;
			};
			let wr = match (mutw) {
			case null => yield null;
			case let w2: *node => yield w2.right;
			};
			if (is_black(wl) && is_black(wr)) {
				set_color(mutw, color::RED);
				x = pp;
				p = pp.parent;
			} else {
				if (is_black(wl)) {
					set_color(wr, color::BLACK);
					set_color(mutw, color::RED);
					match (mutw) {
					case null => void;
					case let ww: *node => rotate_left(m, ww);
					};
					mutw = pp.left;
				};
				match (mutw) {
				case null => void;
				case let w3: *node =>
					w3.color = pp.color;
					set_color(pp, color::BLACK);
					set_color(w3.left, color::BLACK);
					rotate_right(m, pp);
				};
				x = m.root;
				p = null;
			};
		};
	};
	set_color(x, color::BLACK);
};
use ds::map;

// Red–black tree-based map from []u8 to *opaque.
//
// You are advised to create these with [[new]].
export type map = struct {
	vt: map::map,
	root: nullable *node,
};

const _vt: map::vtable = map::vtable {
	getter   = &vt_get,
	setter   = &vt_set,
	deleter  = &vt_del,
	finisher = &vt_finish,
};

fn vt_get(m: *map::map, key: []u8) (*opaque | void) = get(m: *map, key);
fn vt_set(m: *map::map, key: []u8, v: *opaque) (void | nomem) = set(m: *map, key, v);
fn vt_del(m: *map::map, key: []u8) (*opaque | void) = del(m: *map, key);
fn vt_finish(m: *map::map) void = finish(m: *map);
// Creates a new [[map]].
export fn new() (*map | nomem) = {
	let m = alloc(map {
		vt = &_vt,
		root = null,
	})?;
	return m;
};
// SPDX-License-Identifier: MPL-2.0

export type color = enum u8 {
	RED = 0,
	BLACK = 1,
};

export type node = struct {
	color: color,
	key: []u8,
	val: *opaque,
	left: nullable *node,
	right: nullable *node,
	parent: nullable *node,
};
use bytes;

export fn set(m: *map, key: []u8, value: *opaque) (void | nomem) = {
	match (find_node(m, key)) {
	case let ex: *node =>
		ex.val = value;
		return;
	case null => void;
	};

	let z = alloc(node {
		color = color::RED,
		key = key,
		val = value,
		left = null,
		right = null,
		parent = null,
	})?;

	let y: nullable *node = null;
	let x = m.root;

	for (true) {
		match (x) {
		case null => break;
		case let xn: *node =>
			y = xn;
			if (keycmp(z.key, xn.key) < 0) {
				x = xn.left;
			} else {
				x = xn.right;
			};
		};
	};

	z.parent = y;
	match (y) {
	case null =>
		m.root = z;
	case let yn: *node =>
		if (keycmp(z.key, yn.key) < 0) {
			yn.left = z;
		} else {
			yn.right = z;
		};
	};

	insert_fixup(m, z);
};
use errors;
use strings;
use ds::map;

@test fn roundtrip() void = {
	const key: [16]u8 = [0...];

	let m: *map = match (new()) {
	case let p: *map => yield p;
	case nomem => abort("unexpected nomem");
	};
	defer finish(m);

	let v1 = 1, v2 = 2, v3 = 3;
	let p1: *opaque = (&v1: *opaque);
	let p2: *opaque = (&v2: *opaque);
	let p3: *opaque = (&v3: *opaque);

	let k1 = strings::toutf8("alpha");
	let k2 = strings::toutf8("beta");
	let k3 = strings::toutf8("gamma");

	match (map::set(m, k1, p1)) {
	case void => yield;
	case nomem => abort("unexpected nomem in set(k1,p1)");
	};

	match (map::get(m, k1)) {
	case let got: *opaque =>
		assert(got == p1, "get(k1) must return p1");
	case void =>
		abort("get(k1) unexpectedly void");
	};

	match (map::set(m, k1, p2)) {
	case void => yield;
	case nomem => abort("unexpected nomem in replace");
	};
	match (map::get(m, k1)) {
	case let got: *opaque =>
		assert(got == p2, "replace must overwrite prior value");
	case void =>
		abort("get(k1) void after replace");
	};

	match (map::set(m, k2, p3)) {
	case void => yield;
	case nomem => abort("unexpected nomem in set(k2,p3)");
	};

	match (map::get(m, k3)) {
	case void => yield;
	case *opaque =>
		abort("get(k3) must be void for missing key");
	};

	match (map::del(m, k2)) {
	case let got: *opaque =>
		assert(got == p3, "del(k2) must return stored value");
	case void =>
		abort("del(k2) unexpectedly void");
	};
	match (map::del(m, k2)) {
	case void => yield;
	case *opaque =>
		abort("del(k2) must be void after prior delete");
	};
};