Commit ac6b5d30 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #555 from capnproto/network-filter

Extend kj::Network interface for easy SSRF protection
parents b3dec708 04ff4676
......@@ -198,6 +198,19 @@ kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() {
// =======================================================================================
namespace {
class DummyFilter: public kj::LowLevelAsyncIoProvider::NetworkFilter {
public:
bool shouldAllow(const struct sockaddr* addr, uint addrlen) override {
return true;
}
};
static DummyFilter DUMMY_FILTER;
} // namespace
struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>,
public kj::TaskSet::ErrorHandler {
Capability::Client mainInterface;
......@@ -271,7 +284,8 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>,
context(EzRpcContext::getThreadLocal()),
portPromise(kj::Promise<uint>(port).fork()),
tasks(*this) {
acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd), readerOpts);
acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd, DUMMY_FILTER),
readerOpts);
}
void acceptLoop(kj::Own<kj::ConnectionReceiver>&& listener, ReaderOptions readerOpts) {
......
// Copyright (c) 2017 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef KJ_ASYNC_IO_INTERNAL_H_
#define KJ_ASYNC_IO_INTERNAL_H_
#include "string.h"
#include "vector.h"
#include "async-io.h"
#include <stdint.h>
struct sockaddr;
struct sockaddr_un;
namespace kj {
namespace _ { // private
// =======================================================================================
#if !_WIN32
kj::ArrayPtr<const char> safeUnixPath(const struct sockaddr_un* addr, uint addrlen);
// sockaddr_un::sun_path is not required to have a NUL terminator! Thus to be safe unix address
// paths MUST be read using this function.
#endif
class CidrRange {
public:
CidrRange(StringPtr pattern);
static CidrRange inet4(ArrayPtr<const byte> bits, uint bitCount);
static CidrRange inet6(ArrayPtr<const uint16_t> prefix, ArrayPtr<const uint16_t> suffix,
uint bitCount);
// Zeros are inserted between `prefix` and `suffix` to extend the address to 128 bits.
uint getSpecificity() const { return bitCount; }
bool matches(const struct sockaddr* addr) const;
bool matchesFamily(int family) const;
String toString() const;
private:
int family;
byte bits[16];
uint bitCount; // how many bits in `bits` need to match
CidrRange(int family, ArrayPtr<const byte> bits, uint bitCount);
void zeroIrrelevantBits();
};
class NetworkFilter: public LowLevelAsyncIoProvider::NetworkFilter {
public:
NetworkFilter();
NetworkFilter(ArrayPtr<const StringPtr> allow, ArrayPtr<const StringPtr> deny,
NetworkFilter& next);
bool shouldAllow(const struct sockaddr* addr, uint addrlen) override;
bool shouldAllowParse(const struct sockaddr* addr, uint addrlen);
private:
Vector<CidrRange> allowCidrs;
Vector<CidrRange> denyCidrs;
bool allowUnix;
bool allowAbstractUnix;
kj::Maybe<NetworkFilter&> next;
};
} // namespace _ (private)
} // namespace kj
#endif // KJ_ASYNC_IO_INTERNAL_H_
......@@ -19,17 +19,27 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#if _WIN32
// Request Vista-level APIs.
#define WINVER 0x0600
#define _WIN32_WINNT 0x0600
#endif
#include "async-io.h"
#include "async-io-internal.h"
#include "debug.h"
#include <kj/compat/gtest.h>
#include <sys/types.h>
#if _WIN32
#include <ws2tcpip.h>
#include "windows-sanity.h"
#define inet_pton InetPtonA
#define inet_ntop InetNtopA
#else
#include <netdb.h>
#include <unistd.h>
#include <fcntl.h>
#include <arpa/inet.h>
#endif
namespace kj {
......@@ -77,12 +87,13 @@ String tryParse(WaitScope& waitScope, Network& network, StringPtr text, uint por
return network.parseAddress(text, portHint).wait(waitScope)->toString();
}
bool systemSupportsAddress(StringPtr addr) {
bool systemSupportsAddress(StringPtr addr, StringPtr service = nullptr) {
// Can getaddrinfo() parse this addresses? This is only true if the address family (e.g., ipv6)
// is configured on at least one interface. (The loopback interface usually has both ipv4 and
// ipv6 configured, but not always.)
struct addrinfo* list;
int status = getaddrinfo(addr.cStr(), nullptr, nullptr, &list);
int status = getaddrinfo(
addr.cStr(), service == nullptr ? nullptr : service.cStr(), nullptr, &list);
if (status == 0) {
freeaddrinfo(list);
return true;
......@@ -91,7 +102,6 @@ bool systemSupportsAddress(StringPtr addr) {
}
}
TEST(AsyncIo, AddressParsing) {
auto ioContext = setupAsyncIo();
auto& w = ioContext.waitScope;
......@@ -110,7 +120,7 @@ TEST(AsyncIo, AddressParsing) {
// We can parse services by name...
//
// For some reason, Android and some various Linux distros do not support service names.
if (systemSupportsAddress("1.2.3.4:http")) {
if (systemSupportsAddress("1.2.3.4", "http")) {
EXPECT_EQ("1.2.3.4:80", tryParse(w, network, "1.2.3.4:http", 5678));
EXPECT_EQ("*:80", tryParse(w, network, "*:http", 5678));
} else {
......@@ -122,7 +132,7 @@ TEST(AsyncIo, AddressParsing) {
if (systemSupportsAddress("::")) {
EXPECT_EQ("[::]:123", tryParse(w, network, "0::0", 123));
EXPECT_EQ("[12ab:cd::34]:321", tryParse(w, network, "[12ab:cd:0::0:34]:321", 432));
if (systemSupportsAddress("[12ab:cd::34]:http")) {
if (systemSupportsAddress("12ab:cd::34", "http")) {
EXPECT_EQ("[::]:80", tryParse(w, network, "[::]:http", 5678));
EXPECT_EQ("[12ab:cd::34]:80", tryParse(w, network, "[12ab:cd::34]:http", 5678));
} else {
......@@ -412,5 +422,154 @@ TEST(AsyncIo, AbstractUnixSocket) {
#endif // __linux__
KJ_TEST("CIDR parsing") {
KJ_EXPECT(_::CidrRange("1.2.3.4/16").toString() == "1.2.0.0/16");
KJ_EXPECT(_::CidrRange("1.2.255.4/18").toString() == "1.2.192.0/18");
KJ_EXPECT(_::CidrRange("1234::abcd:ffff:ffff/98").toString() == "1234::abcd:c000:0/98");
KJ_EXPECT(_::CidrRange::inet4({1,2,255,4}, 18).toString() == "1.2.192.0/18");
KJ_EXPECT(_::CidrRange::inet6({0x1234, 0x5678}, {0xabcd, 0xffff, 0xffff}, 98).toString() ==
"1234:5678::abcd:c000:0/98");
union {
struct sockaddr addr;
struct sockaddr_in addr4;
struct sockaddr_in6 addr6;
};
memset(&addr6, 0, sizeof(addr6));
{
addr4.sin_family = AF_INET;
addr4.sin_addr.s_addr = htonl(0x0102dfff);
KJ_EXPECT(_::CidrRange("1.2.255.255/18").matches(&addr));
KJ_EXPECT(!_::CidrRange("1.2.255.255/19").matches(&addr));
KJ_EXPECT(_::CidrRange("1.2.0.0/16").matches(&addr));
KJ_EXPECT(!_::CidrRange("1.3.0.0/16").matches(&addr));
KJ_EXPECT(_::CidrRange("1.2.223.255/32").matches(&addr));
KJ_EXPECT(_::CidrRange("0.0.0.0/0").matches(&addr));
KJ_EXPECT(!_::CidrRange("::/0").matches(&addr));
}
{
addr4.sin_family = AF_INET6;
byte bytes[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
memcpy(addr6.sin6_addr.s6_addr, bytes, 16);
KJ_EXPECT(_::CidrRange("0102:03ff::/24").matches(&addr));
KJ_EXPECT(!_::CidrRange("0102:02ff::/24").matches(&addr));
KJ_EXPECT(_::CidrRange("0102:02ff::/23").matches(&addr));
KJ_EXPECT(_::CidrRange("0102:0304:0506:0708:090a:0b0c:0d0e:0f10/128").matches(&addr));
KJ_EXPECT(_::CidrRange("::/0").matches(&addr));
KJ_EXPECT(!_::CidrRange("0.0.0.0/0").matches(&addr));
}
{
addr4.sin_family = AF_INET6;
inet_pton(AF_INET6, "::ffff:1.2.223.255", &addr6.sin6_addr);
KJ_EXPECT(_::CidrRange("1.2.255.255/18").matches(&addr));
KJ_EXPECT(!_::CidrRange("1.2.255.255/19").matches(&addr));
KJ_EXPECT(_::CidrRange("1.2.0.0/16").matches(&addr));
KJ_EXPECT(!_::CidrRange("1.3.0.0/16").matches(&addr));
KJ_EXPECT(_::CidrRange("1.2.223.255/32").matches(&addr));
KJ_EXPECT(_::CidrRange("0.0.0.0/0").matches(&addr));
KJ_EXPECT(_::CidrRange("::/0").matches(&addr));
}
}
bool allowed4(_::NetworkFilter& filter, StringPtr addrStr) {
struct sockaddr_in addr;
memset(&addr, 0, sizeof(addr));
addr.sin_family = AF_INET;
inet_pton(AF_INET, addrStr.cStr(), &addr.sin_addr);
return filter.shouldAllow(reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr));
}
bool allowed6(_::NetworkFilter& filter, StringPtr addrStr) {
struct sockaddr_in6 addr;
memset(&addr, 0, sizeof(addr));
addr.sin6_family = AF_INET6;
inet_pton(AF_INET6, addrStr.cStr(), &addr.sin6_addr);
return filter.shouldAllow(reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr));
}
KJ_TEST("NetworkFilter") {
_::NetworkFilter base;
KJ_EXPECT(allowed4(base, "8.8.8.8"));
KJ_EXPECT(!allowed4(base, "240.1.2.3"));
{
_::NetworkFilter filter({"public"}, {}, base);
KJ_EXPECT(allowed4(filter, "8.8.8.8"));
KJ_EXPECT(!allowed4(filter, "240.1.2.3"));
KJ_EXPECT(!allowed4(filter, "192.168.0.1"));
KJ_EXPECT(!allowed4(filter, "10.1.2.3"));
KJ_EXPECT(!allowed4(filter, "127.0.0.1"));
KJ_EXPECT(!allowed4(filter, "0.0.0.0"));
KJ_EXPECT(allowed6(filter, "2400:cb00:2048:1::c629:d7a2"));
KJ_EXPECT(!allowed6(filter, "fc00::1234"));
KJ_EXPECT(!allowed6(filter, "::1"));
KJ_EXPECT(!allowed6(filter, "::"));
}
{
_::NetworkFilter filter({"private"}, {"local"}, base);
KJ_EXPECT(!allowed4(filter, "8.8.8.8"));
KJ_EXPECT(!allowed4(filter, "240.1.2.3"));
KJ_EXPECT(allowed4(filter, "192.168.0.1"));
KJ_EXPECT(allowed4(filter, "10.1.2.3"));
KJ_EXPECT(!allowed4(filter, "127.0.0.1"));
KJ_EXPECT(!allowed4(filter, "0.0.0.0"));
KJ_EXPECT(!allowed6(filter, "2400:cb00:2048:1::c629:d7a2"));
KJ_EXPECT(allowed6(filter, "fc00::1234"));
KJ_EXPECT(!allowed6(filter, "::1"));
KJ_EXPECT(!allowed6(filter, "::"));
}
{
_::NetworkFilter filter({"1.0.0.0/8", "1.2.3.0/24"}, {"1.2.0.0/16", "1.2.3.4/32"}, base);
KJ_EXPECT(!allowed4(filter, "8.8.8.8"));
KJ_EXPECT(!allowed4(filter, "240.1.2.3"));
KJ_EXPECT(allowed4(filter, "1.0.0.1"));
KJ_EXPECT(!allowed4(filter, "1.2.2.1"));
KJ_EXPECT(allowed4(filter, "1.2.3.1"));
KJ_EXPECT(!allowed4(filter, "1.2.3.4"));
}
}
KJ_TEST("Network::restrictPeers()") {
auto ioContext = setupAsyncIo();
auto& w = ioContext.waitScope;
auto& network = ioContext.provider->getNetwork();
auto restrictedNetwork = network.restrictPeers({"public"});
KJ_EXPECT(tryParse(w, *restrictedNetwork, "8.8.8.8") == "8.8.8.8:0");
#if !_WIN32
KJ_EXPECT_THROW_MESSAGE("restrictPeers", tryParse(w, *restrictedNetwork, "unix:/foo"));
#endif
auto addr = restrictedNetwork->parseAddress("127.0.0.1").wait(w);
auto listener = addr->listen();
auto acceptTask = listener->accept()
.then([](kj::Own<kj::AsyncIoStream>) {
KJ_FAIL_EXPECT("should not have received connection");
}).eagerlyEvaluate(nullptr);
KJ_EXPECT_THROW_MESSAGE("restrictPeers", addr->connect().wait(w));
// We can connect to the listener but the connection will be immediately closed.
auto addr2 = network.parseAddress("127.0.0.1", listener->getPort()).wait(w);
auto conn = addr2->connect().wait(w);
KJ_EXPECT(conn->readAllText().wait(w) == "");
}
} // namespace
} // namespace kj
......@@ -23,6 +23,7 @@
// For Win32 implementation, see async-io-win32.c++.
#include "async-io.h"
#include "async-io-internal.h"
#include "async-unix.h"
#include "debug.h"
#include "thread.h"
......@@ -449,10 +450,11 @@ public:
return str('[', buffer, "]:", ntohs(addr.inet6.sin6_port));
}
case AF_UNIX: {
if (addr.unixDomain.sun_path[0] == '\0') {
return str("unix-abstract:", addr.unixDomain.sun_path + 1);
auto path = _::safeUnixPath(&addr.unixDomain, addrlen);
if (path.size() > 0 && path[0] == '\0') {
return str("unix-abstract:", path.slice(1, path.size()));
} else {
return str("unix:", addr.unixDomain.sun_path);
return str("unix:", path);
}
}
default:
......@@ -461,11 +463,12 @@ public:
}
static Promise<Array<SocketAddress>> lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint);
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
_::NetworkFilter& filter);
// Perform a DNS lookup.
static Promise<Array<SocketAddress>> parse(
LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint) {
LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint, _::NetworkFilter& filter) {
// TODO(someday): Allow commas in `str`.
SocketAddress result;
......@@ -480,6 +483,12 @@ public:
result.addr.unixDomain.sun_family = AF_UNIX;
strcpy(result.addr.unixDomain.sun_path, path.cStr());
result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1;
if (!result.parseAllowedBy(filter)) {
KJ_FAIL_REQUIRE("unix sockets blocked by restrictPeers()");
return Array<SocketAddress>();
}
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result);
return array.finish();
......@@ -495,6 +504,12 @@ public:
// NULL terminator so that we can safely read it back in toString
memcpy(result.addr.unixDomain.sun_path + 1, path.cStr(), path.size() + 1);
result.addrlen = offsetof(struct sockaddr_un, sun_path) + path.size() + 1;
if (!result.parseAllowedBy(filter)) {
KJ_FAIL_REQUIRE("abstract unix sockets blocked by restrictPeers()");
return Array<SocketAddress>();
}
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result);
return array.finish();
......@@ -547,7 +562,8 @@ public:
port = strtoul(portText->cStr(), &endptr, 0);
if (portText->size() == 0 || *endptr != '\0') {
// Not a number. Maybe it's a service name. Fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint);
return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint,
filter);
}
KJ_REQUIRE(port < 65536, "Port number too large.");
} else {
......@@ -569,6 +585,7 @@ public:
result.addr.inet6.sin6_family = AF_INET6;
result.addr.inet6.sin6_port = htons(port);
#endif
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result);
return array.finish();
......@@ -597,13 +614,18 @@ public:
switch (inet_pton(af, buffer, addrTarget)) {
case 1: {
// success.
if (!result.parseAllowedBy(filter)) {
KJ_FAIL_REQUIRE("address family blocked by restrictPeers()");
return Array<SocketAddress>();
}
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result);
return array.finish();
}
case 0:
// It's apparently not a simple address... fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port);
return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port, filter);
default:
KJ_FAIL_SYSCALL("inet_pton", errno, af, addrPart);
}
......@@ -616,6 +638,14 @@ public:
return result;
}
bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) {
return filter.shouldAllow(&addr.generic, addrlen);
}
bool parseAllowedBy(_::NetworkFilter& filter) {
return filter.shouldAllowParse(&addr.generic, addrlen);
}
private:
SocketAddress(): addrlen(0) {
memset(&addr, 0, sizeof(addr));
......@@ -640,8 +670,9 @@ class SocketAddress::LookupReader {
// getaddrinfo.
public:
LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input)
: thread(kj::mv(thread)), input(kj::mv(input)) {}
LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input,
_::NetworkFilter& filter)
: thread(kj::mv(thread)), input(kj::mv(input)), filter(filter) {}
~LookupReader() {
if (thread) thread->detach();
......@@ -654,7 +685,7 @@ public:
thread = nullptr;
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// anyway.
KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no addresses.") { break; }
KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no permitted addresses.") { break; }
return addresses.releaseAsArray();
} else {
// getaddrinfo() can return multiple copies of the same address for several reasons.
......@@ -667,7 +698,9 @@ public:
//
// So we instead resort to de-duping results.
if (alreadySeen.insert(current).second) {
addresses.add(current);
if (current.parseAllowedBy(filter)) {
addresses.add(current);
}
}
return read();
}
......@@ -677,6 +710,7 @@ public:
private:
kj::Own<Thread> thread;
kj::Own<AsyncInputStream> input;
_::NetworkFilter& filter;
SocketAddress current;
kj::Vector<SocketAddress> addresses;
std::set<SocketAddress> alreadySeen;
......@@ -688,7 +722,8 @@ struct SocketAddress::LookupParams {
};
Promise<Array<SocketAddress>> SocketAddress::lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint) {
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
_::NetworkFilter& filter) {
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// the only cross-platform DNS API and it is blocking.
//
......@@ -773,7 +808,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
}
}));
auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input));
auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input), filter);
return reader->read().attach(kj::mv(reader));
}
......@@ -781,22 +816,33 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFileDescriptor {
public:
FdConnectionReceiver(UnixEventPort& eventPort, int fd, uint flags)
: OwnedFileDescriptor(fd, flags), eventPort(eventPort),
FdConnectionReceiver(UnixEventPort& eventPort, int fd,
LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
: OwnedFileDescriptor(fd, flags), eventPort(eventPort), filter(filter),
observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ) {}
Promise<Own<AsyncIoStream>> accept() override {
int newFd;
struct sockaddr_storage addr;
socklen_t addrlen = sizeof(addr);
retry:
#if __linux__ && !__BIONIC__
newFd = ::accept4(fd, nullptr, nullptr, SOCK_NONBLOCK | SOCK_CLOEXEC);
newFd = ::accept4(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen,
SOCK_NONBLOCK | SOCK_CLOEXEC);
#else
newFd = ::accept(fd, nullptr, nullptr);
newFd = ::accept(fd, reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
#endif
if (newFd >= 0) {
return Own<AsyncIoStream>(heap<AsyncStreamFd>(eventPort, newFd, NEW_FD_FLAGS));
if (!filter.shouldAllow(reinterpret_cast<struct sockaddr*>(&addr), addrlen)) {
// Drop disallowed address.
close(newFd);
return accept();
} else {
return Own<AsyncIoStream>(heap<AsyncStreamFd>(eventPort, newFd, NEW_FD_FLAGS));
}
} else {
int error = errno;
......@@ -849,13 +895,15 @@ public:
public:
UnixEventPort& eventPort;
LowLevelAsyncIoProvider::NetworkFilter& filter;
UnixEventPort::FdObserver observer;
};
class DatagramPortImpl final: public DatagramPort, public OwnedFileDescriptor {
public:
DatagramPortImpl(LowLevelAsyncIoProvider& lowLevel, UnixEventPort& eventPort, int fd, uint flags)
: OwnedFileDescriptor(fd, flags), lowLevel(lowLevel), eventPort(eventPort),
DatagramPortImpl(LowLevelAsyncIoProvider& lowLevel, UnixEventPort& eventPort, int fd,
LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
: OwnedFileDescriptor(fd, flags), lowLevel(lowLevel), eventPort(eventPort), filter(filter),
observer(eventPort, fd, UnixEventPort::FdObserver::OBSERVE_READ |
UnixEventPort::FdObserver::OBSERVE_WRITE) {}
......@@ -883,6 +931,7 @@ public:
public:
LowLevelAsyncIoProvider& lowLevel;
UnixEventPort& eventPort;
LowLevelAsyncIoProvider::NetworkFilter& filter;
UnixEventPort::FdObserver observer;
};
......@@ -935,11 +984,13 @@ public:
return kj::mv(stream);
}));
}
Own<ConnectionReceiver> wrapListenSocketFd(int fd, uint flags = 0) override {
return heap<FdConnectionReceiver>(eventPort, fd, flags);
Own<ConnectionReceiver> wrapListenSocketFd(
int fd, NetworkFilter& filter, uint flags = 0) override {
return heap<FdConnectionReceiver>(eventPort, fd, filter, flags);
}
Own<DatagramPort> wrapDatagramSocketFd(int fd, uint flags = 0) override {
return heap<DatagramPortImpl>(*this, eventPort, fd, flags);
Own<DatagramPort> wrapDatagramSocketFd(
int fd, NetworkFilter& filter, uint flags = 0) override {
return heap<DatagramPortImpl>(*this, eventPort, fd, filter, flags);
}
Timer& getTimer() override { return eventPort.getTimer(); }
......@@ -956,12 +1007,14 @@ private:
class NetworkAddressImpl final: public NetworkAddress {
public:
NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel, Array<SocketAddress> addrs)
: lowLevel(lowLevel), addrs(kj::mv(addrs)) {}
NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel,
LowLevelAsyncIoProvider::NetworkFilter& filter,
Array<SocketAddress> addrs)
: lowLevel(lowLevel), filter(filter), addrs(kj::mv(addrs)) {}
Promise<Own<AsyncIoStream>> connect() override {
auto addrsCopy = heapArray(addrs.asPtr());
auto promise = connectImpl(lowLevel, addrsCopy);
auto promise = connectImpl(lowLevel, filter, addrsCopy);
return promise.attach(kj::mv(addrsCopy));
}
......@@ -988,7 +1041,7 @@ public:
KJ_SYSCALL(::listen(fd, SOMAXCONN));
}
return lowLevel.wrapListenSocketFd(fd, NEW_FD_FLAGS);
return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS);
}
Own<DatagramPort> bindDatagramPort() override {
......@@ -1011,11 +1064,11 @@ public:
addrs[0].bind(fd);
}
return lowLevel.wrapDatagramSocketFd(fd, NEW_FD_FLAGS);
return lowLevel.wrapDatagramSocketFd(fd, filter, NEW_FD_FLAGS);
}
Own<NetworkAddress> clone() override {
return kj::heap<NetworkAddressImpl>(lowLevel, kj::heapArray(addrs.asPtr()));
return kj::heap<NetworkAddressImpl>(lowLevel, filter, kj::heapArray(addrs.asPtr()));
}
String toString() override {
......@@ -1029,26 +1082,33 @@ public:
private:
LowLevelAsyncIoProvider& lowLevel;
LowLevelAsyncIoProvider::NetworkFilter& filter;
Array<SocketAddress> addrs;
uint counter = 0;
static Promise<Own<AsyncIoStream>> connectImpl(
LowLevelAsyncIoProvider& lowLevel, ArrayPtr<SocketAddress> addrs) {
LowLevelAsyncIoProvider& lowLevel,
LowLevelAsyncIoProvider::NetworkFilter& filter,
ArrayPtr<SocketAddress> addrs) {
KJ_ASSERT(addrs.size() > 0);
int fd = addrs[0].socket(SOCK_STREAM);
return kj::evalNow([&]() {
return lowLevel.wrapConnectingSocketFd(
fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
return kj::evalNow([&]() -> Promise<Own<AsyncIoStream>> {
if (!addrs[0].allowedBy(filter)) {
return KJ_EXCEPTION(FAILED, "connect() blocked by restrictPeers()");
} else {
return lowLevel.wrapConnectingSocketFd(
fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
}
}).then([](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> {
// Success, pass along.
return kj::mv(stream);
}, [&lowLevel,addrs](Exception&& exception) mutable -> Promise<Own<AsyncIoStream>> {
}, [&lowLevel,&filter,addrs](Exception&& exception) mutable -> Promise<Own<AsyncIoStream>> {
// Connect failed.
if (addrs.size() > 1) {
// Try the next address instead.
return connectImpl(lowLevel, addrs.slice(1, addrs.size()));
return connectImpl(lowLevel, filter, addrs.slice(1, addrs.size()));
} else {
// No more addresses to try, so propagate the exception.
return kj::mv(exception);
......@@ -1060,25 +1120,35 @@ private:
class SocketNetwork final: public Network {
public:
explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
explicit SocketNetwork(SocketNetwork& parent,
kj::ArrayPtr<const kj::StringPtr> allow,
kj::ArrayPtr<const kj::StringPtr> deny)
: lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {}
Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
auto& lowLevelCopy = lowLevel;
return evalLater(mvCapture(heapString(addr),
[&lowLevelCopy,portHint](String&& addr) {
return SocketAddress::parse(lowLevelCopy, addr, portHint);
})).then([&lowLevelCopy](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
return heap<NetworkAddressImpl>(lowLevelCopy, kj::mv(addresses));
return evalLater(mvCapture(heapString(addr), [this,portHint](String&& addr) {
return SocketAddress::parse(lowLevel, addr, portHint, filter);
})).then([this](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
return heap<NetworkAddressImpl>(lowLevel, filter, kj::mv(addresses));
});
}
Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(SocketAddress(sockaddr, len));
return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, array.finish()));
KJ_REQUIRE(array[0].allowedBy(filter), "address blocked by restrictPeers()") { break; }
return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, filter, array.finish()));
}
Own<Network> restrictPeers(
kj::ArrayPtr<const kj::StringPtr> allow,
kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
return heap<SocketNetwork>(*this, allow, deny);
}
private:
LowLevelAsyncIoProvider& lowLevel;
_::NetworkFilter filter;
};
// =======================================================================================
......@@ -1189,10 +1259,16 @@ public:
return receive();
});
} else {
if (!port.filter.shouldAllow(reinterpret_cast<const struct sockaddr*>(msg.msg_name),
msg.msg_namelen)) {
// Ignore message from disallowed source.
return receive();
}
receivedSize = n;
contentTruncated = msg.msg_flags & MSG_TRUNC;
source.emplace(port.lowLevel, msg.msg_name, msg.msg_namelen);
source.emplace(port.lowLevel, port.filter, msg.msg_name, msg.msg_namelen);
ancillaryList.resize(0);
ancillaryTruncated = msg.msg_flags & MSG_CTRUNC;
......@@ -1250,9 +1326,10 @@ private:
bool ancillaryTruncated = false;
struct StoredAddress {
StoredAddress(LowLevelAsyncIoProvider& lowLevel, const void* sockaddr, uint length)
StoredAddress(LowLevelAsyncIoProvider& lowLevel, LowLevelAsyncIoProvider::NetworkFilter& filter,
const void* sockaddr, uint length)
: raw(sockaddr, length),
abstract(lowLevel, Array<SocketAddress>(&raw, 1, NullArrayDisposer::instance)) {}
abstract(lowLevel, filter, Array<SocketAddress>(&raw, 1, NullArrayDisposer::instance)) {}
SocketAddress raw;
NetworkAddressImpl abstract;
......
......@@ -27,6 +27,7 @@
#define _WIN32_WINNT 0x0600
#include "async-io.h"
#include "async-io-internal.h"
#include "async-win32.h"
#include "debug.h"
#include "thread.h"
......@@ -524,11 +525,12 @@ public:
}
static Promise<Array<SocketAddress>> lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint);
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
_::NetworkFilter& filter);
// Perform a DNS lookup.
static Promise<Array<SocketAddress>> parse(
LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint) {
LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint, _::NetworkFilter& filter) {
// TODO(someday): Allow commas in `str`.
SocketAddress result;
......@@ -580,7 +582,8 @@ public:
port = strtoul(portText->cStr(), &endptr, 0);
if (portText->size() == 0 || *endptr != '\0') {
// Not a number. Maybe it's a service name. Fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint);
return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint,
filter);
}
KJ_REQUIRE(port < 65536, "Port number too large.");
} else {
......@@ -622,25 +625,45 @@ public:
switch (InetPtonA(af, buffer, addrTarget)) {
case 1: {
// success.
if (!result.parseAllowedBy(filter)) {
KJ_FAIL_REQUIRE("address family blocked by restrictPeers()");
return Array<SocketAddress>();
}
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(result);
return array.finish();
}
case 0:
// It's apparently not a simple address... fall back to DNS.
return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port);
return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port, filter);
default:
KJ_FAIL_WIN32("InetPton", WSAGetLastError(), af, addrPart);
}
}
static SocketAddress getLocalAddress(int sockfd) {
static SocketAddress getLocalAddress(SOCKET sockfd) {
SocketAddress result;
result.addrlen = sizeof(addr);
KJ_WINSOCK(getsockname(sockfd, &result.addr.generic, &result.addrlen));
return result;
}
static SocketAddress getPeerAddress(SOCKET sockfd) {
SocketAddress result;
result.addrlen = sizeof(addr);
KJ_WINSOCK(getpeername(sockfd, &result.addr.generic, &result.addrlen));
return result;
}
bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) {
return filter.shouldAllow(&addr.generic, addrlen);
}
bool parseAllowedBy(_::NetworkFilter& filter) {
return filter.shouldAllowParse(&addr.generic, addrlen);
}
static SocketAddress getWildcardForFamily(int family) {
SocketAddress result;
switch (family) {
......@@ -680,8 +703,9 @@ class SocketAddress::LookupReader {
// getaddrinfo.
public:
LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input)
: thread(kj::mv(thread)), input(kj::mv(input)) {}
LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input,
_::NetworkFilter& filter)
: thread(kj::mv(thread)), input(kj::mv(input)), filter(filter) {}
~LookupReader() {
if (thread) thread->detach();
......@@ -694,7 +718,7 @@ public:
thread = nullptr;
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// anyway.
KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no addresses.") { break; }
KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no permitted addresses.") { break; }
return addresses.releaseAsArray();
} else {
// getaddrinfo() can return multiple copies of the same address for several reasons.
......@@ -707,7 +731,9 @@ public:
//
// So we instead resort to de-duping results.
if (alreadySeen.insert(current).second) {
addresses.add(current);
if (current.parseAllowedBy(filter)) {
addresses.add(current);
}
}
return read();
}
......@@ -717,6 +743,7 @@ public:
private:
kj::Own<Thread> thread;
kj::Own<AsyncInputStream> input;
_::NetworkFilter& filter;
SocketAddress current;
kj::Vector<SocketAddress> addresses;
std::set<SocketAddress> alreadySeen;
......@@ -728,7 +755,8 @@ struct SocketAddress::LookupParams {
};
Promise<Array<SocketAddress>> SocketAddress::lookupHost(
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint) {
LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
_::NetworkFilter& filter) {
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// the only cross-platform DNS API and it is blocking.
//
......@@ -818,7 +846,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
}
}));
auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input));
auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input), filter);
return reader->read().attach(kj::mv(reader));
}
......@@ -826,8 +854,9 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFd {
public:
FdConnectionReceiver(Win32EventPort& eventPort, SOCKET fd, uint flags)
: OwnedFd(fd, flags), eventPort(eventPort),
FdConnectionReceiver(Win32EventPort& eventPort, SOCKET fd,
LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
: OwnedFd(fd, flags), eventPort(eventPort), filter(filter),
observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))),
address(SocketAddress::getLocalAddress(fd)) {
// In order to accept asynchronously, we need the AcceptEx() function. Apparently, we have
......@@ -858,8 +887,10 @@ public:
}
}
return op->onComplete().attach(kj::mv(scratch)).then(mvCapture(result,
[this](Own<AsyncIoStream> stream, Win32EventPort::IoResult ioResult) {
return op->onComplete().then(mvCapture(result, mvCapture(scratch,
[this,newFd]
(Array<byte> scratch, Own<AsyncIoStream> stream, Win32EventPort::IoResult ioResult)
-> Promise<Own<AsyncIoStream>> {
if (ioResult.errorCode != ERROR_SUCCESS) {
KJ_FAIL_WIN32("AcceptEx()", ioResult.errorCode) { break; }
} else {
......@@ -867,8 +898,19 @@ public:
stream->setsockopt(SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT,
reinterpret_cast<char*>(&me), sizeof(me));
}
return kj::mv(stream);
}));
// Supposedly, AcceptEx() places the local and peer addresses into the buffer (which we've
// named `scratch`). However, the format in which it writes these is undocumented, and
// doesn't even match between native Windows and WINE. Apparently it is useless. I don't know
// why they require the buffer to have space for it in the first place. We'll need to call
// getpeername() to get the address.
auto addr = SocketAddress::getPeerAddress(newFd);
if (addr.allowedBy(filter)) {
return kj::mv(stream);
} else {
return accept();
}
})));
}
uint getPort() override {
......@@ -888,6 +930,7 @@ public:
public:
Win32EventPort& eventPort;
LowLevelAsyncIoProvider::NetworkFilter& filter;
Own<Win32EventPort::IoObserver> observer;
LPFN_ACCEPTEX acceptEx = nullptr;
SocketAddress address;
......@@ -923,8 +966,9 @@ public:
return kj::mv(result);
}));
}
Own<ConnectionReceiver> wrapListenSocketFd(SOCKET fd, uint flags = 0) override {
return heap<FdConnectionReceiver>(eventPort, fd, flags);
Own<ConnectionReceiver> wrapListenSocketFd(
SOCKET fd, NetworkFilter& filter, uint flags = 0) override {
return heap<FdConnectionReceiver>(eventPort, fd, filter, flags);
}
Timer& getTimer() override { return eventPort.getTimer(); }
......@@ -941,12 +985,14 @@ private:
class NetworkAddressImpl final: public NetworkAddress {
public:
NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel, Array<SocketAddress> addrs)
: lowLevel(lowLevel), addrs(kj::mv(addrs)) {}
NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel,
LowLevelAsyncIoProvider::NetworkFilter& filter,
Array<SocketAddress> addrs)
: lowLevel(lowLevel), filter(filter), addrs(kj::mv(addrs)) {}
Promise<Own<AsyncIoStream>> connect() override {
auto addrsCopy = heapArray(addrs.asPtr());
auto promise = connectImpl(lowLevel, addrsCopy);
auto promise = connectImpl(lowLevel, filter, addrsCopy);
return promise.attach(kj::mv(addrsCopy));
}
......@@ -974,7 +1020,7 @@ public:
KJ_WINSOCK(::listen(fd, SOMAXCONN));
}
return lowLevel.wrapListenSocketFd(fd, NEW_FD_FLAGS);
return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS);
}
Own<DatagramPort> bindDatagramPort() override {
......@@ -998,11 +1044,11 @@ public:
addrs[0].bind(fd);
}
return lowLevel.wrapDatagramSocketFd(fd, NEW_FD_FLAGS);
return lowLevel.wrapDatagramSocketFd(fd, filter, NEW_FD_FLAGS);
}
Own<NetworkAddress> clone() override {
return kj::heap<NetworkAddressImpl>(lowLevel, kj::heapArray(addrs.asPtr()));
return kj::heap<NetworkAddressImpl>(lowLevel, filter, kj::heapArray(addrs.asPtr()));
}
String toString() override {
......@@ -1016,26 +1062,34 @@ public:
private:
LowLevelAsyncIoProvider& lowLevel;
LowLevelAsyncIoProvider::NetworkFilter& filter;
Array<SocketAddress> addrs;
uint counter = 0;
static Promise<Own<AsyncIoStream>> connectImpl(
LowLevelAsyncIoProvider& lowLevel, ArrayPtr<SocketAddress> addrs) {
LowLevelAsyncIoProvider& lowLevel,
LowLevelAsyncIoProvider::NetworkFilter& filter,
ArrayPtr<SocketAddress> addrs) {
KJ_ASSERT(addrs.size() > 0);
int fd = addrs[0].socket(SOCK_STREAM);
return kj::evalNow([&]() {
return lowLevel.wrapConnectingSocketFd(
fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
return kj::evalNow([&]() -> Promise<Own<AsyncIoStream>> {
if (!addrs[0].allowedBy(filter)) {
return KJ_EXCEPTION(FAILED, "connect() blocked by restrictPeers()");
} else {
return lowLevel.wrapConnectingSocketFd(
fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
}
}).then([](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> {
// Success, pass along.
return kj::mv(stream);
}, [&lowLevel,KJ_CPCAP(addrs)](Exception&& exception) mutable -> Promise<Own<AsyncIoStream>> {
}, [&lowLevel,&filter,KJ_CPCAP(addrs)](Exception&& exception) mutable
-> Promise<Own<AsyncIoStream>> {
// Connect failed.
if (addrs.size() > 1) {
// Try the next address instead.
return connectImpl(lowLevel, addrs.slice(1, addrs.size()));
return connectImpl(lowLevel, filter, addrs.slice(1, addrs.size()));
} else {
// No more addresses to try, so propagate the exception.
return kj::mv(exception);
......@@ -1047,25 +1101,35 @@ private:
class SocketNetwork final: public Network {
public:
explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
explicit SocketNetwork(SocketNetwork& parent,
kj::ArrayPtr<const kj::StringPtr> allow,
kj::ArrayPtr<const kj::StringPtr> deny)
: lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {}
Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
auto& lowLevelCopy = lowLevel;
return evalLater(mvCapture(heapString(addr),
[&lowLevelCopy,portHint](String&& addr) {
return SocketAddress::parse(lowLevelCopy, addr, portHint);
})).then([&lowLevelCopy](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
return heap<NetworkAddressImpl>(lowLevelCopy, kj::mv(addresses));
return evalLater(mvCapture(heapString(addr), [this,portHint](String&& addr) {
return SocketAddress::parse(lowLevel, addr, portHint, filter);
})).then([this](Array<SocketAddress> addresses) -> Own<NetworkAddress> {
return heap<NetworkAddressImpl>(lowLevel, filter, kj::mv(addresses));
});
}
Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
auto array = kj::heapArrayBuilder<SocketAddress>(1);
array.add(SocketAddress(sockaddr, len));
return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, array.finish()));
KJ_REQUIRE(array[0].allowedBy(filter), "address blocked by restrictPeers()") { break; }
return Own<NetworkAddress>(heap<NetworkAddressImpl>(lowLevel, filter, array.finish()));
}
Own<Network> restrictPeers(
kj::ArrayPtr<const kj::StringPtr> allow,
kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
return heap<SocketNetwork>(*this, allow, deny);
}
private:
LowLevelAsyncIoProvider& lowLevel;
_::NetworkFilter filter;
};
// =======================================================================================
......
......@@ -19,10 +19,30 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#if _WIN32
// Request Vista-level APIs.
#define WINVER 0x0600
#define _WIN32_WINNT 0x0600
#endif
#include "async-io.h"
#include "async-io-internal.h"
#include "debug.h"
#include "vector.h"
#if _WIN32
#include <winsock2.h>
#include <ws2ipdef.h>
#include <ws2tcpip.h>
#include "windows-sanity.h"
#define inet_pton InetPtonA
#define inet_ntop InetNtopA
#else
#include <sys/socket.h>
#include <arpa/inet.h>
#include <sys/un.h>
#endif
namespace kj {
Promise<void> AsyncInputStream::read(void* buffer, size_t bytes) {
......@@ -188,8 +208,351 @@ void DatagramPort::setsockopt(int level, int option, const void* value, uint len
Own<DatagramPort> NetworkAddress::bindDatagramPort() {
KJ_UNIMPLEMENTED("Datagram sockets not implemented.");
}
Own<DatagramPort> LowLevelAsyncIoProvider::wrapDatagramSocketFd(Fd fd, uint flags) {
Own<DatagramPort> LowLevelAsyncIoProvider::wrapDatagramSocketFd(
Fd fd, LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags) {
KJ_UNIMPLEMENTED("Datagram sockets not implemented.");
}
// =======================================================================================
namespace _ { // private
#if !_WIN32
kj::ArrayPtr<const char> safeUnixPath(const struct sockaddr_un* addr, uint addrlen) {
KJ_REQUIRE(addr->sun_family == AF_UNIX, "not a unix address");
KJ_REQUIRE(addrlen >= offsetof(sockaddr_un, sun_path), "invalid unix address");
size_t maxPathlen = addrlen - offsetof(sockaddr_un, sun_path);
size_t pathlen;
if (maxPathlen > 0 && addr->sun_path[0] == '\0') {
// Linux "abstract" unix address
pathlen = strnlen(addr->sun_path + 1, maxPathlen - 1) + 1;
} else {
pathlen = strnlen(addr->sun_path, maxPathlen);
}
return kj::arrayPtr(addr->sun_path, pathlen);
}
#endif // !_WIN32
CidrRange::CidrRange(StringPtr pattern) {
size_t slashPos = KJ_REQUIRE_NONNULL(pattern.findFirst('/'), "invalid CIDR", pattern);
bitCount = pattern.slice(slashPos + 1).parseAs<uint>();
KJ_STACK_ARRAY(char, addr, slashPos + 1, 128, 128);
memcpy(addr.begin(), pattern.begin(), slashPos);
addr[slashPos] = '\0';
if (pattern.findFirst(':') == nullptr) {
family = AF_INET;
KJ_REQUIRE(bitCount <= 32, "invalid CIDR", pattern);
} else {
family = AF_INET6;
KJ_REQUIRE(bitCount <= 128, "invalid CIDR", pattern);
}
KJ_ASSERT(inet_pton(family, addr.begin(), bits) > 0, "invalid CIDR", pattern);
zeroIrrelevantBits();
}
CidrRange::CidrRange(int family, ArrayPtr<const byte> bits, uint bitCount)
: family(family), bitCount(bitCount) {
if (family == AF_INET) {
KJ_REQUIRE(bitCount <= 32);
} else {
KJ_REQUIRE(bitCount <= 128);
}
KJ_REQUIRE(bits.size() * 8 >= bitCount);
size_t byteCount = (bitCount + 7) / 8;
memcpy(this->bits, bits.begin(), byteCount);
memset(this->bits + byteCount, 0, sizeof(this->bits) - byteCount);
zeroIrrelevantBits();
}
CidrRange CidrRange::inet4(ArrayPtr<const byte> bits, uint bitCount) {
return CidrRange(AF_INET, bits, bitCount);
}
CidrRange CidrRange::inet6(
ArrayPtr<const uint16_t> prefix, ArrayPtr<const uint16_t> suffix,
uint bitCount) {
KJ_REQUIRE(prefix.size() + suffix.size() <= 8);
byte bits[16] = { 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0, };
for (size_t i: kj::indices(prefix)) {
bits[i * 2] = prefix[i] >> 8;
bits[i * 2 + 1] = prefix[i] & 0xff;
}
byte* suffixBits = bits + (16 - suffix.size() * 2);
for (size_t i: kj::indices(suffix)) {
suffixBits[i * 2] = suffix[i] >> 8;
suffixBits[i * 2 + 1] = suffix[i] & 0xff;
}
return CidrRange(AF_INET6, bits, bitCount);
}
bool CidrRange::matches(const struct sockaddr* addr) const {
const byte* otherBits;
switch (family) {
case AF_INET:
if (addr->sa_family == AF_INET6) {
otherBits = reinterpret_cast<const struct sockaddr_in6*>(addr)->sin6_addr.s6_addr;
static constexpr byte V6MAPPED[12] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff };
if (memcmp(otherBits, V6MAPPED, sizeof(V6MAPPED)) == 0) {
// We're an ipv4 range and the address is ipv6, but it's a "v6 mapped" address, meaning
// it's equivalent to an ipv4 address. Try to match against the ipv4 part.
otherBits = otherBits + sizeof(V6MAPPED);
} else {
return false;
}
} else if (addr->sa_family == AF_INET) {
otherBits = reinterpret_cast<const byte*>(
&reinterpret_cast<const struct sockaddr_in*>(addr)->sin_addr.s_addr);
} else {
return false;
}
break;
case AF_INET6:
if (addr->sa_family != AF_INET6) return false;
otherBits = reinterpret_cast<const struct sockaddr_in6*>(addr)->sin6_addr.s6_addr;
break;
default:
KJ_UNREACHABLE;
}
if (memcmp(bits, otherBits, bitCount / 8) != 0) return false;
return bitCount == 128 ||
bits[bitCount / 8] == (otherBits[bitCount / 8] & (0xff00 >> (bitCount % 8)));
}
bool CidrRange::matchesFamily(int family) const {
switch (family) {
case AF_INET:
return this->family == AF_INET;
case AF_INET6:
// Even if we're a v4 CIDR, we can match v6 addresses in the v4-mapped range.
return true;
default:
return false;
}
}
String CidrRange::toString() const {
char result[128];
KJ_ASSERT(inet_ntop(family, (void*)bits, result, sizeof(result)) == result);
return kj::str(result, '/', bitCount);
}
void CidrRange::zeroIrrelevantBits() {
// Mask out insignificant bits of partial byte.
if (bitCount < 128) {
bits[bitCount / 8] &= 0xff00 >> (bitCount % 8);
// Zero the remaining bytes.
size_t n = bitCount / 8 + 1;
memset(bits + n, 0, sizeof(bits) - n);
}
}
// -----------------------------------------------------------------------------
ArrayPtr<const CidrRange> localCidrs() {
static const CidrRange result[] = {
// localhost
"127.0.0.0/8"_kj,
"::1/128"_kj,
// Trying to *connect* to 0.0.0.0 on many systems is equivalent to connecting to localhost.
// (wat)
"0.0.0.0/32"_kj,
"::/128"_kj,
};
// TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly
// casting to our return type.
return kj::arrayPtr(result, kj::size(result));
}
ArrayPtr<const CidrRange> privateCidrs() {
static const CidrRange result[] = {
"10.0.0.0/8"_kj, // RFC1918 reserved for internal network
"100.64.0.0/10"_kj, // RFC6598 "shared address space" for carrier-grade NAT
"169.254.0.0/16"_kj, // RFC3927 "link local" (auto-configured LAN in absence of DHCP)
"172.16.0.0/12"_kj, // RFC1918 reserved for internal network
"192.168.0.0/16"_kj, // RFC1918 reserved for internal network
"fc00::/7"_kj, // RFC4193 unique private network
"fe80::/10"_kj, // RFC4291 "link local" (auto-configured LAN in absence of DHCP)
};
// TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly
// casting to our return type.
return kj::arrayPtr(result, kj::size(result));
}
ArrayPtr<const CidrRange> reservedCidrs() {
static const CidrRange result[] = {
"192.0.0.0/24"_kj, // RFC6890 reserved for special protocols
"224.0.0.0/4"_kj, // RFC1112 multicast
"240.0.0.0/4"_kj, // RFC1112 multicast / reserved for future use
"255.255.255.255/32"_kj, // RFC0919 broadcast address
"2001::/23"_kj, // RFC2928 reserved for special protocols
"ff00::/8"_kj, // RFC4291 multicast
};
// TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly
// casting to our return type.
return kj::arrayPtr(result, kj::size(result));
}
ArrayPtr<const CidrRange> exampleAddresses() {
static const CidrRange result[] = {
"192.0.2.0/24"_kj, // RFC5737 "example address" block 1 -- like example.com for IPs
"198.51.100.0/24"_kj, // RFC5737 "example address" block 2 -- like example.com for IPs
"203.0.113.0/24"_kj, // RFC5737 "example address" block 3 -- like example.com for IPs
"2001:db8::/32"_kj, // RFC3849 "example address" block -- like example.com for IPs
};
// TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly
// casting to our return type.
return kj::arrayPtr(result, kj::size(result));
}
NetworkFilter::NetworkFilter()
: allowUnix(true), allowAbstractUnix(true) {
allowCidrs.add(CidrRange::inet4({0,0,0,0}, 0));
allowCidrs.add(CidrRange::inet6({}, {}, 0));
denyCidrs.addAll(reservedCidrs());
}
NetworkFilter::NetworkFilter(ArrayPtr<const StringPtr> allow, ArrayPtr<const StringPtr> deny,
NetworkFilter& next)
: allowUnix(false), allowAbstractUnix(false), next(next) {
for (auto rule: allow) {
if (rule == "local") {
allowCidrs.addAll(localCidrs());
} else if (rule == "network") {
allowCidrs.add(CidrRange::inet4({0,0,0,0}, 0));
allowCidrs.add(CidrRange::inet6({}, {}, 0));
denyCidrs.addAll(localCidrs());
} else if (rule == "private") {
allowCidrs.addAll(privateCidrs());
allowCidrs.addAll(localCidrs());
} else if (rule == "public") {
allowCidrs.add(CidrRange::inet4({0,0,0,0}, 0));
allowCidrs.add(CidrRange::inet6({}, {}, 0));
denyCidrs.addAll(privateCidrs());
denyCidrs.addAll(localCidrs());
} else if (rule == "unix") {
allowUnix = true;
} else if (rule == "unix-abstract") {
allowAbstractUnix = true;
} else {
allowCidrs.add(CidrRange(rule));
}
}
for (auto rule: deny) {
if (rule == "local") {
denyCidrs.addAll(localCidrs());
} else if (rule == "network") {
KJ_FAIL_REQUIRE("don't deny 'network', allow 'local' instead");
} else if (rule == "private") {
denyCidrs.addAll(privateCidrs());
} else if (rule == "public") {
// Tricky: What if we allow 'network' and deny 'public'?
KJ_FAIL_REQUIRE("don't deny 'public', allow 'private' instead");
} else if (rule == "unix") {
allowUnix = false;
} else if (rule == "unix-abstract") {
allowAbstractUnix = false;
} else {
denyCidrs.add(CidrRange(rule));
}
}
}
bool NetworkFilter::shouldAllow(const struct sockaddr* addr, uint addrlen) {
KJ_REQUIRE(addrlen >= sizeof(addr->sa_family));
#if !_WIN32
if (addr->sa_family == AF_UNIX) {
auto path = safeUnixPath(reinterpret_cast<const struct sockaddr_un*>(addr), addrlen);
if (path.size() > 0 && path[0] == '\0') {
return allowAbstractUnix;
} else {
return allowUnix;
}
}
#endif
bool allowed = false;
uint allowSpecificity = 0;
for (auto& cidr: allowCidrs) {
if (cidr.matches(addr)) {
allowSpecificity = kj::max(allowSpecificity, cidr.getSpecificity());
allowed = true;
}
}
if (!allowed) return false;
for (auto& cidr: denyCidrs) {
if (cidr.matches(addr)) {
if (cidr.getSpecificity() >= allowSpecificity) return false;
}
}
KJ_IF_MAYBE(n, next) {
return n->shouldAllow(addr, addrlen);
} else {
return true;
}
}
bool NetworkFilter::shouldAllowParse(const struct sockaddr* addr, uint addrlen) {
bool matched = false;
#if !_WIN32
if (addr->sa_family == AF_UNIX) {
auto path = safeUnixPath(reinterpret_cast<const struct sockaddr_un*>(addr), addrlen);
if (path.size() > 0 && path[0] == '\0') {
if (allowAbstractUnix) matched = true;
} else {
if (allowUnix) matched = true;
}
} else {
#endif
for (auto& cidr: allowCidrs) {
if (cidr.matchesFamily(addr->sa_family)) {
matched = true;
}
}
#if !_WIN32
}
#endif
if (matched) {
KJ_IF_MAYBE(n, next) {
return n->shouldAllowParse(addr, addrlen);
} else {
return true;
}
} else {
// No allow rule matches this address family, so don't even allow parsing it.
return false;
}
}
} // namespace _ (private)
} // namespace kj
......@@ -319,6 +319,67 @@ public:
virtual Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) = 0;
// Construct a network address from a legacy struct sockaddr.
virtual Own<Network> restrictPeers(
kj::ArrayPtr<const kj::StringPtr> allow,
kj::ArrayPtr<const kj::StringPtr> deny = nullptr) KJ_WARN_UNUSED_RESULT = 0;
// Constructs a new Network instance wrapping this one which restricts which peer addresses are
// permitted (both for outgoing and incoming connections).
//
// Communication will be allowed only with peers whose addresses match one of the patterns
// specified in the `allow` array. If a `deny` array is specified, then any address which matches
// a pattern in `deny` and *does not* match any more-specific pattern in `allow` will also be
// denied.
//
// The syntax of address patterns depends on the network, except that three special patterns are
// defined for all networks:
// - "private": Matches network addresses that are reserved by standards for private networks,
// such as "10.0.0.0/8" or "192.168.0.0/16". This is a superset of "local".
// - "public": Opposite of "private".
// - "local": Matches network addresses that are defined by standards to only be accessible from
// the local machine, such as "127.0.0.0/8" or Unix domain addresses.
// - "network": Opposite of "local".
//
// For the standard KJ network implementation, the following patterns are also recognized:
// - Network blocks specified in CIDR notation (ipv4 and ipv6), such as "192.0.2.0/24" or
// "2001:db8::/32".
// - "unix" to match all Unix domain addresses. (In the future, we may support specifying a
// glob.)
// - "unix-abstract" to match Linux's "abstract unix domain" addresses. (In the future, we may
// support specifying a glob.)
//
// Network restrictions apply *after* DNS resolution (otherwise they'd be useless).
//
// It is legal to parseAddress() a restricted address. An exception won't be thrown until
// connect() is called.
//
// It's possible to listen() on a restricted address. However, connections will only be accepted
// from non-restricted addresses; others will be dropped. If a particular listen address has no
// valid peers (e.g. because it's a unix socket address and unix sockets are not allowed) then
// listen() may throw (or may simply never receive any connections).
//
// Examples:
//
// auto restricted = network->restrictPeers({"public"});
//
// Allows connections only to/from public internet addresses. Use this when connecting to an
// address specified by a third party that is not trusted and is not themselves already on your
// private network.
//
// auto restricted = network->restrictPeers({"private"});
//
// Allows connections only to/from the private network. Use this on the server side to reject
// connections from the public internet.
//
// auto restricted = network->restrictPeers({"192.0.2.0/24"}, {"192.0.2.3/32"});
//
// Allows connections only to/from 192.0.2.*, except 192.0.2.3 which is blocked.
//
// auto restricted = network->restrictPeers({"10.0.0.0/8", "10.1.2.3/32"}, {"10.1.2.0/24"});
//
// Allows connections to/from 10.*.*.*, with the exception of 10.1.2.* (which is denied), with an
// exception to the exception of 10.1.2.3 (which is allowed, because it is matched by an allow
// rule that is more specific than the deny rule).
};
// =======================================================================================
......@@ -470,13 +531,21 @@ public:
//
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual Own<ConnectionReceiver> wrapListenSocketFd(Fd fd, uint flags = 0) = 0;
class NetworkFilter {
public:
virtual bool shouldAllow(const struct sockaddr* addr, uint addrlen) = 0;
// Returns true if incoming connections or datagrams from the given peer should be accepted.
// If false, they will be dropped. This is used to implement kj::Network::restrictPeers().
};
virtual Own<ConnectionReceiver> wrapListenSocketFd(
Fd fd, NetworkFilter& filter, uint flags = 0) = 0;
// Create an AsyncIoStream wrapping a listen socket file descriptor. This socket should already
// have had `bind()` and `listen()` called on it, so it's ready for `accept()`.
//
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual Own<DatagramPort> wrapDatagramSocketFd(Fd fd, uint flags = 0);
virtual Own<DatagramPort> wrapDatagramSocketFd(Fd fd, NetworkFilter& filter, uint flags = 0);
virtual Timer& getTimer() = 0;
// Returns a `Timer` based on real time. Time does not pass while event handlers are running --
......
......@@ -1280,7 +1280,7 @@ public:
return ArrayPtr<const T>(ptr, size_);
}
inline size_t size() const { return size_; }
inline constexpr size_t size() const { return size_; }
inline const T& operator[](size_t index) const {
KJ_IREQUIRE(index < size_, "Out-of-bounds ArrayPtr access.");
return ptr[index];
......@@ -1294,8 +1294,8 @@ public:
inline T* end() { return ptr + size_; }
inline T& front() { return *ptr; }
inline T& back() { return *(ptr + size_ - 1); }
inline const T* begin() const { return ptr; }
inline const T* end() const { return ptr + size_; }
inline constexpr const T* begin() const { return ptr; }
inline constexpr const T* end() const { return ptr + size_; }
inline const T& front() const { return *ptr; }
inline const T& back() const { return *(ptr + size_ - 1); }
......
......@@ -443,6 +443,8 @@ private:
class TlsNetwork: public kj::Network {
public:
TlsNetwork(TlsContext& tls, kj::Network& inner): tls(tls), inner(inner) {}
TlsNetwork(TlsContext& tls, kj::Own<kj::Network> inner)
: tls(tls), inner(*inner), ownInner(kj::mv(inner)) {}
Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint) override {
kj::String hostname;
......@@ -463,9 +465,19 @@ public:
KJ_UNIMPLEMENTED("TLS does not implement getSockaddr() because it needs to know hostnames");
}
Own<Network> restrictPeers(
kj::ArrayPtr<const kj::StringPtr> allow,
kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
// TODO(someday): Maybe we could implement the ability to specify CA or hostname restrictions?
// Or is it better to let people do that via the TlsContext? A neat thing about
// restrictPeers() is that it's easy to make user-configurable.
return kj::heap<TlsNetwork>(tls, inner.restrictPeers(allow, deny));
}
private:
TlsContext& tls;
kj::Network& inner;
kj::Own<kj::Network> ownInner;
};
} // namespace
......
......@@ -173,6 +173,18 @@ TEST(String, ToString) {
}
#endif
KJ_TEST("string literals with _kj suffix") {
static constexpr StringPtr FOO = "foo"_kj;
KJ_EXPECT(FOO == "foo", FOO);
KJ_EXPECT(FOO[3] == 0);
KJ_EXPECT("foo\0bar"_kj == StringPtr("foo\0bar", 7));
static constexpr ArrayPtr<const char> ARR = "foo"_kj;
KJ_EXPECT(ARR.size() == 3);
KJ_EXPECT(kj::str(ARR) == "foo");
}
} // namespace
} // namespace _ (private)
} // namespace kj
......@@ -31,11 +31,29 @@
#include <string.h>
namespace kj {
class StringPtr;
class String;
class StringPtr;
class String;
class StringTree; // string-tree.h
}
class StringTree; // string-tree.h
constexpr kj::StringPtr operator "" _kj(const char* str, size_t n);
// You can append _kj to a string literal to make its type be StringPtr. There are a few cases
// where you must do this for correctness:
// - When you want to declare a constexpr StringPtr. Without _kj, this is a compile error.
// - When you want to initialize a static/global StringPtr from a string literal without forcing
// global constructor code to run at dynamic initialization time.
// - When you have a string literal that contains NUL characters. Without _kj, the string will
// be considered to end at the first NUL.
// - When you want to initialize an ArrayPtr<const char> from a string literal, without including
// the NUL terminator in the data. (Initializing an ArrayPtr from a regular string literal is
// a compile error specifically due to this ambiguity.)
//
// In other cases, there should be no difference between initializing a StringPtr from a regular
// string literal vs. one with _kj (assuming the compiler is able to optimize away strlen() on a
// string literal).
namespace kj {
// Our STL string SFINAE trick does not work with GCC 4.7, but it works with Clang and GCC 4.8, so
// we'll just preprocess it out if not supported.
......@@ -75,8 +93,8 @@ public:
// those who don't want it.
#endif
inline operator ArrayPtr<const char>() const;
inline ArrayPtr<const char> asArray() const;
inline constexpr operator ArrayPtr<const char>() const;
inline constexpr ArrayPtr<const char> asArray() const;
inline ArrayPtr<const byte> asBytes() const { return asArray().asBytes(); }
// Result does not include NUL terminator.
......@@ -121,9 +139,11 @@ public:
// Overflowed floating numbers return inf.
private:
inline StringPtr(ArrayPtr<const char> content): content(content) {}
inline constexpr StringPtr(ArrayPtr<const char> content): content(content) {}
ArrayPtr<const char> content;
friend constexpr kj::StringPtr (::operator "" _kj)(const char* str, size_t n);
};
inline bool operator==(const char* a, const StringPtr& b) { return b == a; }
......@@ -427,12 +447,12 @@ inline String Stringifier::operator*(const Array<T>& arr) const {
inline StringPtr::StringPtr(const String& value): content(value.begin(), value.size() + 1) {}
inline StringPtr::operator ArrayPtr<const char>() const {
return content.slice(0, content.size() - 1);
inline constexpr StringPtr::operator ArrayPtr<const char>() const {
return ArrayPtr<const char>(content.begin(), content.size() - 1);
}
inline ArrayPtr<const char> StringPtr::asArray() const {
return content.slice(0, content.size() - 1);
inline constexpr ArrayPtr<const char> StringPtr::asArray() const {
return ArrayPtr<const char>(content.begin(), content.size() - 1);
}
inline bool StringPtr::operator==(const StringPtr& other) const {
......@@ -531,4 +551,8 @@ inline String heapString(ArrayPtr<const char> value) {
} // namespace kj
constexpr kj::StringPtr operator "" _kj(const char* str, size_t n) {
return kj::StringPtr(kj::ArrayPtr<const char>(str, n + 1));
};
#endif // KJ_STRING_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment