From 58ddd667f4d367c500ca33f63c2de427a5a0126b Mon Sep 17 00:00:00 2001 From: Drew DeVault Date: Thu, 28 Sep 2023 10:39:14 +0200 Subject: [PATCH] ev::dial: initial commit --- cmd/dnsclient/main.ha | 66 +++++++++++++---------------------------------------- cmd/hget/main.ha | 84 +++++++++++++++++++++++++++++++++++++++++++++++++++++ ev/dial/TODO | 1 + ev/dial/dial.ha | 74 +++++++++++++++++++++++++++++++++++++++++++++++++++++ ev/dial/ip.ha | 115 +++++++++++++++++++++++++++++++++++++++++++++++++++++ ev/dial/registry.ha | 138 +++++++++++++++++++++++++++++++++++++++++++++++++++++ ev/dial/resolve.ha | 224 +++++++++++++++++++++++++++++++++++++++++++++++++++++ diff --git a/cmd/dnsclient/main.ha b/cmd/dnsclient/main.ha index 615eb880a7fb2d7f41990ab9e42e58821525c441..8fd721b37806432c912921ba2e12aa6c43074458 100644 --- a/cmd/dnsclient/main.ha +++ b/cmd/dnsclient/main.ha @@ -1,8 +1,6 @@ -use crypto::random; use ev; -use edns = ev::dns; +use ev::dial; use fmt; -use net::dns; use net::ip; use os; @@ -10,59 +8,27 @@ export fn main() void = { const loop = ev::newloop()!; defer ev::finish(&loop); - let rand: []u8 = [0, 0]; - random::buffer(rand); - let id = *(&rand[0]: *u16); - - const domain = dns::parse_domain(os::args[1]); - defer free(domain); - - const query = dns::message { - header = dns::header { - id = id, - op = dns::op { - qr = dns::qr::QUERY, - opcode = dns::opcode::QUERY, - rd = true, - ... - }, - qdcount = 1, - ... - }, - questions = [ - dns::question { - qname = domain, - qtype = dns::qtype::A, - qclass = dns::qclass::IN, - }, - ], - ... - }; - - edns::query(&loop, &query, &querycb, &loop)!; + const addr = os::args[1], svc = os::args[2]; + dial::resolve(&loop, "tcp", addr, svc, &resolvecb, &loop)!; for (ev::dispatch(&loop, -1)!) void; }; -fn querycb(user: nullable *opaque, r: (*dns::message | dns::error)) void = { +fn resolvecb( + user: nullable *opaque, + r: (([]ip::addr, u16) | dial::error), +) void = { const loop = user: *ev::loop; - const resp = match (r) { - case let msg: *dns::message => - yield msg; - case let err: dns::error => - fmt::errorln("DNS error:", dns::strerror(err))!; - ev::stop(loop); - return; + defer ev::stop(loop); + + const (ip, port) = match (r) { + case let ip: ([]ip::addr, u16) => + yield ip; + case let err: dial::error => + fmt::fatal("Dial error:", dial::strerror(err)); }; - for (let i = 0z; i < len(resp.answers); i += 1) { - match (resp.answers[i].rdata) { - case let addr: dns::aaaa => - fmt::println(ip::string(addr))!; - case let addr: dns::a => - fmt::println(ip::string(addr))!; - case => void; - }; + for (let i = 0z; i < len(ip); i += 1) { + fmt::printfln("{}:{}", ip::string(ip[i]), port)!; }; - ev::stop(loop); }; diff --git a/cmd/hget/main.ha b/cmd/hget/main.ha new file mode 100644 index 0000000000000000000000000000000000000000..30ad67c22006904f8e1435a723551260fa4d9d37 --- /dev/null +++ b/cmd/hget/main.ha @@ -0,0 +1,84 @@ +use ev; +use ev::dial; +use fmt; +use io; +use net::uri; +use os; +use strings; + +type state = struct { + loop: *ev::loop, + exit: int, + buf: []u8, +}; + +export fn main() void = { + const loop = ev::newloop()!; + defer ev::finish(&loop); + + let state = state { + loop = &loop, + exit = os::status::SUCCESS, + buf = [], + }; + + const uri = net::uri::parse("http://example.org")!; + defer uri::finish(&uri); + + dial::dial_uri(&loop, "tcp", &uri, &dialcb, &state)!; + + for (ev::dispatch(&loop, -1)!) void; + + os::exit(state.exit); +}; + +fn error(state: *state, details: str) void = { + fmt::errorfln("Error: {}", details)!; + state.exit = os::status::FAILURE; + ev::stop(state.loop); +}; + +fn dialcb(user: nullable *opaque, r: (*ev::file | dial::error)) void = { + let state = user: *state; + const file = match (r) { + case let file: *ev::file => + yield file; + case let err: dial::error => + error(state, dial::strerror(err)); + return; + }; + ev::setuser(file, state); + state.buf = strings::toutf8("GET / HTTP/1.1\r\n" + "Host: example.org\r\n" + "Connection: close\r\n\r\n"); + ev::write(file, &writecb, state.buf); +}; + +fn writecb(file: *ev::file, r: (size | io::error)) void = { + const state = ev::getuser(file): *state; + const z = match (r) { + case let z: size => + yield z; + case let err: io::error => + error(state, io::strerror(err)); + return; + }; + assert(z == len(state.buf)); + ev::read(file, &readcb, state.buf); +}; + +fn readcb(file: *ev::file, r: (size | io::EOF | io::error)) void = { + const state = ev::getuser(file): *state; + const z = match (r) { + case let z: size => + yield z; + case io::EOF => + ev::stop(state.loop); + return; + case let err: io::error => + error(state, io::strerror(err)); + return; + }; + io::writeall(os::stdout, state.buf[..z])!; + ev::read(file, &readcb, state.buf); +}; diff --git a/ev/dial/TODO b/ev/dial/TODO new file mode 100644 index 0000000000000000000000000000000000000000..99f8bff2a9f1ab864656a12fd95cfccd1d65bae5 --- /dev/null +++ b/ev/dial/TODO @@ -0,0 +1 @@ +TODO: reduce code duplication with diff --git a/ev/dial/dial.ha b/ev/dial/dial.ha new file mode 100644 index 0000000000000000000000000000000000000000..c6461874566068264560143efa4e8ffcd0b83b70 --- /dev/null +++ b/ev/dial/dial.ha @@ -0,0 +1,74 @@ +use ev; +use fmt; +use net; +use net::dial; +use net::ip; +use net::uri; + +// Callback for a [[dial]] operation. +export type dialcb = fn(user: nullable *opaque, r: (*ev::file | error)) void; + +// Dials a remote address, establishing a connection and returning the resulting +// [[net::socket]] to the callback. The proto parameter should be the transport +// protocol (e.g. "tcp"), the address parameter should be the remote address, +// and the service should be the name of the service, or the default port to +// use. +// +// See also [[net::dial::dial]]. +export fn dial( + loop: *ev::loop, + proto: str, + address: str, + service: str, + cb: *dialcb, + user: nullable *opaque, +) (ev::req | error) = { + for (let i = 0z; i < len(default_protocols); i += 1) { + const p = default_protocols[i]; + if (p.name == proto) { + return p.dial(loop, address, service, cb, user); + }; + }; + for (let i = 0z; i < len(protocols); i += 1) { + const p = protocols[i]; + if (p.name == proto) { + return p.dial(loop, address, service, cb, user); + }; + }; + return net::unknownproto: net::error; +}; + +def HOST_MAX: size = 255; + +// Performs a [[dial]] operation for a given URI, taking the service name from +// the URI scheme and forming an address from the URI host and port. +// +// See also [[net::dial::uri]]. +export fn dial_uri( + loop: *ev::loop, + proto: str, + uri: *uri::uri, + cb: *dialcb, + user: nullable *opaque, +) (ev::req | error) = { + if (uri.host is str && len(uri.host as str) > HOST_MAX) { + return invalid_address; + }; + static let addr: [HOST_MAX + len("[]:65535")]u8 = [0...]; + + const colon = if (uri.port != 0) ":" else ""; + const port: fmt::formattable = if (uri.port != 0) uri.port else ""; + + let addr = match (uri.host) { + case let host: str => + yield fmt::bsprintf(addr, "{}{}{}", host, colon, port); + case let ip: ip::addr4 => + const host = ip::string(ip); + yield fmt::bsprintf(addr, "{}{}{}", host, colon, port); + case let ip: ip::addr6 => + const host = ip::string(ip); + yield fmt::bsprintf(addr, "[{}]{}{}", host, colon, port); + }; + + return dial(loop, proto, addr, uri.scheme, cb, user); +}; diff --git a/ev/dial/ip.ha b/ev/dial/ip.ha new file mode 100644 index 0000000000000000000000000000000000000000..36e28e8ec4429a6d5dc3e5e522de1be41ad94244 --- /dev/null +++ b/ev/dial/ip.ha @@ -0,0 +1,115 @@ +// License: MPL-2.0 +// (c) 2021-2023 Drew DeVault +// (c) 2021 Bor Grošelj Simić +// (c) 2021 Ember Sawady +// +// Provides default dialers for tcp and udp +use errors; +use ev; +use net; +use net::ip; +use net::tcp; +use net::udp; + +type tcp_dialer = struct { + loop: *ev::loop, + cb: *dialcb, + user: nullable *opaque, + req: ev::req, + ip: []ip::addr, + port: u16, +}; + +fn dial_tcp( + loop: *ev::loop, + addr: str, + service: str, + cb: *dialcb, + user: nullable *opaque, +) (ev::req | error) = { + let state = alloc(tcp_dialer { + loop = loop, + cb = cb, + user = user, + ... + }); + + const req = resolve(loop, "tcp", addr, + service, &dial_tcp_resolvecb, state)?; + state.req = req; + return ev::mkreq(&dial_tcp_cancel, state); +}; + +fn dial_tcp_resolvecb( + user: nullable *opaque, + r: (([]ip::addr, u16) | error), +) void = { + let state = user: *tcp_dialer; + state.req = ev::req { ... }; + const (ip, port) = match (r) { + case let r: ([]ip::addr, u16) => + yield r; + case let err: error => + dial_tcp_complete(state, err); + return; + }; + + state.ip = ip; + state.port = port; + dial_tcp_connect(state); +}; + +fn dial_tcp_connect(state: *tcp_dialer) void = { + // TODO: Select IPs from a round-robin, or re-attempt on other IPs? + // TODO: Detect supported networks? i.e. v4/v6 + const req = match (ev::connect_tcp(state.loop, + &dial_tcp_connectcb, + state.ip[0], state.port, + state)) { + case let err: (net::error | errors::error) => + dial_tcp_complete(state, err); + return; + case let req: ev::req => + yield req; + }; + state.req = req; +}; + +fn dial_tcp_connectcb( + r: (*ev::file | net::error), + user: nullable *opaque, +) void = { + let state = user: *tcp_dialer; + match (r) { + case let sock: *ev::file => + ev::setuser(sock, null); + dial_tcp_complete(state, sock); + case let err: net::error => + dial_tcp_complete(state, err); + }; +}; + +fn dial_tcp_cancel(req: *ev::req) void = { + let state = req: *tcp_dialer; + ev::cancel(&state.req); + free(state.ip); + free(state); +}; + +fn dial_tcp_complete(state: *tcp_dialer, r: (*ev::file | error)) void = { + const cb = state.cb; + const user = state.user; + free(state.ip); + free(state); + cb(user, r); +}; + +fn dial_udp( + loop: *ev::loop, + addr: str, + service: str, + cb: *dialcb, + user: nullable *opaque, +) (ev::req | error) = { + abort(); // TODO +}; diff --git a/ev/dial/registry.ha b/ev/dial/registry.ha new file mode 100644 index 0000000000000000000000000000000000000000..09fe1e877de62a6df0e1eb7916f6ce5907f40615 --- /dev/null +++ b/ev/dial/registry.ha @@ -0,0 +1,138 @@ +// License: MPL-2.0 +// (c) 2021-2023 Drew DeVault +// (c) 2021 Ember Sawady +use errors; +use ev; +use net; +use net::dns; +use unix::hosts; + +// Returned if the address parameter was invalid, for example if it specifies an +// invalid port number. +export type invalid_address = !void; + +// Returned if the service parameter does not name a service known to the +// system. +export type unknown_service = !void; + +// Errors which can occur from dial. +export type error = !(invalid_address | unknown_service + | net::error | dns::error | hosts::error | errors::error); + +// Converts an [[error]] to a human-readable string. +export fn strerror(err: error) const str = { + match (err) { + case invalid_address => + return "Attempted to dial an invalid address"; + case unknown_service => + return "Unknown service"; + case let err: net::error => + return net::strerror(err); + case let err: dns::error => + return dns::strerror(err); + case let err: hosts::error => + return hosts::strerror(err); + }; +}; + +// A dialer is a function which implements dial for a specific protocol. +export type dialer = fn( + loop: *ev::loop, + addr: str, + service: str, + cb: *dialcb, + user: nullable *opaque, +) (ev::req | error); + +type protocol = struct { + name: str, + dial: *dialer, +}; + +type service = struct { + proto: str, + name: str, + alias: []str, + port: u16, +}; + +let default_protocols: [_]protocol = [ + protocol { name = "tcp", dial = &dial_tcp }, + protocol { name = "udp", dial = &dial_udp }, +]; + +let default_services: [_]service = [ + service { proto = "tcp", name = "ssh", alias = [], port = 22 }, + service { proto = "tcp", name = "smtp", alias = ["mail"], port = 25 }, + service { proto = "tcp", name = "domain", alias = ["dns"], port = 53 }, + service { proto = "tcp", name = "http", alias = ["www"], port = 80 }, + service { proto = "tcp", name = "imap2", alias = ["imap"], port = 143 }, + service { proto = "tcp", name = "https", alias = [], port = 443 }, + service { proto = "tcp", name = "submission", alias = [], port = 587 }, + service { proto = "tcp", name = "imaps", alias = [], port = 993 }, + service { proto = "udp", name = "domain", alias = ["dns"], port = 53 }, + service { proto = "udp", name = "ntp", alias = [], port = 123 }, +]; + +let protocols: []protocol = []; +let services: []service = []; + +@fini fn fini() void = { + free(protocols); + free(services); +}; + +// Registers a new transport-level protocol (e.g. TCP) with the dialer. The name +// should be statically allocated. +export fn registerproto(name: str, dial: *dialer) void = { + append(protocols, protocol { + name = name, + dial = dial, + }); +}; + +// Registers a new application-level service (e.g. SSH) with the dialer. Note +// that the purpose of services is simply to establish the default outgoing +// port for TCP and UDP connections. The name and alias list should be +// statically allocated. +export fn registersvc( + proto: str, + name: str, + alias: []str, + port: u16, +) void = { + append(services, service { + proto = proto, + name = name, + alias = alias, + port = port, + }); +}; + +fn lookup_service(proto: str, service: str) (u16 | void) = { + for (let i = 0z; i < len(default_services); i += 1) { + const serv = &default_services[i]; + if (service_match(serv, proto, service)) { + return serv.port; + }; + }; + + for (let i = 0z; i < len(services); i += 1) { + const serv = &services[i]; + if (service_match(serv, proto, service)) { + return serv.port; + }; + }; +}; + +fn service_match(candidate: *service, proto: str, service: str) bool = { + if (candidate.name == service) { + return true; + }; + for (let j = 0z; j < len(candidate.alias); j += 1) { + if (candidate.alias[j] == service) { + return true; + }; + }; + return false; +}; diff --git a/ev/dial/resolve.ha b/ev/dial/resolve.ha new file mode 100644 index 0000000000000000000000000000000000000000..b6470317906c88c91b631c35d91e2241b23155d6 --- /dev/null +++ b/ev/dial/resolve.ha @@ -0,0 +1,224 @@ +use crypto::random; +use errors; +use ev; +use edns = ev::dns; +use net; +use net::ip; +use net::dial; +use net::dns; +use unix::hosts; + +// Callback from a [[resolve]] operation. +export type resolvecb = fn( + user: nullable *opaque, + r: (([]ip::addr, u16) | error), +) void; + +type resolve_state = struct { + user: nullable *opaque, + cb: *resolvecb, + r4: ev::req, + r6: ev::req, + nq: uint, + ip: []ip::addr, + port: u16, +}; + +// Performs DNS resolution on a given address string for a given service, +// including /etc/hosts lookup and SRV resolution, and returns a list of +// candidate IP addresses and the appropriate port, or an error, to the +// callback. +// +// The caller must free the [[net::ip::addr]] slice. +export fn resolve( + loop: *ev::loop, + proto: str, + addr: str, + service: str, + cb: *resolvecb, + user: nullable *opaque +) (ev::req | error) = { + // TODO: Reduce duplication with net::dial + let state = alloc(resolve_state { + cb = cb, + user = user, + ... + }); + + const (addr, port) = match (dial::splitaddr(addr, service)) { + case let svc: (str, u16) => + yield svc; + case dial::invalid_address => + resolve_finish(state, invalid_address); + return ev::req { ... }; + }; + + if (service == "unknown" && port == 0) { + resolve_finish(state, unknown_service); + return ev::req { ... }; + }; + + if (port == 0) { + match (lookup_service(proto, service)) { + case let p: u16 => + port = p; + case void => yield; + }; + }; + + // TODO: + // - Consult /etc/services + // - Fetch the SRV record + + if (port == 0) { + resolve_finish(state, unknown_service); + return ev::req { ... }; + }; + + match (ip::parse(addr)) { + case let addr: ip::addr => + const addrs = alloc([addr]); + resolve_finish(state, (addrs, port)); + return ev::req { ... }; + case ip::invalid => yield; + }; + + const addrs = hosts::lookup(addr)?; + if (len(addrs) != 0) { + resolve_finish(state, (addrs, port)); + return ev::req { ... }; + }; + + state.port = port; + return resolve_dns(state, loop, addr); +}; + +fn resolve_dns( + state: *resolve_state, + loop: *ev::loop, + addr: str, +) (ev::req | error) = { + const domain = dns::parse_domain(addr); + defer free(domain); + + let rand: []u8 = [0, 0]; + random::buffer(rand); + let id = *(&rand[0]: *u16); + + const query6 = dns::message { + header = dns::header { + id = id, + op = dns::op { + qr = dns::qr::QUERY, + opcode = dns::opcode::QUERY, + rd = true, + ... + }, + qdcount = 1, + ... + }, + questions = [ + dns::question { + qname = domain, + qtype = dns::qtype::AAAA, + qclass = dns::qclass::IN, + }, + ], + ... + }; + const query4 = dns::message { + header = dns::header { + id = id + 1, + op = dns::op { + qr = dns::qr::QUERY, + opcode = dns::opcode::QUERY, + rd = true, + ... + }, + qdcount = 1, + ... + }, + questions = [ + dns::question { + qname = domain, + qtype = dns::qtype::A, + qclass = dns::qclass::IN, + }, + ], + ... + }; + + state.r6 = edns::query(loop, &query6, &query_cb_v6, state)?; + state.r4 = edns::query(loop, &query4, &query_cb_v4, state)?; + return ev::mkreq(&resolve_cancel, state); +}; + +fn resolve_finish(st: *resolve_state, r: (([]ip::addr, u16) | error)) void = { + const user = st.user; + const cb = st.cb; + if (r is error) { + free(st.ip); + }; + free(st); + cb(user, r); +}; + +fn resolve_cancel(req: *ev::req) void = { + const state = req.user: *resolve_state; + ev::cancel(&state.r4); + ev::cancel(&state.r6); + free(state.ip); + free(state); +}; + +fn query_cb_v4(user: nullable *opaque, r: (*dns::message | dns::error)) void = { + let state = user: *resolve_state; + state.r4 = ev::req { ... }; + + match (r) { + case let err: dns::error => + ev::cancel(&state.r6); + resolve_finish(state, err); + return; + case let msg: *dns::message => + collect_answers(&state.ip, &msg.answers); + state.nq += 1; + }; + + if (state.nq < 2) { + return; + }; + resolve_finish(state, (state.ip, state.port)); +}; + +fn query_cb_v6(user: nullable *opaque, r: (*dns::message | dns::error)) void = { + let state = user: *resolve_state; + state.r6 = ev::req { ... }; + + match (r) { + case let err: dns::error => + ev::cancel(&state.r4); + resolve_finish(state, err); + return; + case let msg: *dns::message => + collect_answers(&state.ip, &msg.answers); + state.nq += 1; + }; + + if (state.nq < 2) { + return; + }; + resolve_finish(state, (state.ip, state.port)); +}; + +fn collect_answers(addrs: *[]ip::addr, answers: *[]dns::rrecord) void = { + for (let i = 0z; i < len(answers); i += 1) { + match (answers[i].rdata) { + case let addr: dns::aaaa => + append(addrs, addr: ip::addr); + case let addr: dns::a => + append(addrs, addr: ip::addr); + case => void; + }; + }; +}; -- 2.48.1