Commit 177fb269 authored by Kenton Varda's avatar Kenton Varda

Fixes #115

parent e2dd39b3
...@@ -58,6 +58,7 @@ public: ...@@ -58,6 +58,7 @@ public:
public: public:
virtual kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) = 0; virtual kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) = 0;
virtual kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() = 0; virtual kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() = 0;
virtual kj::Promise<void> shutdown() = 0;
}; };
virtual kj::Maybe<kj::Own<Connection>> baseConnect(_::StructReader vatId) = 0; virtual kj::Maybe<kj::Own<Connection>> baseConnect(_::StructReader vatId) = 0;
virtual kj::Promise<kj::Own<Connection>> baseAccept() = 0; virtual kj::Promise<kj::Own<Connection>> baseAccept() = 0;
......
...@@ -302,9 +302,14 @@ public: ...@@ -302,9 +302,14 @@ public:
} }
if (messages.empty()) { if (messages.empty()) {
auto paf = kj::newPromiseAndFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>(); KJ_IF_MAYBE(f, fulfillOnEnd) {
fulfillers.push(kj::mv(paf.fulfiller)); f->get()->fulfill();
return kj::mv(paf.promise); return kj::Maybe<kj::Own<IncomingRpcMessage>>(nullptr);
} else {
auto paf = kj::newPromiseAndFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>();
fulfillers.push(kj::mv(paf.fulfiller));
return kj::mv(paf.promise);
}
} else { } else {
++network.received; ++network.received;
auto result = kj::mv(messages.front()); auto result = kj::mv(messages.front());
...@@ -312,6 +317,15 @@ public: ...@@ -312,6 +317,15 @@ public:
return kj::Maybe<kj::Own<IncomingRpcMessage>>(kj::mv(result)); return kj::Maybe<kj::Own<IncomingRpcMessage>>(kj::mv(result));
} }
} }
kj::Promise<void> shutdown() override {
KJ_IF_MAYBE(p, partner) {
auto paf = kj::newPromiseAndFulfiller<void>();
p->fulfillOnEnd = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
} else {
return kj::READY_NOW;
}
}
void taskFailed(kj::Exception&& exception) override { void taskFailed(kj::Exception&& exception) override {
ADD_FAILURE() << kj::str(exception).cStr(); ADD_FAILURE() << kj::str(exception).cStr();
...@@ -326,6 +340,7 @@ public: ...@@ -326,6 +340,7 @@ public:
std::queue<kj::Own<kj::PromiseFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>>> fulfillers; std::queue<kj::Own<kj::PromiseFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>>> fulfillers;
std::queue<kj::Own<IncomingRpcMessage>> messages; std::queue<kj::Own<IncomingRpcMessage>> messages;
kj::Maybe<kj::Own<kj::PromiseFulfiller<void>>> fulfillOnEnd;
kj::Own<kj::TaskSet> tasks; kj::Own<kj::TaskSet> tasks;
}; };
...@@ -975,6 +990,32 @@ TEST(Rpc, CallBrokenPromise) { ...@@ -975,6 +990,32 @@ TEST(Rpc, CallBrokenPromise) {
getCallSequence(client, 1).wait(context.waitScope); getCallSequence(client, 1).wait(context.waitScope);
} }
TEST(Rpc, Abort) {
// Verify that aborts are received.
TestContext context;
MallocMessageBuilder refMessage(128);
auto hostId = refMessage.initRoot<test::TestSturdyRefHostId>();
hostId.setHost("server");
auto conn = KJ_ASSERT_NONNULL(context.clientNetwork.connect(hostId));
{
// Send an invalid message (Return to non-existent question).
auto msg = conn->newOutgoingMessage(128);
auto body = msg->getBody().initAs<rpc::Message>().initReturn();
body.setAnswerId(1234);
body.setCanceled();
msg->send();
}
auto reply = KJ_ASSERT_NONNULL(conn->receiveIncomingMessage().wait(context.waitScope));
EXPECT_EQ(rpc::Message::ABORT, reply->getBody().getAs<rpc::Message>().which());
EXPECT_TRUE(conn->receiveIncomingMessage().wait(context.waitScope) == nullptr);
}
// ======================================================================================= // =======================================================================================
typedef RealmGateway<test::TestSturdyRef, Text> TestRealmGateway; typedef RealmGateway<test::TestSturdyRef, Text> TestRealmGateway;
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "rpc-twoparty.h" #include "rpc-twoparty.h"
#include "test-util.h" #include "test-util.h"
#include <capnp/rpc.capnp.h>
#include <kj/async-unix.h> #include <kj/async-unix.h>
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/thread.h> #include <kj/thread.h>
...@@ -259,6 +260,37 @@ TEST(TwoPartyNetwork, Release) { ...@@ -259,6 +260,37 @@ TEST(TwoPartyNetwork, Release) {
EXPECT_EQ(0, handleCount); EXPECT_EQ(0, handleCount);
} }
TEST(TwoPartyNetwork, Abort) {
// Verify that aborts are received.
auto ioContext = kj::setupAsyncIo();
int callCount = 0;
int handleCount = 0;
auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
MallocMessageBuilder refMessage(128);
auto hostId = refMessage.initRoot<rpc::twoparty::VatId>();
hostId.setSide(rpc::twoparty::Side::SERVER);
auto conn = KJ_ASSERT_NONNULL(network.connect(hostId));
{
// Send an invalid message (Return to non-existent question).
auto msg = conn->newOutgoingMessage(128);
auto body = msg->getBody().initAs<rpc::Message>().initReturn();
body.setAnswerId(1234);
body.setCanceled();
msg->send();
}
auto reply = KJ_ASSERT_NONNULL(conn->receiveIncomingMessage().wait(ioContext.waitScope));
EXPECT_EQ(rpc::Message::ABORT, reply->getBody().getAs<rpc::Message>().which());
EXPECT_TRUE(conn->receiveIncomingMessage().wait(ioContext.waitScope) == nullptr);
}
} // namespace } // namespace
} // namespace _ } // namespace _
} // namespace capnp } // namespace capnp
...@@ -81,7 +81,8 @@ public: ...@@ -81,7 +81,8 @@ public:
} }
void send() override { void send() override {
network.previousWrite = network.previousWrite.then([&]() { network.previousWrite = KJ_ASSERT_NONNULL(network.previousWrite, "already shut down")
.then([&]() {
// Note that if the write fails, all further writes will be skipped due to the exception. // Note that if the write fails, all further writes will be skipped due to the exception.
// We never actually handle this exception because we assume the read end will fail as well // We never actually handle this exception because we assume the read end will fail as well
// and it's cleaner to handle the failure there. // and it's cleaner to handle the failure there.
...@@ -132,4 +133,12 @@ kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveI ...@@ -132,4 +133,12 @@ kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveI
}); });
} }
kj::Promise<void> TwoPartyVatNetwork::shutdown() {
kj::Promise<void> result = KJ_ASSERT_NONNULL(previousWrite, "already shut down").then([this]() {
stream.shutdownWrite();
});
previousWrite = nullptr;
return kj::mv(result);
}
} // namespace capnp } // namespace capnp
...@@ -73,8 +73,9 @@ private: ...@@ -73,8 +73,9 @@ private:
ReaderOptions receiveOptions; ReaderOptions receiveOptions;
bool accepted = false; bool accepted = false;
kj::Promise<void> previousWrite; kj::Maybe<kj::Promise<void>> previousWrite;
// Resolves when the previous write completes. This effectively serves as the write queue. // Resolves when the previous write completes. This effectively serves as the write queue.
// Becomes null when shutdown() is called.
kj::Own<kj::PromiseFulfiller<kj::Own<TwoPartyVatNetworkBase::Connection>>> acceptFulfiller; kj::Own<kj::PromiseFulfiller<kj::Own<TwoPartyVatNetworkBase::Connection>>> acceptFulfiller;
// Fulfiller for the promise returned by acceptConnectionAsRefHost() on the client side, or the // Fulfiller for the promise returned by acceptConnectionAsRefHost() on the client side, or the
...@@ -103,6 +104,7 @@ private: ...@@ -103,6 +104,7 @@ private:
kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) override; kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) override;
kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() override; kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() override;
kj::Promise<void> shutdown() override;
}; };
} // namespace capnp } // namespace capnp
......
...@@ -234,11 +234,16 @@ private: ...@@ -234,11 +234,16 @@ private:
class RpcConnectionState final: public kj::TaskSet::ErrorHandler, public kj::Refcounted { class RpcConnectionState final: public kj::TaskSet::ErrorHandler, public kj::Refcounted {
public: public:
struct DisconnectInfo {
kj::Promise<void> shutdownPromise;
// Task which is working on sending an abort message and cleanly ending the connection.
};
RpcConnectionState(kj::Maybe<Capability::Client> bootstrapInterface, RpcConnectionState(kj::Maybe<Capability::Client> bootstrapInterface,
kj::Maybe<RealmGateway<>::Client> gateway, kj::Maybe<RealmGateway<>::Client> gateway,
kj::Maybe<SturdyRefRestorerBase&> restorer, kj::Maybe<SturdyRefRestorerBase&> restorer,
kj::Own<VatNetworkBase::Connection>&& connectionParam, kj::Own<VatNetworkBase::Connection>&& connectionParam,
kj::Own<kj::PromiseFulfiller<void>>&& disconnectFulfiller) kj::Own<kj::PromiseFulfiller<DisconnectInfo>>&& disconnectFulfiller)
: bootstrapInterface(kj::mv(bootstrapInterface)), gateway(kj::mv(gateway)), : bootstrapInterface(kj::mv(bootstrapInterface)), gateway(kj::mv(gateway)),
restorer(restorer), disconnectFulfiller(kj::mv(disconnectFulfiller)), tasks(*this) { restorer(restorer), disconnectFulfiller(kj::mv(disconnectFulfiller)), tasks(*this) {
connection.init<Connected>(kj::mv(connectionParam)); connection.init<Connected>(kj::mv(connectionParam));
...@@ -354,7 +359,8 @@ public: ...@@ -354,7 +359,8 @@ public:
}); });
// Indicate disconnect. // Indicate disconnect.
disconnectFulfiller->fulfill(); disconnectFulfiller->fulfill(DisconnectInfo {
connection.get<Connected>()->shutdown().attach(kj::mv(connection.get<Connected>())) });
connection.init<Disconnected>(kj::mv(networkException)); connection.init<Disconnected>(kj::mv(networkException));
} }
...@@ -492,7 +498,7 @@ private: ...@@ -492,7 +498,7 @@ private:
// Once the connection has failed, we drop it and replace it with an exception, which will be // Once the connection has failed, we drop it and replace it with an exception, which will be
// thrown from all further calls. // thrown from all further calls.
kj::Own<kj::PromiseFulfiller<void>> disconnectFulfiller; kj::Own<kj::PromiseFulfiller<DisconnectInfo>> disconnectFulfiller;
ExportTable<ExportId, Export> exports; ExportTable<ExportId, Export> exports;
ExportTable<QuestionId, Question> questions; ExportTable<QuestionId, Question> questions;
...@@ -2641,9 +2647,11 @@ private: ...@@ -2641,9 +2647,11 @@ private:
auto iter = connections.find(connection); auto iter = connections.find(connection);
if (iter == connections.end()) { if (iter == connections.end()) {
VatNetworkBase::Connection* connectionPtr = connection; VatNetworkBase::Connection* connectionPtr = connection;
auto onDisconnect = kj::newPromiseAndFulfiller<void>(); auto onDisconnect = kj::newPromiseAndFulfiller<RpcConnectionState::DisconnectInfo>();
tasks.add(onDisconnect.promise.then([this,connectionPtr]() { tasks.add(onDisconnect.promise
.then([this,connectionPtr](RpcConnectionState::DisconnectInfo info) {
connections.erase(connectionPtr); connections.erase(connectionPtr);
tasks.add(kj::mv(info.shutdownPromise));
})); }));
auto newState = kj::refcounted<RpcConnectionState>( auto newState = kj::refcounted<RpcConnectionState>(
bootstrapInterface, gateway, restorer, kj::mv(connection), bootstrapInterface, gateway, restorer, kj::mv(connection),
......
...@@ -261,7 +261,7 @@ struct Message { ...@@ -261,7 +261,7 @@ struct Message {
obsoleteDelete @9 :AnyPointer; obsoleteDelete @9 :AnyPointer;
# Obsolete way to delete a SturdyRef. This was never implemented, therefore it has been # Obsolete way to delete a SturdyRef. This was never implemented, therefore it has been
# reduted to AnyPointer. This operation was never implemented. # reduted to AnyPointer.
# Level 3 features ----------------------------------------------- # Level 3 features -----------------------------------------------
......
...@@ -274,6 +274,10 @@ public: ...@@ -274,6 +274,10 @@ public:
virtual kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() = 0; virtual kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() = 0;
// Wait for a message to be received and return it. If the read stream cleanly terminates, // Wait for a message to be received and return it. If the read stream cleanly terminates,
// return null. If any other problem occurs, throw an exception. // return null. If any other problem occurs, throw an exception.
virtual kj::Promise<void> shutdown() = 0;
// Waits until all outgoing messages have been sent, then shuts down the outgoing stream. The
// returned promise resolves after shutdown is complete.
}; };
// Level 0 features ------------------------------------------------ // Level 0 features ------------------------------------------------
......
...@@ -181,10 +181,10 @@ namespace kj { ...@@ -181,10 +181,10 @@ namespace kj {
(*({ \ (*({ \
auto _kj_result = ::kj::_::readMaybe(value); \ auto _kj_result = ::kj::_::readMaybe(value); \
if (KJ_UNLIKELY(!_kj_result)) { \ if (KJ_UNLIKELY(!_kj_result)) { \
::kj::_::Debug::Fault(__FILE__, __LINE__, 0, \ ::kj::_::Debug::Fault(__FILE__, __LINE__, ::kj::Exception::Type::FAILED, \
#value " != nullptr", "" #__VA_ARGS__, __VA_ARGS__).fatal(); \ #value " != nullptr", "" #__VA_ARGS__, __VA_ARGS__).fatal(); \
} \ } \
_kj_result; \ kj::mv(_kj_result); \
})) }))
#define KJ_EXCEPTION(type, ...) \ #define KJ_EXCEPTION(type, ...) \
...@@ -239,10 +239,10 @@ namespace kj { ...@@ -239,10 +239,10 @@ namespace kj {
(*({ \ (*({ \
auto _kj_result = ::kj::_::readMaybe(value); \ auto _kj_result = ::kj::_::readMaybe(value); \
if (KJ_UNLIKELY(!_kj_result)) { \ if (KJ_UNLIKELY(!_kj_result)) { \
::kj::_::Debug::Fault(__FILE__, __LINE__, 0, \ ::kj::_::Debug::Fault(__FILE__, __LINE__, ::kj::Exception::Type::FAILED, \
#value " != nullptr", #__VA_ARGS__, ##__VA_ARGS__).fatal(); \ #value " != nullptr", #__VA_ARGS__, ##__VA_ARGS__).fatal(); \
} \ } \
_kj_result; \ kj::mv(_kj_result); \
})) }))
#define KJ_EXCEPTION(type, ...) \ #define KJ_EXCEPTION(type, ...) \
......
...@@ -206,8 +206,15 @@ class OwnOwn { ...@@ -206,8 +206,15 @@ class OwnOwn {
public: public:
inline OwnOwn(Own<T>&& value) noexcept: value(kj::mv(value)) {} inline OwnOwn(Own<T>&& value) noexcept: value(kj::mv(value)) {}
#if _MSC_VER
inline Own<T>& operator*() { return value; } inline Own<T>& operator*() { return value; }
inline const Own<T>& operator*() const { return value; } inline const Own<T>& operator*() const { return value; }
#else
inline Own<T>& operator*() & { return value; }
inline const Own<T>& operator*() const & { return value; }
inline Own<T>&& operator*() && { return kj::mv(value); }
inline const Own<T>&& operator*() const && { return kj::mv(value); }
#endif
inline Own<T>* operator->() { return &value; } inline Own<T>* operator->() { return &value; }
inline const Own<T>* operator->() const { return &value; } inline const Own<T>* operator->() const { return &value; }
inline operator Own<T>*() { return value ? &value : nullptr; } inline operator Own<T>*() { return value ? &value : nullptr; }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment