Commit 505e71f7 authored by Kenton Varda's avatar Kenton Varda

Wire up restrictPeers() implementation.

parent 05d0a7ed
...@@ -198,6 +198,19 @@ kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() { ...@@ -198,6 +198,19 @@ kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() {
// ======================================================================================= // =======================================================================================
namespace {
class DummyFilter: public kj::LowLevelAsyncIoProvider::NetworkFilter {
public:
bool shouldAllow(const struct sockaddr* addr, uint addrlen) override {
return true;
}
};
static DummyFilter DUMMY_FILTER;
} // namespace
struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>, struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>,
public kj::TaskSet::ErrorHandler { public kj::TaskSet::ErrorHandler {
Capability::Client mainInterface; Capability::Client mainInterface;
...@@ -271,7 +284,8 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>, ...@@ -271,7 +284,8 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>,
context(EzRpcContext::getThreadLocal()), context(EzRpcContext::getThreadLocal()),
portPromise(kj::Promise<uint>(port).fork()), portPromise(kj::Promise<uint>(port).fork()),
tasks(*this) { tasks(*this) {
acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd), readerOpts); acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd, DUMMY_FILTER),
readerOpts);
} }
void acceptLoop(kj::Own<kj::ConnectionReceiver>&& listener, ReaderOptions readerOpts) { void acceptLoop(kj::Own<kj::ConnectionReceiver>&& listener, ReaderOptions readerOpts) {
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
// For Win32 implementation, see async-io-win32.c++. // For Win32 implementation, see async-io-win32.c++.
#include "async-io.h" #include "async-io.h"
#include "async-io-internal.h"
#include "async-unix.h" #include "async-unix.h"
#include "debug.h" #include "debug.h"
#include "thread.h" #include "thread.h"
...@@ -461,11 +462,12 @@ public: ...@@ -461,11 +462,12 @@ public:
} }
static Promise<Array<SocketAddress>> lookupHost( static Promise<Array<SocketAddress>> lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint); LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
_::NetworkFilter& filter);
// Perform a DNS lookup. // Perform a DNS lookup.
static Promise<Array<SocketAddress>> parse( static Promise<Array<SocketAddress>> parse(
LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint) { LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint, _::NetworkFilter& filter) {
// TODO(someday): Allow commas in `str`. // TODO(someday): Allow commas in `str`.
SocketAddress result; SocketAddress result;
...@@ -480,6 +482,12 @@ public: ...@@ -480,6 +482,12 @@ public:
result.addr.unixDomain.sun_family = AF_UNIX; result.addr.unixDomain.sun_family = AF_UNIX;
strcpy(result.addr.unixDomain.sun_path, path.cStr()); strcpy(result.addr.unixDomain.sun_path, path.cStr());
result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1; result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1;
if (!result.parseAllowedBy(filter)) {
KJ_FAIL_REQUIRE("unix sockets blocked by restrictPeers()");
return Array<SocketAddress>();
}
auto array = kj::heapArrayBuilder<SocketAddress>(1); auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result); array.add(result);
return array.finish(); return array.finish();
...@@ -495,6 +503,12 @@ public: ...@@ -495,6 +503,12 @@ public:
// NULL terminator so that we can safely read it back in toString // NULL terminator so that we can safely read it back in toString
memcpy(result.addr.unixDomain.sun_path + 1, path.cStr(), path.size() + 1); memcpy(result.addr.unixDomain.sun_path + 1, path.cStr(), path.size() + 1);
result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1; result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1;
if (!result.parseAllowedBy(filter)) {
KJ_FAIL_REQUIRE("abstract unix sockets blocked by restrictPeers()");
return Array<SocketAddress>();
}
auto array = kj::heapArrayBuilder<SocketAddress>(1); auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result); array.add(result);
return array.finish(); return array.finish();
...@@ -547,7 +561,8 @@ public: ...@@ -547,7 +561,8 @@ public:
port = strtoul(portText->cStr(), &endptr, 0); port = strtoul(portText->cStr(), &endptr, 0);
if (portText->size() == 0 || *endptr != '\0') { if (portText->size() == 0 || *endptr != '\0') {
// Not a number. Maybe it's a service name. Fall back to DNS. // Not a number. Maybe it's a service name. Fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint); return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint,
filter);
} }
KJ_REQUIRE(port < 65536, "Port number too large."); KJ_REQUIRE(port < 65536, "Port number too large.");
} else { } else {
...@@ -569,6 +584,7 @@ public: ...@@ -569,6 +584,7 @@ public:
result.addr.inet6.sin6_family = AF_INET6; result.addr.inet6.sin6_family = AF_INET6;
result.addr.inet6.sin6_port = htons(port); result.addr.inet6.sin6_port = htons(port);
#endif #endif
auto array = kj::heapArrayBuilder<SocketAddress>(1); auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result); array.add(result);
return array.finish(); return array.finish();
...@@ -597,13 +613,18 @@ public: ...@@ -597,13 +613,18 @@ public:
switch (inet_pton(af, buffer, addrTarget)) { switch (inet_pton(af, buffer, addrTarget)) {
case 1: { case 1: {
// success. // success.
if (!result.parseAllowedBy(filter)) {
KJ_FAIL_REQUIRE("address family blocked by restrictPeers()");
return Array<SocketAddress>();
}
auto array = kj::heapArrayBuilder<SocketAddress>(1); auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result); array.add(result);
return array.finish(); return array.finish();
} }
case 0: case 0:
// It's apparently not a simple address... fall back to DNS. // It's apparently not a simple address... fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port); return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port, filter);
default: default:
KJ_FAIL_SYSCALL("inet_pton", errno, af, addrPart); KJ_FAIL_SYSCALL("inet_pton", errno, af, addrPart);
} }
...@@ -616,6 +637,14 @@ public: ...@@ -616,6 +637,14 @@ public:
return result; return result;
} }
bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) {
return filter.shouldAllow(&addr.generic, addrlen);
}
bool parseAllowedBy(const _::NetworkFilter& filter) {
return filter.shouldAllowParse(&addr.generic);
}
private: private:
SocketAddress(): addrlen(0) { SocketAddress(): addrlen(0) {
memset(&addr, 0, sizeof(addr)); memset(&addr, 0, sizeof(addr));
...@@ -640,8 +669,9 @@ class SocketAddress::LookupReader { ...@@ -640,8 +669,9 @@ class SocketAddress::LookupReader {
// getaddrinfo. // getaddrinfo.
public: public:
LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input) LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input,
: thread(kj::mv(thread)), input(kj::mv(input)) {} _::NetworkFilter& filter)
: thread(kj::mv(thread)), input(kj::mv(input)), filter(filter) {}
~LookupReader() { ~LookupReader() {
if (thread) thread->detach(); if (thread) thread->detach();
...@@ -654,7 +684,7 @@ public: ...@@ -654,7 +684,7 @@ public:
thread = nullptr; thread = nullptr;
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check // getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// anyway. // anyway.
KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no addresses.") { break; } KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no permitted addresses.") { break; }
return addresses.releaseAsArray(); return addresses.releaseAsArray();
} else { } else {
// getaddrinfo() can return multiple copies of the same address for several reasons. // getaddrinfo() can return multiple copies of the same address for several reasons.
...@@ -667,7 +697,9 @@ public: ...@@ -667,7 +697,9 @@ public:
// //
// So we instead resort to de-duping results. // So we instead resort to de-duping results.
if (alreadySeen.insert(current).second) { if (alreadySeen.insert(current).second) {
addresses.add(current); if (current.parseAllowedBy(filter)) {
addresses.add(current);
}
} }
return read(); return read();
} }
...@@ -677,6 +709,7 @@ public: ...@@ -677,6 +709,7 @@ public:
private: private:
kj::Own<Thread> thread; kj::Own<Thread> thread;
kj::Own<AsyncInputStream> input; kj::Own<AsyncInputStream> input;
_::NetworkFilter& filter;
SocketAddress current; SocketAddress current;
kj::Vector<SocketAddress> addresses; kj::Vector<SocketAddress> addresses;
std::set<SocketAddress> alreadySeen; std::set<SocketAddress> alreadySeen;
...@@ -688,7 +721,8 @@ struct SocketAddress::LookupParams { ...@@ -688,7 +721,8 @@ struct SocketAddress::LookupParams {
}; };
Promise<Array<SocketAddress>> SocketAddress::lookupHost( Promise<Array<SocketAddress>> SocketAddress::lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint) { LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
_::NetworkFilter& filter) {
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is // This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// the only cross-platform DNS API and it is blocking. // the only cross-platform DNS API and it is blocking.
// //
...@@ -773,7 +807,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost( ...@@ -773,7 +807,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
} }
})); }));
auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input)); auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input), filter);
return reader->read().attach(kj::mv(reader)); return reader->read().attach(kj::mv(reader));
} }
...@@ -781,22 +815,33 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost( ...@@ -781,22 +815,33 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFileDescriptor { class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFileDescriptor {
public: public:
FdConnectionReceiver(UnixEventPort& eventPort, int fd, uint flags) FdConnectionReceiver(UnixEventPort& eventPort, int fd,
: OwnedFileDescriptor(fd, flags), eventPort(eventPort), LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
: OwnedFileDescriptor(fd, flags), eventPort(eventPort), filter(filter),
observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ) {} observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ) {}
Promise<Own<AsyncIoStream>> accept() override { Promise<Own<AsyncIoStream>> accept() override {
int newFd; int newFd;
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
retry: retry:
#if __linux__ && !__BIONIC__ #if __linux__ && !__BIONIC__
newFd = ::accept4(fd, nullptr, nullptr, SOCK_NONBLOCK | SOCK_CLOEXEC); newFd = ::accept4(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen,
SOCK_NONBLOCK | SOCK_CLOEXEC);
#else #else
newFd = ::accept(fd, nullptr, nullptr); newFd = ::accept(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
#endif #endif
if (newFd >= 0) { if (newFd >= 0) {
return Own<AsyncIoStream>(heap<AsyncStreamFd>(eventPort, newFd, NEW_FD_FLAGS)); if (!filter.shouldAllow(reinterpret_cast<struct sockaddr*>(&addr), addrlen)) {
// Drop disallowed address.
close(newFd);
return accept();
} else {
return Own<AsyncIoStream>(heap<AsyncStreamFd>(eventPort, newFd, NEW_FD_FLAGS));
}
} else { } else {
int error = errno; int error = errno;
...@@ -849,13 +894,15 @@ public: ...@@ -849,13 +894,15 @@ public:
public: public:
UnixEventPort& eventPort; UnixEventPort& eventPort;
LowLevelAsyncIoProvider::NetworkFilter& filter;
UnixEventPort::FdObserver observer; UnixEventPort::FdObserver observer;
}; };
class DatagramPortImpl final: public DatagramPort, public OwnedFileDescriptor { class DatagramPortImpl final: public DatagramPort, public OwnedFileDescriptor {
public: public:
DatagramPortImpl(LowLevelAsyncIoProvider& lowLevel, UnixEventPort& eventPort, int fd, uint flags) DatagramPortImpl(LowLevelAsyncIoProvider& lowLevel, UnixEventPort& eventPort, int fd,
: OwnedFileDescriptor(fd, flags), lowLevel(lowLevel), eventPort(eventPort), LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
: OwnedFileDescriptor(fd, flags), lowLevel(lowLevel), eventPort(eventPort), filter(filter),
observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ | observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ |
UnixEventPort::FdObserver::OBSERVE_WRITE) {} UnixEventPort::FdObserver::OBSERVE_WRITE) {}
...@@ -883,6 +930,7 @@ public: ...@@ -883,6 +930,7 @@ public:
public: public:
LowLevelAsyncIoProvider& lowLevel; LowLevelAsyncIoProvider& lowLevel;
UnixEventPort& eventPort; UnixEventPort& eventPort;
LowLevelAsyncIoProvider::NetworkFilter& filter;
UnixEventPort::FdObserver observer; UnixEventPort::FdObserver observer;
}; };
...@@ -935,11 +983,13 @@ public: ...@@ -935,11 +983,13 @@ public:
return kj::mv(stream); return kj::mv(stream);
})); }));
} }
Own<ConnectionReceiver> wrapListenSocketFd(int fd, uint flags = 0) override { Own<ConnectionReceiver> wrapListenSocketFd(
return heap<FdConnectionReceiver>(eventPort, fd, flags); int fd, NetworkFilter& filter, uint flags = 0) override {
return heap<FdConnectionReceiver>(eventPort, fd, filter, flags);
} }
Own<DatagramPort> wrapDatagramSocketFd(int fd, uint flags = 0) override { Own<DatagramPort> wrapDatagramSocketFd(
return heap<DatagramPortImpl>(*this, eventPort, fd, flags); int fd, NetworkFilter& filter, uint flags = 0) override {
return heap<DatagramPortImpl>(*this, eventPort, fd, filter, flags);
} }
Timer& getTimer() override { return eventPort.getTimer(); } Timer& getTimer() override { return eventPort.getTimer(); }
...@@ -956,12 +1006,14 @@ private: ...@@ -956,12 +1006,14 @@ private:
class NetworkAddressImpl final: public NetworkAddress { class NetworkAddressImpl final: public NetworkAddress {
public: public:
NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel, Array<SocketAddress> addrs) NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel,
: lowLevel(lowLevel), addrs(kj::mv(addrs)) {} LowLevelAsyncIoProvider::NetworkFilter& filter,
Array<SocketAddress> addrs)
: lowLevel(lowLevel), filter(filter), addrs(kj::mv(addrs)) {}
Promise<Own<AsyncIoStream>> connect() override { Promise<Own<AsyncIoStream>> connect() override {
auto addrsCopy = heapArray(addrs.asPtr()); auto addrsCopy = heapArray(addrs.asPtr());
auto promise = connectImpl(lowLevel, addrsCopy); auto promise = connectImpl(lowLevel, filter, addrsCopy);
return promise.attach(kj::mv(addrsCopy)); return promise.attach(kj::mv(addrsCopy));
} }
...@@ -988,7 +1040,7 @@ public: ...@@ -988,7 +1040,7 @@ public:
KJ_SYSCALL(::listen(fd, SOMAXCONN)); KJ_SYSCALL(::listen(fd, SOMAXCONN));
} }
return lowLevel.wrapListenSocketFd(fd, NEW_FD_FLAGS); return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS);
} }
Own<DatagramPort> bindDatagramPort() override { Own<DatagramPort> bindDatagramPort() override {
...@@ -1011,11 +1063,11 @@ public: ...@@ -1011,11 +1063,11 @@ public:
addrs[0].bind(fd); addrs[0].bind(fd);
} }
return lowLevel.wrapDatagramSocketFd(fd, NEW_FD_FLAGS); return lowLevel.wrapDatagramSocketFd(fd, filter, NEW_FD_FLAGS);
} }
Own<NetworkAddress> clone() override { Own<NetworkAddress> clone() override {
return kj::heap<NetworkAddressImpl>(lowLevel, kj::heapArray(addrs.asPtr())); return kj::heap<NetworkAddressImpl>(lowLevel, filter, kj::heapArray(addrs.asPtr()));
} }
String toString() override { String toString() override {
...@@ -1029,26 +1081,33 @@ public: ...@@ -1029,26 +1081,33 @@ public:
private: private:
LowLevelAsyncIoProvider& lowLevel; LowLevelAsyncIoProvider& lowLevel;
LowLevelAsyncIoProvider::NetworkFilter& filter;
Array<SocketAddress> addrs; Array<SocketAddress> addrs;
uint counter = 0; uint counter = 0;
static Promise<Own<AsyncIoStream>> connectImpl( static Promise<Own<AsyncIoStream>> connectImpl(
LowLevelAsyncIoProvider& lowLevel, ArrayPtr<SocketAddress> addrs) { LowLevelAsyncIoProvider& lowLevel,
LowLevelAsyncIoProvider::NetworkFilter& filter,
ArrayPtr<SocketAddress> addrs) {
KJ_ASSERT(addrs.size() > 0); KJ_ASSERT(addrs.size() > 0);
int fd = addrs[0].socket(SOCK_STREAM); int fd = addrs[0].socket(SOCK_STREAM);
return kj::evalNow([&]() { return kj::evalNow([&]() -> Promise<Own<AsyncIoStream>> {
return lowLevel.wrapConnectingSocketFd( if (!addrs[0].allowedBy(filter)) {
fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS); return KJ_EXCEPTION(FAILED, "connect() blocked by restrictPeers()");
} else {
return lowLevel.wrapConnectingSocketFd(
fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
}
}).then([](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> { }).then([](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> {
// Success, pass along. // Success, pass along.
return kj::mv(stream); return kj::mv(stream);
}, [&lowLevel,addrs](Exception&& exception) mutable -> Promise<Own<AsyncIoStream>> { }, [&lowLevel,&filter,addrs](Exception&& exception) mutable -> Promise<Own<AsyncIoStream>> {
// Connect failed. // Connect failed.
if (addrs.size() > 1) { if (addrs.size() > 1) {
// Try the next address instead. // Try the next address instead.
return connectImpl(lowLevel, addrs.slice(1, addrs.size())); return connectImpl(lowLevel, filter, addrs.slice(1, addrs.size()));
} else { } else {
// No more addresses to try, so propagate the exception. // No more addresses to try, so propagate the exception.
return kj::mv(exception); return kj::mv(exception);
...@@ -1060,25 +1119,35 @@ private: ...@@ -1060,25 +1119,35 @@ private:
class SocketNetwork final: public Network { class SocketNetwork final: public Network {
public: public:
explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {} explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
explicit SocketNetwork(SocketNetwork& parent,
kj::ArrayPtr<const kj::StringPtr> allow,
kj::ArrayPtr<const kj::StringPtr> deny)
: lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {}
Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override { Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
auto& lowLevelCopy = lowLevel; return evalLater(mvCapture(heapString(addr), [this,portHint](String&& addr) {
return evalLater(mvCapture(heapString(addr), return SocketAddress::parse(lowLevel, addr, portHint, filter);
[&lowLevelCopy,portHint](String&& addr) { })).then([this](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
return SocketAddress::parse(lowLevelCopy, addr, portHint); return heap<NetworkAddressImpl>(lowLevel, filter, kj::mv(addresses));
})).then([&lowLevelCopy](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
return heap<NetworkAddressImpl>(lowLevelCopy, kj::mv(addresses));
}); });
} }
Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override { Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
auto array = kj::heapArrayBuilder<SocketAddress>(1); auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(SocketAddress(sockaddr, len)); array.add(SocketAddress(sockaddr, len));
return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, array.finish())); KJ_REQUIRE(array[0].allowedBy(filter), "address blocked by restrictPeers()") { break; }
return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, filter, array.finish()));
}
Own<Network> restrictPeers(
kj::ArrayPtr<const kj::StringPtr> allow,
kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
return heap<SocketNetwork>(*this, allow, deny);
} }
private: private:
LowLevelAsyncIoProvider& lowLevel; LowLevelAsyncIoProvider& lowLevel;
_::NetworkFilter filter;
}; };
// ======================================================================================= // =======================================================================================
...@@ -1189,10 +1258,16 @@ public: ...@@ -1189,10 +1258,16 @@ public:
return receive(); return receive();
}); });
} else { } else {
if (!port.filter.shouldAllow(reinterpret_cast<const struct sockaddr*>(msg.msg_name),
msg.msg_namelen)) {
// Ignore message from disallowed source.
return receive();
}
receivedSize = n; receivedSize = n;
contentTruncated = msg.msg_flags & MSG_TRUNC; contentTruncated = msg.msg_flags & MSG_TRUNC;
source.emplace(port.lowLevel, msg.msg_name, msg.msg_namelen); source.emplace(port.lowLevel, port.filter, msg.msg_name, msg.msg_namelen);
ancillaryList.resize(0); ancillaryList.resize(0);
ancillaryTruncated = msg.msg_flags & MSG_CTRUNC; ancillaryTruncated = msg.msg_flags & MSG_CTRUNC;
...@@ -1250,9 +1325,10 @@ private: ...@@ -1250,9 +1325,10 @@ private:
bool ancillaryTruncated = false; bool ancillaryTruncated = false;
struct StoredAddress { struct StoredAddress {
StoredAddress(LowLevelAsyncIoProvider& lowLevel, const void* sockaddr, uint length) StoredAddress(LowLevelAsyncIoProvider& lowLevel, LowLevelAsyncIoProvider::NetworkFilter& filter,
const void* sockaddr, uint length)
: raw(sockaddr, length), : raw(sockaddr, length),
abstract(lowLevel, Array<SocketAddress>(&raw, 1, NullArrayDisposer::instance)) {} abstract(lowLevel, filter, Array<SocketAddress>(&raw, 1, NullArrayDisposer::instance)) {}
SocketAddress raw; SocketAddress raw;
NetworkAddressImpl abstract; NetworkAddressImpl abstract;
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#define _WIN32_WINNT 0x0600 #define _WIN32_WINNT 0x0600
#include "async-io.h" #include "async-io.h"
#include "async-io-internal.h"
#include "async-win32.h" #include "async-win32.h"
#include "debug.h" #include "debug.h"
#include "thread.h" #include "thread.h"
...@@ -524,11 +525,12 @@ public: ...@@ -524,11 +525,12 @@ public:
} }
static Promise<Array<SocketAddress>> lookupHost( static Promise<Array<SocketAddress>> lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint); LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
_::NetworkFilter& filter);
// Perform a DNS lookup. // Perform a DNS lookup.
static Promise<Array<SocketAddress>> parse( static Promise<Array<SocketAddress>> parse(
LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint) { LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint, _::NetworkFilter& filter) {
// TODO(someday): Allow commas in `str`. // TODO(someday): Allow commas in `str`.
SocketAddress result; SocketAddress result;
...@@ -580,7 +582,8 @@ public: ...@@ -580,7 +582,8 @@ public:
port = strtoul(portText->cStr(), &endptr, 0); port = strtoul(portText->cStr(), &endptr, 0);
if (portText->size() == 0 || *endptr != '\0') { if (portText->size() == 0 || *endptr != '\0') {
// Not a number. Maybe it's a service name. Fall back to DNS. // Not a number. Maybe it's a service name. Fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint); return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint,
filter);
} }
KJ_REQUIRE(port < 65536, "Port number too large."); KJ_REQUIRE(port < 65536, "Port number too large.");
} else { } else {
...@@ -622,13 +625,18 @@ public: ...@@ -622,13 +625,18 @@ public:
switch (InetPtonA(af, buffer, addrTarget)) { switch (InetPtonA(af, buffer, addrTarget)) {
case 1: { case 1: {
// success. // success.
if (!result.parseAllowedBy(filter)) {
KJ_FAIL_REQUIRE("address family blocked by restrictPeers()");
return Array<SocketAddress>();
}
auto array = kj::heapArrayBuilder<SocketAddress>(1); auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result); array.add(result);
return array.finish(); return array.finish();
} }
case 0: case 0:
// It's apparently not a simple address... fall back to DNS. // It's apparently not a simple address... fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port); return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port, filter);
default: default:
KJ_FAIL_WIN32("InetPton", WSAGetLastError(), af, addrPart); KJ_FAIL_WIN32("InetPton", WSAGetLastError(), af, addrPart);
} }
...@@ -641,6 +649,14 @@ public: ...@@ -641,6 +649,14 @@ public:
return result; return result;
} }
bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) {
return filter.shouldAllow(&addr.generic, addrlen);
}
bool parseAllowedBy(const _::NetworkFilter& filter) {
return filter.shouldAllowParse(&addr.generic);
}
static SocketAddress getWildcardForFamily(int family) { static SocketAddress getWildcardForFamily(int family) {
SocketAddress result; SocketAddress result;
switch (family) { switch (family) {
...@@ -680,8 +696,9 @@ class SocketAddress::LookupReader { ...@@ -680,8 +696,9 @@ class SocketAddress::LookupReader {
// getaddrinfo. // getaddrinfo.
public: public:
LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input) LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input,
: thread(kj::mv(thread)), input(kj::mv(input)) {} _::NetworkFilter& filter)
: thread(kj::mv(thread)), input(kj::mv(input)), filter(filter) {}
~LookupReader() { ~LookupReader() {
if (thread) thread->detach(); if (thread) thread->detach();
...@@ -694,7 +711,7 @@ public: ...@@ -694,7 +711,7 @@ public:
thread = nullptr; thread = nullptr;
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check // getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// anyway. // anyway.
KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no addresses.") { break; } KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no permitted addresses.") { break; }
return addresses.releaseAsArray(); return addresses.releaseAsArray();
} else { } else {
// getaddrinfo() can return multiple copies of the same address for several reasons. // getaddrinfo() can return multiple copies of the same address for several reasons.
...@@ -707,7 +724,9 @@ public: ...@@ -707,7 +724,9 @@ public:
// //
// So we instead resort to de-duping results. // So we instead resort to de-duping results.
if (alreadySeen.insert(current).second) { if (alreadySeen.insert(current).second) {
addresses.add(current); if (current.parseAllowedBy(filter)) {
addresses.add(current);
}
} }
return read(); return read();
} }
...@@ -717,6 +736,7 @@ public: ...@@ -717,6 +736,7 @@ public:
private: private:
kj::Own<Thread> thread; kj::Own<Thread> thread;
kj::Own<AsyncInputStream> input; kj::Own<AsyncInputStream> input;
_::NetworkFilter& filter;
SocketAddress current; SocketAddress current;
kj::Vector<SocketAddress> addresses; kj::Vector<SocketAddress> addresses;
std::set<SocketAddress> alreadySeen; std::set<SocketAddress> alreadySeen;
...@@ -728,7 +748,8 @@ struct SocketAddress::LookupParams { ...@@ -728,7 +748,8 @@ struct SocketAddress::LookupParams {
}; };
Promise<Array<SocketAddress>> SocketAddress::lookupHost( Promise<Array<SocketAddress>> SocketAddress::lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint) { LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
_::NetworkFilter& filter) {
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is // This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// the only cross-platform DNS API and it is blocking. // the only cross-platform DNS API and it is blocking.
// //
...@@ -818,7 +839,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost( ...@@ -818,7 +839,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
} }
})); }));
auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input)); auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input), filter);
return reader->read().attach(kj::mv(reader)); return reader->read().attach(kj::mv(reader));
} }
...@@ -826,8 +847,9 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost( ...@@ -826,8 +847,9 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFd { class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFd {
public: public:
FdConnectionReceiver(Win32EventPort& eventPort, SOCKET fd, uint flags) FdConnectionReceiver(Win32EventPort& eventPort, SOCKET fd,
: OwnedFd(fd, flags), eventPort(eventPort), LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
: OwnedFd(fd, flags), eventPort(eventPort), filter(filter),
observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))), observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))),
address(SocketAddress::getLocalAddress(fd)) { address(SocketAddress::getLocalAddress(fd)) {
// In order to accept asynchronously, we need the AcceptEx() function. Apparently, we have // In order to accept asynchronously, we need the AcceptEx() function. Apparently, we have
...@@ -858,8 +880,9 @@ public: ...@@ -858,8 +880,9 @@ public:
} }
} }
return op->onComplete().attach(kj::mv(scratch)).then(mvCapture(result, return op->onComplete().then(mvCapture(result, mvCapture(scratch,
[this](Own<AsyncIoStream> stream, Win32EventPort::IoResult ioResult) { [this](Array<byte> scratch, Own<AsyncIoStream> stream, Win32EventPort::IoResult ioResult)
-> Promise<Own<AsyncIoStream>> {
if (ioResult.errorCode != ERROR_SUCCESS) { if (ioResult.errorCode != ERROR_SUCCESS) {
KJ_FAIL_WIN32("AcceptEx()", ioResult.errorCode) { break; } KJ_FAIL_WIN32("AcceptEx()", ioResult.errorCode) { break; }
} else { } else {
...@@ -867,8 +890,17 @@ public: ...@@ -867,8 +890,17 @@ public:
stream->setsockopt(SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, stream->setsockopt(SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT,
reinterpret_cast<char*>(&me), sizeof(me)); reinterpret_cast<char*>(&me), sizeof(me));
} }
return kj::mv(stream);
})); auto addr = reinterpret_cast<struct sockaddr*>(scratch.begin() + 128);
size_t addrlen = addr->sa_family == AF_INET
? sizeof(struct sockaddr_in)
: sizeof(struct sockaddr_in6);
if (filter.shouldAllow(addr, addrlen)) {
return kj::mv(stream);
} else {
return accept();
}
})));
} }
uint getPort() override { uint getPort() override {
...@@ -888,6 +920,7 @@ public: ...@@ -888,6 +920,7 @@ public:
public: public:
Win32EventPort& eventPort; Win32EventPort& eventPort;
LowLevelAsyncIoProvider::NetworkFilter& filter;
Own<Win32EventPort::IoObserver> observer; Own<Win32EventPort::IoObserver> observer;
LPFN_ACCEPTEX acceptEx = nullptr; LPFN_ACCEPTEX acceptEx = nullptr;
SocketAddress address; SocketAddress address;
...@@ -923,8 +956,9 @@ public: ...@@ -923,8 +956,9 @@ public:
return kj::mv(result); return kj::mv(result);
})); }));
} }
Own<ConnectionReceiver> wrapListenSocketFd(SOCKET fd, uint flags = 0) override { Own<ConnectionReceiver> wrapListenSocketFd(
return heap<FdConnectionReceiver>(eventPort, fd, flags); SOCKET fd, NetworkFilter& filter, uint flags = 0) override {
return heap<FdConnectionReceiver>(eventPort, fd, filter, flags);
} }
Timer& getTimer() override { return eventPort.getTimer(); } Timer& getTimer() override { return eventPort.getTimer(); }
...@@ -941,12 +975,14 @@ private: ...@@ -941,12 +975,14 @@ private:
class NetworkAddressImpl final: public NetworkAddress { class NetworkAddressImpl final: public NetworkAddress {
public: public:
NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel, Array<SocketAddress> addrs) NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel,
: lowLevel(lowLevel), addrs(kj::mv(addrs)) {} LowLevelAsyncIoProvider::NetworkFilter& filter,
Array<SocketAddress> addrs)
: lowLevel(lowLevel), filter(filter), addrs(kj::mv(addrs)) {}
Promise<Own<AsyncIoStream>> connect() override { Promise<Own<AsyncIoStream>> connect() override {
auto addrsCopy = heapArray(addrs.asPtr()); auto addrsCopy = heapArray(addrs.asPtr());
auto promise = connectImpl(lowLevel, addrsCopy); auto promise = connectImpl(lowLevel, filter, addrsCopy);
return promise.attach(kj::mv(addrsCopy)); return promise.attach(kj::mv(addrsCopy));
} }
...@@ -974,7 +1010,7 @@ public: ...@@ -974,7 +1010,7 @@ public:
KJ_WINSOCK(::listen(fd, SOMAXCONN)); KJ_WINSOCK(::listen(fd, SOMAXCONN));
} }
return lowLevel.wrapListenSocketFd(fd, NEW_FD_FLAGS); return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS);
} }
Own<DatagramPort> bindDatagramPort() override { Own<DatagramPort> bindDatagramPort() override {
...@@ -998,11 +1034,11 @@ public: ...@@ -998,11 +1034,11 @@ public:
addrs[0].bind(fd); addrs[0].bind(fd);
} }
return lowLevel.wrapDatagramSocketFd(fd, NEW_FD_FLAGS); return lowLevel.wrapDatagramSocketFd(fd, filter, NEW_FD_FLAGS);
} }
Own<NetworkAddress> clone() override { Own<NetworkAddress> clone() override {
return kj::heap<NetworkAddressImpl>(lowLevel, kj::heapArray(addrs.asPtr())); return kj::heap<NetworkAddressImpl>(lowLevel, filter, kj::heapArray(addrs.asPtr()));
} }
String toString() override { String toString() override {
...@@ -1016,26 +1052,34 @@ public: ...@@ -1016,26 +1052,34 @@ public:
private: private:
LowLevelAsyncIoProvider& lowLevel; LowLevelAsyncIoProvider& lowLevel;
LowLevelAsyncIoProvider::NetworkFilter& filter;
Array<SocketAddress> addrs; Array<SocketAddress> addrs;
uint counter = 0; uint counter = 0;
static Promise<Own<AsyncIoStream>> connectImpl( static Promise<Own<AsyncIoStream>> connectImpl(
LowLevelAsyncIoProvider& lowLevel, ArrayPtr<SocketAddress> addrs) { LowLevelAsyncIoProvider& lowLevel,
LowLevelAsyncIoProvider::NetworkFilter& filter,
ArrayPtr<SocketAddress> addrs) {
KJ_ASSERT(addrs.size() > 0); KJ_ASSERT(addrs.size() > 0);
int fd = addrs[0].socket(SOCK_STREAM); int fd = addrs[0].socket(SOCK_STREAM);
return kj::evalNow([&]() { return kj::evalNow([&]() -> Promise<Own<AsyncIoStream>> {
return lowLevel.wrapConnectingSocketFd( if (!addrs[0].allowedBy(filter)) {
fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS); return KJ_EXCEPTION(FAILED, "connect() blocked by restrictPeers()");
} else {
return lowLevel.wrapConnectingSocketFd(
fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
}
}).then([](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> { }).then([](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> {
// Success, pass along. // Success, pass along.
return kj::mv(stream); return kj::mv(stream);
}, [&lowLevel,KJ_CPCAP(addrs)](Exception&& exception) mutable -> Promise<Own<AsyncIoStream>> { }, [&lowLevel,&filter,KJ_CPCAP(addrs)](Exception&& exception) mutable
-> Promise<Own<AsyncIoStream>> {
// Connect failed. // Connect failed.
if (addrs.size() > 1) { if (addrs.size() > 1) {
// Try the next address instead. // Try the next address instead.
return connectImpl(lowLevel, addrs.slice(1, addrs.size())); return connectImpl(lowLevel, filter, addrs.slice(1, addrs.size()));
} else { } else {
// No more addresses to try, so propagate the exception. // No more addresses to try, so propagate the exception.
return kj::mv(exception); return kj::mv(exception);
...@@ -1047,25 +1091,35 @@ private: ...@@ -1047,25 +1091,35 @@ private:
class SocketNetwork final: public Network { class SocketNetwork final: public Network {
public: public:
explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {} explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
explicit SocketNetwork(SocketNetwork& parent,
kj::ArrayPtr<const kj::StringPtr> allow,
kj::ArrayPtr<const kj::StringPtr> deny)
: lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {}
Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override { Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
auto& lowLevelCopy = lowLevel; return evalLater(mvCapture(heapString(addr), [this,portHint](String&& addr) {
return evalLater(mvCapture(heapString(addr), return SocketAddress::parse(lowLevel, addr, portHint, filter);
[&lowLevelCopy,portHint](String&& addr) { })).then([this](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
return SocketAddress::parse(lowLevelCopy, addr, portHint); return heap<NetworkAddressImpl>(lowLevel, filter, kj::mv(addresses));
})).then([&lowLevelCopy](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
return heap<NetworkAddressImpl>(lowLevelCopy, kj::mv(addresses));
}); });
} }
Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override { Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
auto array = kj::heapArrayBuilder<SocketAddress>(1); auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(SocketAddress(sockaddr, len)); array.add(SocketAddress(sockaddr, len));
return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, array.finish())); KJ_REQUIRE(array[0].allowedBy(filter), "address blocked by restrictPeers()") { break; }
return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, filter, array.finish()));
}
Own<Network> restrictPeers(
kj::ArrayPtr<const kj::StringPtr> allow,
kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
return heap<SocketNetwork>(*this, allow, deny);
} }
private: private:
LowLevelAsyncIoProvider& lowLevel; LowLevelAsyncIoProvider& lowLevel;
_::NetworkFilter filter;
}; };
// ======================================================================================= // =======================================================================================
......
...@@ -19,15 +19,18 @@ ...@@ -19,15 +19,18 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE. // THE SOFTWARE.
#if _WIN32
// Request Vista-level APIs.
#define WINVER 0x0600
#define _WIN32_WINNT 0x0600
#endif
#include "async-io.h" #include "async-io.h"
#include "async-io-internal.h" #include "async-io-internal.h"
#include "debug.h" #include "debug.h"
#include "vector.h" #include "vector.h"
#if _WIN32 #if _WIN32
// Request Vista-level APIs.
#define WINVER 0x0600
#define _WIN32_WINNT 0x0600
#include <winsock2.h> #include <winsock2.h>
#include <ws2ipdef.h> #include <ws2ipdef.h>
#include <ws2tcpip.h> #include <ws2tcpip.h>
...@@ -205,7 +208,8 @@ void DatagramPort::setsockopt(int level, int option, const void* value, uint len ...@@ -205,7 +208,8 @@ void DatagramPort::setsockopt(int level, int option, const void* value, uint len
Own<DatagramPort> NetworkAddress::bindDatagramPort() { Own<DatagramPort> NetworkAddress::bindDatagramPort() {
KJ_UNIMPLEMENTED("Datagram sockets not implemented."); KJ_UNIMPLEMENTED("Datagram sockets not implemented.");
} }
Own<DatagramPort> LowLevelAsyncIoProvider::wrapDatagramSocketFd(Fd fd, uint flags) { Own<DatagramPort> LowLevelAsyncIoProvider::wrapDatagramSocketFd(
Fd fd, LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags) {
KJ_UNIMPLEMENTED("Datagram sockets not implemented."); KJ_UNIMPLEMENTED("Datagram sockets not implemented.");
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment