Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
C
capnproto
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
capnproto
Commits
ac6b5d30
Commit
ac6b5d30
authored
Sep 15, 2017
by
Kenton Varda
Committed by
GitHub
Sep 15, 2017
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #555 from capnproto/network-filter
Extend kj::Network interface for easy SSRF protection
parents
b3dec708
04ff4676
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
983 additions
and
98 deletions
+983
-98
ez-rpc.c++
c++/src/capnp/ez-rpc.c++
+15
-1
async-io-internal.h
c++/src/kj/async-io-internal.h
+91
-0
async-io-test.c++
c++/src/kj/async-io-test.c++
+164
-5
async-io-unix.c++
c++/src/kj/async-io-unix.c++
+119
-42
async-io-win32.c++
c++/src/kj/async-io-win32.c++
+98
-34
async-io.c++
c++/src/kj/async-io.c++
+364
-1
async-io.h
c++/src/kj/async-io.h
+71
-2
common.h
c++/src/kj/common.h
+3
-3
tls.c++
c++/src/kj/compat/tls.c++
+12
-0
string-test.c++
c++/src/kj/string-test.c++
+12
-0
string.h
c++/src/kj/string.h
+34
-10
No files found.
c++/src/capnp/ez-rpc.c++
View file @
ac6b5d30
...
@@ -198,6 +198,19 @@ kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() {
...
@@ -198,6 +198,19 @@ kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() {
// =======================================================================================
// =======================================================================================
namespace
{
class
DummyFilter
:
public
kj
::
LowLevelAsyncIoProvider
::
NetworkFilter
{
public
:
bool
shouldAllow
(
const
struct
sockaddr
*
addr
,
uint
addrlen
)
override
{
return
true
;
}
};
static
DummyFilter
DUMMY_FILTER
;
}
// namespace
struct
EzRpcServer
::
Impl
final
:
public
SturdyRefRestorer
<
AnyPointer
>
,
struct
EzRpcServer
::
Impl
final
:
public
SturdyRefRestorer
<
AnyPointer
>
,
public
kj
::
TaskSet
::
ErrorHandler
{
public
kj
::
TaskSet
::
ErrorHandler
{
Capability
::
Client
mainInterface
;
Capability
::
Client
mainInterface
;
...
@@ -271,7 +284,8 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>,
...
@@ -271,7 +284,8 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>,
context
(
EzRpcContext
::
getThreadLocal
()),
context
(
EzRpcContext
::
getThreadLocal
()),
portPromise
(
kj
::
Promise
<
uint
>
(
port
).
fork
()),
portPromise
(
kj
::
Promise
<
uint
>
(
port
).
fork
()),
tasks
(
*
this
)
{
tasks
(
*
this
)
{
acceptLoop
(
context
->
getLowLevelIoProvider
().
wrapListenSocketFd
(
socketFd
),
readerOpts
);
acceptLoop
(
context
->
getLowLevelIoProvider
().
wrapListenSocketFd
(
socketFd
,
DUMMY_FILTER
),
readerOpts
);
}
}
void
acceptLoop
(
kj
::
Own
<
kj
::
ConnectionReceiver
>&&
listener
,
ReaderOptions
readerOpts
)
{
void
acceptLoop
(
kj
::
Own
<
kj
::
ConnectionReceiver
>&&
listener
,
ReaderOptions
readerOpts
)
{
...
...
c++/src/kj/async-io-internal.h
0 → 100644
View file @
ac6b5d30
// Copyright (c) 2017 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef KJ_ASYNC_IO_INTERNAL_H_
#define KJ_ASYNC_IO_INTERNAL_H_
#include "string.h"
#include "vector.h"
#include "async-io.h"
#include <stdint.h>
struct
sockaddr
;
struct
sockaddr_un
;
namespace
kj
{
namespace
_
{
// private
// =======================================================================================
#if !_WIN32
kj
::
ArrayPtr
<
const
char
>
safeUnixPath
(
const
struct
sockaddr_un
*
addr
,
uint
addrlen
);
// sockaddr_un::sun_path is not required to have a NUL terminator! Thus to be safe unix address
// paths MUST be read using this function.
#endif
class
CidrRange
{
public
:
CidrRange
(
StringPtr
pattern
);
static
CidrRange
inet4
(
ArrayPtr
<
const
byte
>
bits
,
uint
bitCount
);
static
CidrRange
inet6
(
ArrayPtr
<
const
uint16_t
>
prefix
,
ArrayPtr
<
const
uint16_t
>
suffix
,
uint
bitCount
);
// Zeros are inserted between `prefix` and `suffix` to extend the address to 128 bits.
uint
getSpecificity
()
const
{
return
bitCount
;
}
bool
matches
(
const
struct
sockaddr
*
addr
)
const
;
bool
matchesFamily
(
int
family
)
const
;
String
toString
()
const
;
private
:
int
family
;
byte
bits
[
16
];
uint
bitCount
;
// how many bits in `bits` need to match
CidrRange
(
int
family
,
ArrayPtr
<
const
byte
>
bits
,
uint
bitCount
);
void
zeroIrrelevantBits
();
};
class
NetworkFilter
:
public
LowLevelAsyncIoProvider
::
NetworkFilter
{
public
:
NetworkFilter
();
NetworkFilter
(
ArrayPtr
<
const
StringPtr
>
allow
,
ArrayPtr
<
const
StringPtr
>
deny
,
NetworkFilter
&
next
);
bool
shouldAllow
(
const
struct
sockaddr
*
addr
,
uint
addrlen
)
override
;
bool
shouldAllowParse
(
const
struct
sockaddr
*
addr
,
uint
addrlen
);
private
:
Vector
<
CidrRange
>
allowCidrs
;
Vector
<
CidrRange
>
denyCidrs
;
bool
allowUnix
;
bool
allowAbstractUnix
;
kj
::
Maybe
<
NetworkFilter
&>
next
;
};
}
// namespace _ (private)
}
// namespace kj
#endif // KJ_ASYNC_IO_INTERNAL_H_
c++/src/kj/async-io-test.c++
View file @
ac6b5d30
...
@@ -19,17 +19,27 @@
...
@@ -19,17 +19,27 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
// THE SOFTWARE.
#if _WIN32
// Request Vista-level APIs.
#define WINVER 0x0600
#define _WIN32_WINNT 0x0600
#endif
#include "async-io.h"
#include "async-io.h"
#include "async-io-internal.h"
#include "debug.h"
#include "debug.h"
#include <kj/compat/gtest.h>
#include <kj/compat/gtest.h>
#include <sys/types.h>
#include <sys/types.h>
#if _WIN32
#if _WIN32
#include <ws2tcpip.h>
#include <ws2tcpip.h>
#include "windows-sanity.h"
#include "windows-sanity.h"
#define inet_pton InetPtonA
#define inet_ntop InetNtopA
#else
#else
#include <netdb.h>
#include <netdb.h>
#include <unistd.h>
#include <unistd.h>
#include <fcntl.h>
#include <fcntl.h>
#include <arpa/inet.h>
#endif
#endif
namespace
kj
{
namespace
kj
{
...
@@ -77,12 +87,13 @@ String tryParse(WaitScope& waitScope, Network& network, StringPtr text, uint por
...
@@ -77,12 +87,13 @@ String tryParse(WaitScope& waitScope, Network& network, StringPtr text, uint por
return
network
.
parseAddress
(
text
,
portHint
).
wait
(
waitScope
)
->
toString
();
return
network
.
parseAddress
(
text
,
portHint
).
wait
(
waitScope
)
->
toString
();
}
}
bool
systemSupportsAddress
(
StringPtr
addr
)
{
bool
systemSupportsAddress
(
StringPtr
addr
,
StringPtr
service
=
nullptr
)
{
// Can getaddrinfo() parse this addresses? This is only true if the address family (e.g., ipv6)
// Can getaddrinfo() parse this addresses? This is only true if the address family (e.g., ipv6)
// is configured on at least one interface. (The loopback interface usually has both ipv4 and
// is configured on at least one interface. (The loopback interface usually has both ipv4 and
// ipv6 configured, but not always.)
// ipv6 configured, but not always.)
struct
addrinfo
*
list
;
struct
addrinfo
*
list
;
int
status
=
getaddrinfo
(
addr
.
cStr
(),
nullptr
,
nullptr
,
&
list
);
int
status
=
getaddrinfo
(
addr
.
cStr
(),
service
==
nullptr
?
nullptr
:
service
.
cStr
(),
nullptr
,
&
list
);
if
(
status
==
0
)
{
if
(
status
==
0
)
{
freeaddrinfo
(
list
);
freeaddrinfo
(
list
);
return
true
;
return
true
;
...
@@ -91,7 +102,6 @@ bool systemSupportsAddress(StringPtr addr) {
...
@@ -91,7 +102,6 @@ bool systemSupportsAddress(StringPtr addr) {
}
}
}
}
TEST
(
AsyncIo
,
AddressParsing
)
{
TEST
(
AsyncIo
,
AddressParsing
)
{
auto
ioContext
=
setupAsyncIo
();
auto
ioContext
=
setupAsyncIo
();
auto
&
w
=
ioContext
.
waitScope
;
auto
&
w
=
ioContext
.
waitScope
;
...
@@ -110,7 +120,7 @@ TEST(AsyncIo, AddressParsing) {
...
@@ -110,7 +120,7 @@ TEST(AsyncIo, AddressParsing) {
// We can parse services by name...
// We can parse services by name...
//
//
// For some reason, Android and some various Linux distros do not support service names.
// For some reason, Android and some various Linux distros do not support service names.
if
(
systemSupportsAddress
(
"1.2.3.4
:
http"
))
{
if
(
systemSupportsAddress
(
"1.2.3.4
"
,
"
http"
))
{
EXPECT_EQ
(
"1.2.3.4:80"
,
tryParse
(
w
,
network
,
"1.2.3.4:http"
,
5678
));
EXPECT_EQ
(
"1.2.3.4:80"
,
tryParse
(
w
,
network
,
"1.2.3.4:http"
,
5678
));
EXPECT_EQ
(
"*:80"
,
tryParse
(
w
,
network
,
"*:http"
,
5678
));
EXPECT_EQ
(
"*:80"
,
tryParse
(
w
,
network
,
"*:http"
,
5678
));
}
else
{
}
else
{
...
@@ -122,7 +132,7 @@ TEST(AsyncIo, AddressParsing) {
...
@@ -122,7 +132,7 @@ TEST(AsyncIo, AddressParsing) {
if
(
systemSupportsAddress
(
"::"
))
{
if
(
systemSupportsAddress
(
"::"
))
{
EXPECT_EQ
(
"[::]:123"
,
tryParse
(
w
,
network
,
"0::0"
,
123
));
EXPECT_EQ
(
"[::]:123"
,
tryParse
(
w
,
network
,
"0::0"
,
123
));
EXPECT_EQ
(
"[12ab:cd::34]:321"
,
tryParse
(
w
,
network
,
"[12ab:cd:0::0:34]:321"
,
432
));
EXPECT_EQ
(
"[12ab:cd::34]:321"
,
tryParse
(
w
,
network
,
"[12ab:cd:0::0:34]:321"
,
432
));
if
(
systemSupportsAddress
(
"
[12ab:cd::34]:
http"
))
{
if
(
systemSupportsAddress
(
"
12ab:cd::34"
,
"
http"
))
{
EXPECT_EQ
(
"[::]:80"
,
tryParse
(
w
,
network
,
"[::]:http"
,
5678
));
EXPECT_EQ
(
"[::]:80"
,
tryParse
(
w
,
network
,
"[::]:http"
,
5678
));
EXPECT_EQ
(
"[12ab:cd::34]:80"
,
tryParse
(
w
,
network
,
"[12ab:cd::34]:http"
,
5678
));
EXPECT_EQ
(
"[12ab:cd::34]:80"
,
tryParse
(
w
,
network
,
"[12ab:cd::34]:http"
,
5678
));
}
else
{
}
else
{
...
@@ -412,5 +422,154 @@ TEST(AsyncIo, AbstractUnixSocket) {
...
@@ -412,5 +422,154 @@ TEST(AsyncIo, AbstractUnixSocket) {
#endif // __linux__
#endif // __linux__
KJ_TEST
(
"CIDR parsing"
)
{
KJ_EXPECT
(
_
::
CidrRange
(
"1.2.3.4/16"
).
toString
()
==
"1.2.0.0/16"
);
KJ_EXPECT
(
_
::
CidrRange
(
"1.2.255.4/18"
).
toString
()
==
"1.2.192.0/18"
);
KJ_EXPECT
(
_
::
CidrRange
(
"1234::abcd:ffff:ffff/98"
).
toString
()
==
"1234::abcd:c000:0/98"
);
KJ_EXPECT
(
_
::
CidrRange
::
inet4
({
1
,
2
,
255
,
4
},
18
).
toString
()
==
"1.2.192.0/18"
);
KJ_EXPECT
(
_
::
CidrRange
::
inet6
({
0x1234
,
0x5678
},
{
0xabcd
,
0xffff
,
0xffff
},
98
).
toString
()
==
"1234:5678::abcd:c000:0/98"
);
union
{
struct
sockaddr
addr
;
struct
sockaddr_in
addr4
;
struct
sockaddr_in6
addr6
;
};
memset
(
&
addr6
,
0
,
sizeof
(
addr6
));
{
addr4
.
sin_family
=
AF_INET
;
addr4
.
sin_addr
.
s_addr
=
htonl
(
0x0102dfff
);
KJ_EXPECT
(
_
::
CidrRange
(
"1.2.255.255/18"
).
matches
(
&
addr
));
KJ_EXPECT
(
!
_
::
CidrRange
(
"1.2.255.255/19"
).
matches
(
&
addr
));
KJ_EXPECT
(
_
::
CidrRange
(
"1.2.0.0/16"
).
matches
(
&
addr
));
KJ_EXPECT
(
!
_
::
CidrRange
(
"1.3.0.0/16"
).
matches
(
&
addr
));
KJ_EXPECT
(
_
::
CidrRange
(
"1.2.223.255/32"
).
matches
(
&
addr
));
KJ_EXPECT
(
_
::
CidrRange
(
"0.0.0.0/0"
).
matches
(
&
addr
));
KJ_EXPECT
(
!
_
::
CidrRange
(
"::/0"
).
matches
(
&
addr
));
}
{
addr4
.
sin_family
=
AF_INET6
;
byte
bytes
[
16
]
=
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
};
memcpy
(
addr6
.
sin6_addr
.
s6_addr
,
bytes
,
16
);
KJ_EXPECT
(
_
::
CidrRange
(
"0102:03ff::/24"
).
matches
(
&
addr
));
KJ_EXPECT
(
!
_
::
CidrRange
(
"0102:02ff::/24"
).
matches
(
&
addr
));
KJ_EXPECT
(
_
::
CidrRange
(
"0102:02ff::/23"
).
matches
(
&
addr
));
KJ_EXPECT
(
_
::
CidrRange
(
"0102:0304:0506:0708:090a:0b0c:0d0e:0f10/128"
).
matches
(
&
addr
));
KJ_EXPECT
(
_
::
CidrRange
(
"::/0"
).
matches
(
&
addr
));
KJ_EXPECT
(
!
_
::
CidrRange
(
"0.0.0.0/0"
).
matches
(
&
addr
));
}
{
addr4
.
sin_family
=
AF_INET6
;
inet_pton
(
AF_INET6
,
"::ffff:1.2.223.255"
,
&
addr6
.
sin6_addr
);
KJ_EXPECT
(
_
::
CidrRange
(
"1.2.255.255/18"
).
matches
(
&
addr
));
KJ_EXPECT
(
!
_
::
CidrRange
(
"1.2.255.255/19"
).
matches
(
&
addr
));
KJ_EXPECT
(
_
::
CidrRange
(
"1.2.0.0/16"
).
matches
(
&
addr
));
KJ_EXPECT
(
!
_
::
CidrRange
(
"1.3.0.0/16"
).
matches
(
&
addr
));
KJ_EXPECT
(
_
::
CidrRange
(
"1.2.223.255/32"
).
matches
(
&
addr
));
KJ_EXPECT
(
_
::
CidrRange
(
"0.0.0.0/0"
).
matches
(
&
addr
));
KJ_EXPECT
(
_
::
CidrRange
(
"::/0"
).
matches
(
&
addr
));
}
}
bool
allowed4
(
_
::
NetworkFilter
&
filter
,
StringPtr
addrStr
)
{
struct
sockaddr_in
addr
;
memset
(
&
addr
,
0
,
sizeof
(
addr
));
addr
.
sin_family
=
AF_INET
;
inet_pton
(
AF_INET
,
addrStr
.
cStr
(),
&
addr
.
sin_addr
);
return
filter
.
shouldAllow
(
reinterpret_cast
<
struct
sockaddr
*>
(
&
addr
),
sizeof
(
addr
));
}
bool
allowed6
(
_
::
NetworkFilter
&
filter
,
StringPtr
addrStr
)
{
struct
sockaddr_in6
addr
;
memset
(
&
addr
,
0
,
sizeof
(
addr
));
addr
.
sin6_family
=
AF_INET6
;
inet_pton
(
AF_INET6
,
addrStr
.
cStr
(),
&
addr
.
sin6_addr
);
return
filter
.
shouldAllow
(
reinterpret_cast
<
struct
sockaddr
*>
(
&
addr
),
sizeof
(
addr
));
}
KJ_TEST
(
"NetworkFilter"
)
{
_
::
NetworkFilter
base
;
KJ_EXPECT
(
allowed4
(
base
,
"8.8.8.8"
));
KJ_EXPECT
(
!
allowed4
(
base
,
"240.1.2.3"
));
{
_
::
NetworkFilter
filter
({
"public"
},
{},
base
);
KJ_EXPECT
(
allowed4
(
filter
,
"8.8.8.8"
));
KJ_EXPECT
(
!
allowed4
(
filter
,
"240.1.2.3"
));
KJ_EXPECT
(
!
allowed4
(
filter
,
"192.168.0.1"
));
KJ_EXPECT
(
!
allowed4
(
filter
,
"10.1.2.3"
));
KJ_EXPECT
(
!
allowed4
(
filter
,
"127.0.0.1"
));
KJ_EXPECT
(
!
allowed4
(
filter
,
"0.0.0.0"
));
KJ_EXPECT
(
allowed6
(
filter
,
"2400:cb00:2048:1::c629:d7a2"
));
KJ_EXPECT
(
!
allowed6
(
filter
,
"fc00::1234"
));
KJ_EXPECT
(
!
allowed6
(
filter
,
"::1"
));
KJ_EXPECT
(
!
allowed6
(
filter
,
"::"
));
}
{
_
::
NetworkFilter
filter
({
"private"
},
{
"local"
},
base
);
KJ_EXPECT
(
!
allowed4
(
filter
,
"8.8.8.8"
));
KJ_EXPECT
(
!
allowed4
(
filter
,
"240.1.2.3"
));
KJ_EXPECT
(
allowed4
(
filter
,
"192.168.0.1"
));
KJ_EXPECT
(
allowed4
(
filter
,
"10.1.2.3"
));
KJ_EXPECT
(
!
allowed4
(
filter
,
"127.0.0.1"
));
KJ_EXPECT
(
!
allowed4
(
filter
,
"0.0.0.0"
));
KJ_EXPECT
(
!
allowed6
(
filter
,
"2400:cb00:2048:1::c629:d7a2"
));
KJ_EXPECT
(
allowed6
(
filter
,
"fc00::1234"
));
KJ_EXPECT
(
!
allowed6
(
filter
,
"::1"
));
KJ_EXPECT
(
!
allowed6
(
filter
,
"::"
));
}
{
_
::
NetworkFilter
filter
({
"1.0.0.0/8"
,
"1.2.3.0/24"
},
{
"1.2.0.0/16"
,
"1.2.3.4/32"
},
base
);
KJ_EXPECT
(
!
allowed4
(
filter
,
"8.8.8.8"
));
KJ_EXPECT
(
!
allowed4
(
filter
,
"240.1.2.3"
));
KJ_EXPECT
(
allowed4
(
filter
,
"1.0.0.1"
));
KJ_EXPECT
(
!
allowed4
(
filter
,
"1.2.2.1"
));
KJ_EXPECT
(
allowed4
(
filter
,
"1.2.3.1"
));
KJ_EXPECT
(
!
allowed4
(
filter
,
"1.2.3.4"
));
}
}
KJ_TEST
(
"Network::restrictPeers()"
)
{
auto
ioContext
=
setupAsyncIo
();
auto
&
w
=
ioContext
.
waitScope
;
auto
&
network
=
ioContext
.
provider
->
getNetwork
();
auto
restrictedNetwork
=
network
.
restrictPeers
({
"public"
});
KJ_EXPECT
(
tryParse
(
w
,
*
restrictedNetwork
,
"8.8.8.8"
)
==
"8.8.8.8:0"
);
#if !_WIN32
KJ_EXPECT_THROW_MESSAGE
(
"restrictPeers"
,
tryParse
(
w
,
*
restrictedNetwork
,
"unix:/foo"
));
#endif
auto
addr
=
restrictedNetwork
->
parseAddress
(
"127.0.0.1"
).
wait
(
w
);
auto
listener
=
addr
->
listen
();
auto
acceptTask
=
listener
->
accept
()
.
then
([](
kj
::
Own
<
kj
::
AsyncIoStream
>
)
{
KJ_FAIL_EXPECT
(
"should not have received connection"
);
}).
eagerlyEvaluate
(
nullptr
);
KJ_EXPECT_THROW_MESSAGE
(
"restrictPeers"
,
addr
->
connect
().
wait
(
w
));
// We can connect to the listener but the connection will be immediately closed.
auto
addr2
=
network
.
parseAddress
(
"127.0.0.1"
,
listener
->
getPort
()).
wait
(
w
);
auto
conn
=
addr2
->
connect
().
wait
(
w
);
KJ_EXPECT
(
conn
->
readAllText
().
wait
(
w
)
==
""
);
}
}
// namespace
}
// namespace
}
// namespace kj
}
// namespace kj
c++/src/kj/async-io-unix.c++
View file @
ac6b5d30
...
@@ -23,6 +23,7 @@
...
@@ -23,6 +23,7 @@
// For Win32 implementation, see async-io-win32.c++.
// For Win32 implementation, see async-io-win32.c++.
#include "async-io.h"
#include "async-io.h"
#include "async-io-internal.h"
#include "async-unix.h"
#include "async-unix.h"
#include "debug.h"
#include "debug.h"
#include "thread.h"
#include "thread.h"
...
@@ -449,10 +450,11 @@ public:
...
@@ -449,10 +450,11 @@ public:
return
str
(
'['
,
buffer
,
"]:"
,
ntohs
(
addr
.
inet6
.
sin6_port
));
return
str
(
'['
,
buffer
,
"]:"
,
ntohs
(
addr
.
inet6
.
sin6_port
));
}
}
case
AF_UNIX
:
{
case
AF_UNIX
:
{
if
(
addr
.
unixDomain
.
sun_path
[
0
]
==
'\0'
)
{
auto
path
=
_
::
safeUnixPath
(
&
addr
.
unixDomain
,
addrlen
);
return
str
(
"unix-abstract:"
,
addr
.
unixDomain
.
sun_path
+
1
);
if
(
path
.
size
()
>
0
&&
path
[
0
]
==
'\0'
)
{
return
str
(
"unix-abstract:"
,
path
.
slice
(
1
,
path
.
size
()));
}
else
{
}
else
{
return
str
(
"unix:"
,
addr
.
unixDomain
.
sun_
path
);
return
str
(
"unix:"
,
path
);
}
}
}
}
default
:
default
:
...
@@ -461,11 +463,12 @@ public:
...
@@ -461,11 +463,12 @@ public:
}
}
static
Promise
<
Array
<
SocketAddress
>>
lookupHost
(
static
Promise
<
Array
<
SocketAddress
>>
lookupHost
(
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
);
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
,
_
::
NetworkFilter
&
filter
);
// Perform a DNS lookup.
// Perform a DNS lookup.
static
Promise
<
Array
<
SocketAddress
>>
parse
(
static
Promise
<
Array
<
SocketAddress
>>
parse
(
LowLevelAsyncIoProvider
&
lowLevel
,
StringPtr
str
,
uint
portHint
)
{
LowLevelAsyncIoProvider
&
lowLevel
,
StringPtr
str
,
uint
portHint
,
_
::
NetworkFilter
&
filter
)
{
// TODO(someday): Allow commas in `str`.
// TODO(someday): Allow commas in `str`.
SocketAddress
result
;
SocketAddress
result
;
...
@@ -480,6 +483,12 @@ public:
...
@@ -480,6 +483,12 @@ public:
result
.
addr
.
unixDomain
.
sun_family
=
AF_UNIX
;
result
.
addr
.
unixDomain
.
sun_family
=
AF_UNIX
;
strcpy
(
result
.
addr
.
unixDomain
.
sun_path
,
path
.
cStr
());
strcpy
(
result
.
addr
.
unixDomain
.
sun_path
,
path
.
cStr
());
result
.
addrlen
=
offsetof
(
struct
sockaddr_un
,
sun_path
)
+
path
.
size
()
+
1
;
result
.
addrlen
=
offsetof
(
struct
sockaddr_un
,
sun_path
)
+
path
.
size
()
+
1
;
if
(
!
result
.
parseAllowedBy
(
filter
))
{
KJ_FAIL_REQUIRE
(
"unix sockets blocked by restrictPeers()"
);
return
Array
<
SocketAddress
>
();
}
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
array
.
add
(
result
);
array
.
add
(
result
);
return
array
.
finish
();
return
array
.
finish
();
...
@@ -495,6 +504,12 @@ public:
...
@@ -495,6 +504,12 @@ public:
// NULL terminator so that we can safely read it back in toString
// NULL terminator so that we can safely read it back in toString
memcpy
(
result
.
addr
.
unixDomain
.
sun_path
+
1
,
path
.
cStr
(),
path
.
size
()
+
1
);
memcpy
(
result
.
addr
.
unixDomain
.
sun_path
+
1
,
path
.
cStr
(),
path
.
size
()
+
1
);
result
.
addrlen
=
offsetof
(
struct
sockaddr_un
,
sun_path
)
+
path
.
size
()
+
1
;
result
.
addrlen
=
offsetof
(
struct
sockaddr_un
,
sun_path
)
+
path
.
size
()
+
1
;
if
(
!
result
.
parseAllowedBy
(
filter
))
{
KJ_FAIL_REQUIRE
(
"abstract unix sockets blocked by restrictPeers()"
);
return
Array
<
SocketAddress
>
();
}
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
array
.
add
(
result
);
array
.
add
(
result
);
return
array
.
finish
();
return
array
.
finish
();
...
@@ -547,7 +562,8 @@ public:
...
@@ -547,7 +562,8 @@ public:
port
=
strtoul
(
portText
->
cStr
(),
&
endptr
,
0
);
port
=
strtoul
(
portText
->
cStr
(),
&
endptr
,
0
);
if
(
portText
->
size
()
==
0
||
*
endptr
!=
'\0'
)
{
if
(
portText
->
size
()
==
0
||
*
endptr
!=
'\0'
)
{
// Not a number. Maybe it's a service name. Fall back to DNS.
// Not a number. Maybe it's a service name. Fall back to DNS.
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
kj
::
heapString
(
*
portText
),
portHint
);
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
kj
::
heapString
(
*
portText
),
portHint
,
filter
);
}
}
KJ_REQUIRE
(
port
<
65536
,
"Port number too large."
);
KJ_REQUIRE
(
port
<
65536
,
"Port number too large."
);
}
else
{
}
else
{
...
@@ -569,6 +585,7 @@ public:
...
@@ -569,6 +585,7 @@ public:
result
.
addr
.
inet6
.
sin6_family
=
AF_INET6
;
result
.
addr
.
inet6
.
sin6_family
=
AF_INET6
;
result
.
addr
.
inet6
.
sin6_port
=
htons
(
port
);
result
.
addr
.
inet6
.
sin6_port
=
htons
(
port
);
#endif
#endif
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
array
.
add
(
result
);
array
.
add
(
result
);
return
array
.
finish
();
return
array
.
finish
();
...
@@ -597,13 +614,18 @@ public:
...
@@ -597,13 +614,18 @@ public:
switch
(
inet_pton
(
af
,
buffer
,
addrTarget
))
{
switch
(
inet_pton
(
af
,
buffer
,
addrTarget
))
{
case
1
:
{
case
1
:
{
// success.
// success.
if
(
!
result
.
parseAllowedBy
(
filter
))
{
KJ_FAIL_REQUIRE
(
"address family blocked by restrictPeers()"
);
return
Array
<
SocketAddress
>
();
}
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
array
.
add
(
result
);
array
.
add
(
result
);
return
array
.
finish
();
return
array
.
finish
();
}
}
case
0
:
case
0
:
// It's apparently not a simple address... fall back to DNS.
// It's apparently not a simple address... fall back to DNS.
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
nullptr
,
port
);
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
nullptr
,
port
,
filter
);
default
:
default
:
KJ_FAIL_SYSCALL
(
"inet_pton"
,
errno
,
af
,
addrPart
);
KJ_FAIL_SYSCALL
(
"inet_pton"
,
errno
,
af
,
addrPart
);
}
}
...
@@ -616,6 +638,14 @@ public:
...
@@ -616,6 +638,14 @@ public:
return
result
;
return
result
;
}
}
bool
allowedBy
(
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
)
{
return
filter
.
shouldAllow
(
&
addr
.
generic
,
addrlen
);
}
bool
parseAllowedBy
(
_
::
NetworkFilter
&
filter
)
{
return
filter
.
shouldAllowParse
(
&
addr
.
generic
,
addrlen
);
}
private
:
private
:
SocketAddress
()
:
addrlen
(
0
)
{
SocketAddress
()
:
addrlen
(
0
)
{
memset
(
&
addr
,
0
,
sizeof
(
addr
));
memset
(
&
addr
,
0
,
sizeof
(
addr
));
...
@@ -640,8 +670,9 @@ class SocketAddress::LookupReader {
...
@@ -640,8 +670,9 @@ class SocketAddress::LookupReader {
// getaddrinfo.
// getaddrinfo.
public
:
public
:
LookupReader
(
kj
::
Own
<
Thread
>&&
thread
,
kj
::
Own
<
AsyncInputStream
>&&
input
)
LookupReader
(
kj
::
Own
<
Thread
>&&
thread
,
kj
::
Own
<
AsyncInputStream
>&&
input
,
:
thread
(
kj
::
mv
(
thread
)),
input
(
kj
::
mv
(
input
))
{}
_
::
NetworkFilter
&
filter
)
:
thread
(
kj
::
mv
(
thread
)),
input
(
kj
::
mv
(
input
)),
filter
(
filter
)
{}
~
LookupReader
()
{
~
LookupReader
()
{
if
(
thread
)
thread
->
detach
();
if
(
thread
)
thread
->
detach
();
...
@@ -654,7 +685,7 @@ public:
...
@@ -654,7 +685,7 @@ public:
thread
=
nullptr
;
thread
=
nullptr
;
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// anyway.
// anyway.
KJ_REQUIRE
(
addresses
.
size
()
>
0
,
"DNS lookup returned no addresses."
)
{
break
;
}
KJ_REQUIRE
(
addresses
.
size
()
>
0
,
"DNS lookup returned no
permitted
addresses."
)
{
break
;
}
return
addresses
.
releaseAsArray
();
return
addresses
.
releaseAsArray
();
}
else
{
}
else
{
// getaddrinfo() can return multiple copies of the same address for several reasons.
// getaddrinfo() can return multiple copies of the same address for several reasons.
...
@@ -667,8 +698,10 @@ public:
...
@@ -667,8 +698,10 @@ public:
//
//
// So we instead resort to de-duping results.
// So we instead resort to de-duping results.
if
(
alreadySeen
.
insert
(
current
).
second
)
{
if
(
alreadySeen
.
insert
(
current
).
second
)
{
if
(
current
.
parseAllowedBy
(
filter
))
{
addresses
.
add
(
current
);
addresses
.
add
(
current
);
}
}
}
return
read
();
return
read
();
}
}
});
});
...
@@ -677,6 +710,7 @@ public:
...
@@ -677,6 +710,7 @@ public:
private
:
private
:
kj
::
Own
<
Thread
>
thread
;
kj
::
Own
<
Thread
>
thread
;
kj
::
Own
<
AsyncInputStream
>
input
;
kj
::
Own
<
AsyncInputStream
>
input
;
_
::
NetworkFilter
&
filter
;
SocketAddress
current
;
SocketAddress
current
;
kj
::
Vector
<
SocketAddress
>
addresses
;
kj
::
Vector
<
SocketAddress
>
addresses
;
std
::
set
<
SocketAddress
>
alreadySeen
;
std
::
set
<
SocketAddress
>
alreadySeen
;
...
@@ -688,7 +722,8 @@ struct SocketAddress::LookupParams {
...
@@ -688,7 +722,8 @@ struct SocketAddress::LookupParams {
};
};
Promise
<
Array
<
SocketAddress
>>
SocketAddress
::
lookupHost
(
Promise
<
Array
<
SocketAddress
>>
SocketAddress
::
lookupHost
(
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
)
{
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
,
_
::
NetworkFilter
&
filter
)
{
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// the only cross-platform DNS API and it is blocking.
// the only cross-platform DNS API and it is blocking.
//
//
...
@@ -773,7 +808,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
...
@@ -773,7 +808,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
}
}
}));
}));
auto
reader
=
heap
<
LookupReader
>
(
kj
::
mv
(
thread
),
kj
::
mv
(
input
));
auto
reader
=
heap
<
LookupReader
>
(
kj
::
mv
(
thread
),
kj
::
mv
(
input
)
,
filter
);
return
reader
->
read
().
attach
(
kj
::
mv
(
reader
));
return
reader
->
read
().
attach
(
kj
::
mv
(
reader
));
}
}
...
@@ -781,22 +816,33 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
...
@@ -781,22 +816,33 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
class
FdConnectionReceiver
final
:
public
ConnectionReceiver
,
public
OwnedFileDescriptor
{
class
FdConnectionReceiver
final
:
public
ConnectionReceiver
,
public
OwnedFileDescriptor
{
public
:
public
:
FdConnectionReceiver
(
UnixEventPort
&
eventPort
,
int
fd
,
uint
flags
)
FdConnectionReceiver
(
UnixEventPort
&
eventPort
,
int
fd
,
:
OwnedFileDescriptor
(
fd
,
flags
),
eventPort
(
eventPort
),
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
uint
flags
)
:
OwnedFileDescriptor
(
fd
,
flags
),
eventPort
(
eventPort
),
filter
(
filter
),
observer
(
eventPort
,
fd
,
UnixEventPort
::
FdObserver
::
OBSERVE_READ
)
{}
observer
(
eventPort
,
fd
,
UnixEventPort
::
FdObserver
::
OBSERVE_READ
)
{}
Promise
<
Own
<
AsyncIoStream
>>
accept
()
override
{
Promise
<
Own
<
AsyncIoStream
>>
accept
()
override
{
int
newFd
;
int
newFd
;
struct
sockaddr_storage
addr
;
socklen_t
addrlen
=
sizeof
(
addr
);
retry
:
retry
:
#if __linux__ && !__BIONIC__
#if __linux__ && !__BIONIC__
newFd
=
::
accept4
(
fd
,
nullptr
,
nullptr
,
SOCK_NONBLOCK
|
SOCK_CLOEXEC
);
newFd
=
::
accept4
(
fd
,
reinterpret_cast
<
struct
sockaddr
*>
(
&
addr
),
&
addrlen
,
SOCK_NONBLOCK
|
SOCK_CLOEXEC
);
#else
#else
newFd
=
::
accept
(
fd
,
nullptr
,
nullptr
);
newFd
=
::
accept
(
fd
,
reinterpret_cast
<
struct
sockaddr
*>
(
&
addr
),
&
addrlen
);
#endif
#endif
if
(
newFd
>=
0
)
{
if
(
newFd
>=
0
)
{
if
(
!
filter
.
shouldAllow
(
reinterpret_cast
<
struct
sockaddr
*>
(
&
addr
),
addrlen
))
{
// Drop disallowed address.
close
(
newFd
);
return
accept
();
}
else
{
return
Own
<
AsyncIoStream
>
(
heap
<
AsyncStreamFd
>
(
eventPort
,
newFd
,
NEW_FD_FLAGS
));
return
Own
<
AsyncIoStream
>
(
heap
<
AsyncStreamFd
>
(
eventPort
,
newFd
,
NEW_FD_FLAGS
));
}
}
else
{
}
else
{
int
error
=
errno
;
int
error
=
errno
;
...
@@ -849,13 +895,15 @@ public:
...
@@ -849,13 +895,15 @@ public:
public
:
public
:
UnixEventPort
&
eventPort
;
UnixEventPort
&
eventPort
;
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
;
UnixEventPort
::
FdObserver
observer
;
UnixEventPort
::
FdObserver
observer
;
};
};
class
DatagramPortImpl
final
:
public
DatagramPort
,
public
OwnedFileDescriptor
{
class
DatagramPortImpl
final
:
public
DatagramPort
,
public
OwnedFileDescriptor
{
public
:
public
:
DatagramPortImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
UnixEventPort
&
eventPort
,
int
fd
,
uint
flags
)
DatagramPortImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
UnixEventPort
&
eventPort
,
int
fd
,
:
OwnedFileDescriptor
(
fd
,
flags
),
lowLevel
(
lowLevel
),
eventPort
(
eventPort
),
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
uint
flags
)
:
OwnedFileDescriptor
(
fd
,
flags
),
lowLevel
(
lowLevel
),
eventPort
(
eventPort
),
filter
(
filter
),
observer
(
eventPort
,
fd
,
UnixEventPort
::
FdObserver
::
OBSERVE_READ
|
observer
(
eventPort
,
fd
,
UnixEventPort
::
FdObserver
::
OBSERVE_READ
|
UnixEventPort
::
FdObserver
::
OBSERVE_WRITE
)
{}
UnixEventPort
::
FdObserver
::
OBSERVE_WRITE
)
{}
...
@@ -883,6 +931,7 @@ public:
...
@@ -883,6 +931,7 @@ public:
public
:
public
:
LowLevelAsyncIoProvider
&
lowLevel
;
LowLevelAsyncIoProvider
&
lowLevel
;
UnixEventPort
&
eventPort
;
UnixEventPort
&
eventPort
;
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
;
UnixEventPort
::
FdObserver
observer
;
UnixEventPort
::
FdObserver
observer
;
};
};
...
@@ -935,11 +984,13 @@ public:
...
@@ -935,11 +984,13 @@ public:
return
kj
::
mv
(
stream
);
return
kj
::
mv
(
stream
);
}));
}));
}
}
Own
<
ConnectionReceiver
>
wrapListenSocketFd
(
int
fd
,
uint
flags
=
0
)
override
{
Own
<
ConnectionReceiver
>
wrapListenSocketFd
(
return
heap
<
FdConnectionReceiver
>
(
eventPort
,
fd
,
flags
);
int
fd
,
NetworkFilter
&
filter
,
uint
flags
=
0
)
override
{
return
heap
<
FdConnectionReceiver
>
(
eventPort
,
fd
,
filter
,
flags
);
}
}
Own
<
DatagramPort
>
wrapDatagramSocketFd
(
int
fd
,
uint
flags
=
0
)
override
{
Own
<
DatagramPort
>
wrapDatagramSocketFd
(
return
heap
<
DatagramPortImpl
>
(
*
this
,
eventPort
,
fd
,
flags
);
int
fd
,
NetworkFilter
&
filter
,
uint
flags
=
0
)
override
{
return
heap
<
DatagramPortImpl
>
(
*
this
,
eventPort
,
fd
,
filter
,
flags
);
}
}
Timer
&
getTimer
()
override
{
return
eventPort
.
getTimer
();
}
Timer
&
getTimer
()
override
{
return
eventPort
.
getTimer
();
}
...
@@ -956,12 +1007,14 @@ private:
...
@@ -956,12 +1007,14 @@ private:
class
NetworkAddressImpl
final
:
public
NetworkAddress
{
class
NetworkAddressImpl
final
:
public
NetworkAddress
{
public
:
public
:
NetworkAddressImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
Array
<
SocketAddress
>
addrs
)
NetworkAddressImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
:
lowLevel
(
lowLevel
),
addrs
(
kj
::
mv
(
addrs
))
{}
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
Array
<
SocketAddress
>
addrs
)
:
lowLevel
(
lowLevel
),
filter
(
filter
),
addrs
(
kj
::
mv
(
addrs
))
{}
Promise
<
Own
<
AsyncIoStream
>>
connect
()
override
{
Promise
<
Own
<
AsyncIoStream
>>
connect
()
override
{
auto
addrsCopy
=
heapArray
(
addrs
.
asPtr
());
auto
addrsCopy
=
heapArray
(
addrs
.
asPtr
());
auto
promise
=
connectImpl
(
lowLevel
,
addrsCopy
);
auto
promise
=
connectImpl
(
lowLevel
,
filter
,
addrsCopy
);
return
promise
.
attach
(
kj
::
mv
(
addrsCopy
));
return
promise
.
attach
(
kj
::
mv
(
addrsCopy
));
}
}
...
@@ -988,7 +1041,7 @@ public:
...
@@ -988,7 +1041,7 @@ public:
KJ_SYSCALL
(
::
listen
(
fd
,
SOMAXCONN
));
KJ_SYSCALL
(
::
listen
(
fd
,
SOMAXCONN
));
}
}
return
lowLevel
.
wrapListenSocketFd
(
fd
,
NEW_FD_FLAGS
);
return
lowLevel
.
wrapListenSocketFd
(
fd
,
filter
,
NEW_FD_FLAGS
);
}
}
Own
<
DatagramPort
>
bindDatagramPort
()
override
{
Own
<
DatagramPort
>
bindDatagramPort
()
override
{
...
@@ -1011,11 +1064,11 @@ public:
...
@@ -1011,11 +1064,11 @@ public:
addrs
[
0
].
bind
(
fd
);
addrs
[
0
].
bind
(
fd
);
}
}
return
lowLevel
.
wrapDatagramSocketFd
(
fd
,
NEW_FD_FLAGS
);
return
lowLevel
.
wrapDatagramSocketFd
(
fd
,
filter
,
NEW_FD_FLAGS
);
}
}
Own
<
NetworkAddress
>
clone
()
override
{
Own
<
NetworkAddress
>
clone
()
override
{
return
kj
::
heap
<
NetworkAddressImpl
>
(
lowLevel
,
kj
::
heapArray
(
addrs
.
asPtr
()));
return
kj
::
heap
<
NetworkAddressImpl
>
(
lowLevel
,
filter
,
kj
::
heapArray
(
addrs
.
asPtr
()));
}
}
String
toString
()
override
{
String
toString
()
override
{
...
@@ -1029,26 +1082,33 @@ public:
...
@@ -1029,26 +1082,33 @@ public:
private
:
private
:
LowLevelAsyncIoProvider
&
lowLevel
;
LowLevelAsyncIoProvider
&
lowLevel
;
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
;
Array
<
SocketAddress
>
addrs
;
Array
<
SocketAddress
>
addrs
;
uint
counter
=
0
;
uint
counter
=
0
;
static
Promise
<
Own
<
AsyncIoStream
>>
connectImpl
(
static
Promise
<
Own
<
AsyncIoStream
>>
connectImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
ArrayPtr
<
SocketAddress
>
addrs
)
{
LowLevelAsyncIoProvider
&
lowLevel
,
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
ArrayPtr
<
SocketAddress
>
addrs
)
{
KJ_ASSERT
(
addrs
.
size
()
>
0
);
KJ_ASSERT
(
addrs
.
size
()
>
0
);
int
fd
=
addrs
[
0
].
socket
(
SOCK_STREAM
);
int
fd
=
addrs
[
0
].
socket
(
SOCK_STREAM
);
return
kj
::
evalNow
([
&
]()
{
return
kj
::
evalNow
([
&
]()
->
Promise
<
Own
<
AsyncIoStream
>>
{
if
(
!
addrs
[
0
].
allowedBy
(
filter
))
{
return
KJ_EXCEPTION
(
FAILED
,
"connect() blocked by restrictPeers()"
);
}
else
{
return
lowLevel
.
wrapConnectingSocketFd
(
return
lowLevel
.
wrapConnectingSocketFd
(
fd
,
addrs
[
0
].
getRaw
(),
addrs
[
0
].
getRawSize
(),
NEW_FD_FLAGS
);
fd
,
addrs
[
0
].
getRaw
(),
addrs
[
0
].
getRawSize
(),
NEW_FD_FLAGS
);
}
}).
then
([](
Own
<
AsyncIoStream
>&&
stream
)
->
Promise
<
Own
<
AsyncIoStream
>>
{
}).
then
([](
Own
<
AsyncIoStream
>&&
stream
)
->
Promise
<
Own
<
AsyncIoStream
>>
{
// Success, pass along.
// Success, pass along.
return
kj
::
mv
(
stream
);
return
kj
::
mv
(
stream
);
},
[
&
lowLevel
,
addrs
](
Exception
&&
exception
)
mutable
->
Promise
<
Own
<
AsyncIoStream
>>
{
},
[
&
lowLevel
,
&
filter
,
addrs
](
Exception
&&
exception
)
mutable
->
Promise
<
Own
<
AsyncIoStream
>>
{
// Connect failed.
// Connect failed.
if
(
addrs
.
size
()
>
1
)
{
if
(
addrs
.
size
()
>
1
)
{
// Try the next address instead.
// Try the next address instead.
return
connectImpl
(
lowLevel
,
addrs
.
slice
(
1
,
addrs
.
size
()));
return
connectImpl
(
lowLevel
,
filter
,
addrs
.
slice
(
1
,
addrs
.
size
()));
}
else
{
}
else
{
// No more addresses to try, so propagate the exception.
// No more addresses to try, so propagate the exception.
return
kj
::
mv
(
exception
);
return
kj
::
mv
(
exception
);
...
@@ -1060,25 +1120,35 @@ private:
...
@@ -1060,25 +1120,35 @@ private:
class
SocketNetwork
final
:
public
Network
{
class
SocketNetwork
final
:
public
Network
{
public
:
public
:
explicit
SocketNetwork
(
LowLevelAsyncIoProvider
&
lowLevel
)
:
lowLevel
(
lowLevel
)
{}
explicit
SocketNetwork
(
LowLevelAsyncIoProvider
&
lowLevel
)
:
lowLevel
(
lowLevel
)
{}
explicit
SocketNetwork
(
SocketNetwork
&
parent
,
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
allow
,
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
deny
)
:
lowLevel
(
parent
.
lowLevel
),
filter
(
allow
,
deny
,
parent
.
filter
)
{}
Promise
<
Own
<
NetworkAddress
>>
parseAddress
(
StringPtr
addr
,
uint
portHint
=
0
)
override
{
Promise
<
Own
<
NetworkAddress
>>
parseAddress
(
StringPtr
addr
,
uint
portHint
=
0
)
override
{
auto
&
lowLevelCopy
=
lowLevel
;
return
evalLater
(
mvCapture
(
heapString
(
addr
),
[
this
,
portHint
](
String
&&
addr
)
{
return
evalLater
(
mvCapture
(
heapString
(
addr
),
return
SocketAddress
::
parse
(
lowLevel
,
addr
,
portHint
,
filter
);
[
&
lowLevelCopy
,
portHint
](
String
&&
addr
)
{
})).
then
([
this
](
Array
<
SocketAddress
>
addresses
)
->
Own
<
NetworkAddress
>
{
return
SocketAddress
::
parse
(
lowLevelCopy
,
addr
,
portHint
);
return
heap
<
NetworkAddressImpl
>
(
lowLevel
,
filter
,
kj
::
mv
(
addresses
));
})).
then
([
&
lowLevelCopy
](
Array
<
SocketAddress
>
addresses
)
->
Own
<
NetworkAddress
>
{
return
heap
<
NetworkAddressImpl
>
(
lowLevelCopy
,
kj
::
mv
(
addresses
));
});
});
}
}
Own
<
NetworkAddress
>
getSockaddr
(
const
void
*
sockaddr
,
uint
len
)
override
{
Own
<
NetworkAddress
>
getSockaddr
(
const
void
*
sockaddr
,
uint
len
)
override
{
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
array
.
add
(
SocketAddress
(
sockaddr
,
len
));
array
.
add
(
SocketAddress
(
sockaddr
,
len
));
return
Own
<
NetworkAddress
>
(
heap
<
NetworkAddressImpl
>
(
lowLevel
,
array
.
finish
()));
KJ_REQUIRE
(
array
[
0
].
allowedBy
(
filter
),
"address blocked by restrictPeers()"
)
{
break
;
}
return
Own
<
NetworkAddress
>
(
heap
<
NetworkAddressImpl
>
(
lowLevel
,
filter
,
array
.
finish
()));
}
Own
<
Network
>
restrictPeers
(
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
allow
,
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
deny
=
nullptr
)
override
{
return
heap
<
SocketNetwork
>
(
*
this
,
allow
,
deny
);
}
}
private
:
private
:
LowLevelAsyncIoProvider
&
lowLevel
;
LowLevelAsyncIoProvider
&
lowLevel
;
_
::
NetworkFilter
filter
;
};
};
// =======================================================================================
// =======================================================================================
...
@@ -1189,10 +1259,16 @@ public:
...
@@ -1189,10 +1259,16 @@ public:
return
receive
();
return
receive
();
});
});
}
else
{
}
else
{
if
(
!
port
.
filter
.
shouldAllow
(
reinterpret_cast
<
const
struct
sockaddr
*>
(
msg
.
msg_name
),
msg
.
msg_namelen
))
{
// Ignore message from disallowed source.
return
receive
();
}
receivedSize
=
n
;
receivedSize
=
n
;
contentTruncated
=
msg
.
msg_flags
&
MSG_TRUNC
;
contentTruncated
=
msg
.
msg_flags
&
MSG_TRUNC
;
source
.
emplace
(
port
.
lowLevel
,
msg
.
msg_name
,
msg
.
msg_namelen
);
source
.
emplace
(
port
.
lowLevel
,
port
.
filter
,
msg
.
msg_name
,
msg
.
msg_namelen
);
ancillaryList
.
resize
(
0
);
ancillaryList
.
resize
(
0
);
ancillaryTruncated
=
msg
.
msg_flags
&
MSG_CTRUNC
;
ancillaryTruncated
=
msg
.
msg_flags
&
MSG_CTRUNC
;
...
@@ -1250,9 +1326,10 @@ private:
...
@@ -1250,9 +1326,10 @@ private:
bool
ancillaryTruncated
=
false
;
bool
ancillaryTruncated
=
false
;
struct
StoredAddress
{
struct
StoredAddress
{
StoredAddress
(
LowLevelAsyncIoProvider
&
lowLevel
,
const
void
*
sockaddr
,
uint
length
)
StoredAddress
(
LowLevelAsyncIoProvider
&
lowLevel
,
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
const
void
*
sockaddr
,
uint
length
)
:
raw
(
sockaddr
,
length
),
:
raw
(
sockaddr
,
length
),
abstract
(
lowLevel
,
Array
<
SocketAddress
>
(
&
raw
,
1
,
NullArrayDisposer
::
instance
))
{}
abstract
(
lowLevel
,
filter
,
Array
<
SocketAddress
>
(
&
raw
,
1
,
NullArrayDisposer
::
instance
))
{}
SocketAddress
raw
;
SocketAddress
raw
;
NetworkAddressImpl
abstract
;
NetworkAddressImpl
abstract
;
...
...
c++/src/kj/async-io-win32.c++
View file @
ac6b5d30
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#define _WIN32_WINNT 0x0600
#define _WIN32_WINNT 0x0600
#include "async-io.h"
#include "async-io.h"
#include "async-io-internal.h"
#include "async-win32.h"
#include "async-win32.h"
#include "debug.h"
#include "debug.h"
#include "thread.h"
#include "thread.h"
...
@@ -524,11 +525,12 @@ public:
...
@@ -524,11 +525,12 @@ public:
}
}
static
Promise
<
Array
<
SocketAddress
>>
lookupHost
(
static
Promise
<
Array
<
SocketAddress
>>
lookupHost
(
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
);
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
,
_
::
NetworkFilter
&
filter
);
// Perform a DNS lookup.
// Perform a DNS lookup.
static
Promise
<
Array
<
SocketAddress
>>
parse
(
static
Promise
<
Array
<
SocketAddress
>>
parse
(
LowLevelAsyncIoProvider
&
lowLevel
,
StringPtr
str
,
uint
portHint
)
{
LowLevelAsyncIoProvider
&
lowLevel
,
StringPtr
str
,
uint
portHint
,
_
::
NetworkFilter
&
filter
)
{
// TODO(someday): Allow commas in `str`.
// TODO(someday): Allow commas in `str`.
SocketAddress
result
;
SocketAddress
result
;
...
@@ -580,7 +582,8 @@ public:
...
@@ -580,7 +582,8 @@ public:
port
=
strtoul
(
portText
->
cStr
(),
&
endptr
,
0
);
port
=
strtoul
(
portText
->
cStr
(),
&
endptr
,
0
);
if
(
portText
->
size
()
==
0
||
*
endptr
!=
'\0'
)
{
if
(
portText
->
size
()
==
0
||
*
endptr
!=
'\0'
)
{
// Not a number. Maybe it's a service name. Fall back to DNS.
// Not a number. Maybe it's a service name. Fall back to DNS.
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
kj
::
heapString
(
*
portText
),
portHint
);
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
kj
::
heapString
(
*
portText
),
portHint
,
filter
);
}
}
KJ_REQUIRE
(
port
<
65536
,
"Port number too large."
);
KJ_REQUIRE
(
port
<
65536
,
"Port number too large."
);
}
else
{
}
else
{
...
@@ -622,25 +625,45 @@ public:
...
@@ -622,25 +625,45 @@ public:
switch
(
InetPtonA
(
af
,
buffer
,
addrTarget
))
{
switch
(
InetPtonA
(
af
,
buffer
,
addrTarget
))
{
case
1
:
{
case
1
:
{
// success.
// success.
if
(
!
result
.
parseAllowedBy
(
filter
))
{
KJ_FAIL_REQUIRE
(
"address family blocked by restrictPeers()"
);
return
Array
<
SocketAddress
>
();
}
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
array
.
add
(
result
);
array
.
add
(
result
);
return
array
.
finish
();
return
array
.
finish
();
}
}
case
0
:
case
0
:
// It's apparently not a simple address... fall back to DNS.
// It's apparently not a simple address... fall back to DNS.
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
nullptr
,
port
);
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
nullptr
,
port
,
filter
);
default
:
default
:
KJ_FAIL_WIN32
(
"InetPton"
,
WSAGetLastError
(),
af
,
addrPart
);
KJ_FAIL_WIN32
(
"InetPton"
,
WSAGetLastError
(),
af
,
addrPart
);
}
}
}
}
static
SocketAddress
getLocalAddress
(
int
sockfd
)
{
static
SocketAddress
getLocalAddress
(
SOCKET
sockfd
)
{
SocketAddress
result
;
SocketAddress
result
;
result
.
addrlen
=
sizeof
(
addr
);
result
.
addrlen
=
sizeof
(
addr
);
KJ_WINSOCK
(
getsockname
(
sockfd
,
&
result
.
addr
.
generic
,
&
result
.
addrlen
));
KJ_WINSOCK
(
getsockname
(
sockfd
,
&
result
.
addr
.
generic
,
&
result
.
addrlen
));
return
result
;
return
result
;
}
}
static
SocketAddress
getPeerAddress
(
SOCKET
sockfd
)
{
SocketAddress
result
;
result
.
addrlen
=
sizeof
(
addr
);
KJ_WINSOCK
(
getpeername
(
sockfd
,
&
result
.
addr
.
generic
,
&
result
.
addrlen
));
return
result
;
}
bool
allowedBy
(
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
)
{
return
filter
.
shouldAllow
(
&
addr
.
generic
,
addrlen
);
}
bool
parseAllowedBy
(
_
::
NetworkFilter
&
filter
)
{
return
filter
.
shouldAllowParse
(
&
addr
.
generic
,
addrlen
);
}
static
SocketAddress
getWildcardForFamily
(
int
family
)
{
static
SocketAddress
getWildcardForFamily
(
int
family
)
{
SocketAddress
result
;
SocketAddress
result
;
switch
(
family
)
{
switch
(
family
)
{
...
@@ -680,8 +703,9 @@ class SocketAddress::LookupReader {
...
@@ -680,8 +703,9 @@ class SocketAddress::LookupReader {
// getaddrinfo.
// getaddrinfo.
public
:
public
:
LookupReader
(
kj
::
Own
<
Thread
>&&
thread
,
kj
::
Own
<
AsyncInputStream
>&&
input
)
LookupReader
(
kj
::
Own
<
Thread
>&&
thread
,
kj
::
Own
<
AsyncInputStream
>&&
input
,
:
thread
(
kj
::
mv
(
thread
)),
input
(
kj
::
mv
(
input
))
{}
_
::
NetworkFilter
&
filter
)
:
thread
(
kj
::
mv
(
thread
)),
input
(
kj
::
mv
(
input
)),
filter
(
filter
)
{}
~
LookupReader
()
{
~
LookupReader
()
{
if
(
thread
)
thread
->
detach
();
if
(
thread
)
thread
->
detach
();
...
@@ -694,7 +718,7 @@ public:
...
@@ -694,7 +718,7 @@ public:
thread
=
nullptr
;
thread
=
nullptr
;
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// anyway.
// anyway.
KJ_REQUIRE
(
addresses
.
size
()
>
0
,
"DNS lookup returned no addresses."
)
{
break
;
}
KJ_REQUIRE
(
addresses
.
size
()
>
0
,
"DNS lookup returned no
permitted
addresses."
)
{
break
;
}
return
addresses
.
releaseAsArray
();
return
addresses
.
releaseAsArray
();
}
else
{
}
else
{
// getaddrinfo() can return multiple copies of the same address for several reasons.
// getaddrinfo() can return multiple copies of the same address for several reasons.
...
@@ -707,8 +731,10 @@ public:
...
@@ -707,8 +731,10 @@ public:
//
//
// So we instead resort to de-duping results.
// So we instead resort to de-duping results.
if
(
alreadySeen
.
insert
(
current
).
second
)
{
if
(
alreadySeen
.
insert
(
current
).
second
)
{
if
(
current
.
parseAllowedBy
(
filter
))
{
addresses
.
add
(
current
);
addresses
.
add
(
current
);
}
}
}
return
read
();
return
read
();
}
}
});
});
...
@@ -717,6 +743,7 @@ public:
...
@@ -717,6 +743,7 @@ public:
private
:
private
:
kj
::
Own
<
Thread
>
thread
;
kj
::
Own
<
Thread
>
thread
;
kj
::
Own
<
AsyncInputStream
>
input
;
kj
::
Own
<
AsyncInputStream
>
input
;
_
::
NetworkFilter
&
filter
;
SocketAddress
current
;
SocketAddress
current
;
kj
::
Vector
<
SocketAddress
>
addresses
;
kj
::
Vector
<
SocketAddress
>
addresses
;
std
::
set
<
SocketAddress
>
alreadySeen
;
std
::
set
<
SocketAddress
>
alreadySeen
;
...
@@ -728,7 +755,8 @@ struct SocketAddress::LookupParams {
...
@@ -728,7 +755,8 @@ struct SocketAddress::LookupParams {
};
};
Promise
<
Array
<
SocketAddress
>>
SocketAddress
::
lookupHost
(
Promise
<
Array
<
SocketAddress
>>
SocketAddress
::
lookupHost
(
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
)
{
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
,
_
::
NetworkFilter
&
filter
)
{
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// the only cross-platform DNS API and it is blocking.
// the only cross-platform DNS API and it is blocking.
//
//
...
@@ -818,7 +846,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
...
@@ -818,7 +846,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
}
}
}));
}));
auto
reader
=
heap
<
LookupReader
>
(
kj
::
mv
(
thread
),
kj
::
mv
(
input
));
auto
reader
=
heap
<
LookupReader
>
(
kj
::
mv
(
thread
),
kj
::
mv
(
input
)
,
filter
);
return
reader
->
read
().
attach
(
kj
::
mv
(
reader
));
return
reader
->
read
().
attach
(
kj
::
mv
(
reader
));
}
}
...
@@ -826,8 +854,9 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
...
@@ -826,8 +854,9 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
class
FdConnectionReceiver
final
:
public
ConnectionReceiver
,
public
OwnedFd
{
class
FdConnectionReceiver
final
:
public
ConnectionReceiver
,
public
OwnedFd
{
public
:
public
:
FdConnectionReceiver
(
Win32EventPort
&
eventPort
,
SOCKET
fd
,
uint
flags
)
FdConnectionReceiver
(
Win32EventPort
&
eventPort
,
SOCKET
fd
,
:
OwnedFd
(
fd
,
flags
),
eventPort
(
eventPort
),
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
uint
flags
)
:
OwnedFd
(
fd
,
flags
),
eventPort
(
eventPort
),
filter
(
filter
),
observer
(
eventPort
.
observeIo
(
reinterpret_cast
<
HANDLE
>
(
fd
))),
observer
(
eventPort
.
observeIo
(
reinterpret_cast
<
HANDLE
>
(
fd
))),
address
(
SocketAddress
::
getLocalAddress
(
fd
))
{
address
(
SocketAddress
::
getLocalAddress
(
fd
))
{
// In order to accept asynchronously, we need the AcceptEx() function. Apparently, we have
// In order to accept asynchronously, we need the AcceptEx() function. Apparently, we have
...
@@ -858,8 +887,10 @@ public:
...
@@ -858,8 +887,10 @@ public:
}
}
}
}
return
op
->
onComplete
().
attach
(
kj
::
mv
(
scratch
)).
then
(
mvCapture
(
result
,
return
op
->
onComplete
().
then
(
mvCapture
(
result
,
mvCapture
(
scratch
,
[
this
](
Own
<
AsyncIoStream
>
stream
,
Win32EventPort
::
IoResult
ioResult
)
{
[
this
,
newFd
]
(
Array
<
byte
>
scratch
,
Own
<
AsyncIoStream
>
stream
,
Win32EventPort
::
IoResult
ioResult
)
->
Promise
<
Own
<
AsyncIoStream
>>
{
if
(
ioResult
.
errorCode
!=
ERROR_SUCCESS
)
{
if
(
ioResult
.
errorCode
!=
ERROR_SUCCESS
)
{
KJ_FAIL_WIN32
(
"AcceptEx()"
,
ioResult
.
errorCode
)
{
break
;
}
KJ_FAIL_WIN32
(
"AcceptEx()"
,
ioResult
.
errorCode
)
{
break
;
}
}
else
{
}
else
{
...
@@ -867,8 +898,19 @@ public:
...
@@ -867,8 +898,19 @@ public:
stream
->
setsockopt
(
SOL_SOCKET
,
SO_UPDATE_ACCEPT_CONTEXT
,
stream
->
setsockopt
(
SOL_SOCKET
,
SO_UPDATE_ACCEPT_CONTEXT
,
reinterpret_cast
<
char
*>
(
&
me
),
sizeof
(
me
));
reinterpret_cast
<
char
*>
(
&
me
),
sizeof
(
me
));
}
}
// Supposedly, AcceptEx() places the local and peer addresses into the buffer (which we've
// named `scratch`). However, the format in which it writes these is undocumented, and
// doesn't even match between native Windows and WINE. Apparently it is useless. I don't know
// why they require the buffer to have space for it in the first place. We'll need to call
// getpeername() to get the address.
auto
addr
=
SocketAddress
::
getPeerAddress
(
newFd
);
if
(
addr
.
allowedBy
(
filter
))
{
return
kj
::
mv
(
stream
);
return
kj
::
mv
(
stream
);
}));
}
else
{
return
accept
();
}
})));
}
}
uint
getPort
()
override
{
uint
getPort
()
override
{
...
@@ -888,6 +930,7 @@ public:
...
@@ -888,6 +930,7 @@ public:
public
:
public
:
Win32EventPort
&
eventPort
;
Win32EventPort
&
eventPort
;
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
;
Own
<
Win32EventPort
::
IoObserver
>
observer
;
Own
<
Win32EventPort
::
IoObserver
>
observer
;
LPFN_ACCEPTEX
acceptEx
=
nullptr
;
LPFN_ACCEPTEX
acceptEx
=
nullptr
;
SocketAddress
address
;
SocketAddress
address
;
...
@@ -923,8 +966,9 @@ public:
...
@@ -923,8 +966,9 @@ public:
return
kj
::
mv
(
result
);
return
kj
::
mv
(
result
);
}));
}));
}
}
Own
<
ConnectionReceiver
>
wrapListenSocketFd
(
SOCKET
fd
,
uint
flags
=
0
)
override
{
Own
<
ConnectionReceiver
>
wrapListenSocketFd
(
return
heap
<
FdConnectionReceiver
>
(
eventPort
,
fd
,
flags
);
SOCKET
fd
,
NetworkFilter
&
filter
,
uint
flags
=
0
)
override
{
return
heap
<
FdConnectionReceiver
>
(
eventPort
,
fd
,
filter
,
flags
);
}
}
Timer
&
getTimer
()
override
{
return
eventPort
.
getTimer
();
}
Timer
&
getTimer
()
override
{
return
eventPort
.
getTimer
();
}
...
@@ -941,12 +985,14 @@ private:
...
@@ -941,12 +985,14 @@ private:
class
NetworkAddressImpl
final
:
public
NetworkAddress
{
class
NetworkAddressImpl
final
:
public
NetworkAddress
{
public
:
public
:
NetworkAddressImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
Array
<
SocketAddress
>
addrs
)
NetworkAddressImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
:
lowLevel
(
lowLevel
),
addrs
(
kj
::
mv
(
addrs
))
{}
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
Array
<
SocketAddress
>
addrs
)
:
lowLevel
(
lowLevel
),
filter
(
filter
),
addrs
(
kj
::
mv
(
addrs
))
{}
Promise
<
Own
<
AsyncIoStream
>>
connect
()
override
{
Promise
<
Own
<
AsyncIoStream
>>
connect
()
override
{
auto
addrsCopy
=
heapArray
(
addrs
.
asPtr
());
auto
addrsCopy
=
heapArray
(
addrs
.
asPtr
());
auto
promise
=
connectImpl
(
lowLevel
,
addrsCopy
);
auto
promise
=
connectImpl
(
lowLevel
,
filter
,
addrsCopy
);
return
promise
.
attach
(
kj
::
mv
(
addrsCopy
));
return
promise
.
attach
(
kj
::
mv
(
addrsCopy
));
}
}
...
@@ -974,7 +1020,7 @@ public:
...
@@ -974,7 +1020,7 @@ public:
KJ_WINSOCK
(
::
listen
(
fd
,
SOMAXCONN
));
KJ_WINSOCK
(
::
listen
(
fd
,
SOMAXCONN
));
}
}
return
lowLevel
.
wrapListenSocketFd
(
fd
,
NEW_FD_FLAGS
);
return
lowLevel
.
wrapListenSocketFd
(
fd
,
filter
,
NEW_FD_FLAGS
);
}
}
Own
<
DatagramPort
>
bindDatagramPort
()
override
{
Own
<
DatagramPort
>
bindDatagramPort
()
override
{
...
@@ -998,11 +1044,11 @@ public:
...
@@ -998,11 +1044,11 @@ public:
addrs
[
0
].
bind
(
fd
);
addrs
[
0
].
bind
(
fd
);
}
}
return
lowLevel
.
wrapDatagramSocketFd
(
fd
,
NEW_FD_FLAGS
);
return
lowLevel
.
wrapDatagramSocketFd
(
fd
,
filter
,
NEW_FD_FLAGS
);
}
}
Own
<
NetworkAddress
>
clone
()
override
{
Own
<
NetworkAddress
>
clone
()
override
{
return
kj
::
heap
<
NetworkAddressImpl
>
(
lowLevel
,
kj
::
heapArray
(
addrs
.
asPtr
()));
return
kj
::
heap
<
NetworkAddressImpl
>
(
lowLevel
,
filter
,
kj
::
heapArray
(
addrs
.
asPtr
()));
}
}
String
toString
()
override
{
String
toString
()
override
{
...
@@ -1016,26 +1062,34 @@ public:
...
@@ -1016,26 +1062,34 @@ public:
private
:
private
:
LowLevelAsyncIoProvider
&
lowLevel
;
LowLevelAsyncIoProvider
&
lowLevel
;
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
;
Array
<
SocketAddress
>
addrs
;
Array
<
SocketAddress
>
addrs
;
uint
counter
=
0
;
uint
counter
=
0
;
static
Promise
<
Own
<
AsyncIoStream
>>
connectImpl
(
static
Promise
<
Own
<
AsyncIoStream
>>
connectImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
ArrayPtr
<
SocketAddress
>
addrs
)
{
LowLevelAsyncIoProvider
&
lowLevel
,
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
ArrayPtr
<
SocketAddress
>
addrs
)
{
KJ_ASSERT
(
addrs
.
size
()
>
0
);
KJ_ASSERT
(
addrs
.
size
()
>
0
);
int
fd
=
addrs
[
0
].
socket
(
SOCK_STREAM
);
int
fd
=
addrs
[
0
].
socket
(
SOCK_STREAM
);
return
kj
::
evalNow
([
&
]()
{
return
kj
::
evalNow
([
&
]()
->
Promise
<
Own
<
AsyncIoStream
>>
{
if
(
!
addrs
[
0
].
allowedBy
(
filter
))
{
return
KJ_EXCEPTION
(
FAILED
,
"connect() blocked by restrictPeers()"
);
}
else
{
return
lowLevel
.
wrapConnectingSocketFd
(
return
lowLevel
.
wrapConnectingSocketFd
(
fd
,
addrs
[
0
].
getRaw
(),
addrs
[
0
].
getRawSize
(),
NEW_FD_FLAGS
);
fd
,
addrs
[
0
].
getRaw
(),
addrs
[
0
].
getRawSize
(),
NEW_FD_FLAGS
);
}
}).
then
([](
Own
<
AsyncIoStream
>&&
stream
)
->
Promise
<
Own
<
AsyncIoStream
>>
{
}).
then
([](
Own
<
AsyncIoStream
>&&
stream
)
->
Promise
<
Own
<
AsyncIoStream
>>
{
// Success, pass along.
// Success, pass along.
return
kj
::
mv
(
stream
);
return
kj
::
mv
(
stream
);
},
[
&
lowLevel
,
KJ_CPCAP
(
addrs
)](
Exception
&&
exception
)
mutable
->
Promise
<
Own
<
AsyncIoStream
>>
{
},
[
&
lowLevel
,
&
filter
,
KJ_CPCAP
(
addrs
)](
Exception
&&
exception
)
mutable
->
Promise
<
Own
<
AsyncIoStream
>>
{
// Connect failed.
// Connect failed.
if
(
addrs
.
size
()
>
1
)
{
if
(
addrs
.
size
()
>
1
)
{
// Try the next address instead.
// Try the next address instead.
return
connectImpl
(
lowLevel
,
addrs
.
slice
(
1
,
addrs
.
size
()));
return
connectImpl
(
lowLevel
,
filter
,
addrs
.
slice
(
1
,
addrs
.
size
()));
}
else
{
}
else
{
// No more addresses to try, so propagate the exception.
// No more addresses to try, so propagate the exception.
return
kj
::
mv
(
exception
);
return
kj
::
mv
(
exception
);
...
@@ -1047,25 +1101,35 @@ private:
...
@@ -1047,25 +1101,35 @@ private:
class
SocketNetwork
final
:
public
Network
{
class
SocketNetwork
final
:
public
Network
{
public
:
public
:
explicit
SocketNetwork
(
LowLevelAsyncIoProvider
&
lowLevel
)
:
lowLevel
(
lowLevel
)
{}
explicit
SocketNetwork
(
LowLevelAsyncIoProvider
&
lowLevel
)
:
lowLevel
(
lowLevel
)
{}
explicit
SocketNetwork
(
SocketNetwork
&
parent
,
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
allow
,
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
deny
)
:
lowLevel
(
parent
.
lowLevel
),
filter
(
allow
,
deny
,
parent
.
filter
)
{}
Promise
<
Own
<
NetworkAddress
>>
parseAddress
(
StringPtr
addr
,
uint
portHint
=
0
)
override
{
Promise
<
Own
<
NetworkAddress
>>
parseAddress
(
StringPtr
addr
,
uint
portHint
=
0
)
override
{
auto
&
lowLevelCopy
=
lowLevel
;
return
evalLater
(
mvCapture
(
heapString
(
addr
),
[
this
,
portHint
](
String
&&
addr
)
{
return
evalLater
(
mvCapture
(
heapString
(
addr
),
return
SocketAddress
::
parse
(
lowLevel
,
addr
,
portHint
,
filter
);
[
&
lowLevelCopy
,
portHint
](
String
&&
addr
)
{
})).
then
([
this
](
Array
<
SocketAddress
>
addresses
)
->
Own
<
NetworkAddress
>
{
return
SocketAddress
::
parse
(
lowLevelCopy
,
addr
,
portHint
);
return
heap
<
NetworkAddressImpl
>
(
lowLevel
,
filter
,
kj
::
mv
(
addresses
));
})).
then
([
&
lowLevelCopy
](
Array
<
SocketAddress
>
addresses
)
->
Own
<
NetworkAddress
>
{
return
heap
<
NetworkAddressImpl
>
(
lowLevelCopy
,
kj
::
mv
(
addresses
));
});
});
}
}
Own
<
NetworkAddress
>
getSockaddr
(
const
void
*
sockaddr
,
uint
len
)
override
{
Own
<
NetworkAddress
>
getSockaddr
(
const
void
*
sockaddr
,
uint
len
)
override
{
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
array
.
add
(
SocketAddress
(
sockaddr
,
len
));
array
.
add
(
SocketAddress
(
sockaddr
,
len
));
return
Own
<
NetworkAddress
>
(
heap
<
NetworkAddressImpl
>
(
lowLevel
,
array
.
finish
()));
KJ_REQUIRE
(
array
[
0
].
allowedBy
(
filter
),
"address blocked by restrictPeers()"
)
{
break
;
}
return
Own
<
NetworkAddress
>
(
heap
<
NetworkAddressImpl
>
(
lowLevel
,
filter
,
array
.
finish
()));
}
Own
<
Network
>
restrictPeers
(
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
allow
,
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
deny
=
nullptr
)
override
{
return
heap
<
SocketNetwork
>
(
*
this
,
allow
,
deny
);
}
}
private
:
private
:
LowLevelAsyncIoProvider
&
lowLevel
;
LowLevelAsyncIoProvider
&
lowLevel
;
_
::
NetworkFilter
filter
;
};
};
// =======================================================================================
// =======================================================================================
...
...
c++/src/kj/async-io.c++
View file @
ac6b5d30
...
@@ -19,10 +19,30 @@
...
@@ -19,10 +19,30 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
// THE SOFTWARE.
#if _WIN32
// Request Vista-level APIs.
#define WINVER 0x0600
#define _WIN32_WINNT 0x0600
#endif
#include "async-io.h"
#include "async-io.h"
#include "async-io-internal.h"
#include "debug.h"
#include "debug.h"
#include "vector.h"
#include "vector.h"
#if _WIN32
#include <winsock2.h>
#include <ws2ipdef.h>
#include <ws2tcpip.h>
#include "windows-sanity.h"
#define inet_pton InetPtonA
#define inet_ntop InetNtopA
#else
#include <sys/socket.h>
#include <arpa/inet.h>
#include <sys/un.h>
#endif
namespace
kj
{
namespace
kj
{
Promise
<
void
>
AsyncInputStream
::
read
(
void
*
buffer
,
size_t
bytes
)
{
Promise
<
void
>
AsyncInputStream
::
read
(
void
*
buffer
,
size_t
bytes
)
{
...
@@ -188,8 +208,351 @@ void DatagramPort::setsockopt(int level, int option, const void* value, uint len
...
@@ -188,8 +208,351 @@ void DatagramPort::setsockopt(int level, int option, const void* value, uint len
Own
<
DatagramPort
>
NetworkAddress
::
bindDatagramPort
()
{
Own
<
DatagramPort
>
NetworkAddress
::
bindDatagramPort
()
{
KJ_UNIMPLEMENTED
(
"Datagram sockets not implemented."
);
KJ_UNIMPLEMENTED
(
"Datagram sockets not implemented."
);
}
}
Own
<
DatagramPort
>
LowLevelAsyncIoProvider
::
wrapDatagramSocketFd
(
Fd
fd
,
uint
flags
)
{
Own
<
DatagramPort
>
LowLevelAsyncIoProvider
::
wrapDatagramSocketFd
(
Fd
fd
,
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
uint
flags
)
{
KJ_UNIMPLEMENTED
(
"Datagram sockets not implemented."
);
KJ_UNIMPLEMENTED
(
"Datagram sockets not implemented."
);
}
}
// =======================================================================================
namespace
_
{
// private
#if !_WIN32
kj
::
ArrayPtr
<
const
char
>
safeUnixPath
(
const
struct
sockaddr_un
*
addr
,
uint
addrlen
)
{
KJ_REQUIRE
(
addr
->
sun_family
==
AF_UNIX
,
"not a unix address"
);
KJ_REQUIRE
(
addrlen
>=
offsetof
(
sockaddr_un
,
sun_path
),
"invalid unix address"
);
size_t
maxPathlen
=
addrlen
-
offsetof
(
sockaddr_un
,
sun_path
);
size_t
pathlen
;
if
(
maxPathlen
>
0
&&
addr
->
sun_path
[
0
]
==
'\0'
)
{
// Linux "abstract" unix address
pathlen
=
strnlen
(
addr
->
sun_path
+
1
,
maxPathlen
-
1
)
+
1
;
}
else
{
pathlen
=
strnlen
(
addr
->
sun_path
,
maxPathlen
);
}
return
kj
::
arrayPtr
(
addr
->
sun_path
,
pathlen
);
}
#endif // !_WIN32
CidrRange
::
CidrRange
(
StringPtr
pattern
)
{
size_t
slashPos
=
KJ_REQUIRE_NONNULL
(
pattern
.
findFirst
(
'/'
),
"invalid CIDR"
,
pattern
);
bitCount
=
pattern
.
slice
(
slashPos
+
1
).
parseAs
<
uint
>
();
KJ_STACK_ARRAY
(
char
,
addr
,
slashPos
+
1
,
128
,
128
);
memcpy
(
addr
.
begin
(),
pattern
.
begin
(),
slashPos
);
addr
[
slashPos
]
=
'\0'
;
if
(
pattern
.
findFirst
(
':'
)
==
nullptr
)
{
family
=
AF_INET
;
KJ_REQUIRE
(
bitCount
<=
32
,
"invalid CIDR"
,
pattern
);
}
else
{
family
=
AF_INET6
;
KJ_REQUIRE
(
bitCount
<=
128
,
"invalid CIDR"
,
pattern
);
}
KJ_ASSERT
(
inet_pton
(
family
,
addr
.
begin
(),
bits
)
>
0
,
"invalid CIDR"
,
pattern
);
zeroIrrelevantBits
();
}
CidrRange
::
CidrRange
(
int
family
,
ArrayPtr
<
const
byte
>
bits
,
uint
bitCount
)
:
family
(
family
),
bitCount
(
bitCount
)
{
if
(
family
==
AF_INET
)
{
KJ_REQUIRE
(
bitCount
<=
32
);
}
else
{
KJ_REQUIRE
(
bitCount
<=
128
);
}
KJ_REQUIRE
(
bits
.
size
()
*
8
>=
bitCount
);
size_t
byteCount
=
(
bitCount
+
7
)
/
8
;
memcpy
(
this
->
bits
,
bits
.
begin
(),
byteCount
);
memset
(
this
->
bits
+
byteCount
,
0
,
sizeof
(
this
->
bits
)
-
byteCount
);
zeroIrrelevantBits
();
}
CidrRange
CidrRange
::
inet4
(
ArrayPtr
<
const
byte
>
bits
,
uint
bitCount
)
{
return
CidrRange
(
AF_INET
,
bits
,
bitCount
);
}
CidrRange
CidrRange
::
inet6
(
ArrayPtr
<
const
uint16_t
>
prefix
,
ArrayPtr
<
const
uint16_t
>
suffix
,
uint
bitCount
)
{
KJ_REQUIRE
(
prefix
.
size
()
+
suffix
.
size
()
<=
8
);
byte
bits
[
16
]
=
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
};
for
(
size_t
i
:
kj
::
indices
(
prefix
))
{
bits
[
i
*
2
]
=
prefix
[
i
]
>>
8
;
bits
[
i
*
2
+
1
]
=
prefix
[
i
]
&
0xff
;
}
byte
*
suffixBits
=
bits
+
(
16
-
suffix
.
size
()
*
2
);
for
(
size_t
i
:
kj
::
indices
(
suffix
))
{
suffixBits
[
i
*
2
]
=
suffix
[
i
]
>>
8
;
suffixBits
[
i
*
2
+
1
]
=
suffix
[
i
]
&
0xff
;
}
return
CidrRange
(
AF_INET6
,
bits
,
bitCount
);
}
bool
CidrRange
::
matches
(
const
struct
sockaddr
*
addr
)
const
{
const
byte
*
otherBits
;
switch
(
family
)
{
case
AF_INET
:
if
(
addr
->
sa_family
==
AF_INET6
)
{
otherBits
=
reinterpret_cast
<
const
struct
sockaddr_in6
*>
(
addr
)
->
sin6_addr
.
s6_addr
;
static
constexpr
byte
V6MAPPED
[
12
]
=
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0xff
,
0xff
};
if
(
memcmp
(
otherBits
,
V6MAPPED
,
sizeof
(
V6MAPPED
))
==
0
)
{
// We're an ipv4 range and the address is ipv6, but it's a "v6 mapped" address, meaning
// it's equivalent to an ipv4 address. Try to match against the ipv4 part.
otherBits
=
otherBits
+
sizeof
(
V6MAPPED
);
}
else
{
return
false
;
}
}
else
if
(
addr
->
sa_family
==
AF_INET
)
{
otherBits
=
reinterpret_cast
<
const
byte
*>
(
&
reinterpret_cast
<
const
struct
sockaddr_in
*>
(
addr
)
->
sin_addr
.
s_addr
);
}
else
{
return
false
;
}
break
;
case
AF_INET6
:
if
(
addr
->
sa_family
!=
AF_INET6
)
return
false
;
otherBits
=
reinterpret_cast
<
const
struct
sockaddr_in6
*>
(
addr
)
->
sin6_addr
.
s6_addr
;
break
;
default
:
KJ_UNREACHABLE
;
}
if
(
memcmp
(
bits
,
otherBits
,
bitCount
/
8
)
!=
0
)
return
false
;
return
bitCount
==
128
||
bits
[
bitCount
/
8
]
==
(
otherBits
[
bitCount
/
8
]
&
(
0xff00
>>
(
bitCount
%
8
)));
}
bool
CidrRange
::
matchesFamily
(
int
family
)
const
{
switch
(
family
)
{
case
AF_INET
:
return
this
->
family
==
AF_INET
;
case
AF_INET6
:
// Even if we're a v4 CIDR, we can match v6 addresses in the v4-mapped range.
return
true
;
default
:
return
false
;
}
}
String
CidrRange
::
toString
()
const
{
char
result
[
128
];
KJ_ASSERT
(
inet_ntop
(
family
,
(
void
*
)
bits
,
result
,
sizeof
(
result
))
==
result
);
return
kj
::
str
(
result
,
'/'
,
bitCount
);
}
void
CidrRange
::
zeroIrrelevantBits
()
{
// Mask out insignificant bits of partial byte.
if
(
bitCount
<
128
)
{
bits
[
bitCount
/
8
]
&=
0xff00
>>
(
bitCount
%
8
);
// Zero the remaining bytes.
size_t
n
=
bitCount
/
8
+
1
;
memset
(
bits
+
n
,
0
,
sizeof
(
bits
)
-
n
);
}
}
// -----------------------------------------------------------------------------
ArrayPtr
<
const
CidrRange
>
localCidrs
()
{
static
const
CidrRange
result
[]
=
{
// localhost
"127.0.0.0/8"
_kj
,
"::1/128"
_kj
,
// Trying to *connect* to 0.0.0.0 on many systems is equivalent to connecting to localhost.
// (wat)
"0.0.0.0/32"
_kj
,
"::/128"
_kj
,
};
// TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly
// casting to our return type.
return
kj
::
arrayPtr
(
result
,
kj
::
size
(
result
));
}
ArrayPtr
<
const
CidrRange
>
privateCidrs
()
{
static
const
CidrRange
result
[]
=
{
"10.0.0.0/8"
_kj
,
// RFC1918 reserved for internal network
"100.64.0.0/10"
_kj
,
// RFC6598 "shared address space" for carrier-grade NAT
"169.254.0.0/16"
_kj
,
// RFC3927 "link local" (auto-configured LAN in absence of DHCP)
"172.16.0.0/12"
_kj
,
// RFC1918 reserved for internal network
"192.168.0.0/16"
_kj
,
// RFC1918 reserved for internal network
"fc00::/7"
_kj
,
// RFC4193 unique private network
"fe80::/10"
_kj
,
// RFC4291 "link local" (auto-configured LAN in absence of DHCP)
};
// TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly
// casting to our return type.
return
kj
::
arrayPtr
(
result
,
kj
::
size
(
result
));
}
ArrayPtr
<
const
CidrRange
>
reservedCidrs
()
{
static
const
CidrRange
result
[]
=
{
"192.0.0.0/24"
_kj
,
// RFC6890 reserved for special protocols
"224.0.0.0/4"
_kj
,
// RFC1112 multicast
"240.0.0.0/4"
_kj
,
// RFC1112 multicast / reserved for future use
"255.255.255.255/32"
_kj
,
// RFC0919 broadcast address
"2001::/23"
_kj
,
// RFC2928 reserved for special protocols
"ff00::/8"
_kj
,
// RFC4291 multicast
};
// TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly
// casting to our return type.
return
kj
::
arrayPtr
(
result
,
kj
::
size
(
result
));
}
ArrayPtr
<
const
CidrRange
>
exampleAddresses
()
{
static
const
CidrRange
result
[]
=
{
"192.0.2.0/24"
_kj
,
// RFC5737 "example address" block 1 -- like example.com for IPs
"198.51.100.0/24"
_kj
,
// RFC5737 "example address" block 2 -- like example.com for IPs
"203.0.113.0/24"
_kj
,
// RFC5737 "example address" block 3 -- like example.com for IPs
"2001:db8::/32"
_kj
,
// RFC3849 "example address" block -- like example.com for IPs
};
// TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly
// casting to our return type.
return
kj
::
arrayPtr
(
result
,
kj
::
size
(
result
));
}
NetworkFilter
::
NetworkFilter
()
:
allowUnix
(
true
),
allowAbstractUnix
(
true
)
{
allowCidrs
.
add
(
CidrRange
::
inet4
({
0
,
0
,
0
,
0
},
0
));
allowCidrs
.
add
(
CidrRange
::
inet6
({},
{},
0
));
denyCidrs
.
addAll
(
reservedCidrs
());
}
NetworkFilter
::
NetworkFilter
(
ArrayPtr
<
const
StringPtr
>
allow
,
ArrayPtr
<
const
StringPtr
>
deny
,
NetworkFilter
&
next
)
:
allowUnix
(
false
),
allowAbstractUnix
(
false
),
next
(
next
)
{
for
(
auto
rule
:
allow
)
{
if
(
rule
==
"local"
)
{
allowCidrs
.
addAll
(
localCidrs
());
}
else
if
(
rule
==
"network"
)
{
allowCidrs
.
add
(
CidrRange
::
inet4
({
0
,
0
,
0
,
0
},
0
));
allowCidrs
.
add
(
CidrRange
::
inet6
({},
{},
0
));
denyCidrs
.
addAll
(
localCidrs
());
}
else
if
(
rule
==
"private"
)
{
allowCidrs
.
addAll
(
privateCidrs
());
allowCidrs
.
addAll
(
localCidrs
());
}
else
if
(
rule
==
"public"
)
{
allowCidrs
.
add
(
CidrRange
::
inet4
({
0
,
0
,
0
,
0
},
0
));
allowCidrs
.
add
(
CidrRange
::
inet6
({},
{},
0
));
denyCidrs
.
addAll
(
privateCidrs
());
denyCidrs
.
addAll
(
localCidrs
());
}
else
if
(
rule
==
"unix"
)
{
allowUnix
=
true
;
}
else
if
(
rule
==
"unix-abstract"
)
{
allowAbstractUnix
=
true
;
}
else
{
allowCidrs
.
add
(
CidrRange
(
rule
));
}
}
for
(
auto
rule
:
deny
)
{
if
(
rule
==
"local"
)
{
denyCidrs
.
addAll
(
localCidrs
());
}
else
if
(
rule
==
"network"
)
{
KJ_FAIL_REQUIRE
(
"don't deny 'network', allow 'local' instead"
);
}
else
if
(
rule
==
"private"
)
{
denyCidrs
.
addAll
(
privateCidrs
());
}
else
if
(
rule
==
"public"
)
{
// Tricky: What if we allow 'network' and deny 'public'?
KJ_FAIL_REQUIRE
(
"don't deny 'public', allow 'private' instead"
);
}
else
if
(
rule
==
"unix"
)
{
allowUnix
=
false
;
}
else
if
(
rule
==
"unix-abstract"
)
{
allowAbstractUnix
=
false
;
}
else
{
denyCidrs
.
add
(
CidrRange
(
rule
));
}
}
}
bool
NetworkFilter
::
shouldAllow
(
const
struct
sockaddr
*
addr
,
uint
addrlen
)
{
KJ_REQUIRE
(
addrlen
>=
sizeof
(
addr
->
sa_family
));
#if !_WIN32
if
(
addr
->
sa_family
==
AF_UNIX
)
{
auto
path
=
safeUnixPath
(
reinterpret_cast
<
const
struct
sockaddr_un
*>
(
addr
),
addrlen
);
if
(
path
.
size
()
>
0
&&
path
[
0
]
==
'\0'
)
{
return
allowAbstractUnix
;
}
else
{
return
allowUnix
;
}
}
#endif
bool
allowed
=
false
;
uint
allowSpecificity
=
0
;
for
(
auto
&
cidr
:
allowCidrs
)
{
if
(
cidr
.
matches
(
addr
))
{
allowSpecificity
=
kj
::
max
(
allowSpecificity
,
cidr
.
getSpecificity
());
allowed
=
true
;
}
}
if
(
!
allowed
)
return
false
;
for
(
auto
&
cidr
:
denyCidrs
)
{
if
(
cidr
.
matches
(
addr
))
{
if
(
cidr
.
getSpecificity
()
>=
allowSpecificity
)
return
false
;
}
}
KJ_IF_MAYBE
(
n
,
next
)
{
return
n
->
shouldAllow
(
addr
,
addrlen
);
}
else
{
return
true
;
}
}
bool
NetworkFilter
::
shouldAllowParse
(
const
struct
sockaddr
*
addr
,
uint
addrlen
)
{
bool
matched
=
false
;
#if !_WIN32
if
(
addr
->
sa_family
==
AF_UNIX
)
{
auto
path
=
safeUnixPath
(
reinterpret_cast
<
const
struct
sockaddr_un
*>
(
addr
),
addrlen
);
if
(
path
.
size
()
>
0
&&
path
[
0
]
==
'\0'
)
{
if
(
allowAbstractUnix
)
matched
=
true
;
}
else
{
if
(
allowUnix
)
matched
=
true
;
}
}
else
{
#endif
for
(
auto
&
cidr
:
allowCidrs
)
{
if
(
cidr
.
matchesFamily
(
addr
->
sa_family
))
{
matched
=
true
;
}
}
#if !_WIN32
}
#endif
if
(
matched
)
{
KJ_IF_MAYBE
(
n
,
next
)
{
return
n
->
shouldAllowParse
(
addr
,
addrlen
);
}
else
{
return
true
;
}
}
else
{
// No allow rule matches this address family, so don't even allow parsing it.
return
false
;
}
}
}
// namespace _ (private)
}
// namespace kj
}
// namespace kj
c++/src/kj/async-io.h
View file @
ac6b5d30
...
@@ -319,6 +319,67 @@ public:
...
@@ -319,6 +319,67 @@ public:
virtual
Own
<
NetworkAddress
>
getSockaddr
(
const
void
*
sockaddr
,
uint
len
)
=
0
;
virtual
Own
<
NetworkAddress
>
getSockaddr
(
const
void
*
sockaddr
,
uint
len
)
=
0
;
// Construct a network address from a legacy struct sockaddr.
// Construct a network address from a legacy struct sockaddr.
virtual
Own
<
Network
>
restrictPeers
(
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
allow
,
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
deny
=
nullptr
)
KJ_WARN_UNUSED_RESULT
=
0
;
// Constructs a new Network instance wrapping this one which restricts which peer addresses are
// permitted (both for outgoing and incoming connections).
//
// Communication will be allowed only with peers whose addresses match one of the patterns
// specified in the `allow` array. If a `deny` array is specified, then any address which matches
// a pattern in `deny` and *does not* match any more-specific pattern in `allow` will also be
// denied.
//
// The syntax of address patterns depends on the network, except that three special patterns are
// defined for all networks:
// - "private": Matches network addresses that are reserved by standards for private networks,
// such as "10.0.0.0/8" or "192.168.0.0/16". This is a superset of "local".
// - "public": Opposite of "private".
// - "local": Matches network addresses that are defined by standards to only be accessible from
// the local machine, such as "127.0.0.0/8" or Unix domain addresses.
// - "network": Opposite of "local".
//
// For the standard KJ network implementation, the following patterns are also recognized:
// - Network blocks specified in CIDR notation (ipv4 and ipv6), such as "192.0.2.0/24" or
// "2001:db8::/32".
// - "unix" to match all Unix domain addresses. (In the future, we may support specifying a
// glob.)
// - "unix-abstract" to match Linux's "abstract unix domain" addresses. (In the future, we may
// support specifying a glob.)
//
// Network restrictions apply *after* DNS resolution (otherwise they'd be useless).
//
// It is legal to parseAddress() a restricted address. An exception won't be thrown until
// connect() is called.
//
// It's possible to listen() on a restricted address. However, connections will only be accepted
// from non-restricted addresses; others will be dropped. If a particular listen address has no
// valid peers (e.g. because it's a unix socket address and unix sockets are not allowed) then
// listen() may throw (or may simply never receive any connections).
//
// Examples:
//
// auto restricted = network->restrictPeers({"public"});
//
// Allows connections only to/from public internet addresses. Use this when connecting to an
// address specified by a third party that is not trusted and is not themselves already on your
// private network.
//
// auto restricted = network->restrictPeers({"private"});
//
// Allows connections only to/from the private network. Use this on the server side to reject
// connections from the public internet.
//
// auto restricted = network->restrictPeers({"192.0.2.0/24"}, {"192.0.2.3/32"});
//
// Allows connections only to/from 192.0.2.*, except 192.0.2.3 which is blocked.
//
// auto restricted = network->restrictPeers({"10.0.0.0/8", "10.1.2.3/32"}, {"10.1.2.0/24"});
//
// Allows connections to/from 10.*.*.*, with the exception of 10.1.2.* (which is denied), with an
// exception to the exception of 10.1.2.3 (which is allowed, because it is matched by an allow
// rule that is more specific than the deny rule).
};
};
// =======================================================================================
// =======================================================================================
...
@@ -470,13 +531,21 @@ public:
...
@@ -470,13 +531,21 @@ public:
//
//
// `flags` is a bitwise-OR of the values of the `Flags` enum.
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual
Own
<
ConnectionReceiver
>
wrapListenSocketFd
(
Fd
fd
,
uint
flags
=
0
)
=
0
;
class
NetworkFilter
{
public
:
virtual
bool
shouldAllow
(
const
struct
sockaddr
*
addr
,
uint
addrlen
)
=
0
;
// Returns true if incoming connections or datagrams from the given peer should be accepted.
// If false, they will be dropped. This is used to implement kj::Network::restrictPeers().
};
virtual
Own
<
ConnectionReceiver
>
wrapListenSocketFd
(
Fd
fd
,
NetworkFilter
&
filter
,
uint
flags
=
0
)
=
0
;
// Create an AsyncIoStream wrapping a listen socket file descriptor. This socket should already
// Create an AsyncIoStream wrapping a listen socket file descriptor. This socket should already
// have had `bind()` and `listen()` called on it, so it's ready for `accept()`.
// have had `bind()` and `listen()` called on it, so it's ready for `accept()`.
//
//
// `flags` is a bitwise-OR of the values of the `Flags` enum.
// `flags` is a bitwise-OR of the values of the `Flags` enum.
virtual
Own
<
DatagramPort
>
wrapDatagramSocketFd
(
Fd
fd
,
uint
flags
=
0
);
virtual
Own
<
DatagramPort
>
wrapDatagramSocketFd
(
Fd
fd
,
NetworkFilter
&
filter
,
uint
flags
=
0
);
virtual
Timer
&
getTimer
()
=
0
;
virtual
Timer
&
getTimer
()
=
0
;
// Returns a `Timer` based on real time. Time does not pass while event handlers are running --
// Returns a `Timer` based on real time. Time does not pass while event handlers are running --
...
...
c++/src/kj/common.h
View file @
ac6b5d30
...
@@ -1280,7 +1280,7 @@ public:
...
@@ -1280,7 +1280,7 @@ public:
return
ArrayPtr
<
const
T
>
(
ptr
,
size_
);
return
ArrayPtr
<
const
T
>
(
ptr
,
size_
);
}
}
inline
size_t
size
()
const
{
return
size_
;
}
inline
constexpr
size_t
size
()
const
{
return
size_
;
}
inline
const
T
&
operator
[](
size_t
index
)
const
{
inline
const
T
&
operator
[](
size_t
index
)
const
{
KJ_IREQUIRE
(
index
<
size_
,
"Out-of-bounds ArrayPtr access."
);
KJ_IREQUIRE
(
index
<
size_
,
"Out-of-bounds ArrayPtr access."
);
return
ptr
[
index
];
return
ptr
[
index
];
...
@@ -1294,8 +1294,8 @@ public:
...
@@ -1294,8 +1294,8 @@ public:
inline
T
*
end
()
{
return
ptr
+
size_
;
}
inline
T
*
end
()
{
return
ptr
+
size_
;
}
inline
T
&
front
()
{
return
*
ptr
;
}
inline
T
&
front
()
{
return
*
ptr
;
}
inline
T
&
back
()
{
return
*
(
ptr
+
size_
-
1
);
}
inline
T
&
back
()
{
return
*
(
ptr
+
size_
-
1
);
}
inline
const
T
*
begin
()
const
{
return
ptr
;
}
inline
const
expr
const
T
*
begin
()
const
{
return
ptr
;
}
inline
const
T
*
end
()
const
{
return
ptr
+
size_
;
}
inline
const
expr
const
T
*
end
()
const
{
return
ptr
+
size_
;
}
inline
const
T
&
front
()
const
{
return
*
ptr
;
}
inline
const
T
&
front
()
const
{
return
*
ptr
;
}
inline
const
T
&
back
()
const
{
return
*
(
ptr
+
size_
-
1
);
}
inline
const
T
&
back
()
const
{
return
*
(
ptr
+
size_
-
1
);
}
...
...
c++/src/kj/compat/tls.c++
View file @
ac6b5d30
...
@@ -443,6 +443,8 @@ private:
...
@@ -443,6 +443,8 @@ private:
class
TlsNetwork
:
public
kj
::
Network
{
class
TlsNetwork
:
public
kj
::
Network
{
public
:
public
:
TlsNetwork
(
TlsContext
&
tls
,
kj
::
Network
&
inner
)
:
tls
(
tls
),
inner
(
inner
)
{}
TlsNetwork
(
TlsContext
&
tls
,
kj
::
Network
&
inner
)
:
tls
(
tls
),
inner
(
inner
)
{}
TlsNetwork
(
TlsContext
&
tls
,
kj
::
Own
<
kj
::
Network
>
inner
)
:
tls
(
tls
),
inner
(
*
inner
),
ownInner
(
kj
::
mv
(
inner
))
{}
Promise
<
Own
<
NetworkAddress
>>
parseAddress
(
StringPtr
addr
,
uint
portHint
)
override
{
Promise
<
Own
<
NetworkAddress
>>
parseAddress
(
StringPtr
addr
,
uint
portHint
)
override
{
kj
::
String
hostname
;
kj
::
String
hostname
;
...
@@ -463,9 +465,19 @@ public:
...
@@ -463,9 +465,19 @@ public:
KJ_UNIMPLEMENTED
(
"TLS does not implement getSockaddr() because it needs to know hostnames"
);
KJ_UNIMPLEMENTED
(
"TLS does not implement getSockaddr() because it needs to know hostnames"
);
}
}
Own
<
Network
>
restrictPeers
(
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
allow
,
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
deny
=
nullptr
)
override
{
// TODO(someday): Maybe we could implement the ability to specify CA or hostname restrictions?
// Or is it better to let people do that via the TlsContext? A neat thing about
// restrictPeers() is that it's easy to make user-configurable.
return
kj
::
heap
<
TlsNetwork
>
(
tls
,
inner
.
restrictPeers
(
allow
,
deny
));
}
private
:
private
:
TlsContext
&
tls
;
TlsContext
&
tls
;
kj
::
Network
&
inner
;
kj
::
Network
&
inner
;
kj
::
Own
<
kj
::
Network
>
ownInner
;
};
};
}
// namespace
}
// namespace
...
...
c++/src/kj/string-test.c++
View file @
ac6b5d30
...
@@ -173,6 +173,18 @@ TEST(String, ToString) {
...
@@ -173,6 +173,18 @@ TEST(String, ToString) {
}
}
#endif
#endif
KJ_TEST
(
"string literals with _kj suffix"
)
{
static
constexpr
StringPtr
FOO
=
"foo"
_kj
;
KJ_EXPECT
(
FOO
==
"foo"
,
FOO
);
KJ_EXPECT
(
FOO
[
3
]
==
0
);
KJ_EXPECT
(
"foo
\0
bar"
_kj
==
StringPtr
(
"foo
\0
bar"
,
7
));
static
constexpr
ArrayPtr
<
const
char
>
ARR
=
"foo"
_kj
;
KJ_EXPECT
(
ARR
.
size
()
==
3
);
KJ_EXPECT
(
kj
::
str
(
ARR
)
==
"foo"
);
}
}
// namespace
}
// namespace
}
// namespace _ (private)
}
// namespace _ (private)
}
// namespace kj
}
// namespace kj
c++/src/kj/string.h
View file @
ac6b5d30
...
@@ -31,11 +31,29 @@
...
@@ -31,11 +31,29 @@
#include <string.h>
#include <string.h>
namespace
kj
{
namespace
kj
{
class
StringPtr
;
class
String
;
class
StringPtr
;
class
StringTree
;
// string-tree.h
class
String
;
}
class
StringTree
;
// string-tree.h
constexpr
kj
::
StringPtr
operator
""
_kj
(
const
char
*
str
,
size_t
n
);
// You can append _kj to a string literal to make its type be StringPtr. There are a few cases
// where you must do this for correctness:
// - When you want to declare a constexpr StringPtr. Without _kj, this is a compile error.
// - When you want to initialize a static/global StringPtr from a string literal without forcing
// global constructor code to run at dynamic initialization time.
// - When you have a string literal that contains NUL characters. Without _kj, the string will
// be considered to end at the first NUL.
// - When you want to initialize an ArrayPtr<const char> from a string literal, without including
// the NUL terminator in the data. (Initializing an ArrayPtr from a regular string literal is
// a compile error specifically due to this ambiguity.)
//
// In other cases, there should be no difference between initializing a StringPtr from a regular
// string literal vs. one with _kj (assuming the compiler is able to optimize away strlen() on a
// string literal).
namespace
kj
{
// Our STL string SFINAE trick does not work with GCC 4.7, but it works with Clang and GCC 4.8, so
// Our STL string SFINAE trick does not work with GCC 4.7, but it works with Clang and GCC 4.8, so
// we'll just preprocess it out if not supported.
// we'll just preprocess it out if not supported.
...
@@ -75,8 +93,8 @@ public:
...
@@ -75,8 +93,8 @@ public:
// those who don't want it.
// those who don't want it.
#endif
#endif
inline
operator
ArrayPtr
<
const
char
>
()
const
;
inline
constexpr
operator
ArrayPtr
<
const
char
>
()
const
;
inline
ArrayPtr
<
const
char
>
asArray
()
const
;
inline
constexpr
ArrayPtr
<
const
char
>
asArray
()
const
;
inline
ArrayPtr
<
const
byte
>
asBytes
()
const
{
return
asArray
().
asBytes
();
}
inline
ArrayPtr
<
const
byte
>
asBytes
()
const
{
return
asArray
().
asBytes
();
}
// Result does not include NUL terminator.
// Result does not include NUL terminator.
...
@@ -121,9 +139,11 @@ public:
...
@@ -121,9 +139,11 @@ public:
// Overflowed floating numbers return inf.
// Overflowed floating numbers return inf.
private
:
private
:
inline
StringPtr
(
ArrayPtr
<
const
char
>
content
)
:
content
(
content
)
{}
inline
constexpr
StringPtr
(
ArrayPtr
<
const
char
>
content
)
:
content
(
content
)
{}
ArrayPtr
<
const
char
>
content
;
ArrayPtr
<
const
char
>
content
;
friend
constexpr
kj
::
StringPtr
(
::
operator
""
_kj
)(
const
char
*
str
,
size_t
n
);
};
};
inline
bool
operator
==
(
const
char
*
a
,
const
StringPtr
&
b
)
{
return
b
==
a
;
}
inline
bool
operator
==
(
const
char
*
a
,
const
StringPtr
&
b
)
{
return
b
==
a
;
}
...
@@ -427,12 +447,12 @@ inline String Stringifier::operator*(const Array<T>& arr) const {
...
@@ -427,12 +447,12 @@ inline String Stringifier::operator*(const Array<T>& arr) const {
inline
StringPtr
::
StringPtr
(
const
String
&
value
)
:
content
(
value
.
begin
(),
value
.
size
()
+
1
)
{}
inline
StringPtr
::
StringPtr
(
const
String
&
value
)
:
content
(
value
.
begin
(),
value
.
size
()
+
1
)
{}
inline
StringPtr
::
operator
ArrayPtr
<
const
char
>
()
const
{
inline
constexpr
StringPtr
::
operator
ArrayPtr
<
const
char
>
()
const
{
return
content
.
slice
(
0
,
content
.
size
()
-
1
);
return
ArrayPtr
<
const
char
>
(
content
.
begin
()
,
content
.
size
()
-
1
);
}
}
inline
ArrayPtr
<
const
char
>
StringPtr
::
asArray
()
const
{
inline
constexpr
ArrayPtr
<
const
char
>
StringPtr
::
asArray
()
const
{
return
content
.
slice
(
0
,
content
.
size
()
-
1
);
return
ArrayPtr
<
const
char
>
(
content
.
begin
()
,
content
.
size
()
-
1
);
}
}
inline
bool
StringPtr
::
operator
==
(
const
StringPtr
&
other
)
const
{
inline
bool
StringPtr
::
operator
==
(
const
StringPtr
&
other
)
const
{
...
@@ -531,4 +551,8 @@ inline String heapString(ArrayPtr<const char> value) {
...
@@ -531,4 +551,8 @@ inline String heapString(ArrayPtr<const char> value) {
}
// namespace kj
}
// namespace kj
constexpr
kj
::
StringPtr
operator
""
_kj
(
const
char
*
str
,
size_t
n
)
{
return
kj
::
StringPtr
(
kj
::
ArrayPtr
<
const
char
>
(
str
,
n
+
1
));
};
#endif // KJ_STRING_H_
#endif // KJ_STRING_H_
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment