Lindenii Project Forge
Login

hare-ev

Temporary fork of hare-ev for... reasons
Commit info
ID
1cfc0d1127e9b71aace5ec5360c107af529a4e24
Author
Drew DeVault <sir@cmpwn.com>
Author date
Mon, 01 Apr 2024 15:26:30 +0200
Committer
Drew DeVault <sir@cmpwn.com>
Committer date
Mon, 01 Apr 2024 15:26:30 +0200
Actions
all: for-each updates

Signed-off-by: Drew DeVault <sir@cmpwn.com>
use errors;
use io;
use rt;
use time;
use types;
use unix::signal;

// Dispatch callback. See [[ondispatch]].
export type dispatchcb = fn(loop: *loop, user: nullable *opaque) void;

export type ondispatch = struct {
	cb: *dispatchcb,
	user: nullable *opaque,
	loop: *loop,
};

export type loop = struct {
	fd: io::file,
	events: []rt::epoll_event,
	dispatch: []*ondispatch,
	stop: bool,
};

// Creates a new event loop. The user must pass the return value to [[finish]]
// to free associated resources when done using the loop.
export fn newloop() (loop | errors::error) = {
	const fd = match (rt::epoll_create1(rt::EPOLL_CLOEXEC)) {
	case let fd: int =>
		yield fd: io::file;
	case let err: rt::errno =>
		return errors::errno(err);
	};

	return loop {
		fd = fd,
		// XXX: Should the number of events be customizable?
		events = alloc([rt::epoll_event {
			events = 0,
			data = rt::epoll_data {
				fd = 0,
			}
		}...], 256),
		dispatch = [],
		stop = false,
	};
};

// Frees resources associated with an event loop. Must only be called once per
// event loop object. Calling finish invalidates all I/O objects associated with
// the event loop.
export fn finish(loop: *loop) void = {
	free(loop.events);
	io::close(loop.fd)!;
};

// Returns an [[io::file]] for this event loop which can be polled on when
// events are available for processing, for chaining together different event
// loops. The exact semantics of this function are platform-specific, and it may
// not be available for all implementations.
export fn loop_file(loop: *loop) io::file = {
	return loop.fd;
};

// Registers a callback to be invoked before the event loop dispatches pending
// I/O requests. The callback may schedule additional I/O requests to be
// processed in this batch.
export fn do(
	loop: *loop,
	cb: *dispatchcb,
	user: nullable *opaque,
) req = {
	const dispatch = alloc(ondispatch {
		cb = cb,
		user = user,
		loop = loop,
	});
	append(loop.dispatch, dispatch);
	return mkreq(&do_cancel, dispatch);
};

fn do_cancel(req: *req) void = {
	const dispatch = req.user: *ondispatch;
	const loop = dispatch.loop;
	for (let i = 0z; i < len(loop.dispatch); i += 1) {
		if (loop.dispatch[i] == dispatch) {
			delete(loop.dispatch[i]);
			break;
		};
	};
	free(dispatch);
};

// Dispatches the event loop, waiting for new events and calling their callbacks
// as appropriate.
//
// A timeout of -1 will block indefinitely until the next event occurs. A
// timeout of 0 will cause dispatch to return immediately if no events are
// available to process. Portable use of the timeout argument supports only
// millisecond granularity of up to 24 days ([[types::INT_MAX]] milliseconds).
// Negative values other than -1 will cause the program to abort.
//
// Returns false if the loop has been stopped via [[stop]], or true otherwise.
export fn dispatch(
	loop: *loop,
	timeout: time::duration,
) (bool | errors::error) = {
	const millis: int = if (timeout == -1) {
		yield -1;
	} else if (timeout < 0) {
		abort("ev::dispatch: invalid timeout");
	} else {
		yield (timeout / time::MILLISECOND): int;
	};
	if (loop.stop) {
		return false;
	};

	let todo = loop.dispatch;
	loop.dispatch = [];
	for (let i = 0z; i < len(todo); i += 1) {
		const dispatch = todo[i];
	for (let dispatch .. todo) {
		dispatch.cb(loop, dispatch.user);
		free(dispatch);
	};
	free(todo);

	if (len(loop.events) == 0) {
		return true;
	};

	// TODO: Deal with signals
	const maxev = len(loop.events);
	assert(maxev <= types::INT_MAX: size, "ev::dispatch: too many events");
	const nevent = match(rt::epoll_pwait(
		loop.fd, &loop.events[0],
		maxev: int, millis, null)) {
	case let nevent: int =>
		yield nevent;
	case let err: rt::errno =>
		switch (err) {
		case rt::EINTR =>
			// We shallow system suspension error code
			return true;
		case =>
			abort("ev::dispatch: epoll_pwait failure");
		};
	};

	for (let i = 0; i < nevent; i += 1) {
		const ev = &loop.events[i];
	for (let ev &.. loop.events) {
		const file = ev.data.ptr: *file;
		if (ev.events == 0) {
			continue;
		};
		const pending = file.op;
		if (ev.events & (rt::EPOLLIN | rt::EPOLLHUP) != 0
				&& file.op & op::READV != 0) {
			readv_ready(file, ev);
		};
		if (ev.events & (rt::EPOLLOUT | rt::EPOLLHUP) != 0
				&& file.op & op::WRITEV != 0) {
			writev_ready(file, ev);
		};
		switch (pending) {
		case op::NONE =>
			abort("No operation pending for ready object");
		case op::READABLE =>
			readable_ready(file, ev);
		case op::WRITABLE =>
			writable_ready(file, ev);
		case op::ACCEPT =>
			accept_ready(file, ev);
		case op::CONNECT_TCP =>
			connect_tcp_ready(file, ev);
		case op::CONNECT_UNIX =>
			connect_unix_ready(file, ev);
		case op::SIGNAL =>
			signal_ready(file, ev);
		case op::TIMER =>
			timer_ready(file, ev);
		case op::SENDTO =>
			sendto_ready(file, ev);
		case op::RECVFROM =>
			recvfrom_ready(file, ev);
		case op::SEND =>
			send_ready(file, ev);
		case op::RECV =>
			recv_ready(file, ev);
		case =>
			assert(pending & ~(op::READV | op::WRITEV) == 0);
		};
	};

	return !loop.stop;
};

// Signals the loop to stop processing events. If called during a callback, it
// will cause that invocation of [[dispatch]] to return false. Otherwise, false
// will be returned only upon the next call to [[dispatch]].
export fn stop(loop: *loop) void = {
	loop.stop = true;
};
use ev;
use fmt;
use net;
use net::dial;
use net::ip;
use net::uri;

// Callback for a [[dial]] operation.
export type dialcb = fn(user: nullable *opaque, r: (*ev::file | error)) void;

// Dials a remote address, establishing a connection and returning the resulting
// [[net::socket]] to the callback. The proto parameter should be the transport
// protocol (e.g. "tcp"), the address parameter should be the remote address,
// and the service should be the name of the service, or the default port to
// use.
//
// See also [[net::dial::dial]].
export fn dial(
	loop: *ev::loop,
	proto: str,
	address: str,
	service: str,
	cb: *dialcb,
	user: nullable *opaque,
) (ev::req | error) = {
	for (let i = 0z; i < len(default_protocols); i += 1) {
		const p = default_protocols[i];
	for (const p .. default_protocols) {
		if (p.name == proto) {
			return p.dial(loop, address, service, cb, user);
		};
	};
	for (let i = 0z; i < len(protocols); i += 1) {
		const p = protocols[i];
	for (const p .. protocols) {
		if (p.name == proto) {
			return p.dial(loop, address, service, cb, user);
		};
	};
	return net::unknownproto: net::error;
};

def HOST_MAX: size = 255;

// Performs a [[dial]] operation for a given URI, taking the service name from
// the URI scheme and forming an address from the URI host and port.
//
// See also [[net::dial::uri]].
export fn dial_uri(
	loop: *ev::loop,
	proto: str,
	uri: *uri::uri,
	cb: *dialcb,
	user: nullable *opaque,
) (ev::req | error) = {
	if (uri.host is str && len(uri.host as str) > HOST_MAX) {
		return invalid_address;
	};
	static let addr: [HOST_MAX + len("[]:65535")]u8 = [0...];

	const colon = if (uri.port != 0) ":" else "";
	const port: fmt::formattable = if (uri.port != 0) uri.port else "";

	let addr = match (uri.host) {
	case let host: str =>
		yield fmt::bsprintf(addr, "{}{}{}", host, colon, port);
	case let ip: ip::addr4 =>
		const host = ip::string(ip);
		yield fmt::bsprintf(addr, "{}{}{}", host, colon, port);
	case let ip: ip::addr6 =>
		const host = ip::string(ip);
		yield fmt::bsprintf(addr, "[{}]{}{}", host, colon, port);
	};

	return dial(loop, proto, addr, uri.scheme, cb, user);
};
use crypto::random;
use errors;
use ev;
use edns = ev::dns;
use net;
use net::ip;
use net::dial;
use net::dns;
use unix::hosts;

// Callback from a [[resolve]] operation.
export type resolvecb = fn(
	user: nullable *opaque,
	r: (([]ip::addr, u16) | error),
) void;

type resolve_state = struct {
	user: nullable *opaque,
	cb: *resolvecb,
	r4: ev::req,
	r6: ev::req,
	nq: uint,
	ip: []ip::addr,
	port: u16,
};

// Performs DNS resolution on a given address string for a given service,
// including /etc/hosts lookup and SRV resolution, and returns a list of
// candidate IP addresses and the appropriate port, or an error, to the
// callback.
//
// The caller must free the [[net::ip::addr]] slice.
export fn resolve(
	loop: *ev::loop,
	proto: str,
	addr: str,
	service: str,
	cb: *resolvecb,
	user: nullable *opaque
) (ev::req | error) = {
	// TODO: Reduce duplication with net::dial
	let state = alloc(resolve_state {
		cb = cb,
		user = user,
		...
	});

	const (addr, port) = match (dial::splitaddr(addr, service)) {
	case let svc: (str, u16) =>
		yield svc;
	case dial::invalid_address =>
		resolve_finish(state, invalid_address);
		return ev::req { ... };
	};

	if (service == "unknown" && port == 0) {
		resolve_finish(state, unknown_service);
		return ev::req { ... };
	};

	if (port == 0) {
		match (lookup_service(proto, service)) {
		case let p: u16 =>
			port = p;
		case void => yield;
		};
	};

	// TODO:
	// - Consult /etc/services
	// - Fetch the SRV record

	if (port == 0) {
		resolve_finish(state, unknown_service);
		return ev::req { ... };
	};

	match (ip::parse(addr)) {
	case let addr: ip::addr =>
		const addrs = alloc([addr]);
		resolve_finish(state, (addrs, port));
		return ev::req { ... };
	case ip::invalid => yield;
	};

	const addrs = hosts::lookup(addr)?;
	if (len(addrs) != 0) {
		resolve_finish(state, (addrs, port));
		return ev::req { ... };
	};

	state.port = port;
	return resolve_dns(state, loop, addr);
};

fn resolve_dns(
	state: *resolve_state,
	loop: *ev::loop,
	addr: str,
) (ev::req | error) = {
	const domain = dns::parse_domain(addr);
	defer free(domain);

	let rand: []u8 = [0, 0];
	random::buffer(rand);
	let id = *(&rand[0]: *u16);

	const query6 = dns::message {
		header = dns::header {
			id = id,
			op = dns::op {
				qr = dns::qr::QUERY,
				opcode = dns::opcode::QUERY,
				rd = true,
				...
			},
			qdcount = 1,
			...
		},
		questions = [
			dns::question {
				qname = domain,
				qtype = dns::qtype::AAAA,
				qclass = dns::qclass::IN,
			},
		],
		...
	};
	const query4 = dns::message {
		header = dns::header {
			id = id + 1,
			op = dns::op {
				qr = dns::qr::QUERY,
				opcode = dns::opcode::QUERY,
				rd = true,
				...
			},
			qdcount = 1,
			...
		},
		questions = [
			dns::question {
				qname = domain,
				qtype = dns::qtype::A,
				qclass = dns::qclass::IN,
			},
		],
		...
	};

	state.r6 = edns::query(loop, &query6, &query_cb_v6, state)?;
	state.r4 = edns::query(loop, &query4, &query_cb_v4, state)?;
	return ev::mkreq(&resolve_cancel, state);
};

fn resolve_finish(st: *resolve_state, r: (([]ip::addr, u16) | error)) void = {
	const user = st.user;
	const cb = st.cb;
	if (r is error) {
		free(st.ip);
	};
	free(st);
	cb(user, r);
};

fn resolve_cancel(req: *ev::req) void = {
	const state = req.user: *resolve_state;
	ev::cancel(&state.r4);
	ev::cancel(&state.r6);
	free(state.ip);
	free(state);
};

fn query_cb_v4(user: nullable *opaque, r: (*dns::message | dns::error)) void = {
	let state = user: *resolve_state;
	state.r4 = ev::req { ... };

	match (r) {
	case let err: dns::error =>
		ev::cancel(&state.r6);
		resolve_finish(state, err);
		return;
	case let msg: *dns::message =>
		collect_answers(&state.ip, &msg.answers);
		state.nq += 1;
	};

	if (state.nq < 2) {
		return;
	};
	resolve_finish(state, (state.ip, state.port));
};

fn query_cb_v6(user: nullable *opaque, r: (*dns::message | dns::error)) void = {
	let state = user: *resolve_state;
	state.r6 = ev::req { ... };

	match (r) {
	case let err: dns::error =>
		ev::cancel(&state.r4);
		resolve_finish(state, err);
		return;
	case let msg: *dns::message =>
		collect_answers(&state.ip, &msg.answers);
		state.nq += 1;
	};

	if (state.nq < 2) {
		return;
	};
	resolve_finish(state, (state.ip, state.port));
};

fn collect_answers(addrs: *[]ip::addr, answers: *[]dns::rrecord) void = {
	for (let i = 0z; i < len(answers); i += 1) {
		match (answers[i].rdata) {
	for (let answer &.. answers) {
		match (answer.rdata) {
		case let addr: dns::aaaa =>
			append(addrs, addr: ip::addr);
		case let addr: dns::a =>
			append(addrs, addr: ip::addr);
		case => void;
		};
	};
};
use errors;
use ev;
use net;
use net::dns;
use net::ip;
use net::udp;
use time;
use unix::resolvconf;

// TODO: Let users customize this?
def TIMEOUT: time::duration = 3 * time::SECOND;

// Callback for a [[query]] operation.
export type querycb = fn(
	user: nullable *opaque,
	r: (*dns::message | dns::error),
) void;

type qstate = struct {
	buf: [512]u8,
	socket4: *ev::file,
	socket6: *ev::file,
	r4: ev::req,
	r6: ev::req,
	timer: *ev::file,
	rid: u16,
	cb: *querycb,
	user: nullable *opaque,
};

// Performs a DNS query against the provided set of DNS servers, or the list of
// servers from /etc/resolv.conf if none are specified. The user must free the
// message passed to the callback with [[net::dns::message_free]].
export fn query(
	loop: *ev::loop,
	query: *dns::message,
	cb: *querycb,
	user: nullable *opaque,
	servers: ip::addr...
) (ev::req | dns::error | net::error | errors::error) = {
	if (len(servers) == 0) {
		servers = resolvconf::load();
	};
	if (len(servers) == 0) {
		// Fall back to localhost
		servers = [ip::LOCAL_V6, ip::LOCAL_V4];
	};

	const socket4 = ev::listen_udp(loop, ip::ANY_V4, 0)?;
	const socket6 = ev::listen_udp(loop, ip::ANY_V6, 0)?;
	const timeout = ev::newtimer(loop, &timeoutcb, time::clock::MONOTONIC)?;
	let state = alloc(qstate {
		socket4 = socket4,
		socket6 = socket6,
		timer = timeout,
		rid = query.header.id,
		cb = cb,
		user = user,
		...
	});
	const z = dns::encode(state.buf, query)?;
	ev::setuser(socket4, state);
	ev::setuser(socket6, state);
	ev::setuser(timeout, state);
	ev::timer_configure(timeout, TIMEOUT, 0);

	// Note: the initial set of requests is sent directly through net::udp
	// as it is assumed they can fit into the kernel's internal send buffer
	// and will finish without blocking
	const buf = state.buf[..z];
	for (let i = 0z; i < len(servers); i += 1) match (servers[i]) {
	case ip::addr4 =>
		udp::sendto(ev::getfd(socket4), buf, servers[i], 53)?;
	case ip::addr6 =>
		udp::sendto(ev::getfd(socket6), buf, servers[i], 53)?;
	for (const server .. servers) {
		match (server) {
		case ip::addr4 =>
			udp::sendto(ev::getfd(socket4), buf, server, 53)?;
		case ip::addr6 =>
			udp::sendto(ev::getfd(socket6), buf, server, 53)?;
		};
	};

	state.r4 = ev::recvfrom(socket4, &qrecvcb, state.buf);
	state.r6 = ev::recvfrom(socket6, &qrecvcb, state.buf);
	return ev::mkreq(&query_cancel, state);
};

fn query_cancel(req: *ev::req) void = {
	const q = req.user: *qstate;
	query_destroy(q);
};

fn query_destroy(q: *qstate) void = {
	ev::cancel(&q.r4);
	ev::cancel(&q.r6);
	ev::close(q.socket4);
	ev::close(q.socket6);
	ev::close(q.timer);
	free(q);
};

fn query_complete(q: *qstate, r: (*dns::message | dns::error)) void = {
	const cb = q.cb;
	const user = q.user;
	query_destroy(q);
	cb(user, r);
};

fn timeoutcb(file: *ev::file) void = {
	const q = ev::getuser(file): *qstate;
	query_complete(q, errors::timeout);
};

fn qrecvcb(file: *ev::file, r: ((size, ip::addr, u16) | net::error)) void = {
	const q = ev::getuser(file): *qstate;
	let req: *ev::req = if (file == q.socket4) &q.r4 else &q.r6;
	*req = ev::req { ... };

	const (z, addr, port) = match (r) {
	case let r: (size, ip::addr, u16) =>
		yield r;
	case let err: net::error =>
		query_complete(q, err);
		return;
	};

	const resp = match (dns::decode(q.buf[..z])) {
	case dns::format =>
		*req = ev::recvfrom(file, &qrecvcb, q.buf);
		return;
	case let msg: *dns::message =>
		yield msg;
	};
	defer dns::message_free(resp);

	if (resp.header.id != q.rid || resp.header.op.qr != dns::qr::RESPONSE) {
		*req = ev::recvfrom(file, &qrecvcb, q.buf);
		return;
	};

	if (!resp.header.op.tc) {
		query_complete(q, resp);
		return;
	};

	abort(); // TODO: retry over TCP
};