Commit 177fb269 authored by Kenton Varda's avatar Kenton Varda

Fixes #115

parent e2dd39b3
......@@ -58,6 +58,7 @@ public:
public:
virtual kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) = 0;
virtual kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() = 0;
virtual kj::Promise<void> shutdown() = 0;
};
virtual kj::Maybe<kj::Own<Connection>> baseConnect(_::StructReader vatId) = 0;
virtual kj::Promise<kj::Own<Connection>> baseAccept() = 0;
......
......@@ -302,9 +302,14 @@ public:
}
if (messages.empty()) {
KJ_IF_MAYBE(f, fulfillOnEnd) {
f->get()->fulfill();
return kj::Maybe<kj::Own<IncomingRpcMessage>>(nullptr);
} else {
auto paf = kj::newPromiseAndFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>();
fulfillers.push(kj::mv(paf.fulfiller));
return kj::mv(paf.promise);
}
} else {
++network.received;
auto result = kj::mv(messages.front());
......@@ -312,6 +317,15 @@ public:
return kj::Maybe<kj::Own<IncomingRpcMessage>>(kj::mv(result));
}
}
kj::Promise<void> shutdown() override {
KJ_IF_MAYBE(p, partner) {
auto paf = kj::newPromiseAndFulfiller<void>();
p->fulfillOnEnd = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
} else {
return kj::READY_NOW;
}
}
void taskFailed(kj::Exception&& exception) override {
ADD_FAILURE() << kj::str(exception).cStr();
......@@ -326,6 +340,7 @@ public:
std::queue<kj::Own<kj::PromiseFulfiller<kj::Maybe<kj::Own<IncomingRpcMessage>>>>> fulfillers;
std::queue<kj::Own<IncomingRpcMessage>> messages;
kj::Maybe<kj::Own<kj::PromiseFulfiller<void>>> fulfillOnEnd;
kj::Own<kj::TaskSet> tasks;
};
......@@ -975,6 +990,32 @@ TEST(Rpc, CallBrokenPromise) {
getCallSequence(client, 1).wait(context.waitScope);
}
TEST(Rpc, Abort) {
// Verify that aborts are received.
TestContext context;
MallocMessageBuilder refMessage(128);
auto hostId = refMessage.initRoot<test::TestSturdyRefHostId>();
hostId.setHost("server");
auto conn = KJ_ASSERT_NONNULL(context.clientNetwork.connect(hostId));
{
// Send an invalid message (Return to non-existent question).
auto msg = conn->newOutgoingMessage(128);
auto body = msg->getBody().initAs<rpc::Message>().initReturn();
body.setAnswerId(1234);
body.setCanceled();
msg->send();
}
auto reply = KJ_ASSERT_NONNULL(conn->receiveIncomingMessage().wait(context.waitScope));
EXPECT_EQ(rpc::Message::ABORT, reply->getBody().getAs<rpc::Message>().which());
EXPECT_TRUE(conn->receiveIncomingMessage().wait(context.waitScope) == nullptr);
}
// =======================================================================================
typedef RealmGateway<test::TestSturdyRef, Text> TestRealmGateway;
......
......@@ -21,6 +21,7 @@
#include "rpc-twoparty.h"
#include "test-util.h"
#include <capnp/rpc.capnp.h>
#include <kj/async-unix.h>
#include <kj/debug.h>
#include <kj/thread.h>
......@@ -259,6 +260,37 @@ TEST(TwoPartyNetwork, Release) {
EXPECT_EQ(0, handleCount);
}
TEST(TwoPartyNetwork, Abort) {
// Verify that aborts are received.
auto ioContext = kj::setupAsyncIo();
int callCount = 0;
int handleCount = 0;
auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
MallocMessageBuilder refMessage(128);
auto hostId = refMessage.initRoot<rpc::twoparty::VatId>();
hostId.setSide(rpc::twoparty::Side::SERVER);
auto conn = KJ_ASSERT_NONNULL(network.connect(hostId));
{
// Send an invalid message (Return to non-existent question).
auto msg = conn->newOutgoingMessage(128);
auto body = msg->getBody().initAs<rpc::Message>().initReturn();
body.setAnswerId(1234);
body.setCanceled();
msg->send();
}
auto reply = KJ_ASSERT_NONNULL(conn->receiveIncomingMessage().wait(ioContext.waitScope));
EXPECT_EQ(rpc::Message::ABORT, reply->getBody().getAs<rpc::Message>().which());
EXPECT_TRUE(conn->receiveIncomingMessage().wait(ioContext.waitScope) == nullptr);
}
} // namespace
} // namespace _
} // namespace capnp
......@@ -81,7 +81,8 @@ public:
}
void send() override {
network.previousWrite = network.previousWrite.then([&]() {
network.previousWrite = KJ_ASSERT_NONNULL(network.previousWrite, "already shut down")
.then([&]() {
// Note that if the write fails, all further writes will be skipped due to the exception.
// We never actually handle this exception because we assume the read end will fail as well
// and it's cleaner to handle the failure there.
......@@ -132,4 +133,12 @@ kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveI
});
}
kj::Promise<void> TwoPartyVatNetwork::shutdown() {
kj::Promise<void> result = KJ_ASSERT_NONNULL(previousWrite, "already shut down").then([this]() {
stream.shutdownWrite();
});
previousWrite = nullptr;
return kj::mv(result);
}
} // namespace capnp
......@@ -73,8 +73,9 @@ private:
ReaderOptions receiveOptions;
bool accepted = false;
kj::Promise<void> previousWrite;
kj::Maybe<kj::Promise<void>> previousWrite;
// Resolves when the previous write completes. This effectively serves as the write queue.
// Becomes null when shutdown() is called.
kj::Own<kj::PromiseFulfiller<kj::Own<TwoPartyVatNetworkBase::Connection>>> acceptFulfiller;
// Fulfiller for the promise returned by acceptConnectionAsRefHost() on the client side, or the
......@@ -103,6 +104,7 @@ private:
kj::Own<OutgoingRpcMessage> newOutgoingMessage(uint firstSegmentWordSize) override;
kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() override;
kj::Promise<void> shutdown() override;
};
} // namespace capnp
......
......@@ -234,11 +234,16 @@ private:
class RpcConnectionState final: public kj::TaskSet::ErrorHandler, public kj::Refcounted {
public:
struct DisconnectInfo {
kj::Promise<void> shutdownPromise;
// Task which is working on sending an abort message and cleanly ending the connection.
};
RpcConnectionState(kj::Maybe<Capability::Client> bootstrapInterface,
kj::Maybe<RealmGateway<>::Client> gateway,
kj::Maybe<SturdyRefRestorerBase&> restorer,
kj::Own<VatNetworkBase::Connection>&& connectionParam,
kj::Own<kj::PromiseFulfiller<void>>&& disconnectFulfiller)
kj::Own<kj::PromiseFulfiller<DisconnectInfo>>&& disconnectFulfiller)
: bootstrapInterface(kj::mv(bootstrapInterface)), gateway(kj::mv(gateway)),
restorer(restorer), disconnectFulfiller(kj::mv(disconnectFulfiller)), tasks(*this) {
connection.init<Connected>(kj::mv(connectionParam));
......@@ -354,7 +359,8 @@ public:
});
// Indicate disconnect.
disconnectFulfiller->fulfill();
disconnectFulfiller->fulfill(DisconnectInfo {
connection.get<Connected>()->shutdown().attach(kj::mv(connection.get<Connected>())) });
connection.init<Disconnected>(kj::mv(networkException));
}
......@@ -492,7 +498,7 @@ private:
// Once the connection has failed, we drop it and replace it with an exception, which will be
// thrown from all further calls.
kj::Own<kj::PromiseFulfiller<void>> disconnectFulfiller;
kj::Own<kj::PromiseFulfiller<DisconnectInfo>> disconnectFulfiller;
ExportTable<ExportId, Export> exports;
ExportTable<QuestionId, Question> questions;
......@@ -2641,9 +2647,11 @@ private:
auto iter = connections.find(connection);
if (iter == connections.end()) {
VatNetworkBase::Connection* connectionPtr = connection;
auto onDisconnect = kj::newPromiseAndFulfiller<void>();
tasks.add(onDisconnect.promise.then([this,connectionPtr]() {
auto onDisconnect = kj::newPromiseAndFulfiller<RpcConnectionState::DisconnectInfo>();
tasks.add(onDisconnect.promise
.then([this,connectionPtr](RpcConnectionState::DisconnectInfo info) {
connections.erase(connectionPtr);
tasks.add(kj::mv(info.shutdownPromise));
}));
auto newState = kj::refcounted<RpcConnectionState>(
bootstrapInterface, gateway, restorer, kj::mv(connection),
......
......@@ -261,7 +261,7 @@ struct Message {
obsoleteDelete @9 :AnyPointer;
# Obsolete way to delete a SturdyRef. This was never implemented, therefore it has been
# reduted to AnyPointer. This operation was never implemented.
# reduted to AnyPointer.
# Level 3 features -----------------------------------------------
......
......@@ -274,6 +274,10 @@ public:
virtual kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> receiveIncomingMessage() = 0;
// Wait for a message to be received and return it. If the read stream cleanly terminates,
// return null. If any other problem occurs, throw an exception.
virtual kj::Promise<void> shutdown() = 0;
// Waits until all outgoing messages have been sent, then shuts down the outgoing stream. The
// returned promise resolves after shutdown is complete.
};
// Level 0 features ------------------------------------------------
......
......@@ -181,10 +181,10 @@ namespace kj {
(*({ \
auto _kj_result = ::kj::_::readMaybe(value); \
if (KJ_UNLIKELY(!_kj_result)) { \
::kj::_::Debug::Fault(__FILE__, __LINE__, 0, \
::kj::_::Debug::Fault(__FILE__, __LINE__, ::kj::Exception::Type::FAILED, \
#value " != nullptr", "" #__VA_ARGS__, __VA_ARGS__).fatal(); \
} \
_kj_result; \
kj::mv(_kj_result); \
}))
#define KJ_EXCEPTION(type, ...) \
......@@ -239,10 +239,10 @@ namespace kj {
(*({ \
auto _kj_result = ::kj::_::readMaybe(value); \
if (KJ_UNLIKELY(!_kj_result)) { \
::kj::_::Debug::Fault(__FILE__, __LINE__, 0, \
::kj::_::Debug::Fault(__FILE__, __LINE__, ::kj::Exception::Type::FAILED, \
#value " != nullptr", #__VA_ARGS__, ##__VA_ARGS__).fatal(); \
} \
_kj_result; \
kj::mv(_kj_result); \
}))
#define KJ_EXCEPTION(type, ...) \
......
......@@ -206,8 +206,15 @@ class OwnOwn {
public:
inline OwnOwn(Own<T>&& value) noexcept: value(kj::mv(value)) {}
#if _MSC_VER
inline Own<T>& operator*() { return value; }
inline const Own<T>& operator*() const { return value; }
#else
inline Own<T>& operator*() & { return value; }
inline const Own<T>& operator*() const & { return value; }
inline Own<T>&& operator*() && { return kj::mv(value); }
inline const Own<T>&& operator*() const && { return kj::mv(value); }
#endif
inline Own<T>* operator->() { return &value; }
inline const Own<T>* operator->() const { return &value; }
inline operator Own<T>*() { return value ? &value : nullptr; }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment