Commit 81b9bccc authored by Kenton Varda's avatar Kenton Varda

Fix Windows problems with restrictPeers().

parent 9137a15f
...@@ -551,7 +551,9 @@ KJ_TEST("Network::restrictPeers()") { ...@@ -551,7 +551,9 @@ KJ_TEST("Network::restrictPeers()") {
auto restrictedNetwork = network.restrictPeers({"public"}); auto restrictedNetwork = network.restrictPeers({"public"});
KJ_EXPECT(tryParse(w, *restrictedNetwork, "8.8.8.8") == "8.8.8.8:0"); KJ_EXPECT(tryParse(w, *restrictedNetwork, "8.8.8.8") == "8.8.8.8:0");
#if !_WIN32
KJ_EXPECT_THROW_MESSAGE("restrictPeers", tryParse(w, *restrictedNetwork, "unix:/foo")); KJ_EXPECT_THROW_MESSAGE("restrictPeers", tryParse(w, *restrictedNetwork, "unix:/foo"));
#endif
auto addr = restrictedNetwork->parseAddress("127.0.0.1").wait(w); auto addr = restrictedNetwork->parseAddress("127.0.0.1").wait(w);
......
...@@ -642,18 +642,25 @@ public: ...@@ -642,18 +642,25 @@ public:
} }
} }
static SocketAddress getLocalAddress(int sockfd) { static SocketAddress getLocalAddress(SOCKET sockfd) {
SocketAddress result; SocketAddress result;
result.addrlen = sizeof(addr); result.addrlen = sizeof(addr);
KJ_WINSOCK(getsockname(sockfd, &result.addr.generic, &result.addrlen)); KJ_WINSOCK(getsockname(sockfd, &result.addr.generic, &result.addrlen));
return result; return result;
} }
static SocketAddress getPeerAddress(SOCKET sockfd) {
SocketAddress result;
result.addrlen = sizeof(addr);
KJ_WINSOCK(getpeername(sockfd, &result.addr.generic, &result.addrlen));
return result;
}
bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) { bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) {
return filter.shouldAllow(&addr.generic, addrlen); return filter.shouldAllow(&addr.generic, addrlen);
} }
bool parseAllowedBy(const _::NetworkFilter& filter) { bool parseAllowedBy(_::NetworkFilter& filter) {
return filter.shouldAllowParse(&addr.generic, addrlen); return filter.shouldAllowParse(&addr.generic, addrlen);
} }
...@@ -881,7 +888,8 @@ public: ...@@ -881,7 +888,8 @@ public:
} }
return op->onComplete().then(mvCapture(result, mvCapture(scratch, return op->onComplete().then(mvCapture(result, mvCapture(scratch,
[this](Array<byte> scratch, Own<AsyncIoStream> stream, Win32EventPort::IoResult ioResult) [this,newFd]
(Array<byte> scratch, Own<AsyncIoStream> stream, Win32EventPort::IoResult ioResult)
-> Promise<Own<AsyncIoStream>> { -> Promise<Own<AsyncIoStream>> {
if (ioResult.errorCode != ERROR_SUCCESS) { if (ioResult.errorCode != ERROR_SUCCESS) {
KJ_FAIL_WIN32("AcceptEx()", ioResult.errorCode) { break; } KJ_FAIL_WIN32("AcceptEx()", ioResult.errorCode) { break; }
...@@ -891,11 +899,13 @@ public: ...@@ -891,11 +899,13 @@ public:
reinterpret_cast<char*>(&me), sizeof(me)); reinterpret_cast<char*>(&me), sizeof(me));
} }
auto addr = reinterpret_cast<struct sockaddr*>(scratch.begin() + 128); // Supposedly, AcceptEx() places the local and peer addresses into the buffer (which we've
size_t addrlen = addr->sa_family == AF_INET // named `scratch`). However, the format in which it writes these is undocumented, and
? sizeof(struct sockaddr_in) // doesn't even match between native Windows and WINE. Apparently it is useless. I don't know
: sizeof(struct sockaddr_in6); // why they require the buffer to have space for it in the first place. We'll need to call
if (filter.shouldAllow(addr, addrlen)) { // getpeername() to get the address.
auto addr = SocketAddress::getPeerAddress(newFd);
if (addr.allowedBy(filter)) {
return kj::mv(stream); return kj::mv(stream);
} else { } else {
return accept(); return accept();
......
...@@ -268,7 +268,7 @@ CidrRange::CidrRange(int family, ArrayPtr<const byte> bits, uint bitCount) ...@@ -268,7 +268,7 @@ CidrRange::CidrRange(int family, ArrayPtr<const byte> bits, uint bitCount)
KJ_REQUIRE(bits.size() * 8 >= bitCount); KJ_REQUIRE(bits.size() * 8 >= bitCount);
size_t byteCount = (bitCount + 7) / 8; size_t byteCount = (bitCount + 7) / 8;
memcpy(this->bits, bits.begin(), byteCount); memcpy(this->bits, bits.begin(), byteCount);
memset(this->bits + byteCount, 0, sizeof(bits) - byteCount); memset(this->bits + byteCount, 0, sizeof(this->bits) - byteCount);
zeroIrrelevantBits(); zeroIrrelevantBits();
} }
...@@ -486,7 +486,7 @@ NetworkFilter::NetworkFilter(ArrayPtr<const StringPtr> allow, ArrayPtr<const Str ...@@ -486,7 +486,7 @@ NetworkFilter::NetworkFilter(ArrayPtr<const StringPtr> allow, ArrayPtr<const Str
} }
bool NetworkFilter::shouldAllow(const struct sockaddr* addr, uint addrlen) { bool NetworkFilter::shouldAllow(const struct sockaddr* addr, uint addrlen) {
KJ_REQUIRE(addrlen > sizeof(addr->sa_family)); KJ_REQUIRE(addrlen >= sizeof(addr->sa_family));
#if !_WIN32 #if !_WIN32
if (addr->sa_family == AF_UNIX) { if (addr->sa_family == AF_UNIX) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment