-            \\Connect via WebSocket at this URL.
-            \\
-        ;
-}
diff --git a/src/spider.zig b/src/spider.zig
index 408c432..4374fec 100644
--- a/src/spider.zig
+++ b/src/spider.zig
@@ -13,7 +13,7 @@ const BATCH_CREATION_DELAY_MS: u64 = 500;
 const RECONNECT_DELAY_MS: u64 = 10_000;
 const MAX_RECONNECT_DELAY_MS: u64 = 3600_000;
 const BLACKOUT_MS: u64 = 24 * 3600_000;
-const QUICK_DISCONNECT_MS: i64 = 120_000;
+const QUICK_DISCONNECT_MS: i64 = 30_000;
 const RATE_LIMIT_BACKOFF_MS: u64 = 60_000;
 const MAX_RATE_LIMIT_BACKOFF_MS: u64 = 1800_000;
 const CATCHUP_WINDOW_MS: i64 = 1800_000;
@@ -24,6 +24,7 @@ pub const Spider = struct {
     store: *Store,
     broadcaster: *Broadcaster,
     running: std.atomic.Value(bool),
+    global_shutdown: *std.atomic.Value(bool),
     relays: std.StringArrayHashMap(RelayConn),
     follow_pubkeys: std.ArrayListUnmanaged([32]u8),
     follow_mutex: std.Thread.Mutex,
@@ -35,6 +36,7 @@ pub const Spider = struct {
         config: *const Config,
         store: *Store,
         broadcaster: *Broadcaster,
+        global_shutdown: *std.atomic.Value(bool),
     ) !Spider {
         var spider = Spider{
             .allocator = allocator,
@@ -42,6 +44,7 @@ pub const Spider = struct {
             .store = store,
             .broadcaster = broadcaster,
             .running = std.atomic.Value(bool).init(false),
+            .global_shutdown = global_shutdown,
             .relays = std.StringArrayHashMap(RelayConn).init(allocator),
             .follow_pubkeys = .{},
             .follow_mutex = .{},
@@ -100,10 +103,28 @@ pub const Spider = struct {
         try self.threads.append(self.allocator, refresh_thread);
     }
 
+    fn shouldRun(self: *Spider) bool {
+        return self.running.load(.acquire) and !self.global_shutdown.load(.acquire);
+    }
+
+    fn interruptibleSleep(self: *Spider, ms: u64) void {
+        const interval_ms: u64 = 100;
+        var remaining = ms;
+        while (remaining > 0 and self.shouldRun()) {
+            const sleep_ms = @min(remaining, interval_ms);
+            std.Thread.sleep(sleep_ms * @as(u64, std.time.ns_per_ms));
+            remaining -|= sleep_ms;
+        }
+    }
+
     pub fn stop(self: *Spider) void {
         log.info("Spider stopping...", .{});
         self.running.store(false, .release);
 
+        for (self.relays.values()) |*conn| {
+            conn.closeClient();
+        }
+
         for (self.threads.items) |thread| {
             thread.join();
         }
@@ -167,11 +188,16 @@ pub const Spider = struct {
 
         client.writeText(@constCast(req_msg)) catch continue;
 
+        client.readTimeout(5000) catch {};
+
         var events_received: u64 = 0;
         var got_kind3 = false;
-        var msg_count: usize = 0;
-        while (msg_count < 20) : (msg_count += 1) {
-            const message = client.read() catch break;
+        const bootstrap_start = std.time.milliTimestamp();
+        while (std.time.milliTimestamp() - bootstrap_start < 10_000) {
+            const message = client.read() catch |err| {
+                if (err == error.WouldBlock) continue;
+                break;
+            };
             if (message) |msg| {
                 defer client.done(msg);
                 if (msg.data.len > 0) {
@@ -243,12 +269,11 @@ pub const Spider = struct {
     }
 
     fn refreshLoop(self: *Spider) void {
-        const interval_s: u64 = @max(1, @as(u64, self.config.spider_sync_interval));
-        const interval_ms: u64 = interval_s * 1000;
-        while (self.running.load(.acquire)) {
-            std.Thread.sleep(interval_ms * std.time.ns_per_ms);
+        const interval_ms: u64 = @max(1000, @as(u64, self.config.spider_sync_interval) * 1000);
+        while (self.shouldRun()) {
+            self.interruptibleSleep(interval_ms);
 
-            if (!self.running.load(.acquire)) break;
+            if (!self.shouldRun()) break;
 
             const old_count = blk: {
                 self.follow_mutex.lock();
@@ -275,13 +300,13 @@ pub const Spider = struct {
 
         var conn = self.relays.getPtr(relay_url) orelse return;
 
-        while (self.running.load(.acquire)) {
+        while (self.shouldRun()) {
             if (conn.blackout_until > 0) {
                 const now = std.time.milliTimestamp();
                 if (now < conn.blackout_until) {
                     const wait_ms: u64 = @intCast(conn.blackout_until - now);
                     log.info("{s}: In blackout for {d}ms more", .{ relay_url, wait_ms });
-                    std.Thread.sleep(@as(u64, @min(wait_ms, 60_000)) * std.time.ns_per_ms);
+                    self.interruptibleSleep(@min(wait_ms, 60_000));
                     continue;
                 }
                 conn.blackout_until = 0;
@@ -293,22 +318,21 @@ pub const Spider = struct {
             const success = self.connectAndSubscribe(conn, relay_url);
             const connection_duration = std.time.milliTimestamp() - connect_start;
 
+            if (!self.shouldRun()) break;
+
             if (success) {
                 if (connection_duration < QUICK_DISCONNECT_MS) {
-                    log.warn("{s}: Quick disconnect after {d}ms", .{ relay_url, connection_duration });
+                    log.warn("{s}: Quick disconnect after {d}ms, waiting {d}ms", .{ relay_url, connection_duration, conn.reconnect_delay_ms });
+                    self.interruptibleSleep(conn.reconnect_delay_ms);
                     conn.reconnect_delay_ms = @min(conn.reconnect_delay_ms * 2, MAX_RECONNECT_DELAY_MS);
                 } else {
                     log.info("{s}: Disconnected after {d}ms uptime", .{ relay_url, connection_duration });
-                    if (conn.reconnect_delay_ms > RECONNECT_DELAY_MS * 8) {
-                        conn.reconnect_delay_ms = conn.reconnect_delay_ms / 2;
-                    } else {
-                        conn.reconnect_delay_ms = RECONNECT_DELAY_MS;
-                    }
+                    conn.reconnect_delay_ms = RECONNECT_DELAY_MS;
+                    self.interruptibleSleep(5000);
                 }
-                std.Thread.sleep(5 * std.time.ns_per_s);
             } else {
                 log.warn("{s}: Connection failed, waiting {d}ms", .{ relay_url, conn.reconnect_delay_ms });
-                std.Thread.sleep(conn.reconnect_delay_ms * std.time.ns_per_ms);
+                self.interruptibleSleep(conn.reconnect_delay_ms);
                 conn.reconnect_delay_ms = @min(conn.reconnect_delay_ms * 2, MAX_RECONNECT_DELAY_MS);
             }
 
@@ -326,7 +350,8 @@ pub const Spider = struct {
         if (conn.isRateLimited()) {
             const wait_ms: u64 = @intCast(@max(0, conn.rate_limit_until - std.time.milliTimestamp()));
             log.info("{s}: Rate limited, waiting {d}ms", .{ relay_url, wait_ms });
-            std.Thread.sleep(wait_ms * std.time.ns_per_ms);
+            self.interruptibleSleep(wait_ms);
+            if (!self.shouldRun()) return false;
         }
 
         const parsed = parseRelayUrl(relay_url) orelse {
@@ -352,6 +377,9 @@ pub const Spider = struct {
             return false;
         };
 
+        conn.active_client = &client;
+        defer conn.active_client = null;
+
         client.handshake(parsed.path, .{
             .headers = host_header,
         }) catch |err| {
@@ -363,15 +391,22 @@ pub const Spider = struct {
         conn.state = .connected;
         const now = std.time.milliTimestamp();
 
-        if (conn.last_connect == 0) {
-            if (!self.performNegentropySync(&client, relay_url)) {
-                log.err("{s}: Connection lost during negentropy sync", .{relay_url});
+        client.readTimeout(1000) catch {};
+
+        if (!self.shouldRun()) return false;
+
+        if (conn.last_connect == 0 and conn.negentropy_supported) {
+            if (!self.performNegentropySync(&client, relay_url, conn)) {
                 return false;
             }
         } else if (conn.last_disconnect > 0) {
-            self.performCatchup(&client, conn, relay_url);
+            if (!self.performCatchup(&client, conn, relay_url)) {
+                return false;
+            }
         }
 
+        if (!self.shouldRun()) return false;
+
         conn.last_connect = now;
         conn.clearRateLimit();
@@ -387,7 +422,7 @@ pub const Spider = struct {
         return true;
     }
 
-    fn performCatchup(self: *Spider, client: *websocket.Client, conn: *RelayConn, relay_url: []const u8) void {
+    fn performCatchup(self: *Spider, client: *websocket.Client, conn: *RelayConn, relay_url: []const u8) bool {
         const since_ts = conn.last_disconnect - CATCHUP_WINDOW_MS;
         const until_ts = std.time.milliTimestamp() + CATCHUP_WINDOW_MS;
         const since_unix = @divFloor(since_ts, 1000);
@@ -398,25 +433,29 @@ pub const Spider = struct {
         self.follow_mutex.lock();
         defer self.follow_mutex.unlock();
 
-        if (self.follow_pubkeys.items.len == 0) return;
+        if (self.follow_pubkeys.items.len == 0) return true;
 
         var msg_buf: [65536]u8 = undefined;
         const msg = buildCatchupReqMessage(&msg_buf, self.follow_pubkeys.items, since_unix, until_unix) catch |err| {
             log.err("{s}: Failed to build catch-up REQ: {}", .{ relay_url, err });
-            return;
+            return true;
         };
 
         client.writeText(@constCast(msg)) catch |err| {
             log.err("{s}: Failed to send catch-up REQ: {}", .{ relay_url, err });
-            return;
+            return false;
         };
 
         var catchup_events: u64 = 0;
         const catchup_start = std.time.milliTimestamp();
         const catchup_timeout_ms: i64 = 30_000;
+        var read_error = false;
 
         while (std.time.milliTimestamp() - catchup_start < catchup_timeout_ms) {
-            const message = client.read() catch break;
+            const message = client.read() catch {
+                read_error = true;
+                break;
+            };
             if (message) |msg_data| {
                 defer client.done(msg_data);
                 if (msg_data.data.len > 0) {
@@ -438,14 +477,23 @@ pub const Spider = struct {
             }
         }
 
+        if (read_error) {
+            log.info("{s}: Catch-up finished with {d} events (connection lost)", .{ relay_url, catchup_events });
+            return false;
+        }
+
         var close_buf: [64]u8 = undefined;
-        const close_msg = std.fmt.bufPrint(&close_buf, "[\"CLOSE\",\"catchup\"]", .{}) catch return;
+        const close_msg = std.fmt.bufPrint(&close_buf, "[\"CLOSE\",\"catchup\"]", .{}) catch {
+            log.info("{s}: Catch-up finished with {d} events", .{ relay_url, catchup_events });
+            return true;
+        };
         client.writeText(@constCast(close_msg)) catch {};
 
         log.info("{s}: Catch-up finished with {d} events", .{ relay_url, catchup_events });
+        return true;
     }
 
-    fn performNegentropySync(self: *Spider, client: *websocket.Client, relay_url: []const u8) bool {
+    fn performNegentropySync(self: *Spider, client: *websocket.Client, relay_url: []const u8, conn: *RelayConn) bool {
         self.follow_mutex.lock();
         const pubkeys = self.allocator.dupe([32]u8, self.follow_pubkeys.items) catch {
             self.follow_mutex.unlock();
@@ -456,6 +504,12 @@ pub const Spider = struct {
 
         if (pubkeys.len == 0) return true;
 
+        if (pubkeys.len > 100) {
+            log.info("{s}: Too many pubkeys ({d}) for negentropy, skipping initial sync", .{ relay_url, pubkeys.len });
+            conn.negentropy_supported = false;
+            return true;
+        }
+
         var local_storage = negentropy.VectorStorage.init(self.allocator);
         defer local_storage.deinit();
 
@@ -471,6 +525,7 @@ pub const Spider = struct {
 
         var local_count: usize = 0;
         while (iter.next() catch null) |json| {
+            if (local_count % 1000 == 0 and !self.shouldRun()) return false;
             var event = nostr.Event.parse(json) catch continue;
             defer event.deinit();
             local_storage.insert(@intCast(event.createdAt()), event.id()) catch continue;
@@ -480,7 +535,7 @@ pub const Spider = struct {
 
         log.info("{s}: Negentropy sync starting with {d} local events", .{ relay_url, local_count });
 
-        var filter_buf: [32768]u8 = undefined;
+        var filter_buf: [65536]u8 = undefined;
        const filter_json = buildNegentropyFilter(&filter_buf, pubkeys) catch {
             log.err("{s}: Failed to build negentropy filter", .{relay_url});
             return true;
@@ -509,6 +564,8 @@ pub const Spider = struct {
             return false;
         };
 
+        client.readTimeout(1000) catch {};
+
         var have_ids: std.ArrayListUnmanaged([32]u8) = .{};
         defer have_ids.deinit(self.allocator);
         var need_ids: std.ArrayListUnmanaged([32]u8) = .{};
@@ -522,11 +579,15 @@ pub const Spider = struct {
         var connection_alive = true;
 
         while (std.time.milliTimestamp() - sync_start < sync_timeout_ms) {
+            if (!self.shouldRun()) return false;
+
             if (!got_response and std.time.milliTimestamp() - sync_start > initial_timeout_ms) {
-                log.warn("{s}: No negentropy response, relay may not support NIP-77", .{relay_url});
+                log.warn("{s}: No negentropy response, disabling for this relay", .{relay_url});
+                conn.negentropy_supported = false;
                 return true;
             }
-            const message = client.read() catch {
+            const message = client.read() catch |err| {
+                if (err == error.WouldBlock) continue;
                 connection_alive = false;
                 break;
             };
@@ -535,7 +596,8 @@ pub const Spider = struct {
                 if (msg.data.len == 0) continue;
 
                 if (std.mem.startsWith(u8, msg.data, "[\"NEG-ERR\"")) {
-                    log.warn("{s}: Negentropy not supported, skipping historical sync", .{relay_url});
+                    log.warn("{s}: Negentropy not supported, disabling for this relay", .{relay_url});
+                    conn.negentropy_supported = false;
                     var close_buf: [64]u8 = undefined;
                     const close_msg = std.fmt.bufPrint(&close_buf, "[\"NEG-CLOSE\",\"neg-sync\"]", .{}) catch break;
                     client.writeText(@constCast(close_msg)) catch {};
@@ -586,27 +648,36 @@ pub const Spider = struct {
             }
         }
 
-        if (connection_alive) {
-            var close_buf: [64]u8 = undefined;
-            const close_msg = std.fmt.bufPrint(&close_buf, "[\"NEG-CLOSE\",\"neg-sync\"]", .{}) catch return true;
-            client.writeText(@constCast(close_msg)) catch {};
+        if (!connection_alive) {
+            if (!got_response) {
+                log.warn("{s}: Connection lost before negentropy response, will retry", .{relay_url});
+            }
+            return false;
         }
 
+        var close_buf: [64]u8 = undefined;
+        const close_msg = std.fmt.bufPrint(&close_buf, "[\"NEG-CLOSE\",\"neg-sync\"]", .{}) catch return true;
+        client.writeText(@constCast(close_msg)) catch {};
+
         log.info("{s}: Need {d} events, have {d} events to skip", .{ relay_url, need_ids.items.len, have_ids.items.len });
 
-        if (connection_alive and need_ids.items.len > 0) {
-            self.fetchEventsByIds(client, relay_url, need_ids.items);
+        if (need_ids.items.len > 0) {
+            if (!self.fetchEventsByIds(client, relay_url, need_ids.items)) {
+                return false;
+            }
         }
 
-        return connection_alive;
+        return true;
     }
 
-    fn fetchEventsByIds(self: *Spider, client: *websocket.Client, relay_url: []const u8, ids: [][32]u8) void {
+    fn fetchEventsByIds(self: *Spider, client: *websocket.Client, relay_url: []const u8, ids: [][32]u8) bool {
         const batch_size: usize = 100;
         var fetched: u64 = 0;
 
         var i: usize = 0;
         while (i < ids.len) {
+            if (!self.shouldRun()) return true;
+
             const end = @min(i + batch_size, ids.len);
             const batch = ids[i..end];
 
@@ -627,15 +698,24 @@ pub const Spider = struct {
             };
 
             if (req_msg) |msg| {
-                client.writeText(@constCast(msg)) catch break;
+                client.writeText(@constCast(msg)) catch {
+                    log.info("{s}: Fetched {d} events via negentropy sync (connection lost)", .{ relay_url, fetched });
+                    return false;
+                };
             } else {
                 i = end;
                 continue;
             }
 
             const fetch_start = std.time.milliTimestamp();
+            var read_error = false;
             while (std.time.milliTimestamp() - fetch_start < 30_000) {
-                const message = client.read() catch break;
+                if (!self.shouldRun()) return true;
+                const message = client.read() catch |err| {
+                    if (err == error.WouldBlock) continue;
+                    read_error = true;
+                    break;
+                };
                 if (message) |msg| {
                     defer client.done(msg);
                     if (std.mem.startsWith(u8, msg.data, "[\"EOSE\"")) break;
@@ -645,14 +725,23 @@ pub const Spider = struct {
                 }
             }
 
+            if (read_error) {
+                log.info("{s}: Fetched {d} events via negentropy sync (connection lost)", .{ relay_url, fetched });
+                return false;
+            }
+
             var close_buf: [64]u8 = undefined;
-            const close_msg = std.fmt.bufPrint(&close_buf, "[\"CLOSE\",\"fetch\"]", .{}) catch break;
+            const close_msg = std.fmt.bufPrint(&close_buf, "[\"CLOSE\",\"fetch\"]", .{}) catch {
+                i = end;
+                continue;
+            };
             client.writeText(@constCast(close_msg)) catch {};
 
             i = end;
         }
 
         log.info("{s}: Fetched {d} events via negentropy sync", .{ relay_url, fetched });
+        return true;
     }
 
     fn sendSubscriptions(self: *Spider, client: *websocket.Client, relay_url: []const u8) !void {
@@ -668,6 +757,8 @@ pub const Spider = struct {
 
         var i: usize = 0;
         while (i < self.follow_pubkeys.items.len) {
+            if (!self.shouldRun()) return error.Shutdown;
+
             const end = @min(i + BATCH_SIZE, self.follow_pubkeys.items.len);
             const batch = self.follow_pubkeys.items[i..end];
 
@@ -698,8 +789,11 @@ pub const Spider = struct {
     fn readLoop(self: *Spider, client: *websocket.Client, relay_url: []const u8) void {
         var events_received: u64 = 0;
 
-        while (self.running.load(.acquire)) {
+        client.readTimeout(1000) catch {};
+
+        while (self.shouldRun()) {
             const message = client.read() catch |err| {
+                if (err == error.WouldBlock) continue;
                 if (err == error.Closed or err == error.ConnectionResetByPeer) {
                     log.info("{s}: Connection closed", .{relay_url});
                 } else {
@@ -850,6 +944,8 @@ const RelayConn = struct {
     last_disconnect: i64 = 0,
     rate_limit_until: i64 = 0,
     rate_limit_backoff_ms: u64 = RATE_LIMIT_BACKOFF_MS,
+    negentropy_supported: bool = true,
+    active_client: ?*websocket.Client = null,
 
     const State = enum { disconnected, connecting, connected };
 
@@ -857,6 +953,10 @@ const RelayConn = struct {
         return .{ .url = url };
     }
 
+    fn closeClient(self: *RelayConn) void {
+        _ = self;
+    }
+
     fn applyRateLimit(self: *RelayConn) void {
         const now = std.time.milliTimestamp();
         self.rate_limit_until = now + @as(i64, @intCast(self.rate_limit_backoff_ms));
diff --git a/src/subscriptions.zig b/src/subscriptions.zig
index bd7a88f..1c96c96 100644
--- a/src/subscriptions.zig
+++ b/src/subscriptions.zig
@@ -33,6 +33,13 @@ pub const Subscriptions = struct {
         try self.connections.put(conn.id, conn);
     }
 
+    pub fn tryAddConnection(self: *Subscriptions, conn: *Connection, max_connections: usize) !void {
+        self.rwlock.lock();
+        defer self.rwlock.unlock();
+        if (self.connections.count() >= max_connections) return error.TooManyConnections;
+        try self.connections.put(conn.id, conn);
+    }
+
     pub fn removeConnection(self: *Subscriptions, conn_id: u64) void {
         self.rwlock.lock();
         defer self.rwlock.unlock();
@@ -56,6 +63,36 @@ pub const Subscriptions = struct {
         return self.connections.get(conn_id);
     }
 
+    pub fn withConnection(self: *Subscriptions, conn_id: u64, comptime func: fn (*Connection) void) void {
+        self.rwlock.lockShared();
+        defer self.rwlock.unlockShared();
+        if (self.connections.get(conn_id)) |conn| {
+            func(conn);
+        }
+    }
+
+    pub fn sendToConnection(self: *Subscriptions, conn_id: u64, data: []const u8) bool {
+        self.rwlock.lockShared();
+        defer self.rwlock.unlockShared();
+        if (self.connections.get(conn_id)) |conn| {
+            return conn.send(data);
+        }
+        return false;
+    }
+
+    pub fn closeIdleConnection(self: *Subscriptions, conn_id: u64, notice: []const u8) bool {
+        self.rwlock.lockShared();
+        defer self.rwlock.unlockShared();
+        if (self.connections.get(conn_id)) |conn| {
+            conn.sendDirect(notice);
+            conn.stopWriteQueue();
+            conn.clearDirectWriter();
+            conn.shutdown();
+            return true;
+        }
+        return false;
+    }
+
     pub fn connectionCount(self: *Subscriptions) usize {
         self.rwlock.lockShared();
         defer self.rwlock.unlockShared();
diff --git a/src/tcp_server.zig b/src/tcp_server.zig
new file mode 100644
index 0000000..31902d6
--- /dev/null
+++ b/src/tcp_server.zig
@@ -0,0 +1,391 @@
+const std = @import("std");
+const net = std.net;
+const posix = std.posix;
+const nostr = @import("nostr.zig");
+const ws = nostr.ws;
+
+const Config = @import("config.zig").Config;
+const MsgHandler = @import("handler.zig").Handler;
+const Subscriptions = @import("subscriptions.zig").Subscriptions;
+const Connection = @import("connection.zig").Connection;
+const nip11 = @import("nip11.zig");
+const rate_limiter = @import("rate_limiter.zig");
+const write_queue = @import("write_queue.zig");
+
+const WsWriter = struct {
+    stream: net.Stream,
+    mutex: std.Thread.Mutex = .{},
+    failed: bool = false,
+
+    fn write(ctx: *anyopaque, data: []const u8) void {
+        const self: *WsWriter = @ptrCast(@alignCast(ctx));
+        self.mutex.lock();
+        defer self.mutex.unlock();
+        if (self.failed) return;
+        self.writeWsFrame(data) catch {
+            self.failed = true;
+        };
+    }
+
+    fn writeWsFrame(self: *WsWriter, data: []const u8) !void {
+        var header: [14]u8 = undefined;
+        var header_len: usize = 2;
+
+        header[0] = 0x81;
+
+        if (data.len < 126) {
+            header[1] = @intCast(data.len);
+        } else if (data.len < 65536) {
+            header[1] = 126;
+            header[2] = @intCast((data.len >> 8) & 0xFF);
+            header[3] = @intCast(data.len & 0xFF);
+            header_len = 4;
+        } else {
+            header[1] = 127;
+            const len64: u64 = data.len;
+            header[2] = @intCast((len64 >> 56) & 0xFF);
+            header[3] = @intCast((len64 >> 48) & 0xFF);
+            header[4] = @intCast((len64 >> 40) & 0xFF);
+            header[5] = @intCast((len64 >> 32) & 0xFF);
+            header[6] = @intCast((len64 >> 24) & 0xFF);
+            header[7] = @intCast((len64 >> 16) & 0xFF);
+            header[8] = @intCast((len64 >> 8) & 0xFF);
+            header[9] = @intCast(len64 & 0xFF);
+            header_len = 10;
+        }
+
+        var iovecs = [_]std.posix.iovec_const{
+            .{ .base = &header, .len = header_len },
+            .{ .base = data.ptr, .len = data.len },
+        };
+        _ = try self.stream.writev(&iovecs);
+    }
+};
+
+pub const TcpServer = struct {
+    allocator: std.mem.Allocator,
+    config: *const Config,
+    handler: *MsgHandler,
+    subs: *Subscriptions,
+
+    next_id: u64 = 0,
+    mutex: std.Thread.Mutex = .{},
+
+    listener: ?net.Server = null,
+    shutdown: *std.atomic.Value(bool),
+
+    conn_limiter: rate_limiter.ConnectionLimiter,
+    ip_filter: rate_limiter.IpFilter,
+
+    pub fn init(
+        allocator: std.mem.Allocator,
+        config: *const Config,
+        handler: *MsgHandler,
+        subs: *Subscriptions,
+        shutdown: *std.atomic.Value(bool),
+    ) !TcpServer {
+        var ip_filter = rate_limiter.IpFilter.init(allocator);
+        try ip_filter.loadWhitelist(config.ip_whitelist);
+        try ip_filter.loadBlacklist(config.ip_blacklist);
+
+        return .{
+            .allocator = allocator,
+            .config = config,
+            .handler = handler,
+            .subs = subs,
+            .shutdown = shutdown,
+            .conn_limiter = rate_limiter.ConnectionLimiter.init(allocator, config.max_connections_per_ip),
+            .ip_filter = ip_filter,
+        };
+    }
+
+    pub fn deinit(self: *TcpServer) void {
+        if (self.listener) |*l| {
+            l.deinit();
+            self.listener = null;
+        }
+        self.conn_limiter.deinit();
+        self.ip_filter.deinit();
+    }
+
+    pub fn run(self: *TcpServer) !void {
+        const address = try net.Address.parseIp(self.config.host, self.config.port);
+        self.listener = try address.listen(.{
+            .reuse_address = true,
+        });
+
+        std.log.info("Server running on {s}:{d}", .{ self.config.host, self.config.port });
+
+        const idle_thread = std.Thread.spawn(.{}, idleTimeoutThread, .{ self, self.shutdown }) catch null;
+        defer if (idle_thread) |t| t.join();
+
+        while (!self.shutdown.load(.acquire)) {
+            const conn = self.listener.?.accept() catch |err| {
+                if (err == error.SocketNotListening) break;
+                continue;
+            };
+
+            const thread = std.Thread.spawn(.{}, handleConnection, .{ self, conn }) catch |err| {
+                std.log.warn("Failed to spawn connection thread: {}", .{err});
+                conn.stream.close();
+                continue;
+            };
+            thread.detach();
+        }
+
+        std.log.info("Shutting down server...", .{});
+        if (self.listener) |*l| {
+            l.deinit();
+            self.listener = null;
+        }
+        std.Thread.sleep(200 * std.time.ns_per_ms);
+    }
+
+    fn idleTimeoutThread(self: *TcpServer, shutdown: *std.atomic.Value(bool)) void {
+        const check_interval_s: u64 = 30;
+        var seconds_waited: u64 = 0;
+
+        while (!shutdown.load(.acquire)) {
+            std.Thread.sleep(std.time.ns_per_s);
+            if (shutdown.load(.acquire)) break;
+            seconds_waited += 1;
+            if (seconds_waited < check_interval_s) continue;
+            seconds_waited = 0;
+
+            if (self.config.idle_seconds == 0) continue;
+
+            const idle_conn_ids = self.subs.getIdleConnections(self.config.idle_seconds);
+            defer self.allocator.free(idle_conn_ids);
+
+            for (idle_conn_ids) |conn_id| {
+                var buf: [128]u8 = undefined;
+                const notice = nostr.RelayMsg.notice("connection closed: idle timeout", &buf) catch continue;
+                if (self.subs.closeIdleConnection(conn_id, notice)) {
+                    std.log.debug("Closed idle connection {d}", .{conn_id});
+                }
+            }
+        }
+    }
+
+    pub fn stop(self: *TcpServer) void {
+        if (self.listener) |*l| {
+            l.deinit();
+            self.listener = null;
+        }
+    }
+
+    fn handleConnection(self: *TcpServer, conn: net.Server.Connection) void {
+        defer conn.stream.close();
+
+        // Check shutdown early to avoid accessing freed resources
+        if (self.shutdown.load(.acquire)) return;
+
+        var addr_buf: [64]u8 = undefined;
+        const client_ip = extractIp(conn.address, &addr_buf);
+
+        if (self.shutdown.load(.acquire)) return;
+
+        if (!self.ip_filter.isAllowed(client_ip)) {
+            return;
+        }
+
+        if (!self.conn_limiter.canConnect(client_ip)) {
+            return;
+        }
+
+        var buf: [8192]u8 = undefined;
+        const n = conn.stream.read(&buf) catch return;
+        if (n == 0) return;
+
+        const req_data = buf[0..n];
+
+        if (isWebsocketUpgrade(req_data)) {
+            self.handleWebsocket(conn, client_ip, req_data) catch |err| {
+                std.log.debug("Websocket error: {}", .{err});
+            };
+        } else {
+            self.handleHttp(conn, req_data) catch {};
+        }
+    }
+
+    fn handleWebsocket(self: *TcpServer, conn: net.Server.Connection, client_ip: []const u8, initial_data: []const u8) !void {
+        const TCP_NODELAY = 1;
+        posix.setsockopt(conn.stream.handle, posix.IPPROTO.TCP, TCP_NODELAY, &std.mem.toBytes(@as(i32, 1))) catch {};
+
+        const req, const consumed = try ws.handshake.Req.parse(initial_data);
+        _ = consumed;
+
+        const accept = ws.handshake.secAccept(req.key);
+        var response_buf: [256]u8 = undefined;
+        const response = try std.fmt.bufPrint(&response_buf, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {s}\r\n\r\n", .{&accept});
+        try conn.stream.writeAll(response);
+
+        var ws_writer = WsWriter{ .stream = conn.stream };
+
+        self.mutex.lock();
+        const conn_id = self.next_id;
+        self.next_id += 1;
+        self.mutex.unlock();
+
+        const connection = try self.allocator.create(Connection);
+        connection.init(self.allocator, conn_id);
+        connection.setClientIp(client_ip);
+        connection.setSocketHandle(conn.stream.handle);
+        connection.setDirectWriter(WsWriter.write, @ptrCast(&ws_writer));
+        connection.startWriteQueue(WsWriter.write, @ptrCast(&ws_writer));
+
+        self.subs.tryAddConnection(connection, self.config.max_connections) catch |err| {
+            connection.stopWriteQueue();
+            connection.clearDirectWriter();
+            connection.deinit();
+            self.allocator.destroy(connection);
+            return err;
+        };
+        self.conn_limiter.addConnection(client_ip);
+
+        defer {
+            connection.stopWriteQueue();
+            connection.clearDirectWriter();
+            self.subs.removeConnection(conn_id);
+            self.conn_limiter.removeConnection(client_ip);
+            connection.deinit();
+            self.allocator.destroy(connection);
+        }
+
+        if (self.config.auth_required or self.config.auth_to_write) {
+            var auth_buf: [256]u8 = undefined;
+            const auth_msg = nostr.RelayMsg.auth(&connection.auth_challenge, &auth_buf) catch return;
+            connection.sendDirect(auth_msg);
+            connection.challenge_sent = true;
+        }
+
+        var frame_buf: [65536]u8 = undefined;
+        var read_pos: usize = 0;
+
+        while (!self.shutdown.load(.acquire)) {
+            if (read_pos >= frame_buf.len) {
+                var close_response: [16]u8 = undefined;
+                const close_frame = ws.Frame{ .fin = 1, .opcode = .close, .payload = &.{}, .mask = 0 };
+                const close_len = close_frame.encode(&close_response, 1002);
+                conn.stream.writeAll(close_response[0..close_len]) catch {};
+                return;
+            }
+
+            const bytes_read = conn.stream.read(frame_buf[read_pos..]) catch |err| {
+                if (err == error.ConnectionResetByPeer or err == error.BrokenPipe) break;
+                return err;
+            };
+            if (bytes_read == 0) break;
+
+            read_pos += bytes_read;
+            connection.touch();
+
+            while (read_pos > 0) {
+                if (try self.checkOversizedFrame(connection, conn, frame_buf[0..read_pos])) return;
+
+                const frame, const frame_len = ws.Frame.parse(frame_buf[0..read_pos]) catch |err| {
+                    if (err == error.SplitBuffer) break;
+                    return err;
+                };
+
+                try frame.assertValid(false);
+
+                if (frame.opcode == .close) {
+                    var close_response: [16]u8 = undefined;
+                    const close_frame = ws.Frame{ .fin = 1, .opcode = .close, .payload = &.{}, .mask = 0 };
+                    const close_len = close_frame.encode(&close_response, 1000);
+                    conn.stream.writeAll(close_response[0..close_len]) catch {};
+                    return;
+                } else if (frame.opcode == .ping) {
+                    var pong_buf: [256]u8 = undefined;
+                    const pong_frame = ws.Frame{ .fin = 1, .opcode = .pong, .payload = frame.payload, .mask = 0 };
+                    const pong_len = pong_frame.encode(&pong_buf, 0);
+                    conn.stream.writeAll(pong_buf[0..pong_len]) catch {};
+                } else if (frame.opcode == .text or frame.opcode == .binary) {
+                    self.handler.handle(connection, frame.payload);
+                }
+
+                if (frame_len < read_pos) {
+                    std.mem.copyForwards(u8, &frame_buf, frame_buf[frame_len..read_pos]);
+                }
+                read_pos -= frame_len;
+            }
+        }
+    }
+
+    fn checkOversizedFrame(self: *TcpServer, connection: *Connection, conn: net.Server.Connection, data: []const u8) !bool {
+        if (data.len < 2) return false;
+        const payload_len_byte: u8 = data[1] & 0b0111_1111;
+        const payload_len: u64 = switch (payload_len_byte) {
+            126 => blk: {
+                if (data.len < 4) return false;
+                break :blk std.mem.readInt(u16, data[2..4], .big);
+            },
+            127 => blk: {
+                if (data.len < 10) return false;
+                break :blk std.mem.readInt(u64, data[2..10], .big);
+            },
+            else => payload_len_byte,
+        };
+        if (payload_len > self.config.max_message_size) {
+            var notice_buf: [256]u8 = undefined;
+            const notice = nostr.RelayMsg.notice("error: message too large", &notice_buf) catch {
+                return true;
+            };
+            connection.sendDirect(notice);
+            var close_response: [16]u8 = undefined;
+            const close_frame = ws.Frame{ .fin = 1, .opcode = .close, .payload = &.{}, .mask = 0 };
+            const close_len = close_frame.encode(&close_response, 1009);
+            conn.stream.writeAll(close_response[0..close_len]) catch {};
+            return true;
+        }
+        return false;
+    }
+
+    fn handleHttp(self: *TcpServer, conn: net.Server.Connection, req_data: []const u8) !void {
+        const accepts_json = std.mem.indexOf(u8, req_data, "application/nostr+json") != null;
+
+        if (accepts_json) {
+            var response_buf: [4096]u8 = undefined;
+            var content_buf: [2048]u8 = undefined;
+
+            var content_stream = std.io.fixedBufferStream(&content_buf);
+            try nip11.write(self.config, content_stream.writer());
+            const content = content_stream.getWritten();
+
+            const response = try std.fmt.bufPrint(&response_buf, "HTTP/1.1 200 OK\r\nContent-Type: application/nostr+json\r\nAccess-Control-Allow-Origin: *\r\nContent-Length: {d}\r\n\r\n{s}", .{ content.len, content });
+            try conn.stream.writeAll(response);
+        } else {
+            const html =
+                \\
+                \\Connect via WebSocket at this URL.
+                \\
+            ;
+            const response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: " ++ std.fmt.comptimePrint("{d}", .{html.len}) ++ "\r\n\r\n" ++ html;
+            try conn.stream.writeAll(response);
+        }
+    }
+
+    fn isWebsocketUpgrade(data: []const u8) bool {
+        return std.ascii.indexOfIgnoreCase(data, "upgrade: websocket") != null;
+    }
+
+    fn extractIp(address: net.Address, buf: []u8) []const u8 {
+        const formatted = std.fmt.bufPrint(buf, "{any}", .{address}) catch return "unknown";
+        if (std.mem.lastIndexOf(u8, formatted, ":")) |colon| {
+            return formatted[0..colon];
+        }
+        return formatted;
+    }
+
+    pub fn send(self: *TcpServer, conn_id: u64, data: []const u8) void {
+        _ = self.subs.sendToConnection(conn_id, data);
+    }
+
+    pub fn connectionCount(self: *TcpServer) usize {
+        return self.subs.connectionCount();
+    }
+};
diff --git a/src/write_queue.zig b/src/write_queue.zig
index 21bfd82..89c43dd 100644
--- a/src/write_queue.zig
+++ b/src/write_queue.zig
@@ -1,38 +1,44 @@
 const std = @import("std");
-const httpz = @import("httpz");
-const websocket = httpz.websocket;
+const net = std.net;
+
+pub const WriteFn = *const fn (ctx: *anyopaque, data: []const u8) void;
 
 pub const WriteQueue = struct {
-    ws_conn: ?*websocket.Conn,
+    write_fn: ?WriteFn,
+    write_ctx: ?*anyopaque,
     dropped_count: std.atomic.Value(u64),
     allocator: std.mem.Allocator,
 
     pub fn init(allocator: std.mem.Allocator) WriteQueue {
         return .{
-            .ws_conn = null,
+            .write_fn = null,
+            .write_ctx = null,
             .dropped_count = std.atomic.Value(u64).init(0),
             .allocator = allocator,
         };
     }
 
-    pub fn start(self: *WriteQueue, ws_conn: *websocket.Conn) void {
-        self.ws_conn = ws_conn;
+    pub fn start(self: *WriteQueue, write_fn: WriteFn, write_ctx: *anyopaque) void {
+        self.write_fn = write_fn;
+        self.write_ctx = write_ctx;
     }
 
     pub fn stop(self: *WriteQueue) void {
-        self.ws_conn = null;
+        self.write_fn = null;
+        self.write_ctx = null;
     }
 
     pub fn enqueue(self: *WriteQueue, data: []const u8) bool {
-        if (self.ws_conn) |conn| {
-            conn.write(data) catch {
-                _ = self.dropped_count.fetchAdd(1, .monotonic);
-                return false;
-            };
-            return true;
-        }
-        _ = self.dropped_count.fetchAdd(1, .monotonic);
-        return false;
+        const write_fn = self.write_fn orelse {
+            _ = self.dropped_count.fetchAdd(1, .monotonic);
+            return false;
+        };
+        const ctx = self.write_ctx orelse {
+            _ = self.dropped_count.fetchAdd(1, .monotonic);
+            return false;
+        };
+        write_fn(ctx, data);
+        return true;
     }
 
     pub fn droppedCount(self: *WriteQueue) u64 {
@@ -40,6 +46,6 @@ pub const WriteQueue = struct {
     }
 
     pub fn queueDepth(_: *WriteQueue) usize {
-        return 0; // No queue
+        return 0;
     }
 };