504 lines
24 KiB
Zig
504 lines
24 KiB
Zig
//! zap.App takes the zap.Endpoint concept one step further: instead of having
|
|
//! only per-endpoint instance data (fields of your Endpoint struct), endpoints
|
|
//! in a zap.App easily share a global 'App Context'.
|
|
//!
|
|
//! In addition to the global App Context, all Endpoint request handlers also
|
|
//! receive an arena allocator for easy, care-free allocations. There is one
|
|
//! arena allocator per thread, and arenas are reset after each request.
|
|
//!
|
|
//! Just like regular / legacy zap.Endpoints, returning errors from request
|
|
//! handlers is OK. It's decided on a per-endpoint basis how errors are dealt
|
|
//! with, via the ErrorStrategy enum field.
|
|
//!
|
|
//! See `App.Create()`.
|
|
const std = @import("std");
|
|
const Allocator = std.mem.Allocator;
|
|
const ArenaAllocator = std.heap.ArenaAllocator;
|
|
const Thread = std.Thread;
|
|
const RwLock = Thread.RwLock;
|
|
|
|
const zap = @import("zap");
|
|
const Request = zap.Request;
|
|
const HttpListener = zap.HttpListener;
|
|
const ErrorStrategy = zap.Endpoint.ErrorStrategy;
|
|
|
|
/// Options passed to `App.init()`.
pub const AppOpts = struct {
    /// ErrorStrategy applied to errors returned by the (optional)
    /// `unhandledRequest` handler, i.e. when no endpoint matches a request.
    default_error_strategy: ErrorStrategy = .log_to_console,

    /// Each per-thread arena is reset after every request with
    /// `.retain_with_limit` set to this value: up to this many bytes
    /// (16 MiB by default) stay allocated for reuse by the next request.
    arena_retain_capacity: usize = 16 << 20,
};
|
|
|
|
/// Creates an App type bound to a user-defined "Global App Context" type.
///
/// The returned type keeps its state in a single static instance, so only one
/// App per `Context` type can be live at a time (see `init` / `deinit`).
///
/// About App Contexts:
///
/// ```zig
/// const MyContext = struct {
///     // You may (optionally) define the following global handlers:
///     pub fn unhandledRequest(_: *MyContext, _: Allocator, _: Request) anyerror!void {}
///     pub fn unhandledError(_: *MyContext, _: Request, _: anyerror) void {}
/// };
/// ```
pub fn Create(
    /// Your user-defined "Global App Context" type
    comptime Context: type,
) type {
    return struct {
        const App = @This();

        // The following fields are static so they can be accessed from a
        // context-free, pure zap request handler (see `onRequest`).
        const InstanceData = struct {
            context: *Context = undefined,
            gpa: Allocator = undefined,
            opts: AppOpts = undefined,
            endpoints: std.ArrayListUnmanaged(*Endpoint.Interface) = .empty,

            there_can_be_only_one: bool = false,

            /// One arena per thread, keyed by thread id. The arenas live on the
            /// heap so that pointers returned by `get_arena` stay valid when
            /// another thread's insert makes the map rehash (a rehash moves
            /// the stored pointers, not the arenas themselves).
            track_arenas: std.AutoHashMapUnmanaged(Thread.Id, *ArenaAllocator) = .empty,
            track_arena_lock: RwLock = .{},

            /// the internal http listener
            listener: HttpListener = undefined,

            /// Function pointer to the handler for otherwise unhandled requests.
            /// Set automatically by `init` if your Context provides an
            /// `unhandledRequest` function of type
            /// `fn(*Context, Allocator, Request) anyerror!void`.
            unhandled_request: ?*const fn (*Context, Allocator, Request) anyerror!void = null,

            /// Function pointer to the handler for unhandled errors.
            /// Errors are unhandled if they are raised (not logged) by the
            /// ErrorStrategy. Set automatically by `init` if your Context
            /// provides an `unhandledError` function of type
            /// `fn(*Context, Request, anyerror) void`.
            unhandled_error: ?*const fn (*Context, Request, anyerror) void = null,
        };
        var _static: InstanceData = .{};

        pub const Endpoint = struct {
            /// Type-erased handle stored in the App's endpoint list. The concrete
            /// `Bind(T)` wrapper is recovered from it via `@fieldParentPtr`.
            pub const Interface = struct {
                call: *const fn (*Interface, Request) anyerror!void = undefined,
                path: []const u8,
                destroy: *const fn (*Interface, Allocator) void = undefined,
            };

            /// Wraps a user endpoint type so it can be dispatched through `Interface`.
            pub fn Bind(ArbitraryEndpoint: type) type {
                return struct {
                    endpoint: *ArbitraryEndpoint,
                    interface: Interface,

                    // tbh: unnecessary, since we have it in _static
                    app_context: *Context,

                    const Bound = @This();

                    /// Recovers the wrapper that embeds `interface`.
                    pub fn unwrap(interface: *Interface) *Bound {
                        const self: *Bound = @alignCast(@fieldParentPtr("interface", interface));
                        return self;
                    }

                    /// Frees the wrapper (not the user endpoint) with `allocator`.
                    pub fn destroy(interface: *Interface, allocator: Allocator) void {
                        allocator.destroy(Bound.unwrap(interface));
                    }

                    /// Entry point used by the App dispatcher: runs the request on
                    /// this thread's arena and resets the arena afterwards.
                    pub fn onRequestInterface(interface: *Interface, r: Request) !void {
                        const self = Bound.unwrap(interface);
                        const arena = try get_arena();
                        // Reset even when the handler fails; otherwise a failing
                        // request would keep its allocations until the next success.
                        defer _ = arena.reset(.{ .retain_with_limit = _static.opts.arena_retain_capacity });
                        try self.onRequest(arena.allocator(), self.app_context, r);
                    }

                    /// Dispatches `r` to the endpoint's method handler and applies
                    /// the endpoint's ErrorStrategy to any error it returns.
                    pub fn onRequest(self: *Bound, arena: Allocator, app_context: *Context, r: Request) !void {
                        const ret = switch (r.methodAsEnum()) {
                            .GET => callHandlerIfExist("get", self.endpoint, arena, app_context, r),
                            .POST => callHandlerIfExist("post", self.endpoint, arena, app_context, r),
                            .PUT => callHandlerIfExist("put", self.endpoint, arena, app_context, r),
                            .DELETE => callHandlerIfExist("delete", self.endpoint, arena, app_context, r),
                            .PATCH => callHandlerIfExist("patch", self.endpoint, arena, app_context, r),
                            .OPTIONS => callHandlerIfExist("options", self.endpoint, arena, app_context, r),
                            .HEAD => callHandlerIfExist("head", self.endpoint, arena, app_context, r),
                            else => error.UnsupportedHtmlRequestMethod,
                        };

                        if (ret) {
                            // handled without error
                        } else |err| {
                            switch (self.endpoint.error_strategy) {
                                .raise => return err,
                                // 500 Internal Server Error: the handler failed.
                                .log_to_response => return r.sendError(err, if (@errorReturnTrace()) |t| t.* else null, 500),
                                .log_to_console => zap.log.err(
                                    "Error in {} {s} : {}",
                                    .{ Bound, r.method orelse "(no method)", err },
                                ),
                            }
                        }
                    }
                };
            }

            /// Builds the `Bind` wrapper for `endpoint`. The wrapper captures
            /// `_static.context`, so this must run after `App.init`.
            pub fn init(ArbitraryEndpoint: type, endpoint: *ArbitraryEndpoint) Endpoint.Bind(ArbitraryEndpoint) {
                checkEndpointType(ArbitraryEndpoint);
                const BoundEp = Endpoint.Bind(ArbitraryEndpoint);
                return .{
                    .endpoint = endpoint,
                    .interface = .{
                        .path = endpoint.path,
                        .call = BoundEp.onRequestInterface,
                        .destroy = BoundEp.destroy,
                    },
                    .app_context = _static.context,
                };
            }

            /// Compile-time validation of an endpoint type `T`. Requirements:
            /// - a `path: []const u8` field
            /// - an `error_strategy: zap.Endpoint.ErrorStrategy` field
            /// - any of `get`, `post`, `put`, `delete`, `patch`, `options`, `head`
            ///   that is declared must be `fn(*T, Allocator, *Context, Request) !void`
            pub fn checkEndpointType(T: type) void {
                if (@hasField(T, "path")) {
                    if (@FieldType(T, "path") != []const u8) {
                        @compileError("`" ++ @typeName(T) ++ ".path` has wrong type " ++ @typeName(@FieldType(T, "path")) ++ ", expected: []const u8");
                    }
                } else {
                    @compileError(@typeName(T) ++ " has no path field");
                }

                if (@hasField(T, "error_strategy")) {
                    if (@FieldType(T, "error_strategy") != ErrorStrategy) {
                        @compileError("`" ++ @typeName(T) ++ ".error_strategy` has wrong type " ++ @typeName(@FieldType(T, "error_strategy")) ++ ", expected: zap.Endpoint.ErrorStrategy");
                    }
                } else {
                    @compileError(@typeName(T) ++ " has no error_strategy field");
                }

                const methods_to_check = [_][]const u8{
                    "get",
                    "post",
                    "put",
                    "delete",
                    "patch",
                    "options",
                    "head",
                };
                const params_to_check = [_]type{
                    *T,
                    Allocator,
                    *Context,
                    Request,
                };

                inline for (methods_to_check) |method| {
                    if (@hasDecl(T, method)) {
                        const Method = @TypeOf(@field(T, method));
                        const method_info = @typeInfo(Method);
                        if (method_info != .@"fn") {
                            @compileError("Expected `" ++ @typeName(T) ++ "." ++ method ++ "` to be a request handler method, got: " ++ @typeName(Method));
                        }

                        // now check parameters
                        const params = method_info.@"fn".params;
                        if (params.len != params_to_check.len) {
                            @compileError(std.fmt.comptimePrint(
                                "Expected method `{s}.{s}` to have {d} parameters, got {d}",
                                .{
                                    @typeName(T),
                                    method,
                                    params_to_check.len,
                                    params.len,
                                },
                            ));
                        }

                        inline for (params_to_check, 0..) |param_type_expected, i| {
                            // `type` is null for generic (`anytype`) parameters.
                            const param_type = params[i].type orelse @compileError(std.fmt.comptimePrint(
                                "Expected parameter {d} of method {s}.{s} to be {s}, got a generic (anytype) parameter",
                                .{ i + 1, @typeName(T), method, @typeName(param_type_expected) },
                            ));
                            if (param_type != param_type_expected) {
                                @compileError(std.fmt.comptimePrint(
                                    "Expected parameter {d} of method {s}.{s} to be {s}, got {s}",
                                    .{
                                        i + 1,
                                        @typeName(T),
                                        method,
                                        @typeName(param_type_expected),
                                        @typeName(param_type),
                                    },
                                ));
                            }
                        }

                        const ret_type = method_info.@"fn".return_type orelse
                            @compileError("Expected return type of method `" ++ @typeName(T) ++ "." ++ method ++ "` to be !void, got a generic return type");
                        const ret_info = @typeInfo(ret_type);
                        if (ret_info != .error_union) {
                            @compileError("Expected return type of method `" ++ @typeName(T) ++ "." ++ method ++ "` to be !void, got: " ++ @typeName(ret_type));
                        }
                        if (ret_info.error_union.payload != void) {
                            @compileError("Expected return type of method `" ++ @typeName(T) ++ "." ++ method ++ "` to be !void, got: !" ++ @typeName(ret_info.error_union.payload));
                        }
                    }
                }
            }

            /// Wrap an endpoint with an Authenticator
            pub fn Authenticating(EndpointType: type, Authenticator: type) type {
                return struct {
                    authenticator: *Authenticator,
                    ep: *EndpointType,
                    path: []const u8,
                    error_strategy: ErrorStrategy,
                    const AuthenticatingEndpoint = @This();

                    /// Init the authenticating endpoint. Pass in a pointer to the endpoint
                    /// you want to wrap, and the Authenticator that takes care of authenticating
                    /// requests.
                    pub fn init(e: *EndpointType, authenticator: *Authenticator) AuthenticatingEndpoint {
                        return .{
                            .authenticator = authenticator,
                            .ep = e,
                            .path = e.path,
                            .error_strategy = e.error_strategy,
                        };
                    }

                    /// Shared body of all method handlers: authenticate `request`,
                    /// then forward to the wrapped endpoint's `method` handler on
                    /// success or to its `unauthorized` handler on failure. Nothing
                    /// is forwarded when the Authenticator already handled it.
                    fn authenticateThenCall(
                        self: *AuthenticatingEndpoint,
                        comptime method: []const u8,
                        arena: Allocator,
                        context: *Context,
                        request: Request,
                    ) anyerror!void {
                        switch (self.authenticator.authenticateRequest(&request)) {
                            .AuthFailed => return callHandlerIfExist("unauthorized", self.ep, arena, context, request),
                            .AuthOK => return callHandlerIfExist(method, self.ep, arena, context, request),
                            .Handled => {},
                        }
                    }

                    /// Authenticates GET requests using the Authenticator.
                    pub fn get(self: *AuthenticatingEndpoint, arena: Allocator, context: *Context, request: Request) anyerror!void {
                        return self.authenticateThenCall("get", arena, context, request);
                    }

                    /// Authenticates POST requests using the Authenticator.
                    pub fn post(self: *AuthenticatingEndpoint, arena: Allocator, context: *Context, request: Request) anyerror!void {
                        return self.authenticateThenCall("post", arena, context, request);
                    }

                    /// Authenticates PUT requests using the Authenticator.
                    pub fn put(self: *AuthenticatingEndpoint, arena: Allocator, context: *Context, request: Request) anyerror!void {
                        return self.authenticateThenCall("put", arena, context, request);
                    }

                    /// Authenticates DELETE requests using the Authenticator.
                    pub fn delete(self: *AuthenticatingEndpoint, arena: Allocator, context: *Context, request: Request) anyerror!void {
                        return self.authenticateThenCall("delete", arena, context, request);
                    }

                    /// Authenticates PATCH requests using the Authenticator.
                    pub fn patch(self: *AuthenticatingEndpoint, arena: Allocator, context: *Context, request: Request) anyerror!void {
                        return self.authenticateThenCall("patch", arena, context, request);
                    }

                    /// Authenticates OPTIONS requests using the Authenticator.
                    pub fn options(self: *AuthenticatingEndpoint, arena: Allocator, context: *Context, request: Request) anyerror!void {
                        return self.authenticateThenCall("options", arena, context, request);
                    }

                    /// Authenticates HEAD requests using the Authenticator.
                    pub fn head(self: *AuthenticatingEndpoint, arena: Allocator, context: *Context, request: Request) anyerror!void {
                        return self.authenticateThenCall("head", arena, context, request);
                    }
                };
            }
        };

        pub const ListenerSettings = struct {
            /// IP interface, e.g. 0.0.0.0
            interface: [*c]const u8 = null,
            /// IP port to listen on
            port: usize,
            public_folder: ?[]const u8 = null,
            max_clients: ?isize = null,
            max_body_size: ?usize = null,
            timeout: ?u8 = null,
            tls: ?zap.Tls = null,
        };

        /// Initializes the (single, static) App instance.
        /// Returns `error.OnlyOneAppAllowed` if it is already initialized.
        /// Also installs `Context.unhandledRequest` / `Context.unhandledError`
        /// if they are declared, after checking their types at compile time.
        pub fn init(gpa_: Allocator, context_: *Context, opts_: AppOpts) !void {
            if (_static.there_can_be_only_one) {
                return error.OnlyOneAppAllowed;
            }
            _static.context = context_;
            _static.gpa = gpa_;
            _static.opts = opts_;
            _static.there_can_be_only_one = true;

            // set unhandled_request callback if provided by Context
            if (@hasDecl(Context, "unhandledRequest")) {
                // try if we can use it
                const Unhandled = @TypeOf(@field(Context, "unhandledRequest"));
                const Expected = fn (_: *Context, _: Allocator, _: Request) anyerror!void;
                if (Unhandled != Expected) {
                    @compileError("`unhandledRequest` method of " ++ @typeName(Context) ++ " has wrong type:\n" ++ @typeName(Unhandled) ++ "\nexpected:\n" ++ @typeName(Expected));
                }
                _static.unhandled_request = Context.unhandledRequest;
            }
            if (@hasDecl(Context, "unhandledError")) {
                // try if we can use it
                const Unhandled = @TypeOf(@field(Context, "unhandledError"));
                const Expected = fn (_: *Context, _: Request, _: anyerror) void;
                if (Unhandled != Expected) {
                    @compileError("`unhandledError` method of " ++ @typeName(Context) ++ " has wrong type:\n" ++ @typeName(Unhandled) ++ "\nexpected:\n" ++ @typeName(Expected));
                }
                _static.unhandled_error = Context.unhandledError;
            }
        }

        /// Frees all endpoint wrappers and per-thread arenas, then returns the
        /// static instance to its default state so `init` may be called again.
        pub fn deinit() void {
            // we created endpoint wrappers but only tracked their interfaces;
            // hence, we need to destroy the wrappers through their interfaces
            for (_static.endpoints.items) |interface| {
                interface.destroy(interface, _static.gpa);
            }
            _static.endpoints.deinit(_static.gpa);

            {
                _static.track_arena_lock.lock();
                defer _static.track_arena_lock.unlock();

                var it = _static.track_arenas.valueIterator();
                while (it.next()) |arena| {
                    arena.*.deinit();
                    _static.gpa.destroy(arena.*);
                }
                _static.track_arenas.deinit(_static.gpa);
            }

            // Done outside the locked scope: resetting replaces the lock itself.
            _static = .{};
        }

        /// Calls `e.<fn_name>(arena, ctx, r)` if the endpoint type declares it;
        /// otherwise answers with 405 Method Not Allowed. `fn_name` is comptime,
        /// so the `@hasDecl` check is resolved at compile time.
        pub fn callHandlerIfExist(comptime fn_name: []const u8, e: anytype, arena: Allocator, ctx: *Context, r: Request) anyerror!void {
            const EndPoint = @TypeOf(e.*);
            if (@hasDecl(EndPoint, fn_name)) {
                return @field(EndPoint, fn_name)(e, arena, ctx, r);
            }
            zap.log.debug(
                "Unhandled `{s}` {s} request ({s} not implemented in {s})",
                .{ r.method orelse "<unknown>", r.path orelse "", fn_name, @typeName(EndPoint) },
            );
            r.setStatus(.method_not_allowed);
            try r.sendBody("405 - method not allowed\r\n");
        }

        /// Returns the calling thread's arena, creating it on first use.
        /// The returned pointer stays valid until `deinit`.
        pub fn get_arena() !*ArenaAllocator {
            const thread_id = std.Thread.getCurrentId();

            // Fast path: this thread already owns an arena.
            {
                _static.track_arena_lock.lockShared();
                defer _static.track_arena_lock.unlockShared();
                if (_static.track_arenas.get(thread_id)) |arena| return arena;
            }

            // Slow path: first request handled on this thread.
            _static.track_arena_lock.lock();
            defer _static.track_arena_lock.unlock();

            const gop = try _static.track_arenas.getOrPut(_static.gpa, thread_id);
            if (!gop.found_existing) {
                // Don't leave a map entry with an undefined value behind.
                errdefer _static.track_arenas.removeByPtr(gop.key_ptr);
                const arena = try _static.gpa.create(ArenaAllocator);
                arena.* = ArenaAllocator.init(_static.gpa);
                gop.value_ptr.* = arena;
            }
            return gop.value_ptr.*;
        }

        /// Registers `endpoint` (a pointer to an endpoint struct, see
        /// `Endpoint.checkEndpointType`). Fails with `EndpointPathShadowError`
        /// if an endpoint with the same path is already registered.
        pub fn register(endpoint: anytype) !void {
            for (_static.endpoints.items) |other| {
                if (std.mem.eql(u8, other.path, endpoint.path)) {
                    return zap.Endpoint.ListenerError.EndpointPathShadowError;
                }
            }
            const EndpointType = @typeInfo(@TypeOf(endpoint)).pointer.child;
            Endpoint.checkEndpointType(EndpointType);
            const bound = try _static.gpa.create(Endpoint.Bind(EndpointType));
            errdefer _static.gpa.destroy(bound);
            bound.* = Endpoint.init(EndpointType, endpoint);
            try _static.endpoints.append(_static.gpa, &bound.interface);

            std.mem.sort(*Endpoint.Interface, _static.endpoints.items, {}, lessThanFnEndpointInterface);
        }

        /// Orders endpoints by descending path length.
        fn lessThanFnEndpointInterface(_: void, lh: *Endpoint.Interface, rh: *Endpoint.Interface) bool {
            return lh.path.len > rh.path.len;
        }

        /// Creates the internal HttpListener with `l` and starts listening.
        pub fn listen(l: ListenerSettings) !void {
            _static.listener = HttpListener.init(.{
                .interface = l.interface,
                .port = l.port,
                .public_folder = l.public_folder,
                .max_clients = l.max_clients,
                .max_body_size = l.max_body_size,
                .timeout = l.timeout,
                .tls = l.tls,

                .on_request = onRequest,
            });
            try _static.listener.listen();
        }

        /// Strips a single trailing '/' unless the path is just "/".
        fn sanitizePath(path: []const u8) []const u8 {
            if (path.len > 1 and path[path.len - 1] == '/') {
                return path[0..(path.len - 1)];
            }
            return path;
        }

        /// Top-level zap request handler: routes to the endpoint whose path
        /// exactly matches the (sanitized) request path, or else to the
        /// optional `unhandledRequest` handler.
        fn onRequest(r: Request) !void {
            if (r.path) |unsan_p| {
                const p = sanitizePath(unsan_p);
                for (_static.endpoints.items) |interface| {
                    if (std.mem.eql(u8, p, interface.path)) {
                        return interface.call(interface, r) catch |err| {
                            // if error is not dealt with in the interface, e.g.
                            // if error strategy is .raise:
                            if (_static.unhandled_error) |error_cb| {
                                error_cb(_static.context, r, err);
                            } else {
                                zap.log.err(
                                    "App.Endpoint onRequest error {} in endpoint interface {s}\n",
                                    .{ err, interface.path },
                                );
                            }
                        };
                    }
                }
            }

            // this is basically the "not found" handler
            if (_static.unhandled_request) |user_handler| {
                const arena = try get_arena();
                // Same per-request arena lifecycle as endpoint requests.
                defer _ = arena.reset(.{ .retain_with_limit = _static.opts.arena_retain_capacity });
                user_handler(_static.context, arena.allocator(), r) catch |err| {
                    switch (_static.opts.default_error_strategy) {
                        .raise => if (_static.unhandled_error) |error_cb| {
                            error_cb(_static.context, r, err);
                        } else {
                            zap.Logging.on_uncaught_error("App on_request", err);
                        },
                        // 500 Internal Server Error: the handler failed.
                        .log_to_response => return r.sendError(err, if (@errorReturnTrace()) |t| t.* else null, 500),
                        .log_to_console => zap.log.err("Error in {} {s} : {}", .{ App, r.method orelse "(no method)", err }),
                    }
                };
            }
        }
    };
}
|