Commit 505e71f7 authored by Kenton Varda's avatar Kenton Varda

Wire up restrictPeers() implementation.

parent 05d0a7ed
......@@ -198,6 +198,19 @@ kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() {
// =======================================================================================
namespace {
class DummyFilter: public kj::LowLevelAsyncIoProvider::NetworkFilter {
public:
bool shouldAllow(const struct sockaddr* addr, uint addrlen) override {
return true;
}
};
static DummyFilter DUMMY_FILTER;
} // namespace
struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>,
public kj::TaskSet::ErrorHandler {
Capability::Client mainInterface;
......@@ -271,7 +284,8 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>,
context(EzRpcContext::getThreadLocal()),
portPromise(kj::Promise<uint>(port).fork()),
tasks(*this) {
acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd), readerOpts);
acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd, DUMMY_FILTER),
readerOpts);
}
void acceptLoop(kj::Own<kj::ConnectionReceiver>&& listener, ReaderOptions readerOpts) {
......
......@@ -23,6 +23,7 @@
// For Win32 implementation, see async-io-win32.c++.
#include "async-io.h"
#include "async-io-internal.h"
#include "async-unix.h"
#include "debug.h"
#include "thread.h"
......@@ -461,11 +462,12 @@ public:
}
static Promise<Array<SocketAddress>> lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint);
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
_::NetworkFilter& filter);
// Perform a DNS lookup.
static Promise<Array<SocketAddress>> parse(
LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint) {
LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint, _::NetworkFilter& filter) {
// TODO(someday): Allow commas in `str`.
SocketAddress result;
......@@ -480,6 +482,12 @@ public:
result.addr.unixDomain.sun_family = AF_UNIX;
strcpy(result.addr.unixDomain.sun_path, path.cStr());
result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1;
if (!result.parseAllowedBy(filter)) {
KJ_FAIL_REQUIRE("unix sockets blocked by restrictPeers()");
return Array<SocketAddress>();
}
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result);
return array.finish();
......@@ -495,6 +503,12 @@ public:
// NULL terminator so that we can safely read it back in toString
memcpy(result.addr.unixDomain.sun_path + 1, path.cStr(), path.size() + 1);
result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1;
if (!result.parseAllowedBy(filter)) {
KJ_FAIL_REQUIRE("abstract unix sockets blocked by restrictPeers()");
return Array<SocketAddress>();
}
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result);
return array.finish();
......@@ -547,7 +561,8 @@ public:
port = strtoul(portText->cStr(), &endptr, 0);
if (portText->size() == 0 || *endptr != '\0') {
// Not a number. Maybe it's a service name. Fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint);
return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint,
filter);
}
KJ_REQUIRE(port < 65536, "Port number too large.");
} else {
......@@ -569,6 +584,7 @@ public:
result.addr.inet6.sin6_family = AF_INET6;
result.addr.inet6.sin6_port = htons(port);
#endif
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result);
return array.finish();
......@@ -597,13 +613,18 @@ public:
switch (inet_pton(af, buffer, addrTarget)) {
case 1: {
// success.
if (!result.parseAllowedBy(filter)) {
KJ_FAIL_REQUIRE("address family blocked by restrictPeers()");
return Array<SocketAddress>();
}
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result);
return array.finish();
}
case 0:
// It's apparently not a simple address... fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port);
return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port, filter);
default:
KJ_FAIL_SYSCALL("inet_pton", errno, af, addrPart);
}
......@@ -616,6 +637,14 @@ public:
return result;
}
bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) {
return filter.shouldAllow(&addr.generic, addrlen);
}
bool parseAllowedBy(const _::NetworkFilter& filter) {
return filter.shouldAllowParse(&addr.generic);
}
private:
SocketAddress(): addrlen(0) {
memset(&addr, 0, sizeof(addr));
......@@ -640,8 +669,9 @@ class SocketAddress::LookupReader {
// getaddrinfo.
public:
LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input)
: thread(kj::mv(thread)), input(kj::mv(input)) {}
LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input,
_::NetworkFilter& filter)
: thread(kj::mv(thread)), input(kj::mv(input)), filter(filter) {}
~LookupReader() {
if (thread) thread->detach();
......@@ -654,7 +684,7 @@ public:
thread = nullptr;
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// anyway.
KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no addresses.") { break; }
KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no permitted addresses.") { break; }
return addresses.releaseAsArray();
} else {
// getaddrinfo() can return multiple copies of the same address for several reasons.
......@@ -667,8 +697,10 @@ public:
//
// So we instead resort to de-duping results.
if (alreadySeen.insert(current).second) {
if (current.parseAllowedBy(filter)) {
addresses.add(current);
}
}
return read();
}
});
......@@ -677,6 +709,7 @@ public:
private:
kj::Own<Thread> thread;
kj::Own<AsyncInputStream> input;
_::NetworkFilter& filter;
SocketAddress current;
kj::Vector<SocketAddress> addresses;
std::set<SocketAddress> alreadySeen;
......@@ -688,7 +721,8 @@ struct SocketAddress::LookupParams {
};
Promise<Array<SocketAddress>> SocketAddress::lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint) {
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
_::NetworkFilter& filter) {
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// the only cross-platform DNS API and it is blocking.
//
......@@ -773,7 +807,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
}
}));
auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input));
auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input), filter);
return reader->read().attach(kj::mv(reader));
}
......@@ -781,22 +815,33 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFileDescriptor {
public:
FdConnectionReceiver(UnixEventPort& eventPort, int fd, uint flags)
: OwnedFileDescriptor(fd, flags), eventPort(eventPort),
FdConnectionReceiver(UnixEventPort& eventPort, int fd,
LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
: OwnedFileDescriptor(fd, flags), eventPort(eventPort), filter(filter),
observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ) {}
Promise<Own<AsyncIoStream>> accept() override {
int newFd;
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
retry:
#if __linux__ && !__BIONIC__
newFd = ::accept4(fd, nullptr, nullptr, SOCK_NONBLOCK | SOCK_CLOEXEC);
newFd = ::accept4(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen,
SOCK_NONBLOCK | SOCK_CLOEXEC);
#else
newFd = ::accept(fd, nullptr, nullptr);
newFd = ::accept(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
#endif
if (newFd >= 0) {
if (!filter.shouldAllow(reinterpret_cast<struct sockaddr*>(&addr), addrlen)) {
// Drop disallowed address.
close(newFd);
return accept();
} else {
return Own<AsyncIoStream>(heap<AsyncStreamFd>(eventPort, newFd, NEW_FD_FLAGS));
}
} else {
int error = errno;
......@@ -849,13 +894,15 @@ public:
public:
UnixEventPort& eventPort;
LowLevelAsyncIoProvider::NetworkFilter& filter;
UnixEventPort::FdObserver observer;
};
class DatagramPortImpl final: public DatagramPort, public OwnedFileDescriptor {
public:
DatagramPortImpl(LowLevelAsyncIoProvider& lowLevel, UnixEventPort& eventPort, int fd, uint flags)
: OwnedFileDescriptor(fd, flags), lowLevel(lowLevel), eventPort(eventPort),
DatagramPortImpl(LowLevelAsyncIoProvider& lowLevel, UnixEventPort& eventPort, int fd,
LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
: OwnedFileDescriptor(fd, flags), lowLevel(lowLevel), eventPort(eventPort), filter(filter),
observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ |
UnixEventPort::FdObserver::OBSERVE_WRITE) {}
......@@ -883,6 +930,7 @@ public:
public:
LowLevelAsyncIoProvider& lowLevel;
UnixEventPort& eventPort;
LowLevelAsyncIoProvider::NetworkFilter& filter;
UnixEventPort::FdObserver observer;
};
......@@ -935,11 +983,13 @@ public:
return kj::mv(stream);
}));
}
Own<ConnectionReceiver> wrapListenSocketFd(int fd, uint flags = 0) override {
return heap<FdConnectionReceiver>(eventPort, fd, flags);
Own<ConnectionReceiver> wrapListenSocketFd(
int fd, NetworkFilter& filter, uint flags = 0) override {
return heap<FdConnectionReceiver>(eventPort, fd, filter, flags);
}
Own<DatagramPort> wrapDatagramSocketFd(int fd, uint flags = 0) override {
return heap<DatagramPortImpl>(*this, eventPort, fd, flags);
Own<DatagramPort> wrapDatagramSocketFd(
int fd, NetworkFilter& filter, uint flags = 0) override {
return heap<DatagramPortImpl>(*this, eventPort, fd, filter, flags);
}
Timer& getTimer() override { return eventPort.getTimer(); }
......@@ -956,12 +1006,14 @@ private:
class NetworkAddressImpl final: public NetworkAddress {
public:
NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel, Array<SocketAddress> addrs)
: lowLevel(lowLevel), addrs(kj::mv(addrs)) {}
NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel,
LowLevelAsyncIoProvider::NetworkFilter& filter,
Array<SocketAddress> addrs)
: lowLevel(lowLevel), filter(filter), addrs(kj::mv(addrs)) {}
Promise<Own<AsyncIoStream>> connect() override {
auto addrsCopy = heapArray(addrs.asPtr());
auto promise = connectImpl(lowLevel, addrsCopy);
auto promise = connectImpl(lowLevel, filter, addrsCopy);
return promise.attach(kj::mv(addrsCopy));
}
......@@ -988,7 +1040,7 @@ public:
KJ_SYSCALL(::listen(fd, SOMAXCONN));
}
return lowLevel.wrapListenSocketFd(fd, NEW_FD_FLAGS);
return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS);
}
Own<DatagramPort> bindDatagramPort() override {
......@@ -1011,11 +1063,11 @@ public:
addrs[0].bind(fd);
}
return lowLevel.wrapDatagramSocketFd(fd, NEW_FD_FLAGS);
return lowLevel.wrapDatagramSocketFd(fd, filter, NEW_FD_FLAGS);
}
Own<NetworkAddress> clone() override {
return kj::heap<NetworkAddressImpl>(lowLevel, kj::heapArray(addrs.asPtr()));
return kj::heap<NetworkAddressImpl>(lowLevel, filter, kj::heapArray(addrs.asPtr()));
}
String toString() override {
......@@ -1029,26 +1081,33 @@ public:
private:
LowLevelAsyncIoProvider& lowLevel;
LowLevelAsyncIoProvider::NetworkFilter& filter;
Array<SocketAddress> addrs;
uint counter = 0;
static Promise<Own<AsyncIoStream>> connectImpl(
LowLevelAsyncIoProvider& lowLevel, ArrayPtr<SocketAddress> addrs) {
LowLevelAsyncIoProvider& lowLevel,
LowLevelAsyncIoProvider::NetworkFilter& filter,
ArrayPtr<SocketAddress> addrs) {
KJ_ASSERT(addrs.size() > 0);
int fd = addrs[0].socket(SOCK_STREAM);
return kj::evalNow([&]() {
return kj::evalNow([&]() -> Promise<Own<AsyncIoStream>> {
if (!addrs[0].allowedBy(filter)) {
return KJ_EXCEPTION(FAILED, "connect() blocked by restrictPeers()");
} else {
return lowLevel.wrapConnectingSocketFd(
fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
}
}).then([](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> {
// Success, pass along.
return kj::mv(stream);
}, [&lowLevel,addrs](Exception&& exception) mutable -> Promise<Own<AsyncIoStream>> {
}, [&lowLevel,&filter,addrs](Exception&& exception) mutable -> Promise<Own<AsyncIoStream>> {
// Connect failed.
if (addrs.size() > 1) {
// Try the next address instead.
return connectImpl(lowLevel, addrs.slice(1, addrs.size()));
return connectImpl(lowLevel, filter, addrs.slice(1, addrs.size()));
} else {
// No more addresses to try, so propagate the exception.
return kj::mv(exception);
......@@ -1060,25 +1119,35 @@ private:
class SocketNetwork final: public Network {
public:
explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
explicit SocketNetwork(SocketNetwork& parent,
kj::ArrayPtr<const kj::StringPtr> allow,
kj::ArrayPtr<const kj::StringPtr> deny)
: lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {}
Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
auto& lowLevelCopy = lowLevel;
return evalLater(mvCapture(heapString(addr),
[&lowLevelCopy,portHint](String&& addr) {
return SocketAddress::parse(lowLevelCopy, addr, portHint);
})).then([&lowLevelCopy](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
return heap<NetworkAddressImpl>(lowLevelCopy, kj::mv(addresses));
return evalLater(mvCapture(heapString(addr), [this,portHint](String&& addr) {
return SocketAddress::parse(lowLevel, addr, portHint, filter);
})).then([this](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
return heap<NetworkAddressImpl>(lowLevel, filter, kj::mv(addresses));
});
}
Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(SocketAddress(sockaddr, len));
return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, array.finish()));
KJ_REQUIRE(array[0].allowedBy(filter), "address blocked by restrictPeers()") { break; }
return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, filter, array.finish()));
}
Own<Network> restrictPeers(
kj::ArrayPtr<const kj::StringPtr> allow,
kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
return heap<SocketNetwork>(*this, allow, deny);
}
private:
LowLevelAsyncIoProvider& lowLevel;
_::NetworkFilter filter;
};
// =======================================================================================
......@@ -1189,10 +1258,16 @@ public:
return receive();
});
} else {
if (!port.filter.shouldAllow(reinterpret_cast<const struct sockaddr*>(msg.msg_name),
msg.msg_namelen)) {
// Ignore message from disallowed source.
return receive();
}
receivedSize = n;
contentTruncated = msg.msg_flags & MSG_TRUNC;
source.emplace(port.lowLevel, msg.msg_name, msg.msg_namelen);
source.emplace(port.lowLevel, port.filter, msg.msg_name, msg.msg_namelen);
ancillaryList.resize(0);
ancillaryTruncated = msg.msg_flags & MSG_CTRUNC;
......@@ -1250,9 +1325,10 @@ private:
bool ancillaryTruncated = false;
struct StoredAddress {
StoredAddress(LowLevelAsyncIoProvider& lowLevel, const void* sockaddr, uint length)
StoredAddress(LowLevelAsyncIoProvider& lowLevel, LowLevelAsyncIoProvider::NetworkFilter& filter,
const void* sockaddr, uint length)
: raw(sockaddr, length),
abstract(lowLevel, Array<SocketAddress>(&raw, 1, NullArrayDisposer::instance)) {}
abstract(lowLevel, filter, Array<SocketAddress>(&raw, 1, NullArrayDisposer::instance)) {}
SocketAddress raw;
NetworkAddressImpl abstract;
......
......@@ -27,6 +27,7 @@
#define _WIN32_WINNT 0x0600
#include "async-io.h"
#include "async-io-internal.h"
#include "async-win32.h"
#include "debug.h"
#include "thread.h"
......@@ -524,11 +525,12 @@ public:
}
static Promise<Array<SocketAddress>> lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint);
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
_::NetworkFilter& filter);
// Perform a DNS lookup.
static Promise<Array<SocketAddress>> parse(
LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint) {
LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint, _::NetworkFilter& filter) {
// TODO(someday): Allow commas in `str`.
SocketAddress result;
......@@ -580,7 +582,8 @@ public:
port = strtoul(portText->cStr(), &endptr, 0);
if (portText->size() == 0 || *endptr != '\0') {
// Not a number. Maybe it's a service name. Fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint);
return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint,
filter);
}
KJ_REQUIRE(port < 65536, "Port number too large.");
} else {
......@@ -622,13 +625,18 @@ public:
switch (InetPtonA(af, buffer, addrTarget)) {
case 1: {
// success.
if (!result.parseAllowedBy(filter)) {
KJ_FAIL_REQUIRE("address family blocked by restrictPeers()");
return Array<SocketAddress>();
}
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result);
return array.finish();
}
case 0:
// It's apparently not a simple address... fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port);
return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port, filter);
default:
KJ_FAIL_WIN32("InetPton", WSAGetLastError(), af, addrPart);
}
......@@ -641,6 +649,14 @@ public:
return result;
}
bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) {
return filter.shouldAllow(&addr.generic, addrlen);
}
bool parseAllowedBy(const _::NetworkFilter& filter) {
return filter.shouldAllowParse(&addr.generic);
}
static SocketAddress getWildcardForFamily(int family) {
SocketAddress result;
switch (family) {
......@@ -680,8 +696,9 @@ class SocketAddress::LookupReader {
// getaddrinfo.
public:
LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input)
: thread(kj::mv(thread)), input(kj::mv(input)) {}
LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input,
_::NetworkFilter& filter)
: thread(kj::mv(thread)), input(kj::mv(input)), filter(filter) {}
~LookupReader() {
if (thread) thread->detach();
......@@ -694,7 +711,7 @@ public:
thread = nullptr;
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// anyway.
KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no addresses.") { break; }
KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no permitted addresses.") { break; }
return addresses.releaseAsArray();
} else {
// getaddrinfo() can return multiple copies of the same address for several reasons.
......@@ -707,8 +724,10 @@ public:
//
// So we instead resort to de-duping results.
if (alreadySeen.insert(current).second) {
if (current.parseAllowedBy(filter)) {
addresses.add(current);
}
}
return read();
}
});
......@@ -717,6 +736,7 @@ public:
private:
kj::Own<Thread> thread;
kj::Own<AsyncInputStream> input;
_::NetworkFilter& filter;
SocketAddress current;
kj::Vector<SocketAddress> addresses;
std::set<SocketAddress> alreadySeen;
......@@ -728,7 +748,8 @@ struct SocketAddress::LookupParams {
};
Promise<Array<SocketAddress>> SocketAddress::lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint) {
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
_::NetworkFilter& filter) {
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// the only cross-platform DNS API and it is blocking.
//
......@@ -818,7 +839,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
}
}));
auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input));
auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input), filter);
return reader->read().attach(kj::mv(reader));
}
......@@ -826,8 +847,9 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFd {
public:
FdConnectionReceiver(Win32EventPort& eventPort, SOCKET fd, uint flags)
: OwnedFd(fd, flags), eventPort(eventPort),
FdConnectionReceiver(Win32EventPort& eventPort, SOCKET fd,
LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
: OwnedFd(fd, flags), eventPort(eventPort), filter(filter),
observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))),
address(SocketAddress::getLocalAddress(fd)) {
// In order to accept asynchronously, we need the AcceptEx() function. Apparently, we have
......@@ -858,8 +880,9 @@ public:
}
}
return op->onComplete().attach(kj::mv(scratch)).then(mvCapture(result,
[this](Own<AsyncIoStream> stream, Win32EventPort::IoResult ioResult) {
return op->onComplete().then(mvCapture(result, mvCapture(scratch,
[this](Array<byte> scratch, Own<AsyncIoStream> stream, Win32EventPort::IoResult ioResult)
-> Promise<Own<AsyncIoStream>> {
if (ioResult.errorCode != ERROR_SUCCESS) {
KJ_FAIL_WIN32("AcceptEx()", ioResult.errorCode) { break; }
} else {
......@@ -867,8 +890,17 @@ public:
stream->setsockopt(SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT,
reinterpret_cast<char*>(&me), sizeof(me));
}
auto addr = reinterpret_cast<struct sockaddr*>(scratch.begin() + 128);
size_t addrlen = addr->sa_family == AF_INET
? sizeof(struct sockaddr_in)
: sizeof(struct sockaddr_in6);
if (filter.shouldAllow(addr, addrlen)) {
return kj::mv(stream);
}));
} else {
return accept();
}
})));
}
uint getPort() override {
......@@ -888,6 +920,7 @@ public:
public:
Win32EventPort& eventPort;
LowLevelAsyncIoProvider::NetworkFilter& filter;
Own<Win32EventPort::IoObserver> observer;
LPFN_ACCEPTEX acceptEx = nullptr;
SocketAddress address;
......@@ -923,8 +956,9 @@ public:
return kj::mv(result);
}));
}
Own<ConnectionReceiver> wrapListenSocketFd(SOCKET fd, uint flags = 0) override {
return heap<FdConnectionReceiver>(eventPort, fd, flags);
Own<ConnectionReceiver> wrapListenSocketFd(
SOCKET fd, NetworkFilter& filter, uint flags = 0) override {
return heap<FdConnectionReceiver>(eventPort, fd, filter, flags);
}
Timer& getTimer() override { return eventPort.getTimer(); }
......@@ -941,12 +975,14 @@ private:
class NetworkAddressImpl final: public NetworkAddress {
public:
NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel, Array<SocketAddress> addrs)
: lowLevel(lowLevel), addrs(kj::mv(addrs)) {}
NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel,
LowLevelAsyncIoProvider::NetworkFilter& filter,
Array<SocketAddress> addrs)
: lowLevel(lowLevel), filter(filter), addrs(kj::mv(addrs)) {}
Promise<Own<AsyncIoStream>> connect() override {
auto addrsCopy = heapArray(addrs.asPtr());
auto promise = connectImpl(lowLevel, addrsCopy);
auto promise = connectImpl(lowLevel, filter, addrsCopy);
return promise.attach(kj::mv(addrsCopy));
}
......@@ -974,7 +1010,7 @@ public:
KJ_WINSOCK(::listen(fd, SOMAXCONN));
}
return lowLevel.wrapListenSocketFd(fd, NEW_FD_FLAGS);
return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS);
}
Own<DatagramPort> bindDatagramPort() override {
......@@ -998,11 +1034,11 @@ public:
addrs[0].bind(fd);
}
return lowLevel.wrapDatagramSocketFd(fd, NEW_FD_FLAGS);
return lowLevel.wrapDatagramSocketFd(fd, filter, NEW_FD_FLAGS);
}
Own<NetworkAddress> clone() override {
return kj::heap<NetworkAddressImpl>(lowLevel, kj::heapArray(addrs.asPtr()));
return kj::heap<NetworkAddressImpl>(lowLevel, filter, kj::heapArray(addrs.asPtr()));
}
String toString() override {
......@@ -1016,26 +1052,34 @@ public:
private:
LowLevelAsyncIoProvider& lowLevel;
LowLevelAsyncIoProvider::NetworkFilter& filter;
Array<SocketAddress> addrs;
uint counter = 0;
static Promise<Own<AsyncIoStream>> connectImpl(
LowLevelAsyncIoProvider& lowLevel, ArrayPtr<SocketAddress> addrs) {
LowLevelAsyncIoProvider& lowLevel,
LowLevelAsyncIoProvider::NetworkFilter& filter,
ArrayPtr<SocketAddress> addrs) {
KJ_ASSERT(addrs.size() > 0);
int fd = addrs[0].socket(SOCK_STREAM);
return kj::evalNow([&]() {
return kj::evalNow([&]() -> Promise<Own<AsyncIoStream>> {
if (!addrs[0].allowedBy(filter)) {
return KJ_EXCEPTION(FAILED, "connect() blocked by restrictPeers()");
} else {
return lowLevel.wrapConnectingSocketFd(
fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
}
}).then([](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> {
// Success, pass along.
return kj::mv(stream);
}, [&lowLevel,KJ_CPCAP(addrs)](Exception&& exception) mutable -> Promise<Own<AsyncIoStream>> {
}, [&lowLevel,&filter,KJ_CPCAP(addrs)](Exception&& exception) mutable
-> Promise<Own<AsyncIoStream>> {
// Connect failed.
if (addrs.size() > 1) {
// Try the next address instead.
return connectImpl(lowLevel, addrs.slice(1, addrs.size()));
return connectImpl(lowLevel, filter, addrs.slice(1, addrs.size()));
} else {
// No more addresses to try, so propagate the exception.
return kj::mv(exception);
......@@ -1047,25 +1091,35 @@ private:
class SocketNetwork final: public Network {
public:
explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
explicit SocketNetwork(SocketNetwork& parent,
kj::ArrayPtr<const kj::StringPtr> allow,
kj::ArrayPtr<const kj::StringPtr> deny)
: lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {}
Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
auto& lowLevelCopy = lowLevel;
return evalLater(mvCapture(heapString(addr),
[&lowLevelCopy,portHint](String&& addr) {
return SocketAddress::parse(lowLevelCopy, addr, portHint);
})).then([&lowLevelCopy](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
return heap<NetworkAddressImpl>(lowLevelCopy, kj::mv(addresses));
return evalLater(mvCapture(heapString(addr), [this,portHint](String&& addr) {
return SocketAddress::parse(lowLevel, addr, portHint, filter);
})).then([this](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
return heap<NetworkAddressImpl>(lowLevel, filter, kj::mv(addresses));
});
}
Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(SocketAddress(sockaddr, len));
return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, array.finish()));
KJ_REQUIRE(array[0].allowedBy(filter), "address blocked by restrictPeers()") { break; }
return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, filter, array.finish()));
}
Own<Network> restrictPeers(
kj::ArrayPtr<const kj::StringPtr> allow,
kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
return heap<SocketNetwork>(*this, allow, deny);
}
private:
LowLevelAsyncIoProvider& lowLevel;
_::NetworkFilter filter;
};
// =======================================================================================
......
......@@ -19,15 +19,18 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#if _WIN32
// Request Vista-level APIs.
#define WINVER 0x0600
#define _WIN32_WINNT 0x0600
#endif
#include "async-io.h"
#include "async-io-internal.h"
#include "debug.h"
#include "vector.h"
#if _WIN32
// Request Vista-level APIs.
#define WINVER 0x0600
#define _WIN32_WINNT 0x0600
#include <winsock2.h>
#include <ws2ipdef.h>
#include <ws2tcpip.h>
......@@ -205,7 +208,8 @@ void DatagramPort::setsockopt(int level, int option, const void* value, uint len
Own<DatagramPort> NetworkAddress::bindDatagramPort() {
KJ_UNIMPLEMENTED("Datagram sockets not implemented.");
}
Own<DatagramPort> LowLevelAsyncIoProvider::wrapDatagramSocketFd(Fd fd, uint flags) {
Own<DatagramPort> LowLevelAsyncIoProvider::wrapDatagramSocketFd(
Fd fd, LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags) {
KJ_UNIMPLEMENTED("Datagram sockets not implemented.");
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment