Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
C
capnproto
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
capnproto
Commits
505e71f7
Commit
505e71f7
authored
Sep 12, 2017
by
Kenton Varda
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Wire up restrictPeers() implementation.
parent
05d0a7ed
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
233 additions
and
85 deletions
+233
-85
ez-rpc.c++
c++/src/capnp/ez-rpc.c++
+15
-1
async-io-unix.c++
c++/src/kj/async-io-unix.c++
+119
-43
async-io-win32.c++
c++/src/kj/async-io-win32.c++
+91
-37
async-io.c++
c++/src/kj/async-io.c++
+8
-4
No files found.
c++/src/capnp/ez-rpc.c++
View file @
505e71f7
...
@@ -198,6 +198,19 @@ kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() {
...
@@ -198,6 +198,19 @@ kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() {
// =======================================================================================
// =======================================================================================
namespace
{
class
DummyFilter
:
public
kj
::
LowLevelAsyncIoProvider
::
NetworkFilter
{
public
:
bool
shouldAllow
(
const
struct
sockaddr
*
addr
,
uint
addrlen
)
override
{
return
true
;
}
};
static
DummyFilter
DUMMY_FILTER
;
}
// namespace
struct
EzRpcServer
::
Impl
final
:
public
SturdyRefRestorer
<
AnyPointer
>
,
struct
EzRpcServer
::
Impl
final
:
public
SturdyRefRestorer
<
AnyPointer
>
,
public
kj
::
TaskSet
::
ErrorHandler
{
public
kj
::
TaskSet
::
ErrorHandler
{
Capability
::
Client
mainInterface
;
Capability
::
Client
mainInterface
;
...
@@ -271,7 +284,8 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>,
...
@@ -271,7 +284,8 @@ struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>,
context
(
EzRpcContext
::
getThreadLocal
()),
context
(
EzRpcContext
::
getThreadLocal
()),
portPromise
(
kj
::
Promise
<
uint
>
(
port
).
fork
()),
portPromise
(
kj
::
Promise
<
uint
>
(
port
).
fork
()),
tasks
(
*
this
)
{
tasks
(
*
this
)
{
acceptLoop
(
context
->
getLowLevelIoProvider
().
wrapListenSocketFd
(
socketFd
),
readerOpts
);
acceptLoop
(
context
->
getLowLevelIoProvider
().
wrapListenSocketFd
(
socketFd
,
DUMMY_FILTER
),
readerOpts
);
}
}
void
acceptLoop
(
kj
::
Own
<
kj
::
ConnectionReceiver
>&&
listener
,
ReaderOptions
readerOpts
)
{
void
acceptLoop
(
kj
::
Own
<
kj
::
ConnectionReceiver
>&&
listener
,
ReaderOptions
readerOpts
)
{
...
...
c++/src/kj/async-io-unix.c++
View file @
505e71f7
...
@@ -23,6 +23,7 @@
...
@@ -23,6 +23,7 @@
// For Win32 implementation, see async-io-win32.c++.
// For Win32 implementation, see async-io-win32.c++.
#include "async-io.h"
#include "async-io.h"
#include "async-io-internal.h"
#include "async-unix.h"
#include "async-unix.h"
#include "debug.h"
#include "debug.h"
#include "thread.h"
#include "thread.h"
...
@@ -461,11 +462,12 @@ public:
...
@@ -461,11 +462,12 @@ public:
}
}
static
Promise
<
Array
<
SocketAddress
>>
lookupHost
(
static
Promise
<
Array
<
SocketAddress
>>
lookupHost
(
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
);
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
,
_
::
NetworkFilter
&
filter
);
// Perform a DNS lookup.
// Perform a DNS lookup.
static
Promise
<
Array
<
SocketAddress
>>
parse
(
static
Promise
<
Array
<
SocketAddress
>>
parse
(
LowLevelAsyncIoProvider
&
lowLevel
,
StringPtr
str
,
uint
portHint
)
{
LowLevelAsyncIoProvider
&
lowLevel
,
StringPtr
str
,
uint
portHint
,
_
::
NetworkFilter
&
filter
)
{
// TODO(someday): Allow commas in `str`.
// TODO(someday): Allow commas in `str`.
SocketAddress
result
;
SocketAddress
result
;
...
@@ -480,6 +482,12 @@ public:
...
@@ -480,6 +482,12 @@ public:
result
.
addr
.
unixDomain
.
sun_family
=
AF_UNIX
;
result
.
addr
.
unixDomain
.
sun_family
=
AF_UNIX
;
strcpy
(
result
.
addr
.
unixDomain
.
sun_path
,
path
.
cStr
());
strcpy
(
result
.
addr
.
unixDomain
.
sun_path
,
path
.
cStr
());
result
.
addrlen
=
offsetof
(
struct
sockaddr_un
,
sun_path
)
+
path
.
size
()
+
1
;
result
.
addrlen
=
offsetof
(
struct
sockaddr_un
,
sun_path
)
+
path
.
size
()
+
1
;
if
(
!
result
.
parseAllowedBy
(
filter
))
{
KJ_FAIL_REQUIRE
(
"unix sockets blocked by restrictPeers()"
);
return
Array
<
SocketAddress
>
();
}
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
array
.
add
(
result
);
array
.
add
(
result
);
return
array
.
finish
();
return
array
.
finish
();
...
@@ -495,6 +503,12 @@ public:
...
@@ -495,6 +503,12 @@ public:
// NULL terminator so that we can safely read it back in toString
// NULL terminator so that we can safely read it back in toString
memcpy
(
result
.
addr
.
unixDomain
.
sun_path
+
1
,
path
.
cStr
(),
path
.
size
()
+
1
);
memcpy
(
result
.
addr
.
unixDomain
.
sun_path
+
1
,
path
.
cStr
(),
path
.
size
()
+
1
);
result
.
addrlen
=
offsetof
(
struct
sockaddr_un
,
sun_path
)
+
path
.
size
()
+
1
;
result
.
addrlen
=
offsetof
(
struct
sockaddr_un
,
sun_path
)
+
path
.
size
()
+
1
;
if
(
!
result
.
parseAllowedBy
(
filter
))
{
KJ_FAIL_REQUIRE
(
"abstract unix sockets blocked by restrictPeers()"
);
return
Array
<
SocketAddress
>
();
}
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
array
.
add
(
result
);
array
.
add
(
result
);
return
array
.
finish
();
return
array
.
finish
();
...
@@ -547,7 +561,8 @@ public:
...
@@ -547,7 +561,8 @@ public:
port
=
strtoul
(
portText
->
cStr
(),
&
endptr
,
0
);
port
=
strtoul
(
portText
->
cStr
(),
&
endptr
,
0
);
if
(
portText
->
size
()
==
0
||
*
endptr
!=
'\0'
)
{
if
(
portText
->
size
()
==
0
||
*
endptr
!=
'\0'
)
{
// Not a number. Maybe it's a service name. Fall back to DNS.
// Not a number. Maybe it's a service name. Fall back to DNS.
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
kj
::
heapString
(
*
portText
),
portHint
);
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
kj
::
heapString
(
*
portText
),
portHint
,
filter
);
}
}
KJ_REQUIRE
(
port
<
65536
,
"Port number too large."
);
KJ_REQUIRE
(
port
<
65536
,
"Port number too large."
);
}
else
{
}
else
{
...
@@ -569,6 +584,7 @@ public:
...
@@ -569,6 +584,7 @@ public:
result
.
addr
.
inet6
.
sin6_family
=
AF_INET6
;
result
.
addr
.
inet6
.
sin6_family
=
AF_INET6
;
result
.
addr
.
inet6
.
sin6_port
=
htons
(
port
);
result
.
addr
.
inet6
.
sin6_port
=
htons
(
port
);
#endif
#endif
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
array
.
add
(
result
);
array
.
add
(
result
);
return
array
.
finish
();
return
array
.
finish
();
...
@@ -597,13 +613,18 @@ public:
...
@@ -597,13 +613,18 @@ public:
switch
(
inet_pton
(
af
,
buffer
,
addrTarget
))
{
switch
(
inet_pton
(
af
,
buffer
,
addrTarget
))
{
case
1
:
{
case
1
:
{
// success.
// success.
if
(
!
result
.
parseAllowedBy
(
filter
))
{
KJ_FAIL_REQUIRE
(
"address family blocked by restrictPeers()"
);
return
Array
<
SocketAddress
>
();
}
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
array
.
add
(
result
);
array
.
add
(
result
);
return
array
.
finish
();
return
array
.
finish
();
}
}
case
0
:
case
0
:
// It's apparently not a simple address... fall back to DNS.
// It's apparently not a simple address... fall back to DNS.
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
nullptr
,
port
);
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
nullptr
,
port
,
filter
);
default
:
default
:
KJ_FAIL_SYSCALL
(
"inet_pton"
,
errno
,
af
,
addrPart
);
KJ_FAIL_SYSCALL
(
"inet_pton"
,
errno
,
af
,
addrPart
);
}
}
...
@@ -616,6 +637,14 @@ public:
...
@@ -616,6 +637,14 @@ public:
return
result
;
return
result
;
}
}
bool
allowedBy
(
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
)
{
return
filter
.
shouldAllow
(
&
addr
.
generic
,
addrlen
);
}
bool
parseAllowedBy
(
const
_
::
NetworkFilter
&
filter
)
{
return
filter
.
shouldAllowParse
(
&
addr
.
generic
);
}
private
:
private
:
SocketAddress
()
:
addrlen
(
0
)
{
SocketAddress
()
:
addrlen
(
0
)
{
memset
(
&
addr
,
0
,
sizeof
(
addr
));
memset
(
&
addr
,
0
,
sizeof
(
addr
));
...
@@ -640,8 +669,9 @@ class SocketAddress::LookupReader {
...
@@ -640,8 +669,9 @@ class SocketAddress::LookupReader {
// getaddrinfo.
// getaddrinfo.
public
:
public
:
LookupReader
(
kj
::
Own
<
Thread
>&&
thread
,
kj
::
Own
<
AsyncInputStream
>&&
input
)
LookupReader
(
kj
::
Own
<
Thread
>&&
thread
,
kj
::
Own
<
AsyncInputStream
>&&
input
,
:
thread
(
kj
::
mv
(
thread
)),
input
(
kj
::
mv
(
input
))
{}
_
::
NetworkFilter
&
filter
)
:
thread
(
kj
::
mv
(
thread
)),
input
(
kj
::
mv
(
input
)),
filter
(
filter
)
{}
~
LookupReader
()
{
~
LookupReader
()
{
if
(
thread
)
thread
->
detach
();
if
(
thread
)
thread
->
detach
();
...
@@ -654,7 +684,7 @@ public:
...
@@ -654,7 +684,7 @@ public:
thread
=
nullptr
;
thread
=
nullptr
;
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// anyway.
// anyway.
KJ_REQUIRE
(
addresses
.
size
()
>
0
,
"DNS lookup returned no addresses."
)
{
break
;
}
KJ_REQUIRE
(
addresses
.
size
()
>
0
,
"DNS lookup returned no
permitted
addresses."
)
{
break
;
}
return
addresses
.
releaseAsArray
();
return
addresses
.
releaseAsArray
();
}
else
{
}
else
{
// getaddrinfo() can return multiple copies of the same address for several reasons.
// getaddrinfo() can return multiple copies of the same address for several reasons.
...
@@ -667,7 +697,9 @@ public:
...
@@ -667,7 +697,9 @@ public:
//
//
// So we instead resort to de-duping results.
// So we instead resort to de-duping results.
if
(
alreadySeen
.
insert
(
current
).
second
)
{
if
(
alreadySeen
.
insert
(
current
).
second
)
{
addresses
.
add
(
current
);
if
(
current
.
parseAllowedBy
(
filter
))
{
addresses
.
add
(
current
);
}
}
}
return
read
();
return
read
();
}
}
...
@@ -677,6 +709,7 @@ public:
...
@@ -677,6 +709,7 @@ public:
private
:
private
:
kj
::
Own
<
Thread
>
thread
;
kj
::
Own
<
Thread
>
thread
;
kj
::
Own
<
AsyncInputStream
>
input
;
kj
::
Own
<
AsyncInputStream
>
input
;
_
::
NetworkFilter
&
filter
;
SocketAddress
current
;
SocketAddress
current
;
kj
::
Vector
<
SocketAddress
>
addresses
;
kj
::
Vector
<
SocketAddress
>
addresses
;
std
::
set
<
SocketAddress
>
alreadySeen
;
std
::
set
<
SocketAddress
>
alreadySeen
;
...
@@ -688,7 +721,8 @@ struct SocketAddress::LookupParams {
...
@@ -688,7 +721,8 @@ struct SocketAddress::LookupParams {
};
};
Promise
<
Array
<
SocketAddress
>>
SocketAddress
::
lookupHost
(
Promise
<
Array
<
SocketAddress
>>
SocketAddress
::
lookupHost
(
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
)
{
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
,
_
::
NetworkFilter
&
filter
)
{
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// the only cross-platform DNS API and it is blocking.
// the only cross-platform DNS API and it is blocking.
//
//
...
@@ -773,7 +807,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
...
@@ -773,7 +807,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
}
}
}));
}));
auto
reader
=
heap
<
LookupReader
>
(
kj
::
mv
(
thread
),
kj
::
mv
(
input
));
auto
reader
=
heap
<
LookupReader
>
(
kj
::
mv
(
thread
),
kj
::
mv
(
input
)
,
filter
);
return
reader
->
read
().
attach
(
kj
::
mv
(
reader
));
return
reader
->
read
().
attach
(
kj
::
mv
(
reader
));
}
}
...
@@ -781,22 +815,33 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
...
@@ -781,22 +815,33 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
class
FdConnectionReceiver
final
:
public
ConnectionReceiver
,
public
OwnedFileDescriptor
{
class
FdConnectionReceiver
final
:
public
ConnectionReceiver
,
public
OwnedFileDescriptor
{
public
:
public
:
FdConnectionReceiver
(
UnixEventPort
&
eventPort
,
int
fd
,
uint
flags
)
FdConnectionReceiver
(
UnixEventPort
&
eventPort
,
int
fd
,
:
OwnedFileDescriptor
(
fd
,
flags
),
eventPort
(
eventPort
),
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
uint
flags
)
:
OwnedFileDescriptor
(
fd
,
flags
),
eventPort
(
eventPort
),
filter
(
filter
),
observer
(
eventPort
,
fd
,
UnixEventPort
::
FdObserver
::
OBSERVE_READ
)
{}
observer
(
eventPort
,
fd
,
UnixEventPort
::
FdObserver
::
OBSERVE_READ
)
{}
Promise
<
Own
<
AsyncIoStream
>>
accept
()
override
{
Promise
<
Own
<
AsyncIoStream
>>
accept
()
override
{
int
newFd
;
int
newFd
;
struct
sockaddr_storage
addr
;
socklen_t
addrlen
=
sizeof
(
addr
);
retry
:
retry
:
#if __linux__ && !__BIONIC__
#if __linux__ && !__BIONIC__
newFd
=
::
accept4
(
fd
,
nullptr
,
nullptr
,
SOCK_NONBLOCK
|
SOCK_CLOEXEC
);
newFd
=
::
accept4
(
fd
,
reinterpret_cast
<
struct
sockaddr
*>
(
&
addr
),
&
addrlen
,
SOCK_NONBLOCK
|
SOCK_CLOEXEC
);
#else
#else
newFd
=
::
accept
(
fd
,
nullptr
,
nullptr
);
newFd
=
::
accept
(
fd
,
reinterpret_cast
<
struct
sockaddr
*>
(
&
addr
),
&
addrlen
);
#endif
#endif
if
(
newFd
>=
0
)
{
if
(
newFd
>=
0
)
{
return
Own
<
AsyncIoStream
>
(
heap
<
AsyncStreamFd
>
(
eventPort
,
newFd
,
NEW_FD_FLAGS
));
if
(
!
filter
.
shouldAllow
(
reinterpret_cast
<
struct
sockaddr
*>
(
&
addr
),
addrlen
))
{
// Drop disallowed address.
close
(
newFd
);
return
accept
();
}
else
{
return
Own
<
AsyncIoStream
>
(
heap
<
AsyncStreamFd
>
(
eventPort
,
newFd
,
NEW_FD_FLAGS
));
}
}
else
{
}
else
{
int
error
=
errno
;
int
error
=
errno
;
...
@@ -849,13 +894,15 @@ public:
...
@@ -849,13 +894,15 @@ public:
public
:
public
:
UnixEventPort
&
eventPort
;
UnixEventPort
&
eventPort
;
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
;
UnixEventPort
::
FdObserver
observer
;
UnixEventPort
::
FdObserver
observer
;
};
};
class
DatagramPortImpl
final
:
public
DatagramPort
,
public
OwnedFileDescriptor
{
class
DatagramPortImpl
final
:
public
DatagramPort
,
public
OwnedFileDescriptor
{
public
:
public
:
DatagramPortImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
UnixEventPort
&
eventPort
,
int
fd
,
uint
flags
)
DatagramPortImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
UnixEventPort
&
eventPort
,
int
fd
,
:
OwnedFileDescriptor
(
fd
,
flags
),
lowLevel
(
lowLevel
),
eventPort
(
eventPort
),
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
uint
flags
)
:
OwnedFileDescriptor
(
fd
,
flags
),
lowLevel
(
lowLevel
),
eventPort
(
eventPort
),
filter
(
filter
),
observer
(
eventPort
,
fd
,
UnixEventPort
::
FdObserver
::
OBSERVE_READ
|
observer
(
eventPort
,
fd
,
UnixEventPort
::
FdObserver
::
OBSERVE_READ
|
UnixEventPort
::
FdObserver
::
OBSERVE_WRITE
)
{}
UnixEventPort
::
FdObserver
::
OBSERVE_WRITE
)
{}
...
@@ -883,6 +930,7 @@ public:
...
@@ -883,6 +930,7 @@ public:
public
:
public
:
LowLevelAsyncIoProvider
&
lowLevel
;
LowLevelAsyncIoProvider
&
lowLevel
;
UnixEventPort
&
eventPort
;
UnixEventPort
&
eventPort
;
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
;
UnixEventPort
::
FdObserver
observer
;
UnixEventPort
::
FdObserver
observer
;
};
};
...
@@ -935,11 +983,13 @@ public:
...
@@ -935,11 +983,13 @@ public:
return
kj
::
mv
(
stream
);
return
kj
::
mv
(
stream
);
}));
}));
}
}
Own
<
ConnectionReceiver
>
wrapListenSocketFd
(
int
fd
,
uint
flags
=
0
)
override
{
Own
<
ConnectionReceiver
>
wrapListenSocketFd
(
return
heap
<
FdConnectionReceiver
>
(
eventPort
,
fd
,
flags
);
int
fd
,
NetworkFilter
&
filter
,
uint
flags
=
0
)
override
{
return
heap
<
FdConnectionReceiver
>
(
eventPort
,
fd
,
filter
,
flags
);
}
}
Own
<
DatagramPort
>
wrapDatagramSocketFd
(
int
fd
,
uint
flags
=
0
)
override
{
Own
<
DatagramPort
>
wrapDatagramSocketFd
(
return
heap
<
DatagramPortImpl
>
(
*
this
,
eventPort
,
fd
,
flags
);
int
fd
,
NetworkFilter
&
filter
,
uint
flags
=
0
)
override
{
return
heap
<
DatagramPortImpl
>
(
*
this
,
eventPort
,
fd
,
filter
,
flags
);
}
}
Timer
&
getTimer
()
override
{
return
eventPort
.
getTimer
();
}
Timer
&
getTimer
()
override
{
return
eventPort
.
getTimer
();
}
...
@@ -956,12 +1006,14 @@ private:
...
@@ -956,12 +1006,14 @@ private:
class
NetworkAddressImpl
final
:
public
NetworkAddress
{
class
NetworkAddressImpl
final
:
public
NetworkAddress
{
public
:
public
:
NetworkAddressImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
Array
<
SocketAddress
>
addrs
)
NetworkAddressImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
:
lowLevel
(
lowLevel
),
addrs
(
kj
::
mv
(
addrs
))
{}
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
Array
<
SocketAddress
>
addrs
)
:
lowLevel
(
lowLevel
),
filter
(
filter
),
addrs
(
kj
::
mv
(
addrs
))
{}
Promise
<
Own
<
AsyncIoStream
>>
connect
()
override
{
Promise
<
Own
<
AsyncIoStream
>>
connect
()
override
{
auto
addrsCopy
=
heapArray
(
addrs
.
asPtr
());
auto
addrsCopy
=
heapArray
(
addrs
.
asPtr
());
auto
promise
=
connectImpl
(
lowLevel
,
addrsCopy
);
auto
promise
=
connectImpl
(
lowLevel
,
filter
,
addrsCopy
);
return
promise
.
attach
(
kj
::
mv
(
addrsCopy
));
return
promise
.
attach
(
kj
::
mv
(
addrsCopy
));
}
}
...
@@ -988,7 +1040,7 @@ public:
...
@@ -988,7 +1040,7 @@ public:
KJ_SYSCALL
(
::
listen
(
fd
,
SOMAXCONN
));
KJ_SYSCALL
(
::
listen
(
fd
,
SOMAXCONN
));
}
}
return
lowLevel
.
wrapListenSocketFd
(
fd
,
NEW_FD_FLAGS
);
return
lowLevel
.
wrapListenSocketFd
(
fd
,
filter
,
NEW_FD_FLAGS
);
}
}
Own
<
DatagramPort
>
bindDatagramPort
()
override
{
Own
<
DatagramPort
>
bindDatagramPort
()
override
{
...
@@ -1011,11 +1063,11 @@ public:
...
@@ -1011,11 +1063,11 @@ public:
addrs
[
0
].
bind
(
fd
);
addrs
[
0
].
bind
(
fd
);
}
}
return
lowLevel
.
wrapDatagramSocketFd
(
fd
,
NEW_FD_FLAGS
);
return
lowLevel
.
wrapDatagramSocketFd
(
fd
,
filter
,
NEW_FD_FLAGS
);
}
}
Own
<
NetworkAddress
>
clone
()
override
{
Own
<
NetworkAddress
>
clone
()
override
{
return
kj
::
heap
<
NetworkAddressImpl
>
(
lowLevel
,
kj
::
heapArray
(
addrs
.
asPtr
()));
return
kj
::
heap
<
NetworkAddressImpl
>
(
lowLevel
,
filter
,
kj
::
heapArray
(
addrs
.
asPtr
()));
}
}
String
toString
()
override
{
String
toString
()
override
{
...
@@ -1029,26 +1081,33 @@ public:
...
@@ -1029,26 +1081,33 @@ public:
private
:
private
:
LowLevelAsyncIoProvider
&
lowLevel
;
LowLevelAsyncIoProvider
&
lowLevel
;
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
;
Array
<
SocketAddress
>
addrs
;
Array
<
SocketAddress
>
addrs
;
uint
counter
=
0
;
uint
counter
=
0
;
static
Promise
<
Own
<
AsyncIoStream
>>
connectImpl
(
static
Promise
<
Own
<
AsyncIoStream
>>
connectImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
ArrayPtr
<
SocketAddress
>
addrs
)
{
LowLevelAsyncIoProvider
&
lowLevel
,
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
ArrayPtr
<
SocketAddress
>
addrs
)
{
KJ_ASSERT
(
addrs
.
size
()
>
0
);
KJ_ASSERT
(
addrs
.
size
()
>
0
);
int
fd
=
addrs
[
0
].
socket
(
SOCK_STREAM
);
int
fd
=
addrs
[
0
].
socket
(
SOCK_STREAM
);
return
kj
::
evalNow
([
&
]()
{
return
kj
::
evalNow
([
&
]()
->
Promise
<
Own
<
AsyncIoStream
>>
{
return
lowLevel
.
wrapConnectingSocketFd
(
if
(
!
addrs
[
0
].
allowedBy
(
filter
))
{
fd
,
addrs
[
0
].
getRaw
(),
addrs
[
0
].
getRawSize
(),
NEW_FD_FLAGS
);
return
KJ_EXCEPTION
(
FAILED
,
"connect() blocked by restrictPeers()"
);
}
else
{
return
lowLevel
.
wrapConnectingSocketFd
(
fd
,
addrs
[
0
].
getRaw
(),
addrs
[
0
].
getRawSize
(),
NEW_FD_FLAGS
);
}
}).
then
([](
Own
<
AsyncIoStream
>&&
stream
)
->
Promise
<
Own
<
AsyncIoStream
>>
{
}).
then
([](
Own
<
AsyncIoStream
>&&
stream
)
->
Promise
<
Own
<
AsyncIoStream
>>
{
// Success, pass along.
// Success, pass along.
return
kj
::
mv
(
stream
);
return
kj
::
mv
(
stream
);
},
[
&
lowLevel
,
addrs
](
Exception
&&
exception
)
mutable
->
Promise
<
Own
<
AsyncIoStream
>>
{
},
[
&
lowLevel
,
&
filter
,
addrs
](
Exception
&&
exception
)
mutable
->
Promise
<
Own
<
AsyncIoStream
>>
{
// Connect failed.
// Connect failed.
if
(
addrs
.
size
()
>
1
)
{
if
(
addrs
.
size
()
>
1
)
{
// Try the next address instead.
// Try the next address instead.
return
connectImpl
(
lowLevel
,
addrs
.
slice
(
1
,
addrs
.
size
()));
return
connectImpl
(
lowLevel
,
filter
,
addrs
.
slice
(
1
,
addrs
.
size
()));
}
else
{
}
else
{
// No more addresses to try, so propagate the exception.
// No more addresses to try, so propagate the exception.
return
kj
::
mv
(
exception
);
return
kj
::
mv
(
exception
);
...
@@ -1060,25 +1119,35 @@ private:
...
@@ -1060,25 +1119,35 @@ private:
class
SocketNetwork
final
:
public
Network
{
class
SocketNetwork
final
:
public
Network
{
public
:
public
:
explicit
SocketNetwork
(
LowLevelAsyncIoProvider
&
lowLevel
)
:
lowLevel
(
lowLevel
)
{}
explicit
SocketNetwork
(
LowLevelAsyncIoProvider
&
lowLevel
)
:
lowLevel
(
lowLevel
)
{}
explicit
SocketNetwork
(
SocketNetwork
&
parent
,
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
allow
,
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
deny
)
:
lowLevel
(
parent
.
lowLevel
),
filter
(
allow
,
deny
,
parent
.
filter
)
{}
Promise
<
Own
<
NetworkAddress
>>
parseAddress
(
StringPtr
addr
,
uint
portHint
=
0
)
override
{
Promise
<
Own
<
NetworkAddress
>>
parseAddress
(
StringPtr
addr
,
uint
portHint
=
0
)
override
{
auto
&
lowLevelCopy
=
lowLevel
;
return
evalLater
(
mvCapture
(
heapString
(
addr
),
[
this
,
portHint
](
String
&&
addr
)
{
return
evalLater
(
mvCapture
(
heapString
(
addr
),
return
SocketAddress
::
parse
(
lowLevel
,
addr
,
portHint
,
filter
);
[
&
lowLevelCopy
,
portHint
](
String
&&
addr
)
{
})).
then
([
this
](
Array
<
SocketAddress
>
addresses
)
->
Own
<
NetworkAddress
>
{
return
SocketAddress
::
parse
(
lowLevelCopy
,
addr
,
portHint
);
return
heap
<
NetworkAddressImpl
>
(
lowLevel
,
filter
,
kj
::
mv
(
addresses
));
})).
then
([
&
lowLevelCopy
](
Array
<
SocketAddress
>
addresses
)
->
Own
<
NetworkAddress
>
{
return
heap
<
NetworkAddressImpl
>
(
lowLevelCopy
,
kj
::
mv
(
addresses
));
});
});
}
}
Own
<
NetworkAddress
>
getSockaddr
(
const
void
*
sockaddr
,
uint
len
)
override
{
Own
<
NetworkAddress
>
getSockaddr
(
const
void
*
sockaddr
,
uint
len
)
override
{
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
array
.
add
(
SocketAddress
(
sockaddr
,
len
));
array
.
add
(
SocketAddress
(
sockaddr
,
len
));
return
Own
<
NetworkAddress
>
(
heap
<
NetworkAddressImpl
>
(
lowLevel
,
array
.
finish
()));
KJ_REQUIRE
(
array
[
0
].
allowedBy
(
filter
),
"address blocked by restrictPeers()"
)
{
break
;
}
return
Own
<
NetworkAddress
>
(
heap
<
NetworkAddressImpl
>
(
lowLevel
,
filter
,
array
.
finish
()));
}
Own
<
Network
>
restrictPeers
(
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
allow
,
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
deny
=
nullptr
)
override
{
return
heap
<
SocketNetwork
>
(
*
this
,
allow
,
deny
);
}
}
private
:
private
:
LowLevelAsyncIoProvider
&
lowLevel
;
LowLevelAsyncIoProvider
&
lowLevel
;
_
::
NetworkFilter
filter
;
};
};
// =======================================================================================
// =======================================================================================
...
@@ -1189,10 +1258,16 @@ public:
...
@@ -1189,10 +1258,16 @@ public:
return
receive
();
return
receive
();
});
});
}
else
{
}
else
{
if
(
!
port
.
filter
.
shouldAllow
(
reinterpret_cast
<
const
struct
sockaddr
*>
(
msg
.
msg_name
),
msg
.
msg_namelen
))
{
// Ignore message from disallowed source.
return
receive
();
}
receivedSize
=
n
;
receivedSize
=
n
;
contentTruncated
=
msg
.
msg_flags
&
MSG_TRUNC
;
contentTruncated
=
msg
.
msg_flags
&
MSG_TRUNC
;
source
.
emplace
(
port
.
lowLevel
,
msg
.
msg_name
,
msg
.
msg_namelen
);
source
.
emplace
(
port
.
lowLevel
,
port
.
filter
,
msg
.
msg_name
,
msg
.
msg_namelen
);
ancillaryList
.
resize
(
0
);
ancillaryList
.
resize
(
0
);
ancillaryTruncated
=
msg
.
msg_flags
&
MSG_CTRUNC
;
ancillaryTruncated
=
msg
.
msg_flags
&
MSG_CTRUNC
;
...
@@ -1250,9 +1325,10 @@ private:
...
@@ -1250,9 +1325,10 @@ private:
bool
ancillaryTruncated
=
false
;
bool
ancillaryTruncated
=
false
;
struct
StoredAddress
{
struct
StoredAddress
{
StoredAddress
(
LowLevelAsyncIoProvider
&
lowLevel
,
const
void
*
sockaddr
,
uint
length
)
StoredAddress
(
LowLevelAsyncIoProvider
&
lowLevel
,
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
const
void
*
sockaddr
,
uint
length
)
:
raw
(
sockaddr
,
length
),
:
raw
(
sockaddr
,
length
),
abstract
(
lowLevel
,
Array
<
SocketAddress
>
(
&
raw
,
1
,
NullArrayDisposer
::
instance
))
{}
abstract
(
lowLevel
,
filter
,
Array
<
SocketAddress
>
(
&
raw
,
1
,
NullArrayDisposer
::
instance
))
{}
SocketAddress
raw
;
SocketAddress
raw
;
NetworkAddressImpl
abstract
;
NetworkAddressImpl
abstract
;
...
...
c++/src/kj/async-io-win32.c++
View file @
505e71f7
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#define _WIN32_WINNT 0x0600
#define _WIN32_WINNT 0x0600
#include "async-io.h"
#include "async-io.h"
#include "async-io-internal.h"
#include "async-win32.h"
#include "async-win32.h"
#include "debug.h"
#include "debug.h"
#include "thread.h"
#include "thread.h"
...
@@ -524,11 +525,12 @@ public:
...
@@ -524,11 +525,12 @@ public:
}
}
static
Promise
<
Array
<
SocketAddress
>>
lookupHost
(
static
Promise
<
Array
<
SocketAddress
>>
lookupHost
(
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
);
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
,
_
::
NetworkFilter
&
filter
);
// Perform a DNS lookup.
// Perform a DNS lookup.
static
Promise
<
Array
<
SocketAddress
>>
parse
(
static
Promise
<
Array
<
SocketAddress
>>
parse
(
LowLevelAsyncIoProvider
&
lowLevel
,
StringPtr
str
,
uint
portHint
)
{
LowLevelAsyncIoProvider
&
lowLevel
,
StringPtr
str
,
uint
portHint
,
_
::
NetworkFilter
&
filter
)
{
// TODO(someday): Allow commas in `str`.
// TODO(someday): Allow commas in `str`.
SocketAddress
result
;
SocketAddress
result
;
...
@@ -580,7 +582,8 @@ public:
...
@@ -580,7 +582,8 @@ public:
port
=
strtoul
(
portText
->
cStr
(),
&
endptr
,
0
);
port
=
strtoul
(
portText
->
cStr
(),
&
endptr
,
0
);
if
(
portText
->
size
()
==
0
||
*
endptr
!=
'\0'
)
{
if
(
portText
->
size
()
==
0
||
*
endptr
!=
'\0'
)
{
// Not a number. Maybe it's a service name. Fall back to DNS.
// Not a number. Maybe it's a service name. Fall back to DNS.
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
kj
::
heapString
(
*
portText
),
portHint
);
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
kj
::
heapString
(
*
portText
),
portHint
,
filter
);
}
}
KJ_REQUIRE
(
port
<
65536
,
"Port number too large."
);
KJ_REQUIRE
(
port
<
65536
,
"Port number too large."
);
}
else
{
}
else
{
...
@@ -622,13 +625,18 @@ public:
...
@@ -622,13 +625,18 @@ public:
switch
(
InetPtonA
(
af
,
buffer
,
addrTarget
))
{
switch
(
InetPtonA
(
af
,
buffer
,
addrTarget
))
{
case
1
:
{
case
1
:
{
// success.
// success.
if
(
!
result
.
parseAllowedBy
(
filter
))
{
KJ_FAIL_REQUIRE
(
"address family blocked by restrictPeers()"
);
return
Array
<
SocketAddress
>
();
}
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
array
.
add
(
result
);
array
.
add
(
result
);
return
array
.
finish
();
return
array
.
finish
();
}
}
case
0
:
case
0
:
// It's apparently not a simple address... fall back to DNS.
// It's apparently not a simple address... fall back to DNS.
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
nullptr
,
port
);
return
lookupHost
(
lowLevel
,
kj
::
heapString
(
addrPart
),
nullptr
,
port
,
filter
);
default
:
default
:
KJ_FAIL_WIN32
(
"InetPton"
,
WSAGetLastError
(),
af
,
addrPart
);
KJ_FAIL_WIN32
(
"InetPton"
,
WSAGetLastError
(),
af
,
addrPart
);
}
}
...
@@ -641,6 +649,14 @@ public:
...
@@ -641,6 +649,14 @@ public:
return
result
;
return
result
;
}
}
bool
allowedBy
(
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
)
{
return
filter
.
shouldAllow
(
&
addr
.
generic
,
addrlen
);
}
bool
parseAllowedBy
(
const
_
::
NetworkFilter
&
filter
)
{
return
filter
.
shouldAllowParse
(
&
addr
.
generic
);
}
static
SocketAddress
getWildcardForFamily
(
int
family
)
{
static
SocketAddress
getWildcardForFamily
(
int
family
)
{
SocketAddress
result
;
SocketAddress
result
;
switch
(
family
)
{
switch
(
family
)
{
...
@@ -680,8 +696,9 @@ class SocketAddress::LookupReader {
...
@@ -680,8 +696,9 @@ class SocketAddress::LookupReader {
// getaddrinfo.
// getaddrinfo.
public
:
public
:
LookupReader
(
kj
::
Own
<
Thread
>&&
thread
,
kj
::
Own
<
AsyncInputStream
>&&
input
)
LookupReader
(
kj
::
Own
<
Thread
>&&
thread
,
kj
::
Own
<
AsyncInputStream
>&&
input
,
:
thread
(
kj
::
mv
(
thread
)),
input
(
kj
::
mv
(
input
))
{}
_
::
NetworkFilter
&
filter
)
:
thread
(
kj
::
mv
(
thread
)),
input
(
kj
::
mv
(
input
)),
filter
(
filter
)
{}
~
LookupReader
()
{
~
LookupReader
()
{
if
(
thread
)
thread
->
detach
();
if
(
thread
)
thread
->
detach
();
...
@@ -694,7 +711,7 @@ public:
...
@@ -694,7 +711,7 @@ public:
thread
=
nullptr
;
thread
=
nullptr
;
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// getaddrinfo()'s docs seem to say it will never return an empty list, but let's check
// anyway.
// anyway.
KJ_REQUIRE
(
addresses
.
size
()
>
0
,
"DNS lookup returned no addresses."
)
{
break
;
}
KJ_REQUIRE
(
addresses
.
size
()
>
0
,
"DNS lookup returned no
permitted
addresses."
)
{
break
;
}
return
addresses
.
releaseAsArray
();
return
addresses
.
releaseAsArray
();
}
else
{
}
else
{
// getaddrinfo() can return multiple copies of the same address for several reasons.
// getaddrinfo() can return multiple copies of the same address for several reasons.
...
@@ -707,7 +724,9 @@ public:
...
@@ -707,7 +724,9 @@ public:
//
//
// So we instead resort to de-duping results.
// So we instead resort to de-duping results.
if
(
alreadySeen
.
insert
(
current
).
second
)
{
if
(
alreadySeen
.
insert
(
current
).
second
)
{
addresses
.
add
(
current
);
if
(
current
.
parseAllowedBy
(
filter
))
{
addresses
.
add
(
current
);
}
}
}
return
read
();
return
read
();
}
}
...
@@ -717,6 +736,7 @@ public:
...
@@ -717,6 +736,7 @@ public:
private
:
private
:
kj
::
Own
<
Thread
>
thread
;
kj
::
Own
<
Thread
>
thread
;
kj
::
Own
<
AsyncInputStream
>
input
;
kj
::
Own
<
AsyncInputStream
>
input
;
_
::
NetworkFilter
&
filter
;
SocketAddress
current
;
SocketAddress
current
;
kj
::
Vector
<
SocketAddress
>
addresses
;
kj
::
Vector
<
SocketAddress
>
addresses
;
std
::
set
<
SocketAddress
>
alreadySeen
;
std
::
set
<
SocketAddress
>
alreadySeen
;
...
@@ -728,7 +748,8 @@ struct SocketAddress::LookupParams {
...
@@ -728,7 +748,8 @@ struct SocketAddress::LookupParams {
};
};
Promise
<
Array
<
SocketAddress
>>
SocketAddress
::
lookupHost
(
Promise
<
Array
<
SocketAddress
>>
SocketAddress
::
lookupHost
(
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
)
{
LowLevelAsyncIoProvider
&
lowLevel
,
kj
::
String
host
,
kj
::
String
service
,
uint
portHint
,
_
::
NetworkFilter
&
filter
)
{
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// This shitty function spawns a thread to run getaddrinfo(). Unfortunately, getaddrinfo() is
// the only cross-platform DNS API and it is blocking.
// the only cross-platform DNS API and it is blocking.
//
//
...
@@ -818,7 +839,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
...
@@ -818,7 +839,7 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
}
}
}));
}));
auto
reader
=
heap
<
LookupReader
>
(
kj
::
mv
(
thread
),
kj
::
mv
(
input
));
auto
reader
=
heap
<
LookupReader
>
(
kj
::
mv
(
thread
),
kj
::
mv
(
input
)
,
filter
);
return
reader
->
read
().
attach
(
kj
::
mv
(
reader
));
return
reader
->
read
().
attach
(
kj
::
mv
(
reader
));
}
}
...
@@ -826,8 +847,9 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
...
@@ -826,8 +847,9 @@ Promise<Array<SocketAddress>> SocketAddress::lookupHost(
class
FdConnectionReceiver
final
:
public
ConnectionReceiver
,
public
OwnedFd
{
class
FdConnectionReceiver
final
:
public
ConnectionReceiver
,
public
OwnedFd
{
public
:
public
:
FdConnectionReceiver
(
Win32EventPort
&
eventPort
,
SOCKET
fd
,
uint
flags
)
FdConnectionReceiver
(
Win32EventPort
&
eventPort
,
SOCKET
fd
,
:
OwnedFd
(
fd
,
flags
),
eventPort
(
eventPort
),
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
uint
flags
)
:
OwnedFd
(
fd
,
flags
),
eventPort
(
eventPort
),
filter
(
filter
),
observer
(
eventPort
.
observeIo
(
reinterpret_cast
<
HANDLE
>
(
fd
))),
observer
(
eventPort
.
observeIo
(
reinterpret_cast
<
HANDLE
>
(
fd
))),
address
(
SocketAddress
::
getLocalAddress
(
fd
))
{
address
(
SocketAddress
::
getLocalAddress
(
fd
))
{
// In order to accept asynchronously, we need the AcceptEx() function. Apparently, we have
// In order to accept asynchronously, we need the AcceptEx() function. Apparently, we have
...
@@ -858,8 +880,9 @@ public:
...
@@ -858,8 +880,9 @@ public:
}
}
}
}
return
op
->
onComplete
().
attach
(
kj
::
mv
(
scratch
)).
then
(
mvCapture
(
result
,
return
op
->
onComplete
().
then
(
mvCapture
(
result
,
mvCapture
(
scratch
,
[
this
](
Own
<
AsyncIoStream
>
stream
,
Win32EventPort
::
IoResult
ioResult
)
{
[
this
](
Array
<
byte
>
scratch
,
Own
<
AsyncIoStream
>
stream
,
Win32EventPort
::
IoResult
ioResult
)
->
Promise
<
Own
<
AsyncIoStream
>>
{
if
(
ioResult
.
errorCode
!=
ERROR_SUCCESS
)
{
if
(
ioResult
.
errorCode
!=
ERROR_SUCCESS
)
{
KJ_FAIL_WIN32
(
"AcceptEx()"
,
ioResult
.
errorCode
)
{
break
;
}
KJ_FAIL_WIN32
(
"AcceptEx()"
,
ioResult
.
errorCode
)
{
break
;
}
}
else
{
}
else
{
...
@@ -867,8 +890,17 @@ public:
...
@@ -867,8 +890,17 @@ public:
stream
->
setsockopt
(
SOL_SOCKET
,
SO_UPDATE_ACCEPT_CONTEXT
,
stream
->
setsockopt
(
SOL_SOCKET
,
SO_UPDATE_ACCEPT_CONTEXT
,
reinterpret_cast
<
char
*>
(
&
me
),
sizeof
(
me
));
reinterpret_cast
<
char
*>
(
&
me
),
sizeof
(
me
));
}
}
return
kj
::
mv
(
stream
);
}));
auto
addr
=
reinterpret_cast
<
struct
sockaddr
*>
(
scratch
.
begin
()
+
128
);
size_t
addrlen
=
addr
->
sa_family
==
AF_INET
?
sizeof
(
struct
sockaddr_in
)
:
sizeof
(
struct
sockaddr_in6
);
if
(
filter
.
shouldAllow
(
addr
,
addrlen
))
{
return
kj
::
mv
(
stream
);
}
else
{
return
accept
();
}
})));
}
}
uint
getPort
()
override
{
uint
getPort
()
override
{
...
@@ -888,6 +920,7 @@ public:
...
@@ -888,6 +920,7 @@ public:
public
:
public
:
Win32EventPort
&
eventPort
;
Win32EventPort
&
eventPort
;
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
;
Own
<
Win32EventPort
::
IoObserver
>
observer
;
Own
<
Win32EventPort
::
IoObserver
>
observer
;
LPFN_ACCEPTEX
acceptEx
=
nullptr
;
LPFN_ACCEPTEX
acceptEx
=
nullptr
;
SocketAddress
address
;
SocketAddress
address
;
...
@@ -923,8 +956,9 @@ public:
...
@@ -923,8 +956,9 @@ public:
return
kj
::
mv
(
result
);
return
kj
::
mv
(
result
);
}));
}));
}
}
Own
<
ConnectionReceiver
>
wrapListenSocketFd
(
SOCKET
fd
,
uint
flags
=
0
)
override
{
Own
<
ConnectionReceiver
>
wrapListenSocketFd
(
return
heap
<
FdConnectionReceiver
>
(
eventPort
,
fd
,
flags
);
SOCKET
fd
,
NetworkFilter
&
filter
,
uint
flags
=
0
)
override
{
return
heap
<
FdConnectionReceiver
>
(
eventPort
,
fd
,
filter
,
flags
);
}
}
Timer
&
getTimer
()
override
{
return
eventPort
.
getTimer
();
}
Timer
&
getTimer
()
override
{
return
eventPort
.
getTimer
();
}
...
@@ -941,12 +975,14 @@ private:
...
@@ -941,12 +975,14 @@ private:
class
NetworkAddressImpl
final
:
public
NetworkAddress
{
class
NetworkAddressImpl
final
:
public
NetworkAddress
{
public
:
public
:
NetworkAddressImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
Array
<
SocketAddress
>
addrs
)
NetworkAddressImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
:
lowLevel
(
lowLevel
),
addrs
(
kj
::
mv
(
addrs
))
{}
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
Array
<
SocketAddress
>
addrs
)
:
lowLevel
(
lowLevel
),
filter
(
filter
),
addrs
(
kj
::
mv
(
addrs
))
{}
Promise
<
Own
<
AsyncIoStream
>>
connect
()
override
{
Promise
<
Own
<
AsyncIoStream
>>
connect
()
override
{
auto
addrsCopy
=
heapArray
(
addrs
.
asPtr
());
auto
addrsCopy
=
heapArray
(
addrs
.
asPtr
());
auto
promise
=
connectImpl
(
lowLevel
,
addrsCopy
);
auto
promise
=
connectImpl
(
lowLevel
,
filter
,
addrsCopy
);
return
promise
.
attach
(
kj
::
mv
(
addrsCopy
));
return
promise
.
attach
(
kj
::
mv
(
addrsCopy
));
}
}
...
@@ -974,7 +1010,7 @@ public:
...
@@ -974,7 +1010,7 @@ public:
KJ_WINSOCK
(
::
listen
(
fd
,
SOMAXCONN
));
KJ_WINSOCK
(
::
listen
(
fd
,
SOMAXCONN
));
}
}
return
lowLevel
.
wrapListenSocketFd
(
fd
,
NEW_FD_FLAGS
);
return
lowLevel
.
wrapListenSocketFd
(
fd
,
filter
,
NEW_FD_FLAGS
);
}
}
Own
<
DatagramPort
>
bindDatagramPort
()
override
{
Own
<
DatagramPort
>
bindDatagramPort
()
override
{
...
@@ -998,11 +1034,11 @@ public:
...
@@ -998,11 +1034,11 @@ public:
addrs
[
0
].
bind
(
fd
);
addrs
[
0
].
bind
(
fd
);
}
}
return
lowLevel
.
wrapDatagramSocketFd
(
fd
,
NEW_FD_FLAGS
);
return
lowLevel
.
wrapDatagramSocketFd
(
fd
,
filter
,
NEW_FD_FLAGS
);
}
}
Own
<
NetworkAddress
>
clone
()
override
{
Own
<
NetworkAddress
>
clone
()
override
{
return
kj
::
heap
<
NetworkAddressImpl
>
(
lowLevel
,
kj
::
heapArray
(
addrs
.
asPtr
()));
return
kj
::
heap
<
NetworkAddressImpl
>
(
lowLevel
,
filter
,
kj
::
heapArray
(
addrs
.
asPtr
()));
}
}
String
toString
()
override
{
String
toString
()
override
{
...
@@ -1016,26 +1052,34 @@ public:
...
@@ -1016,26 +1052,34 @@ public:
private
:
private
:
LowLevelAsyncIoProvider
&
lowLevel
;
LowLevelAsyncIoProvider
&
lowLevel
;
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
;
Array
<
SocketAddress
>
addrs
;
Array
<
SocketAddress
>
addrs
;
uint
counter
=
0
;
uint
counter
=
0
;
static
Promise
<
Own
<
AsyncIoStream
>>
connectImpl
(
static
Promise
<
Own
<
AsyncIoStream
>>
connectImpl
(
LowLevelAsyncIoProvider
&
lowLevel
,
ArrayPtr
<
SocketAddress
>
addrs
)
{
LowLevelAsyncIoProvider
&
lowLevel
,
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
ArrayPtr
<
SocketAddress
>
addrs
)
{
KJ_ASSERT
(
addrs
.
size
()
>
0
);
KJ_ASSERT
(
addrs
.
size
()
>
0
);
int
fd
=
addrs
[
0
].
socket
(
SOCK_STREAM
);
int
fd
=
addrs
[
0
].
socket
(
SOCK_STREAM
);
return
kj
::
evalNow
([
&
]()
{
return
kj
::
evalNow
([
&
]()
->
Promise
<
Own
<
AsyncIoStream
>>
{
return
lowLevel
.
wrapConnectingSocketFd
(
if
(
!
addrs
[
0
].
allowedBy
(
filter
))
{
fd
,
addrs
[
0
].
getRaw
(),
addrs
[
0
].
getRawSize
(),
NEW_FD_FLAGS
);
return
KJ_EXCEPTION
(
FAILED
,
"connect() blocked by restrictPeers()"
);
}
else
{
return
lowLevel
.
wrapConnectingSocketFd
(
fd
,
addrs
[
0
].
getRaw
(),
addrs
[
0
].
getRawSize
(),
NEW_FD_FLAGS
);
}
}).
then
([](
Own
<
AsyncIoStream
>&&
stream
)
->
Promise
<
Own
<
AsyncIoStream
>>
{
}).
then
([](
Own
<
AsyncIoStream
>&&
stream
)
->
Promise
<
Own
<
AsyncIoStream
>>
{
// Success, pass along.
// Success, pass along.
return
kj
::
mv
(
stream
);
return
kj
::
mv
(
stream
);
},
[
&
lowLevel
,
KJ_CPCAP
(
addrs
)](
Exception
&&
exception
)
mutable
->
Promise
<
Own
<
AsyncIoStream
>>
{
},
[
&
lowLevel
,
&
filter
,
KJ_CPCAP
(
addrs
)](
Exception
&&
exception
)
mutable
->
Promise
<
Own
<
AsyncIoStream
>>
{
// Connect failed.
// Connect failed.
if
(
addrs
.
size
()
>
1
)
{
if
(
addrs
.
size
()
>
1
)
{
// Try the next address instead.
// Try the next address instead.
return
connectImpl
(
lowLevel
,
addrs
.
slice
(
1
,
addrs
.
size
()));
return
connectImpl
(
lowLevel
,
filter
,
addrs
.
slice
(
1
,
addrs
.
size
()));
}
else
{
}
else
{
// No more addresses to try, so propagate the exception.
// No more addresses to try, so propagate the exception.
return
kj
::
mv
(
exception
);
return
kj
::
mv
(
exception
);
...
@@ -1047,25 +1091,35 @@ private:
...
@@ -1047,25 +1091,35 @@ private:
class
SocketNetwork
final
:
public
Network
{
class
SocketNetwork
final
:
public
Network
{
public
:
public
:
explicit
SocketNetwork
(
LowLevelAsyncIoProvider
&
lowLevel
)
:
lowLevel
(
lowLevel
)
{}
explicit
SocketNetwork
(
LowLevelAsyncIoProvider
&
lowLevel
)
:
lowLevel
(
lowLevel
)
{}
explicit
SocketNetwork
(
SocketNetwork
&
parent
,
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
allow
,
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
deny
)
:
lowLevel
(
parent
.
lowLevel
),
filter
(
allow
,
deny
,
parent
.
filter
)
{}
Promise
<
Own
<
NetworkAddress
>>
parseAddress
(
StringPtr
addr
,
uint
portHint
=
0
)
override
{
Promise
<
Own
<
NetworkAddress
>>
parseAddress
(
StringPtr
addr
,
uint
portHint
=
0
)
override
{
auto
&
lowLevelCopy
=
lowLevel
;
return
evalLater
(
mvCapture
(
heapString
(
addr
),
[
this
,
portHint
](
String
&&
addr
)
{
return
evalLater
(
mvCapture
(
heapString
(
addr
),
return
SocketAddress
::
parse
(
lowLevel
,
addr
,
portHint
,
filter
);
[
&
lowLevelCopy
,
portHint
](
String
&&
addr
)
{
})).
then
([
this
](
Array
<
SocketAddress
>
addresses
)
->
Own
<
NetworkAddress
>
{
return
SocketAddress
::
parse
(
lowLevelCopy
,
addr
,
portHint
);
return
heap
<
NetworkAddressImpl
>
(
lowLevel
,
filter
,
kj
::
mv
(
addresses
));
})).
then
([
&
lowLevelCopy
](
Array
<
SocketAddress
>
addresses
)
->
Own
<
NetworkAddress
>
{
return
heap
<
NetworkAddressImpl
>
(
lowLevelCopy
,
kj
::
mv
(
addresses
));
});
});
}
}
Own
<
NetworkAddress
>
getSockaddr
(
const
void
*
sockaddr
,
uint
len
)
override
{
Own
<
NetworkAddress
>
getSockaddr
(
const
void
*
sockaddr
,
uint
len
)
override
{
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
auto
array
=
kj
::
heapArrayBuilder
<
SocketAddress
>
(
1
);
array
.
add
(
SocketAddress
(
sockaddr
,
len
));
array
.
add
(
SocketAddress
(
sockaddr
,
len
));
return
Own
<
NetworkAddress
>
(
heap
<
NetworkAddressImpl
>
(
lowLevel
,
array
.
finish
()));
KJ_REQUIRE
(
array
[
0
].
allowedBy
(
filter
),
"address blocked by restrictPeers()"
)
{
break
;
}
return
Own
<
NetworkAddress
>
(
heap
<
NetworkAddressImpl
>
(
lowLevel
,
filter
,
array
.
finish
()));
}
Own
<
Network
>
restrictPeers
(
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
allow
,
kj
::
ArrayPtr
<
const
kj
::
StringPtr
>
deny
=
nullptr
)
override
{
return
heap
<
SocketNetwork
>
(
*
this
,
allow
,
deny
);
}
}
private
:
private
:
LowLevelAsyncIoProvider
&
lowLevel
;
LowLevelAsyncIoProvider
&
lowLevel
;
_
::
NetworkFilter
filter
;
};
};
// =======================================================================================
// =======================================================================================
...
...
c++/src/kj/async-io.c++
View file @
505e71f7
...
@@ -19,15 +19,18 @@
...
@@ -19,15 +19,18 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
// THE SOFTWARE.
#if _WIN32
// Request Vista-level APIs.
#define WINVER 0x0600
#define _WIN32_WINNT 0x0600
#endif
#include "async-io.h"
#include "async-io.h"
#include "async-io-internal.h"
#include "async-io-internal.h"
#include "debug.h"
#include "debug.h"
#include "vector.h"
#include "vector.h"
#if _WIN32
#if _WIN32
// Request Vista-level APIs.
#define WINVER 0x0600
#define _WIN32_WINNT 0x0600
#include <winsock2.h>
#include <winsock2.h>
#include <ws2ipdef.h>
#include <ws2ipdef.h>
#include <ws2tcpip.h>
#include <ws2tcpip.h>
...
@@ -205,7 +208,8 @@ void DatagramPort::setsockopt(int level, int option, const void* value, uint len
...
@@ -205,7 +208,8 @@ void DatagramPort::setsockopt(int level, int option, const void* value, uint len
Own
<
DatagramPort
>
NetworkAddress
::
bindDatagramPort
()
{
Own
<
DatagramPort
>
NetworkAddress
::
bindDatagramPort
()
{
KJ_UNIMPLEMENTED
(
"Datagram sockets not implemented."
);
KJ_UNIMPLEMENTED
(
"Datagram sockets not implemented."
);
}
}
Own
<
DatagramPort
>
LowLevelAsyncIoProvider
::
wrapDatagramSocketFd
(
Fd
fd
,
uint
flags
)
{
Own
<
DatagramPort
>
LowLevelAsyncIoProvider
::
wrapDatagramSocketFd
(
Fd
fd
,
LowLevelAsyncIoProvider
::
NetworkFilter
&
filter
,
uint
flags
)
{
KJ_UNIMPLEMENTED
(
"Datagram sockets not implemented."
);
KJ_UNIMPLEMENTED
(
"Datagram sockets not implemented."
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment