Commit 6cf522b2 authored by Kenton Varda's avatar Kenton Varda

Detect RPC messages that are so big that the receiver will reject them, and never send them.

parent d9919210
...@@ -336,6 +336,38 @@ TEST(TwoPartyNetwork, ConvenienceClasses) { ...@@ -336,6 +336,38 @@ TEST(TwoPartyNetwork, ConvenienceClasses) {
EXPECT_EQ(1, callCount); EXPECT_EQ(1, callCount);
} }
TEST(TwoPartyNetwork, HugeMessage) {
auto ioContext = kj::setupAsyncIo();
int callCount = 0;
int handleCount = 0;
auto serverThread = runServer(*ioContext.provider, callCount, handleCount);
TwoPartyVatNetwork network(*serverThread.pipe, rpc::twoparty::Side::CLIENT);
auto rpcClient = makeRpcClient(network);
auto client = getPersistentCap(rpcClient, rpc::twoparty::Side::SERVER,
test::TestSturdyRefObjectId::Tag::TEST_MORE_STUFF).castAs<test::TestMoreStuff>();
// Oversized request fails.
{
auto req = client.methodWithDefaultsRequest();
req.initA(100000000); // 100 MB
KJ_EXPECT_THROW_MESSAGE("larger than the single-message size limit",
req.send().wait(ioContext.waitScope));
}
// Oversized response fails.
KJ_EXPECT_THROW_MESSAGE("larger than the single-message size limit",
client.getEnormousStringRequest().send().wait(ioContext.waitScope));
// Connection is still up.
{
auto req = client.getCallSequenceRequest();
req.setExpected(0);
KJ_EXPECT(req.send().wait(ioContext.waitScope).getN() == 0);
}
}
class TestAuthenticatedBootstrapImpl final class TestAuthenticatedBootstrapImpl final
: public test::TestAuthenticatedBootstrap<rpc::twoparty::VatId>::Server { : public test::TestAuthenticatedBootstrap<rpc::twoparty::VatId>::Server {
public: public:
......
...@@ -82,6 +82,17 @@ public: ...@@ -82,6 +82,17 @@ public:
} }
void send() override { void send() override {
size_t size = 0;
for (auto& segment: message.getSegmentsForOutput()) {
size += segment.size();
}
KJ_REQUIRE(size < ReaderOptions().traversalLimitInWords, size,
"Trying to send Cap'n Proto message larger than the single-message size limit. The "
"other side probably won't accept it and would abort the connection, so I won't "
"send it.") {
return;
}
network.previousWrite = KJ_ASSERT_NONNULL(network.previousWrite, "already shut down") network.previousWrite = KJ_ASSERT_NONNULL(network.previousWrite, "already shut down")
.then([&]() { .then([&]() {
// Note that if the write fails, all further writes will be skipped due to the exception. // Note that if the write fails, all further writes will be skipped due to the exception.
......
...@@ -115,7 +115,26 @@ kj::Exception toException(const rpc::Exception::Reader& exception) { ...@@ -115,7 +115,26 @@ kj::Exception toException(const rpc::Exception::Reader& exception) {
void fromException(const kj::Exception& exception, rpc::Exception::Builder builder) { void fromException(const kj::Exception& exception, rpc::Exception::Builder builder) {
// TODO(someday): Indicate the remote server name as part of the stack trace. Maybe even // TODO(someday): Indicate the remote server name as part of the stack trace. Maybe even
// transmit stack traces? // transmit stack traces?
builder.setReason(exception.getDescription());
kj::StringPtr description = exception.getDescription();
// Include context, if any.
kj::Vector<kj::String> contextLines;
for (auto context = exception.getContext();;) {
KJ_IF_MAYBE(c, context) {
contextLines.add(kj::str("context: ", c->file, ": ", c->line, ": ", c->description));
context = c->next;
} else {
break;
}
}
kj::String scratch;
if (contextLines.size() > 0) {
scratch = kj::str(description, '\n', kj::strArray(contextLines, "\n"));
description = scratch;
}
builder.setReason(description);
builder.setType(static_cast<rpc::Exception::Type>(exception.getType())); builder.setType(static_cast<rpc::Exception::Type>(exception.getType()));
if (exception.getType() == kj::Exception::Type::FAILED && if (exception.getType() == kj::Exception::Type::FAILED &&
...@@ -1490,7 +1509,14 @@ private: ...@@ -1490,7 +1509,14 @@ private:
if (isTailCall) { if (isTailCall) {
callBuilder.getSendResultsTo().setYourself(); callBuilder.getSendResultsTo().setYourself();
} }
message->send(); KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
KJ_CONTEXT("sending RPC call",
callBuilder.getInterfaceId(), callBuilder.getMethodId());
message->send();
})) {
KJ_LOG(WARNING, *exception);
kj::throwRecoverableException(kj::mv(*exception));
}
// Make the result promise. // Make the result promise.
SendInternalResult result; SendInternalResult result;
...@@ -1716,9 +1742,12 @@ private: ...@@ -1716,9 +1742,12 @@ private:
kj::Own<IncomingRpcMessage>&& request, kj::Own<IncomingRpcMessage>&& request,
kj::Array<kj::Maybe<kj::Own<ClientHook>>> capTableArray, kj::Array<kj::Maybe<kj::Own<ClientHook>>> capTableArray,
const AnyPointer::Reader& params, const AnyPointer::Reader& params,
bool redirectResults, kj::Own<kj::PromiseFulfiller<void>>&& cancelFulfiller) bool redirectResults, kj::Own<kj::PromiseFulfiller<void>>&& cancelFulfiller,
uint64_t interfaceId, uint16_t methodId)
: connectionState(kj::addRef(connectionState)), : connectionState(kj::addRef(connectionState)),
answerId(answerId), answerId(answerId),
interfaceId(interfaceId),
methodId(methodId),
requestSize(request->getBody().targetSize().wordCount), requestSize(request->getBody().targetSize().wordCount),
request(kj::mv(request)), request(kj::mv(request)),
paramsCapTable(kj::mv(capTableArray)), paramsCapTable(kj::mv(capTableArray)),
...@@ -1784,7 +1813,18 @@ private: ...@@ -1784,7 +1813,18 @@ private:
returnMessage.setAnswerId(answerId); returnMessage.setAnswerId(answerId);
returnMessage.setReleaseParamCaps(false); returnMessage.setReleaseParamCaps(false);
auto exports = kj::downcast<RpcServerResponseImpl>(*KJ_ASSERT_NONNULL(response)).send(); kj::Maybe<kj::Array<ExportId>> exports;
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
// Debug info incase send() fails due to overside message.
KJ_CONTEXT("returning from RPC call", interfaceId, methodId);
exports = kj::downcast<RpcServerResponseImpl>(*KJ_ASSERT_NONNULL(response)).send();
})) {
KJ_LOG(WARNING, *exception);
responseSent = false;
sendErrorReturn(kj::mv(*exception));
return;
}
KJ_IF_MAYBE(e, exports) { KJ_IF_MAYBE(e, exports) {
// Caps were returned, so we can't free the pipeline yet. // Caps were returned, so we can't free the pipeline yet.
cleanupAnswerTable(kj::mv(*e), false); cleanupAnswerTable(kj::mv(*e), false);
...@@ -1936,6 +1976,10 @@ private: ...@@ -1936,6 +1976,10 @@ private:
kj::Own<RpcConnectionState> connectionState; kj::Own<RpcConnectionState> connectionState;
AnswerId answerId; AnswerId answerId;
uint64_t interfaceId;
uint16_t methodId;
// For debugging.
// Request --------------------------------------------- // Request ---------------------------------------------
size_t requestSize; // for flow limit purposes size_t requestSize; // for flow limit purposes
...@@ -2265,7 +2309,8 @@ private: ...@@ -2265,7 +2309,8 @@ private:
auto context = kj::refcounted<RpcCallContext>( auto context = kj::refcounted<RpcCallContext>(
*this, answerId, kj::mv(message), kj::mv(capTableArray), payload.getContent(), *this, answerId, kj::mv(message), kj::mv(capTableArray), payload.getContent(),
redirectResults, kj::mv(cancelPaf.fulfiller)); redirectResults, kj::mv(cancelPaf.fulfiller),
call.getInterfaceId(), call.getMethodId());
// No more using `call` after this point, as it now belongs to the context. // No more using `call` after this point, as it now belongs to the context.
...@@ -2317,7 +2362,7 @@ private: ...@@ -2317,7 +2362,7 @@ private:
contextPtr->sendReturn(); contextPtr->sendReturn();
}, [contextPtr](kj::Exception&& exception) { }, [contextPtr](kj::Exception&& exception) {
contextPtr->sendErrorReturn(kj::mv(exception)); contextPtr->sendErrorReturn(kj::mv(exception));
}).then([]() {}, [&](kj::Exception&& exception) { }).catch_([&](kj::Exception&& exception) {
// Handle exceptions that occur in sendReturn()/sendErrorReturn(). // Handle exceptions that occur in sendReturn()/sendErrorReturn().
taskFailed(kj::mv(exception)); taskFailed(kj::mv(exception));
}).attach(kj::mv(context)) }).attach(kj::mv(context))
......
...@@ -1116,6 +1116,11 @@ kj::Promise<void> TestMoreStuffImpl::getNull(GetNullContext context) { ...@@ -1116,6 +1116,11 @@ kj::Promise<void> TestMoreStuffImpl::getNull(GetNullContext context) {
return kj::READY_NOW; return kj::READY_NOW;
} }
kj::Promise<void> TestMoreStuffImpl::getEnormousString(GetEnormousStringContext context) {
context.getResults().initStr(100000000); // 100MB
return kj::READY_NOW;
}
#endif // !CAPNP_LITE #endif // !CAPNP_LITE
} // namespace _ (private) } // namespace _ (private)
......
...@@ -286,6 +286,8 @@ public: ...@@ -286,6 +286,8 @@ public:
kj::Promise<void> getNull(GetNullContext context) override; kj::Promise<void> getNull(GetNullContext context) override;
kj::Promise<void> getEnormousString(GetEnormousStringContext context) override;
private: private:
int& callCount; int& callCount;
int& handleCount; int& handleCount;
......
...@@ -824,6 +824,9 @@ interface TestMoreStuff extends(TestCallOrder) { ...@@ -824,6 +824,9 @@ interface TestMoreStuff extends(TestCallOrder) {
getNull @10 () -> (nullCap :TestMoreStuff); getNull @10 () -> (nullCap :TestMoreStuff);
# Always returns a null capability. # Always returns a null capability.
getEnormousString @11 () -> (str :Text);
# Attempts to return an 100MB string. Should always fail.
} }
interface TestMembrane { interface TestMembrane {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment