Commit b5f8d487 authored by Kenton Varda's avatar Kenton Varda Committed by GitHub

Merge pull request #378 from sandstorm-io/fix-realm-gateway

Fix RealmGateway to avoid double-transforms on loopback.
parents d3784575 86276457
......@@ -545,6 +545,10 @@ kj::Own<ClientHook> newLocalPromiseClient(kj::Promise<kj::Own<ClientHook>>&& pro
return kj::refcounted<QueuedClient>(kj::mv(promise));
}
kj::Own<PipelineHook> newLocalPromisePipeline(kj::Promise<kj::Own<PipelineHook>>&& promise) {
return kj::refcounted<QueuedPipeline>(kj::mv(promise));
}
// =======================================================================================
namespace {
......
......@@ -595,6 +595,10 @@ kj::Own<ClientHook> newLocalPromiseClient(kj::Promise<kj::Own<ClientHook>>&& pro
// the new client. This hook's `getResolved()` and `whenMoreResolved()` methods will reflect the
// redirection to the eventual replacement client.
kj::Own<PipelineHook> newLocalPromisePipeline(kj::Promise<kj::Own<PipelineHook>>&& promise);
// Returns a PipelineHook that queues up calls until `promise` resolves, then forwards them to
// the new pipeline.
kj::Own<ClientHook> newBrokenCap(kj::StringPtr reason);
kj::Own<ClientHook> newBrokenCap(kj::Exception&& reason);
// Helper function that creates a capability which simply throws exceptions when called.
......
......@@ -1147,6 +1147,114 @@ TEST(Rpc, RealmGatewayExport) {
EXPECT_EQ("exported-foo", response.getSturdyRef());
}
TEST(Rpc, RealmGatewayImportExport) {
// Test that a save request which leaves the realm, bounces through a promise capability, and
// then comes back into the realm, does not actually get translated both ways.
TestRealmGateway::Client gateway = kj::heap<TestGateway>();
Persistent<test::TestSturdyRef>::Client bootstrap = kj::heap<TestPersistent>("foo");
MallocMessageBuilder serverHostIdBuilder;
auto serverHostId = serverHostIdBuilder.getRoot<test::TestSturdyRefHostId>();
serverHostId.setHost("server");
MallocMessageBuilder clientHostIdBuilder;
auto clientHostId = clientHostIdBuilder.getRoot<test::TestSturdyRefHostId>();
clientHostId.setHost("client");
kj::EventLoop loop;
kj::WaitScope waitScope(loop);
TestNetwork network;
TestRestorer restorer;
TestNetworkAdapter& clientNetwork = network.add("client");
TestNetworkAdapter& serverNetwork = network.add("server");
RpcSystem<test::TestSturdyRefHostId> rpcClient =
makeRpcServer(clientNetwork, bootstrap, gateway);
auto paf = kj::newPromiseAndFulfiller<Capability::Client>();
RpcSystem<test::TestSturdyRefHostId> rpcServer =
makeRpcServer(serverNetwork, kj::mv(paf.promise));
auto client = rpcClient.bootstrap(serverHostId).castAs<Persistent<test::TestSturdyRef>>();
bool responseReady = false;
auto responsePromise = client.saveRequest().send()
.then([&](Response<Persistent<test::TestSturdyRef>::SaveResults>&& response) {
responseReady = true;
return kj::mv(response);
}).eagerlyEvaluate(nullptr);
// Crank the event loop to give the message time to reach the server and block on the promise
// resolution.
kj::evalLater([]() {}).wait(waitScope);
kj::evalLater([]() {}).wait(waitScope);
kj::evalLater([]() {}).wait(waitScope);
kj::evalLater([]() {}).wait(waitScope);
EXPECT_FALSE(responseReady);
paf.fulfiller->fulfill(rpcServer.bootstrap(clientHostId));
auto response = responsePromise.wait(waitScope);
// Should have the original value. If it went through export and re-import, though, then this
// will be "imported-exported-foo", which is wrong.
EXPECT_EQ("foo", response.getSturdyRef().getObjectId().getAs<Text>());
}
TEST(Rpc, RealmGatewayImportExport) {
// Test that a save request which enters the realm, bounces through a promise capability, and
// then goes back out of the realm, does not actually get translated both ways.
TestRealmGateway::Client gateway = kj::heap<TestGateway>();
Persistent<Text>::Client bootstrap = kj::heap<TestPersistentText>("foo");
MallocMessageBuilder serverHostIdBuilder;
auto serverHostId = serverHostIdBuilder.getRoot<test::TestSturdyRefHostId>();
serverHostId.setHost("server");
MallocMessageBuilder clientHostIdBuilder;
auto clientHostId = clientHostIdBuilder.getRoot<test::TestSturdyRefHostId>();
clientHostId.setHost("client");
kj::EventLoop loop;
kj::WaitScope waitScope(loop);
TestNetwork network;
TestRestorer restorer;
TestNetworkAdapter& clientNetwork = network.add("client");
TestNetworkAdapter& serverNetwork = network.add("server");
RpcSystem<test::TestSturdyRefHostId> rpcClient =
makeRpcServer(clientNetwork, bootstrap);
auto paf = kj::newPromiseAndFulfiller<Capability::Client>();
RpcSystem<test::TestSturdyRefHostId> rpcServer =
makeRpcServer(serverNetwork, kj::mv(paf.promise), gateway);
auto client = rpcClient.bootstrap(serverHostId).castAs<Persistent<Text>>();
bool responseReady = false;
auto responsePromise = client.saveRequest().send()
.then([&](Response<Persistent<Text>::SaveResults>&& response) {
responseReady = true;
return kj::mv(response);
}).eagerlyEvaluate(nullptr);
// Crank the event loop to give the message time to reach the server and block on the promise
// resolution.
kj::evalLater([]() {}).wait(waitScope);
kj::evalLater([]() {}).wait(waitScope);
kj::evalLater([]() {}).wait(waitScope);
kj::evalLater([]() {}).wait(waitScope);
EXPECT_FALSE(responseReady);
paf.fulfiller->fulfill(rpcServer.bootstrap(clientHostId));
auto response = responsePromise.wait(waitScope);
// Should have the original value. If it went through import and re-export, though, then this
// will be "exported-imported-foo", which is wrong.
EXPECT_EQ("foo", response.getSturdyRef());
}
} // namespace
} // namespace _ (private)
} // namespace capnp
......@@ -663,7 +663,7 @@ private:
// Implement call() by copying params and results messages.
auto params = context->getParams();
auto request = newCall(interfaceId, methodId, params.targetSize());
auto request = newCallNoIntercept(interfaceId, methodId, params.targetSize());
request.set(params);
context->releaseParams();
......@@ -865,12 +865,42 @@ private:
Request<AnyPointer, AnyPointer> newCall(
uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
if (!isResolved && interfaceId == typeId<Persistent<>>() && methodId == 0 &&
connectionState->gateway != nullptr) {
// This is a call to Persistent.save(), and we're not resolved yet, and the underlying
// remote capability will perform a gateway translation. This isn't right if the promise
// ultimately resolves to a local capability. Instead, we'll need to queue the call until
// the promise resolves.
return newLocalPromiseClient(fork.addBranch())
->newCall(interfaceId, methodId, sizeHint);
}
receivedCall = true;
return cap->newCall(interfaceId, methodId, sizeHint);
}
VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
kj::Own<CallContextHook>&& context) override {
if (!isResolved && interfaceId == typeId<Persistent<>>() && methodId == 0 &&
connectionState->gateway != nullptr) {
// This is a call to Persistent.save(), and we're not resolved yet, and the underlying
// remote capability will perform a gateway translation. This isn't right if the promise
// ultimately resolves to a local capability. Instead, we'll need to queue the call until
// the promise resolves.
auto vpapPromises = fork.addBranch().then(kj::mvCapture(context,
[interfaceId,methodId](kj::Own<CallContextHook>&& context,
kj::Own<ClientHook> resolvedCap) {
auto vpap = resolvedCap->call(interfaceId, methodId, kj::mv(context));
return kj::tuple(kj::mv(vpap.promise), kj::mv(vpap.pipeline));
})).split();
return {
kj::mv(kj::get<0>(vpapPromises)),
newLocalPromisePipeline(kj::mv(kj::get<1>(vpapPromises))),
};
}
receivedCall = true;
return cap->call(interfaceId, methodId, kj::mv(context));
}
......@@ -2305,6 +2335,31 @@ private:
// Wait, this is a call to Persistent.save() and we need to translate it through our
// gateway.
KJ_IF_MAYBE(resolvedPromise, capability->whenMoreResolved()) {
// The plot thickens: We're looking at a promise capability. It could end up resolving
// to a capability outside the gateway, in which case we don't want to translate at all.
auto promises = resolvedPromise->then(kj::mvCapture(context,
[this,interfaceId,methodId](kj::Own<CallContextHook>&& context,
kj::Own<ClientHook> resolvedCap) {
auto vpap = startCall(interfaceId, methodId, kj::mv(resolvedCap), kj::mv(context));
return kj::tuple(kj::mv(vpap.promise), kj::mv(vpap.pipeline));
})).attach(addRef(*this)).split();
return {
kj::mv(kj::get<0>(promises)),
newLocalPromisePipeline(kj::mv(kj::get<1>(promises))),
};
}
if (capability->getBrand() == this) {
// This capability is one of our own, pointing back out over the network. That means
// that it would be inappropriate to apply the gateway transformation. We just want to
// reflect the call back.
return kj::downcast<RpcClient>(*capability)
.callNoIntercept(interfaceId, methodId, kj::mv(context));
}
auto params = context->getParams().getAs<Persistent<>::SaveParams>();
auto requestSize = params.totalSize();
......@@ -2321,7 +2376,7 @@ private:
}
}
return capability->call(interfaceId, methodId, context->addRef());
return capability->call(interfaceId, methodId, kj::mv(context));
}
kj::Maybe<kj::Own<ClientHook>> getMessageTarget(const rpc::MessageTarget::Reader& target) {
......
......@@ -26,6 +26,7 @@
#ifndef KJ_ASYNC_H_
#error "Do not include this directly; include kj/async.h."
#include "async.h" // help IDE parse this file
#endif
#ifndef KJ_ASYNC_INL_H_
......@@ -444,6 +445,28 @@ public:
}
};
template <typename T, size_t index>
class SplitBranch final: public ForkBranchBase {
// A PromiseNode that implements one branch of a fork -- i.e. one of the branches that receives
// a const reference.
public:
SplitBranch(Own<ForkHubBase>&& hub): ForkBranchBase(kj::mv(hub)) {}
typedef kj::Decay<decltype(kj::get<index>(kj::instance<T>()))> Element;
void get(ExceptionOrValue& output) noexcept override {
ExceptionOr<T>& hubResult = getHubResultRef().template as<T>();
KJ_IF_MAYBE(value, hubResult.value) {
output.as<Element>().value = kj::mv(kj::get<index>(*value));
} else {
output.as<Element>().value = nullptr;
}
output.exception = hubResult.exception;
releaseHub(output);
}
};
// -------------------------------------------------------------------
class ForkHubBase: public Refcounted, protected Event {
......@@ -479,8 +502,24 @@ public:
return Promise<_::UnfixVoid<T>>(false, kj::heap<ForkBranch<T>>(addRef(*this)));
}
_::SplitTuplePromise<T> split() {
return splitImpl(MakeIndexes<tupleSize<T>()>());
}
private:
ExceptionOr<T> result;
template <size_t... indexes>
_::SplitTuplePromise<T> splitImpl(Indexes<indexes...>) {
return kj::tuple(addSplit<indexes>()...);
}
template <size_t index>
Promise<JoinPromises<typename SplitBranch<T, index>::Element>> addSplit() {
return Promise<JoinPromises<typename SplitBranch<T, index>::Element>>(
false, maybeChain(kj::heap<SplitBranch<T, index>>(addRef(*this)),
implicitCast<typename SplitBranch<T, index>::Element*>(nullptr)));
}
};
inline ExceptionOrValue& ForkBranchBase::getHubResultRef() {
......@@ -833,6 +872,11 @@ Promise<T> ForkedPromise<T>::addBranch() {
return hub->addBranch();
}
template <typename T>
_::SplitTuplePromise<T> Promise<T>::split() {
return refcounted<_::ForkHub<_::FixVoid<T>>>(kj::mv(node))->split();
}
template <typename T>
Promise<T> Promise<T>::exclusiveJoin(Promise<T>&& other) {
return Promise(false, heap<_::ExclusiveJoinPromiseNode>(kj::mv(node), kj::mv(other.node)));
......
......@@ -30,6 +30,7 @@
#endif
#include "exception.h"
#include "tuple.h"
namespace kj {
......@@ -86,6 +87,17 @@ using ReturnType = typename ReturnType_<Func, T>::Type;
// The return type of functor Func given a parameter of type T, with the special exception that if
// T is void, this is the return type of Func called with no arguments.
template <typename T> struct SplitTuplePromise_ { typedef Promise<T> Type; };
template <typename... T>
struct SplitTuplePromise_<kj::_::Tuple<T...>> {
typedef kj::Tuple<Promise<JoinPromises<T>>...> Type;
};
template <typename T>
using SplitTuplePromise = typename SplitTuplePromise_<T>::Type;
// T -> Promise<T>
// Tuple<T> -> Tuple<Promise<T>>
struct Void {};
// Application code should NOT refer to this! See `kj::READY_NOW` instead.
......
......@@ -492,6 +492,21 @@ TEST(Async, ForkRef) {
EXPECT_EQ(789, branch2.wait(waitScope));
}
TEST(Async, Split) {
EventLoop loop;
WaitScope waitScope(loop);
Promise<Tuple<int, String, Promise<int>>> promise = evalLater([&]() {
return kj::tuple(123, str("foo"), Promise<int>(321));
});
Tuple<Promise<int>, Promise<String>, Promise<int>> split = promise.split();
EXPECT_EQ(123, get<0>(split).wait(waitScope));
EXPECT_EQ("foo", get<1>(split).wait(waitScope));
EXPECT_EQ(321, get<2>(split).wait(waitScope));
}
TEST(Async, ExclusiveJoin) {
{
EventLoop loop;
......
......@@ -29,7 +29,6 @@
#include "async-prelude.h"
#include "exception.h"
#include "refcount.h"
#include "tuple.h"
namespace kj {
......@@ -242,6 +241,12 @@ public:
// `Own<U>`, `U` must have a method `Own<U> addRef()` which returns a new reference to the same
// (or an equivalent) object (probably implemented via reference counting).
_::SplitTuplePromise<T> split();
// Split a promise for a tuple into a tuple of promises.
//
// E.g. if you have `Promise<kj::Tuple<T, U>>`, `split()` returns
// `kj::Tuple<Promise<T>, Promise<U>>`.
Promise<T> exclusiveJoin(Promise<T>&& other) KJ_WARN_UNUSED_RESULT;
// Return a new promise that resolves when either the original promise resolves or `other`
// resolves (whichever comes first). The promise that didn't resolve first is canceled.
......
......@@ -350,6 +350,15 @@ inline auto apply(Func&& func, Params&&... params)
return _::expandAndApply(kj::fwd<Func>(func), kj::fwd<Params>(params)...);
}
template <typename T> struct TupleSize_ { static constexpr size_t size = 1; };
template <typename... T> struct TupleSize_<_::Tuple<T...>> {
static constexpr size_t size = sizeof...(T);
};
template <typename T>
constexpr size_t tupleSize() { return TupleSize_<T>::size; }
// Returns size of the tuple T.
} // namespace kj
#endif // KJ_TUPLE_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment