Unverified Commit 05e400a3 authored by Luca Boccassi's avatar Luca Boccassi Committed by GitHub

Merge pull request #3203 from sigiesec/update-wepoll

Problem: wepoll outdated
parents d9ade476 9440f4e5
https://github.com/piscisaureus/wepoll/tree/v1.5.0 https://github.com/piscisaureus/wepoll/tree/v1.5.2
...@@ -126,13 +126,11 @@ WEPOLL_EXPORT int epoll_wait(HANDLE ephnd, ...@@ -126,13 +126,11 @@ WEPOLL_EXPORT int epoll_wait(HANDLE ephnd,
#pragma clang diagnostic ignored "-Wreserved-id-macro" #pragma clang diagnostic ignored "-Wreserved-id-macro"
#endif #endif
#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0600 #ifdef _WIN32_WINNT
#undef _WIN32_WINNT #undef _WIN32_WINNT
#endif #endif
#ifndef _WIN32_WINNT
#define _WIN32_WINNT 0x0600 #define _WIN32_WINNT 0x0600
#endif
#ifdef __clang__ #ifdef __clang__
#pragma clang diagnostic pop #pragma clang diagnostic pop
...@@ -172,10 +170,7 @@ typedef NTSTATUS* PNTSTATUS; ...@@ -172,10 +170,7 @@ typedef NTSTATUS* PNTSTATUS;
#endif #endif
typedef struct _IO_STATUS_BLOCK { typedef struct _IO_STATUS_BLOCK {
union {
NTSTATUS Status; NTSTATUS Status;
PVOID Pointer;
};
ULONG_PTR Information; ULONG_PTR Information;
} IO_STATUS_BLOCK, *PIO_STATUS_BLOCK; } IO_STATUS_BLOCK, *PIO_STATUS_BLOCK;
...@@ -189,6 +184,9 @@ typedef struct _LSA_UNICODE_STRING { ...@@ -189,6 +184,9 @@ typedef struct _LSA_UNICODE_STRING {
PWSTR Buffer; PWSTR Buffer;
} LSA_UNICODE_STRING, *PLSA_UNICODE_STRING, UNICODE_STRING, *PUNICODE_STRING; } LSA_UNICODE_STRING, *PLSA_UNICODE_STRING, UNICODE_STRING, *PUNICODE_STRING;
#define RTL_CONSTANT_STRING(s) \
{ sizeof(s) - sizeof((s)[0]), sizeof(s), s }
typedef struct _OBJECT_ATTRIBUTES { typedef struct _OBJECT_ATTRIBUTES {
ULONG Length; ULONG Length;
HANDLE RootDirectory; HANDLE RootDirectory;
...@@ -198,7 +196,29 @@ typedef struct _OBJECT_ATTRIBUTES { ...@@ -198,7 +196,29 @@ typedef struct _OBJECT_ATTRIBUTES {
PVOID SecurityQualityOfService; PVOID SecurityQualityOfService;
} OBJECT_ATTRIBUTES, *POBJECT_ATTRIBUTES; } OBJECT_ATTRIBUTES, *POBJECT_ATTRIBUTES;
#define NTDLL_IMPORT_LIST(X) \ #define RTL_CONSTANT_OBJECT_ATTRIBUTES(ObjectName, Attributes) \
{ sizeof(OBJECT_ATTRIBUTES), NULL, ObjectName, Attributes, NULL, NULL }
#ifndef FILE_OPEN
#define FILE_OPEN 0x00000001UL
#endif
#define NT_NTDLL_IMPORT_LIST(X) \
X(NTSTATUS, \
NTAPI, \
NtCreateFile, \
(PHANDLE FileHandle, \
ACCESS_MASK DesiredAccess, \
POBJECT_ATTRIBUTES ObjectAttributes, \
PIO_STATUS_BLOCK IoStatusBlock, \
PLARGE_INTEGER AllocationSize, \
ULONG FileAttributes, \
ULONG ShareAccess, \
ULONG CreateDisposition, \
ULONG CreateOptions, \
PVOID EaBuffer, \
ULONG EaLength)) \
\
X(NTSTATUS, \ X(NTSTATUS, \
NTAPI, \ NTAPI, \
NtDeviceIoControlFile, \ NtDeviceIoControlFile, \
...@@ -235,7 +255,7 @@ typedef struct _OBJECT_ATTRIBUTES { ...@@ -235,7 +255,7 @@ typedef struct _OBJECT_ATTRIBUTES {
#define X(return_type, attributes, name, parameters) \ #define X(return_type, attributes, name, parameters) \
WEPOLL_INTERNAL_VAR return_type(attributes* name) parameters; WEPOLL_INTERNAL_VAR return_type(attributes* name) parameters;
NTDLL_IMPORT_LIST(X) NT_NTDLL_IMPORT_LIST(X)
#undef X #undef X
#include <assert.h> #include <assert.h>
...@@ -259,12 +279,6 @@ typedef intptr_t ssize_t; ...@@ -259,12 +279,6 @@ typedef intptr_t ssize_t;
#define inline __inline #define inline __inline
#endif #endif
/* Polyfill `static_assert` for some versions of clang and gcc. */
#if (defined(__clang__) || defined(__GNUC__)) && !defined(static_assert)
#define static_assert(condition, message) typedef __attribute__( \
(__unused__)) int __static_assert_##__LINE__[(condition) ? 1 : -1]
#endif
/* clang-format off */ /* clang-format off */
#define AFD_POLL_RECEIVE 0x0001 #define AFD_POLL_RECEIVE 0x0001
#define AFD_POLL_RECEIVE_EXPEDITED 0x0002 #define AFD_POLL_RECEIVE_EXPEDITED 0x0002
...@@ -290,12 +304,10 @@ typedef struct _AFD_POLL_INFO { ...@@ -290,12 +304,10 @@ typedef struct _AFD_POLL_INFO {
AFD_POLL_HANDLE_INFO Handles[1]; AFD_POLL_HANDLE_INFO Handles[1];
} AFD_POLL_INFO, *PAFD_POLL_INFO; } AFD_POLL_INFO, *PAFD_POLL_INFO;
WEPOLL_INTERNAL int afd_global_init(void); WEPOLL_INTERNAL int afd_create_helper_handle(HANDLE iocp,
HANDLE* afd_helper_handle_out);
WEPOLL_INTERNAL int afd_create_driver_socket(HANDLE iocp,
SOCKET* driver_socket_out);
WEPOLL_INTERNAL int afd_poll(SOCKET driver_socket, WEPOLL_INTERNAL int afd_poll(HANDLE afd_helper_handle,
AFD_POLL_INFO* poll_info, AFD_POLL_INFO* poll_info,
OVERLAPPED* overlapped); OVERLAPPED* overlapped);
...@@ -316,138 +328,69 @@ WEPOLL_INTERNAL void err_set_win_error(DWORD error); ...@@ -316,138 +328,69 @@ WEPOLL_INTERNAL void err_set_win_error(DWORD error);
WEPOLL_INTERNAL int err_check_handle(HANDLE handle); WEPOLL_INTERNAL int err_check_handle(HANDLE handle);
WEPOLL_INTERNAL int ws_global_init(void); WEPOLL_INTERNAL int ws_global_init(void);
WEPOLL_INTERNAL SOCKET ws_get_base_socket(SOCKET socket); WEPOLL_INTERNAL SOCKET ws_get_base_socket(SOCKET socket);
WEPOLL_INTERNAL int ws_get_protocol_catalog(WSAPROTOCOL_INFOW** infos_out,
size_t* infos_count_out);
#define IOCTL_AFD_POLL 0x00012024 #define IOCTL_AFD_POLL 0x00012024
/* clang-format off */ static UNICODE_STRING afd__helper_name =
static const GUID _AFD_PROVIDER_GUID_LIST[] = { RTL_CONSTANT_STRING(L"\\Device\\Afd\\Wepoll");
/* MSAFD Tcpip [TCP+UDP+RAW / IP] */
{0xe70f1aa0, 0xab8b, 0x11cf,
{0x8c, 0xa3, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92}},
/* MSAFD Tcpip [TCP+UDP+RAW / IPv6] */
{0xf9eab0c0, 0x26d4, 0x11d0,
{0xbb, 0xbf, 0x00, 0xaa, 0x00, 0x6c, 0x34, 0xe4}},
/* MSAFD RfComm [Bluetooth] */
{0x9fc48064, 0x7298, 0x43e4,
{0xb7, 0xbd, 0x18, 0x1f, 0x20, 0x89, 0x79, 0x2a}},
/* MSAFD Irda [IrDA] */
{0x3972523d, 0x2af1, 0x11d1,
{0xb6, 0x55, 0x00, 0x80, 0x5f, 0x36, 0x42, 0xcc}}};
/* clang-format on */
static const int _AFD_ANY_PROTOCOL = -1;
/* This protocol info record is used by afd_create_driver_socket() to create
* sockets that can be used as the first argument to afd_poll(). It is
* populated on startup by afd_global_init(). */
static WSAPROTOCOL_INFOW _afd_driver_socket_template;
static const WSAPROTOCOL_INFOW* _afd_find_protocol_info(
const WSAPROTOCOL_INFOW* infos, size_t infos_count, int protocol_id) {
size_t i, j;
for (i = 0; i < infos_count; i++) {
const WSAPROTOCOL_INFOW* info = &infos[i];
/* Apply protocol id filter. */
if (protocol_id != _AFD_ANY_PROTOCOL && protocol_id != info->iProtocol)
continue;
/* Filter out non-MSAFD protocols. */ static OBJECT_ATTRIBUTES afd__helper_attributes =
for (j = 0; j < array_count(_AFD_PROVIDER_GUID_LIST); j++) { RTL_CONSTANT_OBJECT_ATTRIBUTES(&afd__helper_name, 0);
if (memcmp(&info->ProviderId,
&_AFD_PROVIDER_GUID_LIST[j],
sizeof info->ProviderId) == 0)
return info;
}
}
return NULL; /* Not found. */
}
int afd_global_init(void) {
WSAPROTOCOL_INFOW* infos;
size_t infos_count;
const WSAPROTOCOL_INFOW* afd_info;
/* Load the winsock catalog. */
if (ws_get_protocol_catalog(&infos, &infos_count) < 0)
return -1;
/* Find a WSAPROTOCOL_INFOW structure that we can use to create an MSAFD
* socket. Preferentially we pick a UDP socket, otherwise try TCP or any
* other type. */
for (;;) {
afd_info = _afd_find_protocol_info(infos, infos_count, IPPROTO_UDP);
if (afd_info != NULL)
break;
afd_info = _afd_find_protocol_info(infos, infos_count, IPPROTO_TCP);
if (afd_info != NULL)
break;
afd_info = _afd_find_protocol_info(infos, infos_count, _AFD_ANY_PROTOCOL);
if (afd_info != NULL)
break;
free(infos);
return_set_error(-1, WSAENETDOWN); /* No suitable protocol found. */
}
/* Copy found protocol information from the catalog to a static buffer. */
_afd_driver_socket_template = *afd_info;
free(infos);
return 0;
}
int afd_create_driver_socket(HANDLE iocp, SOCKET* driver_socket_out) { int afd_create_helper_handle(HANDLE iocp, HANDLE* afd_helper_handle_out) {
SOCKET socket; HANDLE afd_helper_handle;
IO_STATUS_BLOCK iosb;
NTSTATUS status;
socket = WSASocketW(_afd_driver_socket_template.iAddressFamily, /* By opening \Device\Afd without specifying any extended attributes, we'll
_afd_driver_socket_template.iSocketType, * get a handle that lets us talk to the AFD driver, but that doesn't have an
_afd_driver_socket_template.iProtocol, * associated endpoint (so it's not a socket). */
&_afd_driver_socket_template, status = NtCreateFile(&afd_helper_handle,
SYNCHRONIZE,
&afd__helper_attributes,
&iosb,
NULL,
0, 0,
WSA_FLAG_OVERLAPPED); FILE_SHARE_READ | FILE_SHARE_WRITE,
if (socket == INVALID_SOCKET) FILE_OPEN,
return_map_error(-1); 0,
NULL,
0);
if (status != STATUS_SUCCESS)
return_set_error(-1, RtlNtStatusToDosError(status));
/* TODO: use WSA_FLAG_NOINHERIT on Windows versions that support it. */ if (CreateIoCompletionPort(afd_helper_handle, iocp, 0, 0) == NULL)
if (!SetHandleInformation((HANDLE) socket, HANDLE_FLAG_INHERIT, 0))
goto error; goto error;
if (CreateIoCompletionPort((HANDLE) socket, iocp, 0, 0) == NULL) if (!SetFileCompletionNotificationModes(afd_helper_handle,
FILE_SKIP_SET_EVENT_ON_HANDLE))
goto error; goto error;
*driver_socket_out = socket; *afd_helper_handle_out = afd_helper_handle;
return 0; return 0;
error:; error:
DWORD error = GetLastError(); CloseHandle(afd_helper_handle);
closesocket(socket); return_map_error(-1);
return_set_error(-1, error);
} }
int afd_poll(SOCKET driver_socket, int afd_poll(HANDLE afd_helper_handle,
AFD_POLL_INFO* poll_info, AFD_POLL_INFO* poll_info,
OVERLAPPED* overlapped) { OVERLAPPED* overlapped) {
IO_STATUS_BLOCK iosb; IO_STATUS_BLOCK* iosb;
IO_STATUS_BLOCK* iosb_ptr; HANDLE event;
HANDLE event = NULL;
void* apc_context; void* apc_context;
NTSTATUS status; NTSTATUS status;
if (overlapped != NULL) { /* Blocking operation is not supported. */
/* Overlapped operation. */ assert(overlapped != NULL);
iosb_ptr = (IO_STATUS_BLOCK*) &overlapped->Internal;
iosb = (IO_STATUS_BLOCK*) &overlapped->Internal;
event = overlapped->hEvent; event = overlapped->hEvent;
/* Do not report iocp completion if hEvent is tagged. */ /* Do what other windows APIs would do: if hEvent has it's lowest bit set,
* don't post a completion to the completion port. */
if ((uintptr_t) event & 1) { if ((uintptr_t) event & 1) {
event = (HANDLE)((uintptr_t) event & ~(uintptr_t) 1); event = (HANDLE)((uintptr_t) event & ~(uintptr_t) 1);
apc_context = NULL; apc_context = NULL;
...@@ -455,45 +398,18 @@ int afd_poll(SOCKET driver_socket, ...@@ -455,45 +398,18 @@ int afd_poll(SOCKET driver_socket,
apc_context = overlapped; apc_context = overlapped;
} }
} else { iosb->Status = STATUS_PENDING;
/* Blocking operation. */ status = NtDeviceIoControlFile(afd_helper_handle,
iosb_ptr = &iosb;
event = CreateEventW(NULL, FALSE, FALSE, NULL);
if (event == NULL)
return_map_error(-1);
apc_context = NULL;
}
iosb_ptr->Status = STATUS_PENDING;
status = NtDeviceIoControlFile((HANDLE) driver_socket,
event, event,
NULL, NULL,
apc_context, apc_context,
iosb_ptr, iosb,
IOCTL_AFD_POLL, IOCTL_AFD_POLL,
poll_info, poll_info,
sizeof *poll_info, sizeof *poll_info,
poll_info, poll_info,
sizeof *poll_info); sizeof *poll_info);
if (overlapped == NULL) {
/* If this is a blocking operation, wait for the event to become signaled,
* and then grab the real status from the io status block. */
if (status == STATUS_PENDING) {
DWORD r = WaitForSingleObject(event, INFINITE);
if (r == WAIT_FAILED) {
DWORD error = GetLastError();
CloseHandle(event);
return_set_error(-1, error);
}
status = iosb_ptr->Status;
}
CloseHandle(event);
}
if (status == STATUS_SUCCESS) if (status == STATUS_SUCCESS)
return 0; return 0;
else if (status == STATUS_PENDING) else if (status == STATUS_PENDING)
...@@ -502,7 +418,7 @@ int afd_poll(SOCKET driver_socket, ...@@ -502,7 +418,7 @@ int afd_poll(SOCKET driver_socket,
return_set_error(-1, RtlNtStatusToDosError(status)); return_set_error(-1, RtlNtStatusToDosError(status));
} }
WEPOLL_INTERNAL int api_global_init(void); WEPOLL_INTERNAL int epoll_global_init(void);
WEPOLL_INTERNAL int init(void); WEPOLL_INTERNAL int init(void);
...@@ -544,7 +460,8 @@ WEPOLL_INTERNAL void poll_group_delete(poll_group_t* poll_group); ...@@ -544,7 +460,8 @@ WEPOLL_INTERNAL void poll_group_delete(poll_group_t* poll_group);
WEPOLL_INTERNAL poll_group_t* poll_group_from_queue_node( WEPOLL_INTERNAL poll_group_t* poll_group_from_queue_node(
queue_node_t* queue_node); queue_node_t* queue_node);
WEPOLL_INTERNAL SOCKET poll_group_get_socket(poll_group_t* poll_group); WEPOLL_INTERNAL HANDLE
poll_group_get_afd_helper_handle(poll_group_t* poll_group);
/* N.b.: the tree functions do not set errno or LastError when they fail. Each /* N.b.: the tree functions do not set errno or LastError when they fail. Each
* of the API functions has at most one failure mode. It is up to the caller to * of the API functions has at most one failure mode. It is up to the caller to
...@@ -621,7 +538,7 @@ WEPOLL_INTERNAL tree_node_t* sock_state_to_tree_node(sock_state_t* sock_state); ...@@ -621,7 +538,7 @@ WEPOLL_INTERNAL tree_node_t* sock_state_to_tree_node(sock_state_t* sock_state);
*/ */
typedef struct reflock { typedef struct reflock {
uint32_t state; volatile long state; /* 32-bit Interlocked APIs operate on `long` values. */
} reflock_t; } reflock_t;
WEPOLL_INTERNAL int reflock_global_init(void); WEPOLL_INTERNAL int reflock_global_init(void);
...@@ -702,19 +619,19 @@ WEPOLL_INTERNAL void port_add_deleted_socket(port_state_t* port_state, ...@@ -702,19 +619,19 @@ WEPOLL_INTERNAL void port_add_deleted_socket(port_state_t* port_state,
WEPOLL_INTERNAL void port_remove_deleted_socket(port_state_t* port_state, WEPOLL_INTERNAL void port_remove_deleted_socket(port_state_t* port_state,
sock_state_t* sock_state); sock_state_t* sock_state);
static ts_tree_t _epoll_handle_tree; static ts_tree_t epoll__handle_tree;
static inline port_state_t* _handle_tree_node_to_port( static inline port_state_t* epoll__handle_tree_node_to_port(
ts_tree_node_t* tree_node) { ts_tree_node_t* tree_node) {
return container_of(tree_node, port_state_t, handle_tree_node); return container_of(tree_node, port_state_t, handle_tree_node);
} }
int api_global_init(void) { int epoll_global_init(void) {
ts_tree_init(&_epoll_handle_tree); ts_tree_init(&epoll__handle_tree);
return 0; return 0;
} }
static HANDLE _epoll_create(void) { static HANDLE epoll__create(void) {
port_state_t* port_state; port_state_t* port_state;
HANDLE ephnd; HANDLE ephnd;
...@@ -725,7 +642,7 @@ static HANDLE _epoll_create(void) { ...@@ -725,7 +642,7 @@ static HANDLE _epoll_create(void) {
if (port_state == NULL) if (port_state == NULL)
return NULL; return NULL;
if (ts_tree_add(&_epoll_handle_tree, if (ts_tree_add(&epoll__handle_tree,
&port_state->handle_tree_node, &port_state->handle_tree_node,
(uintptr_t) ephnd) < 0) { (uintptr_t) ephnd) < 0) {
/* This should never happen. */ /* This should never happen. */
...@@ -740,14 +657,14 @@ HANDLE epoll_create(int size) { ...@@ -740,14 +657,14 @@ HANDLE epoll_create(int size) {
if (size <= 0) if (size <= 0)
return_set_error(NULL, ERROR_INVALID_PARAMETER); return_set_error(NULL, ERROR_INVALID_PARAMETER);
return _epoll_create(); return epoll__create();
} }
HANDLE epoll_create1(int flags) { HANDLE epoll_create1(int flags) {
if (flags != 0) if (flags != 0)
return_set_error(NULL, ERROR_INVALID_PARAMETER); return_set_error(NULL, ERROR_INVALID_PARAMETER);
return _epoll_create(); return epoll__create();
} }
int epoll_close(HANDLE ephnd) { int epoll_close(HANDLE ephnd) {
...@@ -757,13 +674,13 @@ int epoll_close(HANDLE ephnd) { ...@@ -757,13 +674,13 @@ int epoll_close(HANDLE ephnd) {
if (init() < 0) if (init() < 0)
return -1; return -1;
tree_node = ts_tree_del_and_ref(&_epoll_handle_tree, (uintptr_t) ephnd); tree_node = ts_tree_del_and_ref(&epoll__handle_tree, (uintptr_t) ephnd);
if (tree_node == NULL) { if (tree_node == NULL) {
err_set_win_error(ERROR_INVALID_PARAMETER); err_set_win_error(ERROR_INVALID_PARAMETER);
goto err; goto err;
} }
port_state = _handle_tree_node_to_port(tree_node); port_state = epoll__handle_tree_node_to_port(tree_node);
port_close(port_state); port_close(port_state);
ts_tree_node_unref_and_destroy(tree_node); ts_tree_node_unref_and_destroy(tree_node);
...@@ -783,13 +700,13 @@ int epoll_ctl(HANDLE ephnd, int op, SOCKET sock, struct epoll_event* ev) { ...@@ -783,13 +700,13 @@ int epoll_ctl(HANDLE ephnd, int op, SOCKET sock, struct epoll_event* ev) {
if (init() < 0) if (init() < 0)
return -1; return -1;
tree_node = ts_tree_find_and_ref(&_epoll_handle_tree, (uintptr_t) ephnd); tree_node = ts_tree_find_and_ref(&epoll__handle_tree, (uintptr_t) ephnd);
if (tree_node == NULL) { if (tree_node == NULL) {
err_set_win_error(ERROR_INVALID_PARAMETER); err_set_win_error(ERROR_INVALID_PARAMETER);
goto err; goto err;
} }
port_state = _handle_tree_node_to_port(tree_node); port_state = epoll__handle_tree_node_to_port(tree_node);
r = port_ctl(port_state, op, sock, ev); r = port_ctl(port_state, op, sock, ev);
ts_tree_node_unref(tree_node); ts_tree_node_unref(tree_node);
...@@ -821,13 +738,13 @@ int epoll_wait(HANDLE ephnd, ...@@ -821,13 +738,13 @@ int epoll_wait(HANDLE ephnd,
if (init() < 0) if (init() < 0)
return -1; return -1;
tree_node = ts_tree_find_and_ref(&_epoll_handle_tree, (uintptr_t) ephnd); tree_node = ts_tree_find_and_ref(&epoll__handle_tree, (uintptr_t) ephnd);
if (tree_node == NULL) { if (tree_node == NULL) {
err_set_win_error(ERROR_INVALID_PARAMETER); err_set_win_error(ERROR_INVALID_PARAMETER);
goto err; goto err;
} }
port_state = _handle_tree_node_to_port(tree_node); port_state = epoll__handle_tree_node_to_port(tree_node);
num_events = port_wait(port_state, events, maxevents, timeout); num_events = port_wait(port_state, events, maxevents, timeout);
ts_tree_node_unref(tree_node); ts_tree_node_unref(tree_node);
...@@ -947,7 +864,7 @@ err: ...@@ -947,7 +864,7 @@ err:
X(WSASYSNOTREADY, ENETDOWN) \ X(WSASYSNOTREADY, ENETDOWN) \
X(WSAVERNOTSUPPORTED, ENOSYS) X(WSAVERNOTSUPPORTED, ENOSYS)
static errno_t _err_map_win_error_to_errno(DWORD error) { static errno_t err__map_win_error_to_errno(DWORD error) {
switch (error) { switch (error) {
#define X(error_sym, errno_sym) \ #define X(error_sym, errno_sym) \
case error_sym: \ case error_sym: \
...@@ -959,12 +876,12 @@ static errno_t _err_map_win_error_to_errno(DWORD error) { ...@@ -959,12 +876,12 @@ static errno_t _err_map_win_error_to_errno(DWORD error) {
} }
void err_map_win_error(void) { void err_map_win_error(void) {
errno = _err_map_win_error_to_errno(GetLastError()); errno = err__map_win_error_to_errno(GetLastError());
} }
void err_set_win_error(DWORD error) { void err_set_win_error(DWORD error) {
SetLastError(error); SetLastError(error);
errno = _err_map_win_error_to_errno(error); errno = err__map_win_error_to_errno(error);
} }
int err_check_handle(HANDLE handle) { int err_check_handle(HANDLE handle) {
...@@ -981,10 +898,10 @@ int err_check_handle(HANDLE handle) { ...@@ -981,10 +898,10 @@ int err_check_handle(HANDLE handle) {
return 0; return 0;
} }
static bool _initialized = false; static bool init__done = false;
static INIT_ONCE _once = INIT_ONCE_STATIC_INIT; static INIT_ONCE init__once = INIT_ONCE_STATIC_INIT;
static BOOL CALLBACK _init_once_callback(INIT_ONCE* once, static BOOL CALLBACK init__once_callback(INIT_ONCE* once,
void* parameter, void* parameter,
void** context) { void** context) {
unused_var(once); unused_var(once);
...@@ -992,17 +909,17 @@ static BOOL CALLBACK _init_once_callback(INIT_ONCE* once, ...@@ -992,17 +909,17 @@ static BOOL CALLBACK _init_once_callback(INIT_ONCE* once,
unused_var(context); unused_var(context);
/* N.b. that initialization order matters here. */ /* N.b. that initialization order matters here. */
if (ws_global_init() < 0 || nt_global_init() < 0 || afd_global_init() < 0 || if (ws_global_init() < 0 || nt_global_init() < 0 ||
reflock_global_init() < 0 || api_global_init() < 0) reflock_global_init() < 0 || epoll_global_init() < 0)
return FALSE; return FALSE;
_initialized = true; init__done = true;
return TRUE; return TRUE;
} }
int init(void) { int init(void) {
if (!_initialized && if (!init__done &&
!InitOnceExecuteOnce(&_once, _init_once_callback, NULL, NULL)) !InitOnceExecuteOnce(&init__once, init__once_callback, NULL, NULL))
return -1; /* LastError and errno aren't touched InitOnceExecuteOnce. */ return -1; /* LastError and errno aren't touched InitOnceExecuteOnce. */
return 0; return 0;
...@@ -1010,7 +927,7 @@ int init(void) { ...@@ -1010,7 +927,7 @@ int init(void) {
#define X(return_type, attributes, name, parameters) \ #define X(return_type, attributes, name, parameters) \
WEPOLL_INTERNAL return_type(attributes* name) parameters = NULL; WEPOLL_INTERNAL return_type(attributes* name) parameters = NULL;
NTDLL_IMPORT_LIST(X) NT_NTDLL_IMPORT_LIST(X)
#undef X #undef X
int nt_global_init(void) { int nt_global_init(void) {
...@@ -1024,7 +941,7 @@ int nt_global_init(void) { ...@@ -1024,7 +941,7 @@ int nt_global_init(void) {
name = (return_type(attributes*) parameters) GetProcAddress(ntdll, #name); \ name = (return_type(attributes*) parameters) GetProcAddress(ntdll, #name); \
if (name == NULL) \ if (name == NULL) \
return -1; return -1;
NTDLL_IMPORT_LIST(X) NT_NTDLL_IMPORT_LIST(X)
#undef X #undef X
return 0; return 0;
...@@ -1032,16 +949,16 @@ int nt_global_init(void) { ...@@ -1032,16 +949,16 @@ int nt_global_init(void) {
#include <string.h> #include <string.h>
static const size_t _POLL_GROUP_MAX_GROUP_SIZE = 32; static const size_t POLL_GROUP__MAX_GROUP_SIZE = 32;
typedef struct poll_group { typedef struct poll_group {
port_state_t* port_state; port_state_t* port_state;
queue_node_t queue_node; queue_node_t queue_node;
SOCKET socket; HANDLE afd_helper_handle;
size_t group_size; size_t group_size;
} poll_group_t; } poll_group_t;
static poll_group_t* _poll_group_new(port_state_t* port_state) { static poll_group_t* poll_group__new(port_state_t* port_state) {
poll_group_t* poll_group = malloc(sizeof *poll_group); poll_group_t* poll_group = malloc(sizeof *poll_group);
if (poll_group == NULL) if (poll_group == NULL)
return_set_error(NULL, ERROR_NOT_ENOUGH_MEMORY); return_set_error(NULL, ERROR_NOT_ENOUGH_MEMORY);
...@@ -1051,7 +968,8 @@ static poll_group_t* _poll_group_new(port_state_t* port_state) { ...@@ -1051,7 +968,8 @@ static poll_group_t* _poll_group_new(port_state_t* port_state) {
queue_node_init(&poll_group->queue_node); queue_node_init(&poll_group->queue_node);
poll_group->port_state = port_state; poll_group->port_state = port_state;
if (afd_create_driver_socket(port_state->iocp, &poll_group->socket) < 0) { if (afd_create_helper_handle(port_state->iocp,
&poll_group->afd_helper_handle) < 0) {
free(poll_group); free(poll_group);
return NULL; return NULL;
} }
...@@ -1063,7 +981,7 @@ static poll_group_t* _poll_group_new(port_state_t* port_state) { ...@@ -1063,7 +981,7 @@ static poll_group_t* _poll_group_new(port_state_t* port_state) {
void poll_group_delete(poll_group_t* poll_group) { void poll_group_delete(poll_group_t* poll_group) {
assert(poll_group->group_size == 0); assert(poll_group->group_size == 0);
closesocket(poll_group->socket); CloseHandle(poll_group->afd_helper_handle);
queue_remove(&poll_group->queue_node); queue_remove(&poll_group->queue_node);
free(poll_group); free(poll_group);
} }
...@@ -1072,8 +990,8 @@ poll_group_t* poll_group_from_queue_node(queue_node_t* queue_node) { ...@@ -1072,8 +990,8 @@ poll_group_t* poll_group_from_queue_node(queue_node_t* queue_node) {
return container_of(queue_node, poll_group_t, queue_node); return container_of(queue_node, poll_group_t, queue_node);
} }
SOCKET poll_group_get_socket(poll_group_t* poll_group) { HANDLE poll_group_get_afd_helper_handle(poll_group_t* poll_group) {
return poll_group->socket; return poll_group->afd_helper_handle;
} }
poll_group_t* poll_group_acquire(port_state_t* port_state) { poll_group_t* poll_group_acquire(port_state_t* port_state) {
...@@ -1084,12 +1002,12 @@ poll_group_t* poll_group_acquire(port_state_t* port_state) { ...@@ -1084,12 +1002,12 @@ poll_group_t* poll_group_acquire(port_state_t* port_state) {
: NULL; : NULL;
if (poll_group == NULL || if (poll_group == NULL ||
poll_group->group_size >= _POLL_GROUP_MAX_GROUP_SIZE) poll_group->group_size >= POLL_GROUP__MAX_GROUP_SIZE)
poll_group = _poll_group_new(port_state); poll_group = poll_group__new(port_state);
if (poll_group == NULL) if (poll_group == NULL)
return NULL; return NULL;
if (++poll_group->group_size == _POLL_GROUP_MAX_GROUP_SIZE) if (++poll_group->group_size == POLL_GROUP__MAX_GROUP_SIZE)
queue_move_first(&port_state->poll_group_queue, &poll_group->queue_node); queue_move_first(&port_state->poll_group_queue, &poll_group->queue_node);
return poll_group; return poll_group;
...@@ -1099,7 +1017,7 @@ void poll_group_release(poll_group_t* poll_group) { ...@@ -1099,7 +1017,7 @@ void poll_group_release(poll_group_t* poll_group) {
port_state_t* port_state = poll_group->port_state; port_state_t* port_state = poll_group->port_state;
poll_group->group_size--; poll_group->group_size--;
assert(poll_group->group_size < _POLL_GROUP_MAX_GROUP_SIZE); assert(poll_group->group_size < POLL_GROUP__MAX_GROUP_SIZE);
queue_move_last(&port_state->poll_group_queue, &poll_group->queue_node); queue_move_last(&port_state->poll_group_queue, &poll_group->queue_node);
...@@ -1108,7 +1026,7 @@ void poll_group_release(poll_group_t* poll_group) { ...@@ -1108,7 +1026,7 @@ void poll_group_release(poll_group_t* poll_group) {
#define PORT__MAX_ON_STACK_COMPLETIONS 256 #define PORT__MAX_ON_STACK_COMPLETIONS 256
static port_state_t* _port_alloc(void) { static port_state_t* port__alloc(void) {
port_state_t* port_state = malloc(sizeof *port_state); port_state_t* port_state = malloc(sizeof *port_state);
if (port_state == NULL) if (port_state == NULL)
return_set_error(NULL, ERROR_NOT_ENOUGH_MEMORY); return_set_error(NULL, ERROR_NOT_ENOUGH_MEMORY);
...@@ -1116,12 +1034,12 @@ static port_state_t* _port_alloc(void) { ...@@ -1116,12 +1034,12 @@ static port_state_t* _port_alloc(void) {
return port_state; return port_state;
} }
static void _port_free(port_state_t* port) { static void port__free(port_state_t* port) {
assert(port != NULL); assert(port != NULL);
free(port); free(port);
} }
static HANDLE _port_create_iocp(void) { static HANDLE port__create_iocp(void) {
HANDLE iocp = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0); HANDLE iocp = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
if (iocp == NULL) if (iocp == NULL)
return_map_error(NULL); return_map_error(NULL);
...@@ -1133,11 +1051,11 @@ port_state_t* port_new(HANDLE* iocp_out) { ...@@ -1133,11 +1051,11 @@ port_state_t* port_new(HANDLE* iocp_out) {
port_state_t* port_state; port_state_t* port_state;
HANDLE iocp; HANDLE iocp;
port_state = _port_alloc(); port_state = port__alloc();
if (port_state == NULL) if (port_state == NULL)
goto err1; goto err1;
iocp = _port_create_iocp(); iocp = port__create_iocp();
if (iocp == NULL) if (iocp == NULL)
goto err2; goto err2;
...@@ -1155,12 +1073,12 @@ port_state_t* port_new(HANDLE* iocp_out) { ...@@ -1155,12 +1073,12 @@ port_state_t* port_new(HANDLE* iocp_out) {
return port_state; return port_state;
err2: err2:
_port_free(port_state); port__free(port_state);
err1: err1:
return NULL; return NULL;
} }
static int _port_close_iocp(port_state_t* port_state) { static int port__close_iocp(port_state_t* port_state) {
HANDLE iocp = port_state->iocp; HANDLE iocp = port_state->iocp;
port_state->iocp = NULL; port_state->iocp = NULL;
...@@ -1174,7 +1092,7 @@ int port_close(port_state_t* port_state) { ...@@ -1174,7 +1092,7 @@ int port_close(port_state_t* port_state) {
int result; int result;
EnterCriticalSection(&port_state->lock); EnterCriticalSection(&port_state->lock);
result = _port_close_iocp(port_state); result = port__close_iocp(port_state);
LeaveCriticalSection(&port_state->lock); LeaveCriticalSection(&port_state->lock);
return result; return result;
...@@ -1206,12 +1124,12 @@ int port_delete(port_state_t* port_state) { ...@@ -1206,12 +1124,12 @@ int port_delete(port_state_t* port_state) {
DeleteCriticalSection(&port_state->lock); DeleteCriticalSection(&port_state->lock);
_port_free(port_state); port__free(port_state);
return 0; return 0;
} }
static int _port_update_events(port_state_t* port_state) { static int port__update_events(port_state_t* port_state) {
queue_t* sock_update_queue = &port_state->sock_update_queue; queue_t* sock_update_queue = &port_state->sock_update_queue;
/* Walk the queue, submitting new poll requests for every socket that needs /* Walk the queue, submitting new poll requests for every socket that needs
...@@ -1229,12 +1147,12 @@ static int _port_update_events(port_state_t* port_state) { ...@@ -1229,12 +1147,12 @@ static int _port_update_events(port_state_t* port_state) {
return 0; return 0;
} }
static void _port_update_events_if_polling(port_state_t* port_state) { static void port__update_events_if_polling(port_state_t* port_state) {
if (port_state->active_poll_count > 0) if (port_state->active_poll_count > 0)
_port_update_events(port_state); port__update_events(port_state);
} }
static int _port_feed_events(port_state_t* port_state, static int port__feed_events(port_state_t* port_state,
struct epoll_event* epoll_events, struct epoll_event* epoll_events,
OVERLAPPED_ENTRY* iocp_events, OVERLAPPED_ENTRY* iocp_events,
DWORD iocp_event_count) { DWORD iocp_event_count) {
...@@ -1251,14 +1169,14 @@ static int _port_feed_events(port_state_t* port_state, ...@@ -1251,14 +1169,14 @@ static int _port_feed_events(port_state_t* port_state,
return epoll_event_count; return epoll_event_count;
} }
static int _port_poll(port_state_t* port_state, static int port__poll(port_state_t* port_state,
struct epoll_event* epoll_events, struct epoll_event* epoll_events,
OVERLAPPED_ENTRY* iocp_events, OVERLAPPED_ENTRY* iocp_events,
DWORD maxevents, DWORD maxevents,
DWORD timeout) { DWORD timeout) {
DWORD completion_count; DWORD completion_count;
if (_port_update_events(port_state) < 0) if (port__update_events(port_state) < 0)
return -1; return -1;
port_state->active_poll_count++; port_state->active_poll_count++;
...@@ -1279,7 +1197,7 @@ static int _port_poll(port_state_t* port_state, ...@@ -1279,7 +1197,7 @@ static int _port_poll(port_state_t* port_state,
if (!r) if (!r)
return_map_error(-1); return_map_error(-1);
return _port_feed_events( return port__feed_events(
port_state, epoll_events, iocp_events, completion_count); port_state, epoll_events, iocp_events, completion_count);
} }
...@@ -1325,7 +1243,7 @@ int port_wait(port_state_t* port_state, ...@@ -1325,7 +1243,7 @@ int port_wait(port_state_t* port_state,
for (;;) { for (;;) {
uint64_t now; uint64_t now;
result = _port_poll( result = port__poll(
port_state, events, iocp_events, (DWORD) maxevents, gqcs_timeout); port_state, events, iocp_events, (DWORD) maxevents, gqcs_timeout);
if (result < 0 || result > 0) if (result < 0 || result > 0)
break; /* Result, error, or time-out. */ break; /* Result, error, or time-out. */
...@@ -1346,7 +1264,7 @@ int port_wait(port_state_t* port_state, ...@@ -1346,7 +1264,7 @@ int port_wait(port_state_t* port_state,
gqcs_timeout = (DWORD)(due - now); gqcs_timeout = (DWORD)(due - now);
} }
_port_update_events_if_polling(port_state); port__update_events_if_polling(port_state);
LeaveCriticalSection(&port_state->lock); LeaveCriticalSection(&port_state->lock);
...@@ -1361,7 +1279,7 @@ int port_wait(port_state_t* port_state, ...@@ -1361,7 +1279,7 @@ int port_wait(port_state_t* port_state,
return -1; return -1;
} }
static int _port_ctl_add(port_state_t* port_state, static int port__ctl_add(port_state_t* port_state,
SOCKET sock, SOCKET sock,
struct epoll_event* ev) { struct epoll_event* ev) {
sock_state_t* sock_state = sock_new(port_state, sock); sock_state_t* sock_state = sock_new(port_state, sock);
...@@ -1373,12 +1291,12 @@ static int _port_ctl_add(port_state_t* port_state, ...@@ -1373,12 +1291,12 @@ static int _port_ctl_add(port_state_t* port_state,
return -1; return -1;
} }
_port_update_events_if_polling(port_state); port__update_events_if_polling(port_state);
return 0; return 0;
} }
static int _port_ctl_mod(port_state_t* port_state, static int port__ctl_mod(port_state_t* port_state,
SOCKET sock, SOCKET sock,
struct epoll_event* ev) { struct epoll_event* ev) {
sock_state_t* sock_state = port_find_socket(port_state, sock); sock_state_t* sock_state = port_find_socket(port_state, sock);
...@@ -1388,12 +1306,12 @@ static int _port_ctl_mod(port_state_t* port_state, ...@@ -1388,12 +1306,12 @@ static int _port_ctl_mod(port_state_t* port_state,
if (sock_set_event(port_state, sock_state, ev) < 0) if (sock_set_event(port_state, sock_state, ev) < 0)
return -1; return -1;
_port_update_events_if_polling(port_state); port__update_events_if_polling(port_state);
return 0; return 0;
} }
static int _port_ctl_del(port_state_t* port_state, SOCKET sock) { static int port__ctl_del(port_state_t* port_state, SOCKET sock) {
sock_state_t* sock_state = port_find_socket(port_state, sock); sock_state_t* sock_state = port_find_socket(port_state, sock);
if (sock_state == NULL) if (sock_state == NULL)
return -1; return -1;
...@@ -1403,17 +1321,17 @@ static int _port_ctl_del(port_state_t* port_state, SOCKET sock) { ...@@ -1403,17 +1321,17 @@ static int _port_ctl_del(port_state_t* port_state, SOCKET sock) {
return 0; return 0;
} }
static int _port_ctl_op(port_state_t* port_state, static int port__ctl_op(port_state_t* port_state,
int op, int op,
SOCKET sock, SOCKET sock,
struct epoll_event* ev) { struct epoll_event* ev) {
switch (op) { switch (op) {
case EPOLL_CTL_ADD: case EPOLL_CTL_ADD:
return _port_ctl_add(port_state, sock, ev); return port__ctl_add(port_state, sock, ev);
case EPOLL_CTL_MOD: case EPOLL_CTL_MOD:
return _port_ctl_mod(port_state, sock, ev); return port__ctl_mod(port_state, sock, ev);
case EPOLL_CTL_DEL: case EPOLL_CTL_DEL:
return _port_ctl_del(port_state, sock); return port__ctl_del(port_state, sock);
default: default:
return_set_error(-1, ERROR_INVALID_PARAMETER); return_set_error(-1, ERROR_INVALID_PARAMETER);
} }
...@@ -1426,7 +1344,7 @@ int port_ctl(port_state_t* port_state, ...@@ -1426,7 +1344,7 @@ int port_ctl(port_state_t* port_state,
int result; int result;
EnterCriticalSection(&port_state->lock); EnterCriticalSection(&port_state->lock);
result = _port_ctl_op(port_state, op, sock, ev); result = port__ctl_op(port_state, op, sock, ev);
LeaveCriticalSection(&port_state->lock); LeaveCriticalSection(&port_state->lock);
return result; return result;
...@@ -1495,7 +1413,7 @@ void queue_node_init(queue_node_t* node) { ...@@ -1495,7 +1413,7 @@ void queue_node_init(queue_node_t* node) {
node->next = node; node->next = node;
} }
static inline void _queue_detach(queue_node_t* node) { static inline void queue__detach_node(queue_node_t* node) {
node->prev->next = node->next; node->prev->next = node->next;
node->next->prev = node->prev; node->next->prev = node->prev;
} }
...@@ -1523,17 +1441,17 @@ void queue_append(queue_t* queue, queue_node_t* node) { ...@@ -1523,17 +1441,17 @@ void queue_append(queue_t* queue, queue_node_t* node) {
} }
void queue_move_first(queue_t* queue, queue_node_t* node) { void queue_move_first(queue_t* queue, queue_node_t* node) {
_queue_detach(node); queue__detach_node(node);
queue_prepend(queue, node); queue_prepend(queue, node);
} }
void queue_move_last(queue_t* queue, queue_node_t* node) { void queue_move_last(queue_t* queue, queue_node_t* node) {
_queue_detach(node); queue__detach_node(node);
queue_append(queue, node); queue_append(queue, node);
} }
void queue_remove(queue_node_t* node) { void queue_remove(queue_node_t* node) {
_queue_detach(node); queue__detach_node(node);
queue_node_init(node); queue_node_init(node);
} }
...@@ -1546,18 +1464,18 @@ bool queue_enqueued(const queue_node_t* node) { ...@@ -1546,18 +1464,18 @@ bool queue_enqueued(const queue_node_t* node) {
} }
/* clang-format off */ /* clang-format off */
static const uint32_t _REF = 0x00000001; static const long REFLOCK__REF = (long) 0x00000001;
static const uint32_t _REF_MASK = 0x0fffffff; static const long REFLOCK__REF_MASK = (long) 0x0fffffff;
static const uint32_t _DESTROY = 0x10000000; static const long REFLOCK__DESTROY = (long) 0x10000000;
static const uint32_t _DESTROY_MASK = 0xf0000000; static const long REFLOCK__DESTROY_MASK = (long) 0xf0000000;
static const uint32_t _POISON = 0x300DEAD0; static const long REFLOCK__POISON = (long) 0x300DEAD0;
/* clang-format on */ /* clang-format on */
static HANDLE _keyed_event = NULL; static HANDLE reflock__keyed_event = NULL;
int reflock_global_init(void) { int reflock_global_init(void) {
NTSTATUS status = NTSTATUS status =
NtCreateKeyedEvent(&_keyed_event, ~(ACCESS_MASK) 0, NULL, 0); NtCreateKeyedEvent(&reflock__keyed_event, ~(ACCESS_MASK) 0, NULL, 0);
if (status != STATUS_SUCCESS) if (status != STATUS_SUCCESS)
return_set_error(-1, RtlNtStatusToDosError(status)); return_set_error(-1, RtlNtStatusToDosError(status));
return 0; return 0;
...@@ -1567,73 +1485,64 @@ void reflock_init(reflock_t* reflock) { ...@@ -1567,73 +1485,64 @@ void reflock_init(reflock_t* reflock) {
reflock->state = 0; reflock->state = 0;
} }
static void _signal_event(void* address) { static void reflock__signal_event(void* address) {
NTSTATUS status = NtReleaseKeyedEvent(_keyed_event, address, FALSE, NULL); NTSTATUS status =
NtReleaseKeyedEvent(reflock__keyed_event, address, FALSE, NULL);
if (status != STATUS_SUCCESS) if (status != STATUS_SUCCESS)
abort(); abort();
} }
static void _await_event(void* address) { static void reflock__await_event(void* address) {
NTSTATUS status = NtWaitForKeyedEvent(_keyed_event, address, FALSE, NULL); NTSTATUS status =
NtWaitForKeyedEvent(reflock__keyed_event, address, FALSE, NULL);
if (status != STATUS_SUCCESS) if (status != STATUS_SUCCESS)
abort(); abort();
} }
static inline uint32_t _sync_add_and_fetch(volatile uint32_t* target,
uint32_t value) {
static_assert(sizeof(*target) == sizeof(long), "");
return (uint32_t) InterlockedAdd((volatile long*) target, (long) value);
}
static inline uint32_t _sync_fetch_and_set(volatile uint32_t* target,
uint32_t value) {
static_assert(sizeof(*target) == sizeof(long), "");
return (uint32_t) InterlockedExchange((volatile long*) target, (long) value);
}
void reflock_ref(reflock_t* reflock) { void reflock_ref(reflock_t* reflock) {
uint32_t state = _sync_add_and_fetch(&reflock->state, _REF); long state = InterlockedAdd(&reflock->state, REFLOCK__REF);
unused_var(state); unused_var(state);
assert((state & _DESTROY_MASK) == 0); /* Overflow or destroyed. */ assert((state & REFLOCK__DESTROY_MASK) == 0); /* Overflow or destroyed. */
} }
void reflock_unref(reflock_t* reflock) { void reflock_unref(reflock_t* reflock) {
uint32_t state = _sync_add_and_fetch(&reflock->state, 0 - _REF); long state = InterlockedAdd(&reflock->state, -REFLOCK__REF);
uint32_t ref_count = state & _REF_MASK; long ref_count = state & REFLOCK__REF_MASK;
uint32_t destroy = state & _DESTROY_MASK; long destroy = state & REFLOCK__DESTROY_MASK;
unused_var(ref_count); unused_var(ref_count);
unused_var(destroy); unused_var(destroy);
if (state == _DESTROY) if (state == REFLOCK__DESTROY)
_signal_event(reflock); reflock__signal_event(reflock);
else else
assert(destroy == 0 || ref_count > 0); assert(destroy == 0 || ref_count > 0);
} }
void reflock_unref_and_destroy(reflock_t* reflock) { void reflock_unref_and_destroy(reflock_t* reflock) {
uint32_t state = _sync_add_and_fetch(&reflock->state, _DESTROY - _REF); long state =
uint32_t ref_count = state & _REF_MASK; InterlockedAdd(&reflock->state, REFLOCK__DESTROY - REFLOCK__REF);
long ref_count = state & REFLOCK__REF_MASK;
assert((state & _DESTROY_MASK) == assert((state & REFLOCK__DESTROY_MASK) ==
_DESTROY); /* Underflow or already destroyed. */ REFLOCK__DESTROY); /* Underflow or already destroyed. */
if (ref_count != 0) if (ref_count != 0)
_await_event(reflock); reflock__await_event(reflock);
state = _sync_fetch_and_set(&reflock->state, _POISON); state = InterlockedExchange(&reflock->state, REFLOCK__POISON);
assert(state == _DESTROY); assert(state == REFLOCK__DESTROY);
} }
static const uint32_t _SOCK_KNOWN_EPOLL_EVENTS = static const uint32_t SOCK__KNOWN_EPOLL_EVENTS =
EPOLLIN | EPOLLPRI | EPOLLOUT | EPOLLERR | EPOLLHUP | EPOLLRDNORM | EPOLLIN | EPOLLPRI | EPOLLOUT | EPOLLERR | EPOLLHUP | EPOLLRDNORM |
EPOLLRDBAND | EPOLLWRNORM | EPOLLWRBAND | EPOLLMSG | EPOLLRDHUP; EPOLLRDBAND | EPOLLWRNORM | EPOLLWRBAND | EPOLLMSG | EPOLLRDHUP;
typedef enum _poll_status { typedef enum sock__poll_status {
_POLL_IDLE = 0, SOCK__POLL_IDLE = 0,
_POLL_PENDING, SOCK__POLL_PENDING,
_POLL_CANCELLED SOCK__POLL_CANCELLED
} _poll_status_t; } sock__poll_status_t;
typedef struct sock_state { typedef struct sock_state {
OVERLAPPED overlapped; OVERLAPPED overlapped;
...@@ -1645,33 +1554,33 @@ typedef struct sock_state { ...@@ -1645,33 +1554,33 @@ typedef struct sock_state {
epoll_data_t user_data; epoll_data_t user_data;
uint32_t user_events; uint32_t user_events;
uint32_t pending_events; uint32_t pending_events;
_poll_status_t poll_status; sock__poll_status_t poll_status;
bool delete_pending; bool delete_pending;
} sock_state_t; } sock_state_t;
static inline sock_state_t* _sock_alloc(void) { static inline sock_state_t* sock__alloc(void) {
sock_state_t* sock_state = malloc(sizeof *sock_state); sock_state_t* sock_state = malloc(sizeof *sock_state);
if (sock_state == NULL) if (sock_state == NULL)
return_set_error(NULL, ERROR_NOT_ENOUGH_MEMORY); return_set_error(NULL, ERROR_NOT_ENOUGH_MEMORY);
return sock_state; return sock_state;
} }
static inline void _sock_free(sock_state_t* sock_state) { static inline void sock__free(sock_state_t* sock_state) {
free(sock_state); free(sock_state);
} }
static int _sock_cancel_poll(sock_state_t* sock_state) { static int sock__cancel_poll(sock_state_t* sock_state) {
HANDLE driver_handle = HANDLE afd_helper_handle =
(HANDLE)(uintptr_t) poll_group_get_socket(sock_state->poll_group); poll_group_get_afd_helper_handle(sock_state->poll_group);
assert(sock_state->poll_status == _POLL_PENDING); assert(sock_state->poll_status == SOCK__POLL_PENDING);
/* CancelIoEx() may fail with ERROR_NOT_FOUND if the overlapped operation has /* CancelIoEx() may fail with ERROR_NOT_FOUND if the overlapped operation has
* already completed. This is not a problem and we proceed normally. */ * already completed. This is not a problem and we proceed normally. */
if (!CancelIoEx(driver_handle, &sock_state->overlapped) && if (!CancelIoEx(afd_helper_handle, &sock_state->overlapped) &&
GetLastError() != ERROR_NOT_FOUND) GetLastError() != ERROR_NOT_FOUND)
return_map_error(-1); return_map_error(-1);
sock_state->poll_status = _POLL_CANCELLED; sock_state->poll_status = SOCK__POLL_CANCELLED;
sock_state->pending_events = 0; sock_state->pending_events = 0;
return 0; return 0;
} }
...@@ -1692,7 +1601,7 @@ sock_state_t* sock_new(port_state_t* port_state, SOCKET socket) { ...@@ -1692,7 +1601,7 @@ sock_state_t* sock_new(port_state_t* port_state, SOCKET socket) {
if (poll_group == NULL) if (poll_group == NULL)
return NULL; return NULL;
sock_state = _sock_alloc(); sock_state = sock__alloc();
if (sock_state == NULL) if (sock_state == NULL)
goto err1; goto err1;
...@@ -1710,19 +1619,19 @@ sock_state_t* sock_new(port_state_t* port_state, SOCKET socket) { ...@@ -1710,19 +1619,19 @@ sock_state_t* sock_new(port_state_t* port_state, SOCKET socket) {
return sock_state; return sock_state;
err2: err2:
_sock_free(sock_state); sock__free(sock_state);
err1: err1:
poll_group_release(poll_group); poll_group_release(poll_group);
return NULL; return NULL;
} }
static int _sock_delete(port_state_t* port_state, static int sock__delete(port_state_t* port_state,
sock_state_t* sock_state, sock_state_t* sock_state,
bool force) { bool force) {
if (!sock_state->delete_pending) { if (!sock_state->delete_pending) {
if (sock_state->poll_status == _POLL_PENDING) if (sock_state->poll_status == SOCK__POLL_PENDING)
_sock_cancel_poll(sock_state); sock__cancel_poll(sock_state);
port_cancel_socket_update(port_state, sock_state); port_cancel_socket_update(port_state, sock_state);
port_unregister_socket_handle(port_state, sock_state); port_unregister_socket_handle(port_state, sock_state);
...@@ -1733,11 +1642,11 @@ static int _sock_delete(port_state_t* port_state, ...@@ -1733,11 +1642,11 @@ static int _sock_delete(port_state_t* port_state,
/* If the poll request still needs to complete, the sock_state object can't /* If the poll request still needs to complete, the sock_state object can't
* be free()d yet. `sock_feed_event()` or `port_close()` will take care * be free()d yet. `sock_feed_event()` or `port_close()` will take care
* of this later. */ * of this later. */
if (force || sock_state->poll_status == _POLL_IDLE) { if (force || sock_state->poll_status == SOCK__POLL_IDLE) {
/* Free the sock_state now. */ /* Free the sock_state now. */
port_remove_deleted_socket(port_state, sock_state); port_remove_deleted_socket(port_state, sock_state);
poll_group_release(sock_state->poll_group); poll_group_release(sock_state->poll_group);
_sock_free(sock_state); sock__free(sock_state);
} else { } else {
/* Free the socket later. */ /* Free the socket later. */
port_add_deleted_socket(port_state, sock_state); port_add_deleted_socket(port_state, sock_state);
...@@ -1747,11 +1656,11 @@ static int _sock_delete(port_state_t* port_state, ...@@ -1747,11 +1656,11 @@ static int _sock_delete(port_state_t* port_state,
} }
void sock_delete(port_state_t* port_state, sock_state_t* sock_state) { void sock_delete(port_state_t* port_state, sock_state_t* sock_state) {
_sock_delete(port_state, sock_state, false); sock__delete(port_state, sock_state, false);
} }
void sock_force_delete(port_state_t* port_state, sock_state_t* sock_state) { void sock_force_delete(port_state_t* port_state, sock_state_t* sock_state) {
_sock_delete(port_state, sock_state, true); sock__delete(port_state, sock_state, true);
} }
int sock_set_event(port_state_t* port_state, int sock_set_event(port_state_t* port_state,
...@@ -1765,13 +1674,13 @@ int sock_set_event(port_state_t* port_state, ...@@ -1765,13 +1674,13 @@ int sock_set_event(port_state_t* port_state,
sock_state->user_events = events; sock_state->user_events = events;
sock_state->user_data = ev->data; sock_state->user_data = ev->data;
if ((events & _SOCK_KNOWN_EPOLL_EVENTS & ~sock_state->pending_events) != 0) if ((events & SOCK__KNOWN_EPOLL_EVENTS & ~sock_state->pending_events) != 0)
port_request_socket_update(port_state, sock_state); port_request_socket_update(port_state, sock_state);
return 0; return 0;
} }
static inline DWORD _epoll_events_to_afd_events(uint32_t epoll_events) { static inline DWORD sock__epoll_events_to_afd_events(uint32_t epoll_events) {
/* Always monitor for AFD_POLL_LOCAL_CLOSE, which is triggered when the /* Always monitor for AFD_POLL_LOCAL_CLOSE, which is triggered when the
* socket is closed with closesocket() or CloseHandle(). */ * socket is closed with closesocket() or CloseHandle(). */
DWORD afd_events = AFD_POLL_LOCAL_CLOSE; DWORD afd_events = AFD_POLL_LOCAL_CLOSE;
...@@ -1792,7 +1701,7 @@ static inline DWORD _epoll_events_to_afd_events(uint32_t epoll_events) { ...@@ -1792,7 +1701,7 @@ static inline DWORD _epoll_events_to_afd_events(uint32_t epoll_events) {
return afd_events; return afd_events;
} }
static inline uint32_t _afd_events_to_epoll_events(DWORD afd_events) { static inline uint32_t sock__afd_events_to_epoll_events(DWORD afd_events) {
uint32_t epoll_events = 0; uint32_t epoll_events = 0;
if (afd_events & (AFD_POLL_RECEIVE | AFD_POLL_ACCEPT)) if (afd_events & (AFD_POLL_RECEIVE | AFD_POLL_ACCEPT))
...@@ -1814,27 +1723,27 @@ static inline uint32_t _afd_events_to_epoll_events(DWORD afd_events) { ...@@ -1814,27 +1723,27 @@ static inline uint32_t _afd_events_to_epoll_events(DWORD afd_events) {
int sock_update(port_state_t* port_state, sock_state_t* sock_state) { int sock_update(port_state_t* port_state, sock_state_t* sock_state) {
assert(!sock_state->delete_pending); assert(!sock_state->delete_pending);
if ((sock_state->poll_status == _POLL_PENDING) && if ((sock_state->poll_status == SOCK__POLL_PENDING) &&
(sock_state->user_events & _SOCK_KNOWN_EPOLL_EVENTS & (sock_state->user_events & SOCK__KNOWN_EPOLL_EVENTS &
~sock_state->pending_events) == 0) { ~sock_state->pending_events) == 0) {
/* All the events the user is interested in are already being monitored by /* All the events the user is interested in are already being monitored by
* the pending poll operation. It might spuriously complete because of an * the pending poll operation. It might spuriously complete because of an
* event that we're no longer interested in; when that happens we'll submit * event that we're no longer interested in; when that happens we'll submit
* a new poll operation with the updated event mask. */ * a new poll operation with the updated event mask. */
} else if (sock_state->poll_status == _POLL_PENDING) { } else if (sock_state->poll_status == SOCK__POLL_PENDING) {
/* A poll operation is already pending, but it's not monitoring for all the /* A poll operation is already pending, but it's not monitoring for all the
* events that the user is interested in. Therefore, cancel the pending * events that the user is interested in. Therefore, cancel the pending
* poll operation; when we receive it's completion package, a new poll * poll operation; when we receive it's completion package, a new poll
* operation will be submitted with the correct event mask. */ * operation will be submitted with the correct event mask. */
if (_sock_cancel_poll(sock_state) < 0) if (sock__cancel_poll(sock_state) < 0)
return -1; return -1;
} else if (sock_state->poll_status == _POLL_CANCELLED) { } else if (sock_state->poll_status == SOCK__POLL_CANCELLED) {
/* The poll operation has already been cancelled, we're still waiting for /* The poll operation has already been cancelled, we're still waiting for
* it to return. For now, there's nothing that needs to be done. */ * it to return. For now, there's nothing that needs to be done. */
} else if (sock_state->poll_status == _POLL_IDLE) { } else if (sock_state->poll_status == SOCK__POLL_IDLE) {
/* No poll operation is pending; start one. */ /* No poll operation is pending; start one. */
sock_state->poll_info.Exclusive = FALSE; sock_state->poll_info.Exclusive = FALSE;
sock_state->poll_info.NumberOfHandles = 1; sock_state->poll_info.NumberOfHandles = 1;
...@@ -1842,11 +1751,11 @@ int sock_update(port_state_t* port_state, sock_state_t* sock_state) { ...@@ -1842,11 +1751,11 @@ int sock_update(port_state_t* port_state, sock_state_t* sock_state) {
sock_state->poll_info.Handles[0].Handle = (HANDLE) sock_state->base_socket; sock_state->poll_info.Handles[0].Handle = (HANDLE) sock_state->base_socket;
sock_state->poll_info.Handles[0].Status = 0; sock_state->poll_info.Handles[0].Status = 0;
sock_state->poll_info.Handles[0].Events = sock_state->poll_info.Handles[0].Events =
_epoll_events_to_afd_events(sock_state->user_events); sock__epoll_events_to_afd_events(sock_state->user_events);
memset(&sock_state->overlapped, 0, sizeof sock_state->overlapped); memset(&sock_state->overlapped, 0, sizeof sock_state->overlapped);
if (afd_poll(poll_group_get_socket(sock_state->poll_group), if (afd_poll(poll_group_get_afd_helper_handle(sock_state->poll_group),
&sock_state->poll_info, &sock_state->poll_info,
&sock_state->overlapped) < 0) { &sock_state->overlapped) < 0) {
switch (GetLastError()) { switch (GetLastError()) {
...@@ -1855,7 +1764,7 @@ int sock_update(port_state_t* port_state, sock_state_t* sock_state) { ...@@ -1855,7 +1764,7 @@ int sock_update(port_state_t* port_state, sock_state_t* sock_state) {
break; break;
case ERROR_INVALID_HANDLE: case ERROR_INVALID_HANDLE:
/* Socket closed; it'll be dropped from the epoll set. */ /* Socket closed; it'll be dropped from the epoll set. */
return _sock_delete(port_state, sock_state, false); return sock__delete(port_state, sock_state, false);
default: default:
/* Other errors are propagated to the caller. */ /* Other errors are propagated to the caller. */
return_map_error(-1); return_map_error(-1);
...@@ -1863,7 +1772,7 @@ int sock_update(port_state_t* port_state, sock_state_t* sock_state) { ...@@ -1863,7 +1772,7 @@ int sock_update(port_state_t* port_state, sock_state_t* sock_state) {
} }
/* The poll request was successfully submitted. */ /* The poll request was successfully submitted. */
sock_state->poll_status = _POLL_PENDING; sock_state->poll_status = SOCK__POLL_PENDING;
sock_state->pending_events = sock_state->user_events; sock_state->pending_events = sock_state->user_events;
} else { } else {
...@@ -1883,12 +1792,12 @@ int sock_feed_event(port_state_t* port_state, ...@@ -1883,12 +1792,12 @@ int sock_feed_event(port_state_t* port_state,
AFD_POLL_INFO* poll_info = &sock_state->poll_info; AFD_POLL_INFO* poll_info = &sock_state->poll_info;
uint32_t epoll_events = 0; uint32_t epoll_events = 0;
sock_state->poll_status = _POLL_IDLE; sock_state->poll_status = SOCK__POLL_IDLE;
sock_state->pending_events = 0; sock_state->pending_events = 0;
if (sock_state->delete_pending) { if (sock_state->delete_pending) {
/* Socket has been deleted earlier and can now be freed. */ /* Socket has been deleted earlier and can now be freed. */
return _sock_delete(port_state, sock_state, false); return sock__delete(port_state, sock_state, false);
} else if ((NTSTATUS) overlapped->Internal == STATUS_CANCELLED) { } else if ((NTSTATUS) overlapped->Internal == STATUS_CANCELLED) {
/* The poll request was cancelled by CancelIoEx. */ /* The poll request was cancelled by CancelIoEx. */
...@@ -1902,11 +1811,12 @@ int sock_feed_event(port_state_t* port_state, ...@@ -1902,11 +1811,12 @@ int sock_feed_event(port_state_t* port_state,
} else if (poll_info->Handles[0].Events & AFD_POLL_LOCAL_CLOSE) { } else if (poll_info->Handles[0].Events & AFD_POLL_LOCAL_CLOSE) {
/* The poll operation reported that the socket was closed. */ /* The poll operation reported that the socket was closed. */
return _sock_delete(port_state, sock_state, false); return sock__delete(port_state, sock_state, false);
} else { } else {
/* Events related to our socket were reported. */ /* Events related to our socket were reported. */
epoll_events = _afd_events_to_epoll_events(poll_info->Handles[0].Events); epoll_events =
sock__afd_events_to_epoll_events(poll_info->Handles[0].Events);
} }
/* Requeue the socket so a new poll request will be submitted. */ /* Requeue the socket so a new poll request will be submitted. */
...@@ -1965,7 +1875,7 @@ int ts_tree_add(ts_tree_t* ts_tree, ts_tree_node_t* node, uintptr_t key) { ...@@ -1965,7 +1875,7 @@ int ts_tree_add(ts_tree_t* ts_tree, ts_tree_node_t* node, uintptr_t key) {
return r; return r;
} }
static inline ts_tree_node_t* _ts_tree_find_node(ts_tree_t* ts_tree, static inline ts_tree_node_t* ts_tree__find_node(ts_tree_t* ts_tree,
uintptr_t key) { uintptr_t key) {
tree_node_t* tree_node = tree_find(&ts_tree->tree, key); tree_node_t* tree_node = tree_find(&ts_tree->tree, key);
if (tree_node == NULL) if (tree_node == NULL)
...@@ -1979,7 +1889,7 @@ ts_tree_node_t* ts_tree_del_and_ref(ts_tree_t* ts_tree, uintptr_t key) { ...@@ -1979,7 +1889,7 @@ ts_tree_node_t* ts_tree_del_and_ref(ts_tree_t* ts_tree, uintptr_t key) {
AcquireSRWLockExclusive(&ts_tree->lock); AcquireSRWLockExclusive(&ts_tree->lock);
ts_tree_node = _ts_tree_find_node(ts_tree, key); ts_tree_node = ts_tree__find_node(ts_tree, key);
if (ts_tree_node != NULL) { if (ts_tree_node != NULL) {
tree_del(&ts_tree->tree, &ts_tree_node->tree_node); tree_del(&ts_tree->tree, &ts_tree_node->tree_node);
reflock_ref(&ts_tree_node->reflock); reflock_ref(&ts_tree_node->reflock);
...@@ -1995,7 +1905,7 @@ ts_tree_node_t* ts_tree_find_and_ref(ts_tree_t* ts_tree, uintptr_t key) { ...@@ -1995,7 +1905,7 @@ ts_tree_node_t* ts_tree_find_and_ref(ts_tree_t* ts_tree, uintptr_t key) {
AcquireSRWLockShared(&ts_tree->lock); AcquireSRWLockShared(&ts_tree->lock);
ts_tree_node = _ts_tree_find_node(ts_tree, key); ts_tree_node = ts_tree__find_node(ts_tree, key);
if (ts_tree_node != NULL) if (ts_tree_node != NULL)
reflock_ref(&ts_tree_node->reflock); reflock_ref(&ts_tree_node->reflock);
...@@ -2041,11 +1951,11 @@ void tree_node_init(tree_node_t* node) { ...@@ -2041,11 +1951,11 @@ void tree_node_init(tree_node_t* node) {
p->trans->parent = p; \ p->trans->parent = p; \
q->cis = p; q->cis = p;
static inline void _tree_rotate_left(tree_t* tree, tree_node_t* node) { static inline void tree__rotate_left(tree_t* tree, tree_node_t* node) {
TREE__ROTATE(left, right) TREE__ROTATE(left, right)
} }
static inline void _tree_rotate_right(tree_t* tree, tree_node_t* node) { static inline void tree__rotate_right(tree_t* tree, tree_node_t* node) {
TREE__ROTATE(right, left) TREE__ROTATE(right, left)
} }
...@@ -2067,13 +1977,13 @@ static inline void _tree_rotate_right(tree_t* tree, tree_node_t* node) { ...@@ -2067,13 +1977,13 @@ static inline void _tree_rotate_right(tree_t* tree, tree_node_t* node) {
node = grandparent; \ node = grandparent; \
} else { \ } else { \
if (node == parent->trans) { \ if (node == parent->trans) { \
_tree_rotate_##cis(tree, parent); \ tree__rotate_##cis(tree, parent); \
node = parent; \ node = parent; \
parent = node->parent; \ parent = node->parent; \
} \ } \
parent->red = false; \ parent->red = false; \
grandparent->red = true; \ grandparent->red = true; \
_tree_rotate_##trans(tree, grandparent); \ tree__rotate_##trans(tree, grandparent); \
} }
int tree_add(tree_t* tree, tree_node_t* node, uintptr_t key) { int tree_add(tree_t* tree, tree_node_t* node, uintptr_t key) {
...@@ -2117,7 +2027,7 @@ int tree_add(tree_t* tree, tree_node_t* node, uintptr_t key) { ...@@ -2117,7 +2027,7 @@ int tree_add(tree_t* tree, tree_node_t* node, uintptr_t key) {
if (sibling->red) { \ if (sibling->red) { \
sibling->red = false; \ sibling->red = false; \
parent->red = true; \ parent->red = true; \
_tree_rotate_##cis(tree, parent); \ tree__rotate_##cis(tree, parent); \
sibling = parent->trans; \ sibling = parent->trans; \
} \ } \
if ((sibling->left && sibling->left->red) || \ if ((sibling->left && sibling->left->red) || \
...@@ -2125,12 +2035,12 @@ int tree_add(tree_t* tree, tree_node_t* node, uintptr_t key) { ...@@ -2125,12 +2035,12 @@ int tree_add(tree_t* tree, tree_node_t* node, uintptr_t key) {
if (!sibling->trans || !sibling->trans->red) { \ if (!sibling->trans || !sibling->trans->red) { \
sibling->cis->red = false; \ sibling->cis->red = false; \
sibling->red = true; \ sibling->red = true; \
_tree_rotate_##trans(tree, sibling); \ tree__rotate_##trans(tree, sibling); \
sibling = parent->trans; \ sibling = parent->trans; \
} \ } \
sibling->red = parent->red; \ sibling->red = parent->red; \
parent->red = sibling->trans->red = false; \ parent->red = sibling->trans->red = false; \
_tree_rotate_##cis(tree, parent); \ tree__rotate_##cis(tree, parent); \
node = tree->root; \ node = tree->root; \
break; \ break; \
} \ } \
...@@ -2230,8 +2140,6 @@ tree_node_t* tree_root(const tree_t* tree) { ...@@ -2230,8 +2140,6 @@ tree_node_t* tree_root(const tree_t* tree) {
#define SIO_BASE_HANDLE 0x48000022 #define SIO_BASE_HANDLE 0x48000022
#endif #endif
#define WS__INITIAL_CATALOG_BUFFER_SIZE 0x4000 /* 16kb. */
int ws_global_init(void) { int ws_global_init(void) {
int r; int r;
WSADATA wsa_data; WSADATA wsa_data;
...@@ -2260,31 +2168,3 @@ SOCKET ws_get_base_socket(SOCKET socket) { ...@@ -2260,31 +2168,3 @@ SOCKET ws_get_base_socket(SOCKET socket) {
return base_socket; return base_socket;
} }
/* Retrieves a copy of the winsock catalog.
* The infos pointer must be released by the caller with free(). */
int ws_get_protocol_catalog(WSAPROTOCOL_INFOW** infos_out,
size_t* infos_count_out) {
DWORD buffer_size = WS__INITIAL_CATALOG_BUFFER_SIZE;
int count;
WSAPROTOCOL_INFOW* infos;
for (;;) {
infos = malloc(buffer_size);
if (infos == NULL)
return_set_error(-1, ERROR_NOT_ENOUGH_MEMORY);
count = WSAEnumProtocolsW(NULL, infos, &buffer_size);
if (count == SOCKET_ERROR) {
free(infos);
if (WSAGetLastError() == WSAENOBUFS)
continue; /* Try again with bigger buffer size. */
else
return_map_error(-1);
}
*infos_out = infos;
*infos_count_out = (size_t) count;
return 0;
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment