Commit 76e35a7c authored by Kenton Varda's avatar Kenton Varda

Add cheaper way to check size of RPC messages for flow control.

Way back in 538a767e I added `RpcSystem::setFlowLimit()`, a blunt mechanism by which an RPC node can arrange to stop reading new messages from the connection when too many incoming calls are in-flight. This was needed to deal with buggy Sandstorm apps that would stream multi-gigabyte files by doing a zillion writes without waiting, which would then all be queued in the HTTP gateway, causing it to run out of memory.

In implementing that, I inadertently caused the RPC system to do a tree walk on every call message it received, in order to sum up the message size. This is silly, becaues it's much cheaper to sum up the segment sizes. In fact, in the case of a malicious peer, the tree walk is potentially insufficient, because it doesn't count holes in the segments. The tree walk also means that any invalid pointers in the message cause an exception to be thrown even if that pointer is never accessed by the app, which isn't the usual behavior.

I seem to recall this issue coming up in discussion once in the past, but I couldn't find the thread.

For the new streaming feature, we'll be paying attention to the size of outgoing messages. Again, here, it would be nice to compute this size by summing segments without doing a tree walk.

So, this commit adds `sizeInWords()` methods that do this.
parent 99d308ff
...@@ -85,6 +85,16 @@ ReaderArena::ReaderArena(MessageReader* message) ...@@ -85,6 +85,16 @@ ReaderArena::ReaderArena(MessageReader* message)
ReaderArena::~ReaderArena() noexcept(false) {} ReaderArena::~ReaderArena() noexcept(false) {}
size_t ReaderArena::sizeInWords() {
size_t total = segment0.getArray().size();
for (uint i = 0; ; i++) {
SegmentReader* segment = tryGetSegment(SegmentId(i));
if (segment == nullptr) return total;
total += unboundAs<size_t>(segment->getSize() / WORDS);
}
}
SegmentReader* ReaderArena::tryGetSegment(SegmentId id) { SegmentReader* ReaderArena::tryGetSegment(SegmentId id) {
if (id == SegmentId(0)) { if (id == SegmentId(0)) {
if (segment0.getArray() == nullptr) { if (segment0.getArray() == nullptr) {
...@@ -165,6 +175,24 @@ BuilderArena::BuilderArena(MessageBuilder* message, ...@@ -165,6 +175,24 @@ BuilderArena::BuilderArena(MessageBuilder* message,
BuilderArena::~BuilderArena() noexcept(false) {} BuilderArena::~BuilderArena() noexcept(false) {}
size_t BuilderArena::sizeInWords() {
KJ_IF_MAYBE(segmentState, moreSegments) {
size_t total = segment0.currentlyAllocated().size();
for (auto& builder: segmentState->get()->builders) {
total += builder->currentlyAllocated().size();
}
return total;
} else {
if (segment0.getArena() == nullptr) {
// We haven't actually allocated any segments yet.
return 0;
} else {
// We have only one segment so far.
return segment0.currentlyAllocated().size();
}
}
}
SegmentBuilder* BuilderArena::getSegment(SegmentId id) { SegmentBuilder* BuilderArena::getSegment(SegmentId id) {
// This method is allowed to fail if the segment ID is not valid. // This method is allowed to fail if the segment ID is not valid.
if (id == SegmentId(0)) { if (id == SegmentId(0)) {
......
...@@ -230,6 +230,8 @@ public: ...@@ -230,6 +230,8 @@ public:
~ReaderArena() noexcept(false); ~ReaderArena() noexcept(false);
KJ_DISALLOW_COPY(ReaderArena); KJ_DISALLOW_COPY(ReaderArena);
size_t sizeInWords();
// implements Arena ------------------------------------------------ // implements Arena ------------------------------------------------
SegmentReader* tryGetSegment(SegmentId id) override; SegmentReader* tryGetSegment(SegmentId id) override;
void reportReadLimitReached() override; void reportReadLimitReached() override;
...@@ -264,6 +266,8 @@ public: ...@@ -264,6 +266,8 @@ public:
~BuilderArena() noexcept(false); ~BuilderArena() noexcept(false);
KJ_DISALLOW_COPY(BuilderArena); KJ_DISALLOW_COPY(BuilderArena);
size_t sizeInWords();
inline SegmentBuilder* getRootSegment() { return &segment0; } inline SegmentBuilder* getRootSegment() { return &segment0; }
kj::ArrayPtr<const kj::ArrayPtr<const word>> getSegmentsForOutput(); kj::ArrayPtr<const kj::ArrayPtr<const word>> getSegmentsForOutput();
......
...@@ -80,6 +80,9 @@ bool MessageReader::isCanonical() { ...@@ -80,6 +80,9 @@ bool MessageReader::isCanonical() {
return rootIsCanonical && allWordsConsumed; return rootIsCanonical && allWordsConsumed;
} }
size_t MessageReader::sizeInWords() {
return arena()->sizeInWords();
}
AnyPointer::Reader MessageReader::getRootInternal() { AnyPointer::Reader MessageReader::getRootInternal() {
if (!allocatedArena) { if (!allocatedArena) {
...@@ -178,6 +181,10 @@ bool MessageBuilder::isCanonical() { ...@@ -178,6 +181,10 @@ bool MessageBuilder::isCanonical() {
.isCanonical(&readHead); .isCanonical(&readHead);
} }
size_t MessageBuilder::sizeInWords() {
return arena()->sizeInWords();
}
kj::Own<_::CapTableBuilder> MessageBuilder::releaseBuiltinCapTable() { kj::Own<_::CapTableBuilder> MessageBuilder::releaseBuiltinCapTable() {
return arena()->releaseLocalCapTable(); return arena()->releaseLocalCapTable();
} }
......
...@@ -127,6 +127,9 @@ public: ...@@ -127,6 +127,9 @@ public:
bool isCanonical(); bool isCanonical();
// Returns whether the message encoded in the reader is in canonical form. // Returns whether the message encoded in the reader is in canonical form.
size_t sizeInWords();
// Add up the size of all segments.
private: private:
ReaderOptions options; ReaderOptions options;
...@@ -238,6 +241,9 @@ public: ...@@ -238,6 +241,9 @@ public:
bool isCanonical(); bool isCanonical();
// Check whether the message builder is in canonical form // Check whether the message builder is in canonical form
size_t sizeInWords();
// Add up the allocated space from all segments.
private: private:
void* arenaSpace[22]; void* arenaSpace[22];
// Space in which we can construct a BuilderArena. We don't use BuilderArena directly here // Space in which we can construct a BuilderArena. We don't use BuilderArena directly here
......
...@@ -246,6 +246,10 @@ public: ...@@ -246,6 +246,10 @@ public:
return message.getRoot<AnyPointer>(); return message.getRoot<AnyPointer>();
} }
size_t sizeInWords() override {
return data.size();
}
kj::Array<word> data; kj::Array<word> data;
FlatArrayMessageReader message; FlatArrayMessageReader message;
}; };
...@@ -291,6 +295,10 @@ public: ...@@ -291,6 +295,10 @@ public:
}))); })));
} }
size_t sizeInWords() override {
return message.sizeInWords();
}
private: private:
ConnectionImpl& connection; ConnectionImpl& connection;
MallocMessageBuilder message; MallocMessageBuilder message;
......
...@@ -128,6 +128,10 @@ public: ...@@ -128,6 +128,10 @@ public:
.eagerlyEvaluate(nullptr); .eagerlyEvaluate(nullptr);
} }
size_t sizeInWords() override {
return message.sizeInWords();
}
private: private:
TwoPartyVatNetwork& network; TwoPartyVatNetwork& network;
MallocMessageBuilder message; MallocMessageBuilder message;
...@@ -153,6 +157,10 @@ public: ...@@ -153,6 +157,10 @@ public:
return fds; return fds;
} }
size_t sizeInWords() override {
return message->sizeInWords();
}
private: private:
kj::Own<MessageReader> message; kj::Own<MessageReader> message;
kj::Array<kj::AutoCloseFd> fdSpace; kj::Array<kj::AutoCloseFd> fdSpace;
......
...@@ -1823,7 +1823,7 @@ private: ...@@ -1823,7 +1823,7 @@ private:
answerId(answerId), answerId(answerId),
interfaceId(interfaceId), interfaceId(interfaceId),
methodId(methodId), methodId(methodId),
requestSize(request->getBody().targetSize().wordCount), requestSize(request->sizeInWords()),
request(kj::mv(request)), request(kj::mv(request)),
paramsCapTable(kj::mv(capTableArray)), paramsCapTable(kj::mv(capTableArray)),
params(paramsCapTable.imbue(params)), params(paramsCapTable.imbue(params)),
......
...@@ -314,6 +314,11 @@ public: ...@@ -314,6 +314,11 @@ public:
virtual void send() = 0; virtual void send() = 0;
// Send the message, or at least put it in a queue to be sent later. Note that the builder // Send the message, or at least put it in a queue to be sent later. Note that the builder
// returned by `getBody()` remains valid at least until the `OutgoingRpcMessage` is destroyed. // returned by `getBody()` remains valid at least until the `OutgoingRpcMessage` is destroyed.
virtual size_t sizeInWords() = 0;
// Get the total size of the message, for flow control purposes. Although the caller could
// also call getBody().targetSize(), doing that would walk the message tree, whereas typical
// implementations can compute the size more cheaply by summing segment sizes.
}; };
class IncomingRpcMessage { class IncomingRpcMessage {
...@@ -331,6 +336,11 @@ public: ...@@ -331,6 +336,11 @@ public:
// should be careful to check if an FD was already consumed by comparing the slot with `nullptr`. // should be careful to check if an FD was already consumed by comparing the slot with `nullptr`.
// (We don't use Maybe here because moving from a Maybe doesn't make it null, so it would only // (We don't use Maybe here because moving from a Maybe doesn't make it null, so it would only
// add confusion. Moving from an AutoCloseFd does in fact make it null.) // add confusion. Moving from an AutoCloseFd does in fact make it null.)
virtual size_t sizeInWords() = 0;
// Get the total size of the message, for flow control purposes. Although the caller could
// also call getBody().targetSize(), doing that would walk the message tree, whereas typical
// implementations can compute the size more cheaply by summing segment sizes.
}; };
template <typename VatId, typename ProvisionId, typename RecipientId, template <typename VatId, typename ProvisionId, typename RecipientId,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment