Commit 98d72869 authored by Kenton Varda's avatar Kenton Varda

Fix bugs.

parent 72af96df
...@@ -210,7 +210,7 @@ public: ...@@ -210,7 +210,7 @@ public:
return kj::mv(KJ_ASSERT_NONNULL(contextPtr->response)); return kj::mv(KJ_ASSERT_NONNULL(contextPtr->response));
}); });
// We also want to notify the context that cancellation was requested in this branch is // We also want to notify the context that cancellation was requested if this branch is
// destroyed. // destroyed.
promise.attach(LocalCallContext::Canceler(kj::mv(context))); promise.attach(LocalCallContext::Canceler(kj::mv(context)));
......
...@@ -421,8 +421,7 @@ public: ...@@ -421,8 +421,7 @@ public:
} }
}; };
class RpcTest: public testing::Test { struct TestContext {
protected:
kj::SimpleEventLoop loop; kj::SimpleEventLoop loop;
TestNetwork network; TestNetwork network;
TestRestorer restorer; TestRestorer restorer;
...@@ -431,6 +430,13 @@ protected: ...@@ -431,6 +430,13 @@ protected:
RpcSystem<test::TestSturdyRefHostId> rpcClient; RpcSystem<test::TestSturdyRefHostId> rpcClient;
RpcSystem<test::TestSturdyRefHostId> rpcServer; RpcSystem<test::TestSturdyRefHostId> rpcServer;
TestContext()
: network(loop),
clientNetwork(network.add("client")),
serverNetwork(network.add("server")),
rpcClient(makeRpcClient(clientNetwork, loop)),
rpcServer(makeRpcServer(serverNetwork, restorer, loop)) {}
Capability::Client connect(test::TestSturdyRefObjectId::Tag tag) { Capability::Client connect(test::TestSturdyRefObjectId::Tag tag) {
MallocMessageBuilder refMessage(128); MallocMessageBuilder refMessage(128);
auto ref = refMessage.initRoot<rpc::SturdyRef>(); auto ref = refMessage.initRoot<rpc::SturdyRef>();
...@@ -440,21 +446,13 @@ protected: ...@@ -440,21 +446,13 @@ protected:
return rpcClient.restore(hostId, ref.getObjectId()); return rpcClient.restore(hostId, ref.getObjectId());
} }
RpcTest()
: network(loop),
clientNetwork(network.add("client")),
serverNetwork(network.add("server")),
rpcClient(makeRpcClient(clientNetwork, loop)),
rpcServer(makeRpcServer(serverNetwork, restorer, loop)) {}
~RpcTest() noexcept {}
// Need to declare this with explicit noexcept otherwise it conflicts with testing::Test::~Test.
// (Urgh, C++11, why did you change this?)
}; };
TEST_F(RpcTest, Basic) { TEST(Rpc, Basic) {
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_INTERFACE) TestContext context;
auto& loop = context.loop;
auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_INTERFACE)
.castAs<test::TestInterface>(); .castAs<test::TestInterface>();
auto request1 = client.fooRequest(); auto request1 = client.fooRequest();
...@@ -478,7 +476,7 @@ TEST_F(RpcTest, Basic) { ...@@ -478,7 +476,7 @@ TEST_F(RpcTest, Basic) {
initTestMessage(request2.initS()); initTestMessage(request2.initS());
auto promise2 = request2.send(); auto promise2 = request2.send();
EXPECT_EQ(0, restorer.callCount); EXPECT_EQ(0, context.restorer.callCount);
auto response1 = loop.wait(kj::mv(promise1)); auto response1 = loop.wait(kj::mv(promise1));
...@@ -488,12 +486,15 @@ TEST_F(RpcTest, Basic) { ...@@ -488,12 +486,15 @@ TEST_F(RpcTest, Basic) {
loop.wait(kj::mv(promise3)); loop.wait(kj::mv(promise3));
EXPECT_EQ(2, restorer.callCount); EXPECT_EQ(2, context.restorer.callCount);
EXPECT_TRUE(barFailed); EXPECT_TRUE(barFailed);
} }
TEST_F(RpcTest, Pipelining) { TEST(Rpc, Pipelining) {
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_PIPELINE) TestContext context;
auto& loop = context.loop;
auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_PIPELINE)
.castAs<test::TestPipeline>(); .castAs<test::TestPipeline>();
int chainedCallCount = 0; int chainedCallCount = 0;
...@@ -514,7 +515,7 @@ TEST_F(RpcTest, Pipelining) { ...@@ -514,7 +515,7 @@ TEST_F(RpcTest, Pipelining) {
promise = nullptr; // Just to be annoying, drop the original promise. promise = nullptr; // Just to be annoying, drop the original promise.
EXPECT_EQ(0, restorer.callCount); EXPECT_EQ(0, context.restorer.callCount);
EXPECT_EQ(0, chainedCallCount); EXPECT_EQ(0, chainedCallCount);
auto response = loop.wait(kj::mv(pipelinePromise)); auto response = loop.wait(kj::mv(pipelinePromise));
...@@ -523,12 +524,15 @@ TEST_F(RpcTest, Pipelining) { ...@@ -523,12 +524,15 @@ TEST_F(RpcTest, Pipelining) {
auto response2 = loop.wait(kj::mv(pipelinePromise2)); auto response2 = loop.wait(kj::mv(pipelinePromise2));
checkTestMessage(response2); checkTestMessage(response2);
EXPECT_EQ(3, restorer.callCount); EXPECT_EQ(3, context.restorer.callCount);
EXPECT_EQ(1, chainedCallCount); EXPECT_EQ(1, chainedCallCount);
} }
TEST_F(RpcTest, TailCall) { TEST(Rpc, TailCall) {
auto caller = connect(test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLER) TestContext context;
auto& loop = context.loop;
auto caller = context.connect(test::TestSturdyRefObjectId::Tag::TEST_TAIL_CALLER)
.castAs<test::TestTailCaller>(); .castAs<test::TestTailCaller>();
int calleeCallCount = 0; int calleeCallCount = 0;
...@@ -556,18 +560,21 @@ TEST_F(RpcTest, TailCall) { ...@@ -556,18 +560,21 @@ TEST_F(RpcTest, TailCall) {
EXPECT_EQ(2, loop.wait(kj::mv(dependentCall2)).getN()); EXPECT_EQ(2, loop.wait(kj::mv(dependentCall2)).getN());
EXPECT_EQ(1, calleeCallCount); EXPECT_EQ(1, calleeCallCount);
EXPECT_EQ(1, restorer.callCount); EXPECT_EQ(1, context.restorer.callCount);
} }
TEST_F(RpcTest, AsyncCancelation) { TEST(Rpc, AsyncCancelation) {
// Tests allowAsyncCancellation(). // Tests allowAsyncCancellation().
TestContext context;
auto& loop = context.loop;
auto paf = kj::newPromiseAndFulfiller<void>(); auto paf = kj::newPromiseAndFulfiller<void>();
bool destroyed = false; bool destroyed = false;
auto destructionPromise = loop.there(kj::mv(paf.promise), [&]() { destroyed = true; }); auto destructionPromise = loop.there(kj::mv(paf.promise), [&]() { destroyed = true; });
destructionPromise.eagerlyEvaluate(loop); destructionPromise.eagerlyEvaluate(loop);
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF) auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>(); .castAs<test::TestMoreStuff>();
kj::Promise<void> promise = nullptr; kj::Promise<void> promise = nullptr;
...@@ -601,12 +608,15 @@ TEST_F(RpcTest, AsyncCancelation) { ...@@ -601,12 +608,15 @@ TEST_F(RpcTest, AsyncCancelation) {
EXPECT_FALSE(returned); EXPECT_FALSE(returned);
} }
TEST_F(RpcTest, SyncCancelation) { TEST(Rpc, SyncCancelation) {
// Tests isCanceled() without allowAsyncCancellation(). // Tests isCanceled() without allowAsyncCancellation().
TestContext context;
auto& loop = context.loop;
int innerCallCount = 0; int innerCallCount = 0;
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF) auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>(); .castAs<test::TestMoreStuff>();
kj::Promise<void> promise = nullptr; kj::Promise<void> promise = nullptr;
...@@ -644,8 +654,11 @@ TEST_F(RpcTest, SyncCancelation) { ...@@ -644,8 +654,11 @@ TEST_F(RpcTest, SyncCancelation) {
EXPECT_FALSE(returned); EXPECT_FALSE(returned);
} }
TEST_F(RpcTest, PromiseResolve) { TEST(Rpc, PromiseResolve) {
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF) TestContext context;
auto& loop = context.loop;
auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>(); .castAs<test::TestMoreStuff>();
int chainedCallCount = 0; int chainedCallCount = 0;
...@@ -667,7 +680,7 @@ TEST_F(RpcTest, PromiseResolve) { ...@@ -667,7 +680,7 @@ TEST_F(RpcTest, PromiseResolve) {
// Make sure getCap() has been called on the server side by sending another call and waiting // Make sure getCap() has been called on the server side by sending another call and waiting
// for it. // for it.
EXPECT_EQ(2, loop.wait(client.getCallSequenceRequest().send()).getN()); EXPECT_EQ(2, loop.wait(client.getCallSequenceRequest().send()).getN());
EXPECT_EQ(3, restorer.callCount); EXPECT_EQ(3, context.restorer.callCount);
// OK, now fulfill the local promise. // OK, now fulfill the local promise.
paf.fulfiller->fulfill(test::TestInterface::Client( paf.fulfiller->fulfill(test::TestInterface::Client(
...@@ -677,18 +690,21 @@ TEST_F(RpcTest, PromiseResolve) { ...@@ -677,18 +690,21 @@ TEST_F(RpcTest, PromiseResolve) {
EXPECT_EQ("bar", loop.wait(kj::mv(promise)).getS()); EXPECT_EQ("bar", loop.wait(kj::mv(promise)).getS());
EXPECT_EQ("bar", loop.wait(kj::mv(promise2)).getS()); EXPECT_EQ("bar", loop.wait(kj::mv(promise2)).getS());
EXPECT_EQ(3, restorer.callCount); EXPECT_EQ(3, context.restorer.callCount);
EXPECT_EQ(2, chainedCallCount); EXPECT_EQ(2, chainedCallCount);
} }
TEST_F(RpcTest, RetainAndRelease) { TEST(Rpc, RetainAndRelease) {
TestContext context;
auto& loop = context.loop;
auto paf = kj::newPromiseAndFulfiller<void>(); auto paf = kj::newPromiseAndFulfiller<void>();
bool destroyed = false; bool destroyed = false;
auto destructionPromise = loop.there(kj::mv(paf.promise), [&]() { destroyed = true; }); auto destructionPromise = loop.there(kj::mv(paf.promise), [&]() { destroyed = true; });
destructionPromise.eagerlyEvaluate(loop); destructionPromise.eagerlyEvaluate(loop);
{ {
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF) auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>(); .castAs<test::TestMoreStuff>();
{ {
...@@ -713,12 +729,12 @@ TEST_F(RpcTest, RetainAndRelease) { ...@@ -713,12 +729,12 @@ TEST_F(RpcTest, RetainAndRelease) {
{ {
// And call it, without any network communications. // And call it, without any network communications.
uint oldSentCount = clientNetwork.getSentCount(); uint oldSentCount = context.clientNetwork.getSentCount();
auto request = capCopy.fooRequest(); auto request = capCopy.fooRequest();
request.setI(123); request.setI(123);
request.setJ(true); request.setJ(true);
EXPECT_EQ("foo", loop.wait(request.send()).getX()); EXPECT_EQ("foo", loop.wait(request.send()).getX());
EXPECT_EQ(oldSentCount, clientNetwork.getSentCount()); EXPECT_EQ(oldSentCount, context.clientNetwork.getSentCount());
} }
{ {
...@@ -744,8 +760,11 @@ TEST_F(RpcTest, RetainAndRelease) { ...@@ -744,8 +760,11 @@ TEST_F(RpcTest, RetainAndRelease) {
EXPECT_TRUE(destroyed); EXPECT_TRUE(destroyed);
} }
TEST_F(RpcTest, Cancel) { TEST(Rpc, Cancel) {
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF) TestContext context;
auto& loop = context.loop;
auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>(); .castAs<test::TestMoreStuff>();
auto paf = kj::newPromiseAndFulfiller<void>(); auto paf = kj::newPromiseAndFulfiller<void>();
...@@ -775,8 +794,11 @@ TEST_F(RpcTest, Cancel) { ...@@ -775,8 +794,11 @@ TEST_F(RpcTest, Cancel) {
EXPECT_TRUE(destroyed); EXPECT_TRUE(destroyed);
} }
TEST_F(RpcTest, SendTwice) { TEST(Rpc, SendTwice) {
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF) TestContext context;
auto& loop = context.loop;
auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>(); .castAs<test::TestMoreStuff>();
auto paf = kj::newPromiseAndFulfiller<void>(); auto paf = kj::newPromiseAndFulfiller<void>();
...@@ -822,8 +844,11 @@ RemotePromise<test::TestCallOrder::GetCallSequenceResults> getCallSequence( ...@@ -822,8 +844,11 @@ RemotePromise<test::TestCallOrder::GetCallSequenceResults> getCallSequence(
return req.send(); return req.send();
} }
TEST_F(RpcTest, Embargo) { TEST(Rpc, Embargo) {
auto client = connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF) TestContext context;
auto& loop = context.loop;
auto client = context.connect(test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF)
.castAs<test::TestMoreStuff>(); .castAs<test::TestMoreStuff>();
auto cap = test::TestCallOrder::Client(kj::heap<TestCallOrderImpl>(), loop); auto cap = test::TestCallOrder::Client(kj::heap<TestCallOrderImpl>(), loop);
......
...@@ -1162,7 +1162,8 @@ private: ...@@ -1162,7 +1162,8 @@ private:
resolutionChain(kj::mv(resolutionChain)) {} resolutionChain(kj::mv(resolutionChain)) {}
~CapExtractorImpl() noexcept(false) { ~CapExtractorImpl() noexcept(false) {
KJ_ASSERT(retainedCaps.getWithoutLock().size() == 0, KJ_ASSERT(retainedCaps.getWithoutLock().size() == 0 ||
connectionState.tables.lockShared()->networkException != nullptr,
"CapExtractorImpl destroyed without getting a chance to retain the caps!") { "CapExtractorImpl destroyed without getting a chance to retain the caps!") {
break; break;
} }
...@@ -1927,6 +1928,8 @@ private: ...@@ -1927,6 +1928,8 @@ private:
if (isFirstResponder()) { if (isFirstResponder()) {
// We haven't sent a return yet, so we must have been canceled. Send a cancellation return. // We haven't sent a return yet, so we must have been canceled. Send a cancellation return.
unwindDetector.catchExceptionsIfUnwinding([&]() { unwindDetector.catchExceptionsIfUnwinding([&]() {
// Don't send anything if the connection is broken.
if (connectionState->tables.lockShared()->networkException == nullptr) {
auto message = connectionState->connection->newOutgoingMessage( auto message = connectionState->connection->newOutgoingMessage(
requestCapExtractor.retainedListSizeHint(true) + messageSizeHint<rpc::Return>()); requestCapExtractor.retainedListSizeHint(true) + messageSizeHint<rpc::Return>());
auto builder = message->getBody().initAs<rpc::Message>().initReturn(); auto builder = message->getBody().initAs<rpc::Message>().initReturn();
...@@ -1937,13 +1940,16 @@ private: ...@@ -1937,13 +1940,16 @@ private:
builder.adoptRetainedCaps(kj::mv(retainedCaps.exportList)); builder.adoptRetainedCaps(kj::mv(retainedCaps.exportList));
if (redirectResults) { if (redirectResults) {
// The reason we haven't sent a return is because the results were sent somewhere else. // The reason we haven't sent a return is because the results were sent somewhere
// else.
builder.setResultsSentElsewhere(); builder.setResultsSentElsewhere();
} else { } else {
builder.setCanceled(); builder.setCanceled();
} }
message->send(); message->send();
}
cleanupAnswerTable(connectionState->tables.lockExclusive(), nullptr); cleanupAnswerTable(connectionState->tables.lockExclusive(), nullptr);
}); });
} }
...@@ -2195,11 +2201,10 @@ private: ...@@ -2195,11 +2201,10 @@ private:
if (__atomic_load_n(&cancellationFlags, __ATOMIC_RELAXED) & CANCEL_REQUESTED) { if (__atomic_load_n(&cancellationFlags, __ATOMIC_RELAXED) & CANCEL_REQUESTED) {
// We are responsible for deleting the answer table entry. Awkwardly, however, the // We are responsible for deleting the answer table entry. Awkwardly, however, the
// answer table may be the only thing holding a reference to the context, and we may even // answer table may be the only thing holding a reference to the context. So we have to
// be called from the continuation represented by answer.asyncOp. So we have to do the // do the actual deletion asynchronously. But we have to remove it from the table *now*,
// actual deletion asynchronously. But we have to remove it from the table *now*, while // while we still hold the lock, because once we send the return message the answer ID is
// we still hold the lock, because once we send the return message the answer ID is free // free for reuse.
// for reuse.
auto promise = connectionState->eventLoop.evalLater([]() {}); auto promise = connectionState->eventLoop.evalLater([]() {});
promise.attach(kj::mv(lock->answers[questionId])); promise.attach(kj::mv(lock->answers[questionId]));
connectionState->tasks.add(kj::mv(promise)); connectionState->tasks.add(kj::mv(promise));
...@@ -2361,32 +2366,38 @@ private: ...@@ -2361,32 +2366,38 @@ private:
auto cancelPaf = kj::newPromiseAndFulfiller<void>(); auto cancelPaf = kj::newPromiseAndFulfiller<void>();
QuestionId questionId = call.getQuestionId(); QuestionId questionId = call.getQuestionId();
// Note: resolutionChainTail couldn't possibly be changing here because we only handle one // Note: resolutionChainTail couldn't possibly be changing here because we only handle one
// message at a time, so we can hold off locking the tables for a bit longer. // message at a time, so we can hold off locking the tables for a bit longer.
auto context = kj::refcounted<RpcCallContext>( auto context = kj::refcounted<RpcCallContext>(
*this, questionId, kj::mv(message), call.getParams(), *this, questionId, kj::mv(message), call.getParams(),
kj::addRef(*tables.getWithoutLock().resolutionChainTail), kj::addRef(*tables.getWithoutLock().resolutionChainTail),
redirectResults, kj::mv(cancelPaf.fulfiller)); redirectResults, kj::mv(cancelPaf.fulfiller));
auto promiseAndPipeline = capability->call(
call.getInterfaceId(), call.getMethodId(), context->addRef());
// No more using `call` after this point! // No more using `call` after this point!
{ {
auto lock = tables.lockExclusive(); auto lock = tables.lockExclusive();
auto& answer = lock->answers[questionId]; auto& answer = lock->answers[questionId];
// We don't want to overwrite an active question because the destructors for the promise and
// pipeline could try to lock our mutex. Of course, we did already fire off the new call
// above, but that's OK because it won't actually ever inspect the Answer table itself, and
// we're about to close the connection anyway.
KJ_REQUIRE(!answer.active, "questionId is already in use") { KJ_REQUIRE(!answer.active, "questionId is already in use") {
return; return;
} }
answer.active = true; answer.active = true;
answer.callContext = *context; answer.callContext = *context;
}
auto promiseAndPipeline = capability->call(
call.getInterfaceId(), call.getMethodId(), context->addRef());
// Things may have changed -- in particular if call() immediately called
// context->directTailCall().
{
auto lock = tables.lockExclusive();
auto& answer = lock->answers[questionId];
answer.pipeline = kj::mv(promiseAndPipeline.pipeline); answer.pipeline = kj::mv(promiseAndPipeline.pipeline);
if (redirectResults) { if (redirectResults) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment