Commit 9dc88268 authored by Kenton Varda's avatar Kenton Varda

CapabilityServerSet::getLocalServer() really needs to throw if given a broken promise capability.

parent f1ccc201
...@@ -864,7 +864,10 @@ TEST(Capability, CapabilityServerSet) { ...@@ -864,7 +864,10 @@ TEST(Capability, CapabilityServerSet) {
auto paf = kj::newPromiseAndFulfiller<test::TestInterface::Client>(); auto paf = kj::newPromiseAndFulfiller<test::TestInterface::Client>();
test::TestInterface::Client clientPromise = kj::mv(paf.promise); test::TestInterface::Client clientPromise = kj::mv(paf.promise);
bool resolved1 = false, resolved2 = false; auto errorPaf = kj::newPromiseAndFulfiller<test::TestInterface::Client>();
test::TestInterface::Client errorPromise = kj::mv(errorPaf.promise);
bool resolved1 = false, resolved2 = false, resolved3 = false;
auto promise1 = set1.getLocalServer(clientPromise) auto promise1 = set1.getLocalServer(clientPromise)
.then([&](kj::Maybe<test::TestInterface::Server&> server) { .then([&](kj::Maybe<test::TestInterface::Server&> server) {
resolved1 = true; resolved1 = true;
...@@ -875,6 +878,13 @@ TEST(Capability, CapabilityServerSet) { ...@@ -875,6 +878,13 @@ TEST(Capability, CapabilityServerSet) {
resolved2 = true; resolved2 = true;
EXPECT_TRUE(server == nullptr); EXPECT_TRUE(server == nullptr);
}); });
auto promise3 = set1.getLocalServer(errorPromise)
.then([&](kj::Maybe<test::TestInterface::Server&> server) {
KJ_FAIL_EXPECT("getLocalServer() on error promise should have thrown");
}, [&](kj::Exception&& e) {
resolved3 = true;
KJ_EXPECT(e.getDescription().endsWith("foo"), e.getDescription());
});
kj::evalLater([](){}).wait(waitScope); kj::evalLater([](){}).wait(waitScope);
kj::evalLater([](){}).wait(waitScope); kj::evalLater([](){}).wait(waitScope);
...@@ -883,14 +893,18 @@ TEST(Capability, CapabilityServerSet) { ...@@ -883,14 +893,18 @@ TEST(Capability, CapabilityServerSet) {
EXPECT_FALSE(resolved1); EXPECT_FALSE(resolved1);
EXPECT_FALSE(resolved2); EXPECT_FALSE(resolved2);
EXPECT_FALSE(resolved3);
paf.fulfiller->fulfill(kj::cp(client1)); paf.fulfiller->fulfill(kj::cp(client1));
errorPaf.fulfiller->reject(KJ_EXCEPTION(FAILED, "foo"));
promise1.wait(waitScope); promise1.wait(waitScope);
promise2.wait(waitScope); promise2.wait(waitScope);
promise3.wait(waitScope);
EXPECT_TRUE(resolved1); EXPECT_TRUE(resolved1);
EXPECT_TRUE(resolved2); EXPECT_TRUE(resolved2);
EXPECT_TRUE(resolved3);
// Check weak pointer. // Check weak pointer.
{ {
......
...@@ -53,6 +53,8 @@ public: ...@@ -53,6 +53,8 @@ public:
static BrokenCapFactoryImpl brokenCapFactory; static BrokenCapFactoryImpl brokenCapFactory;
static kj::Own<ClientHook> newNullCap();
} // namespace } // namespace
ClientHook::ClientHook() { ClientHook::ClientHook() {
...@@ -71,7 +73,7 @@ void MessageReader::initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>> capTa ...@@ -71,7 +73,7 @@ void MessageReader::initCapTable(kj::Array<kj::Maybe<kj::Own<ClientHook>>> capTa
// ======================================================================================= // =======================================================================================
Capability::Client::Client(decltype(nullptr)) Capability::Client::Client(decltype(nullptr))
: hook(newBrokenCap("Called null capability.")) {} : hook(newNullCap()) {}
Capability::Client::Client(kj::Exception&& exception) Capability::Client::Client(kj::Exception&& exception)
: hook(newBrokenCap(kj::mv(exception))) {} : hook(newBrokenCap(kj::mv(exception))) {}
...@@ -603,9 +605,10 @@ public: ...@@ -603,9 +605,10 @@ public:
class BrokenClient final: public ClientHook, public kj::Refcounted { class BrokenClient final: public ClientHook, public kj::Refcounted {
public: public:
BrokenClient(const kj::Exception& exception): exception(exception) {} BrokenClient(const kj::Exception& exception, bool resolved)
BrokenClient(const kj::StringPtr description) : exception(exception), resolved(resolved) {}
: exception(kj::Exception::Type::FAILED, "", 0, kj::str(description)) {} BrokenClient(const kj::StringPtr description, bool resolved)
: exception(kj::Exception::Type::FAILED, "", 0, kj::str(description)), resolved(resolved) {}
Request<AnyPointer, AnyPointer> newCall( Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override { uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
...@@ -622,7 +625,11 @@ public: ...@@ -622,7 +625,11 @@ public:
} }
kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override { kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
return kj::Promise<kj::Own<ClientHook>>(kj::cp(exception)); if (resolved) {
return nullptr;
} else {
return kj::Promise<kj::Own<ClientHook>>(kj::cp(exception));
}
} }
kj::Own<ClientHook> addRef() override { kj::Own<ClientHook> addRef() override {
...@@ -635,20 +642,26 @@ public: ...@@ -635,20 +642,26 @@ public:
private: private:
kj::Exception exception; kj::Exception exception;
bool resolved;
}; };
kj::Own<ClientHook> BrokenPipeline::getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) { kj::Own<ClientHook> BrokenPipeline::getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) {
return kj::refcounted<BrokenClient>(exception); return kj::refcounted<BrokenClient>(exception, false);
}
kj::Own<ClientHook> newNullCap() {
// A null capability, unlike other broken capabilities, is considered resolved.
return kj::refcounted<BrokenClient>("Called null capability.", true);
} }
} // namespace } // namespace
kj::Own<ClientHook> newBrokenCap(kj::StringPtr reason) { kj::Own<ClientHook> newBrokenCap(kj::StringPtr reason) {
return kj::refcounted<BrokenClient>(reason); return kj::refcounted<BrokenClient>(reason, false);
} }
kj::Own<ClientHook> newBrokenCap(kj::Exception&& reason) { kj::Own<ClientHook> newBrokenCap(kj::Exception&& reason) {
return kj::refcounted<BrokenClient>(kj::mv(reason)); return kj::refcounted<BrokenClient>(kj::mv(reason), false);
} }
kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason) { kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason) {
...@@ -705,9 +718,6 @@ kj::Promise<void*> CapabilityServerSetBase::getLocalServerInternal(Capability::C ...@@ -705,9 +718,6 @@ kj::Promise<void*> CapabilityServerSetBase::getLocalServerInternal(Capability::C
.then([this](kj::Own<ClientHook>&& resolved) { .then([this](kj::Own<ClientHook>&& resolved) {
Capability::Client client(kj::mv(resolved)); Capability::Client client(kj::mv(resolved));
return getLocalServerInternal(client); return getLocalServerInternal(client);
}, [](kj::Exception&&) -> void* {
// A broken promise is simply not a local capability.
return nullptr;
}); });
} else { } else {
return hook->getLocalServer(*this); return hook->getLocalServer(*this);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment