Commit 2bcb6e69 authored by Kenton Varda

Simplify reuse strategy, add reuse comparisons to benchmarks.

parent 15aa868b
......@@ -42,16 +42,6 @@ ReaderArena::ReaderArena(MessageReader* message)
ReaderArena::~ReaderArena() {}
// Re-arms the read limiter with the message's configured traversal limit,
// clears ignoreErrors, rebuilds segment0 from the message's segment 0, and
// sets moreSegments to nullptr.
void ReaderArena::reset() {
readLimiter.reset(message->getOptions().traversalLimitInWords * WORDS);
ignoreErrors = false;
// Destroy segment0 and construct a new SegmentReader in the same storage
// (explicit destructor call followed by placement new).
segment0.~SegmentReader();
new(&segment0) SegmentReader(this, SegmentId(0), this->message->getSegment(0), &readLimiter);
// TODO: Reuse the rest of the SegmentReaders?
moreSegments = nullptr;
}
SegmentReader* ReaderArena::tryGetSegment(SegmentId id) {
if (id == SegmentId(0)) {
if (segment0.getArray() == nullptr) {
......@@ -110,16 +100,6 @@ BuilderArena::BuilderArena(MessageBuilder* message)
: message(message), segment0(nullptr, SegmentId(0), nullptr, nullptr) {}
BuilderArena::~BuilderArena() {}
// Resets segment0 and, when additional segments exist, every builder listed in
// moreSegments->builders.
void BuilderArena::reset() {
  segment0.reset();

  if (moreSegments == nullptr) {
    return;
  }

  // TODO: As mentioned in another TODO below, only the last segment will be reused.
  for (auto& extra: moreSegments->builders) {
    extra->reset();
  }
}
SegmentBuilder* BuilderArena::getSegment(SegmentId id) {
// This method is allowed to crash if the segment ID is not valid.
if (id == SegmentId(0)) {
......
......@@ -164,8 +164,6 @@ public:
~ReaderArena();
CAPNPROTO_DISALLOW_COPY(ReaderArena);
void reset();
// implements Arena ------------------------------------------------
SegmentReader* tryGetSegment(SegmentId id) override;
void reportInvalidData(const char* description) override;
......@@ -189,9 +187,6 @@ public:
~BuilderArena();
CAPNPROTO_DISALLOW_COPY(BuilderArena);
void reset();
// Resets all the segments to be empty, so that a new message can be started.
SegmentBuilder* getSegment(SegmentId id);
// Get the segment with the given id. Crashes or throws an exception if no such segment exists.
......
......@@ -185,83 +185,130 @@ public:
// =======================================================================================
template <typename TestCase>
// Reuse strategy that supplies no scratch buffer: ScratchSpace is empty, and
// the reader and builder ignore it and construct their base classes without
// any caller-provided storage.
struct NoScratch {
  // Empty placeholder so callers can declare a scratch object uniformly.
  struct ScratchSpace {};

  struct MessageReader: public StreamFdMessageReader {
    MessageReader(int fd, ScratchSpace&): StreamFdMessageReader(fd) {}
  };

  struct MessageBuilder: public MallocMessageBuilder {
    MessageBuilder(ScratchSpace&): MallocMessageBuilder() {}
  };
};
// Reuse strategy that hands a caller-owned array of `size` words to the reader
// and to the builder each time one is constructed.
template <size_t size>
struct UseScratch {
  // The buffer passed to the reader/builder base classes.
  struct ScratchSpace {
    word words[size];
  };

  struct MessageReader: public StreamFdMessageReader {
    MessageReader(int fd, ScratchSpace& space)
        : StreamFdMessageReader(fd, ReaderOptions(), arrayPtr(space.words, size)) {}
  };

  struct MessageBuilder: public MallocMessageBuilder {
    MessageBuilder(ScratchSpace& space)
        : MallocMessageBuilder(arrayPtr(space.words, size)) {}
  };
};
// =======================================================================================
template <typename TestCase, typename ReuseStrategy>
void syncClient(int inputFd, int outputFd, uint64_t iters) {
MallocMessageBuilder builder;
// StreamFdMessageReader reader(inputFd, ReaderOptions(), InputStrategy::EAGER_WAIT_FOR_READ_NEXT);
typename ReuseStrategy::ScratchSpace scratch;
for (; iters > 0; --iters) {
typename TestCase::Expectation expected = TestCase::setupRequest(
builder.initRoot<typename TestCase::Request>());
typename TestCase::Expectation expected;
{
typename ReuseStrategy::MessageBuilder builder(scratch);
expected = TestCase::setupRequest(
builder.template initRoot<typename TestCase::Request>());
writeMessageToFd(outputFd, builder);
}
// reader.readNext();
StreamFdMessageReader reader(inputFd);
if (!TestCase::checkResponse(reader.getRoot<typename TestCase::Response>(), expected)) {
{
typename ReuseStrategy::MessageReader reader(inputFd, scratch);
if (!TestCase::checkResponse(
reader.template getRoot<typename TestCase::Response>(), expected)) {
throw std::logic_error("Incorrect response.");
}
}
}
}
template <typename TestCase>
template <typename TestCase, typename ReuseStrategy>
void asyncClientSender(int outputFd,
ProducerConsumerQueue<typename TestCase::Expectation>* expectations,
uint64_t iters) {
MallocMessageBuilder builder;
typename ReuseStrategy::ScratchSpace scratch;
for (; iters > 0; --iters) {
expectations->post(TestCase::setupRequest(builder.initRoot<typename TestCase::Request>()));
typename ReuseStrategy::MessageBuilder builder(scratch);
expectations->post(TestCase::setupRequest(
builder.template initRoot<typename TestCase::Request>()));
writeMessageToFd(outputFd, builder);
}
}
template <typename TestCase>
template <typename TestCase, typename ReuseStrategy>
void asyncClientReceiver(int inputFd,
ProducerConsumerQueue<typename TestCase::Expectation>* expectations,
uint64_t iters) {
StreamFdMessageReader reader(inputFd, ReaderOptions(), InputStrategy::EAGER_WAIT_FOR_READ_NEXT);
typename ReuseStrategy::ScratchSpace scratch;
for (; iters > 0; --iters) {
typename TestCase::Expectation expected = expectations->next();
reader.readNext();
if (!TestCase::checkResponse(reader.getRoot<typename TestCase::Response>(), expected)) {
typename ReuseStrategy::MessageReader reader(inputFd, scratch);
if (!TestCase::checkResponse(
reader.template getRoot<typename TestCase::Response>(), expected)) {
throw std::logic_error("Incorrect response.");
}
}
}
template <typename TestCase>
template <typename TestCase, typename ReuseStrategy>
void asyncClient(int inputFd, int outputFd, uint64_t iters) {
ProducerConsumerQueue<typename TestCase::Expectation> expectations;
std::thread receiverThread(asyncClientReceiver<TestCase>, inputFd, &expectations, iters);
asyncClientSender<TestCase>(outputFd, &expectations, iters);
std::thread receiverThread(
asyncClientReceiver<TestCase, ReuseStrategy>, inputFd, &expectations, iters);
asyncClientSender<TestCase, ReuseStrategy>(outputFd, &expectations, iters);
receiverThread.join();
}
template <typename TestCase>
template <typename TestCase, typename ReuseStrategy>
void server(int inputFd, int outputFd, uint64_t iters) {
StreamFdMessageReader reader(inputFd, ReaderOptions(), InputStrategy::EAGER_WAIT_FOR_READ_NEXT);
MallocMessageBuilder builder;
typename ReuseStrategy::ScratchSpace builderScratch;
typename ReuseStrategy::ScratchSpace readerScratch;
for (; iters > 0; --iters) {
reader.readNext();
// StreamFdMessageReader reader(inputFd);
TestCase::handleRequest(reader.getRoot<typename TestCase::Request>(),
builder.initRoot<typename TestCase::Response>());
typename ReuseStrategy::MessageBuilder builder(builderScratch);
typename ReuseStrategy::MessageReader reader(inputFd, readerScratch);
TestCase::handleRequest(reader.template getRoot<typename TestCase::Request>(),
builder.template initRoot<typename TestCase::Response>());
writeMessageToFd(outputFd, builder);
}
}
template <typename TestCase>
template <typename TestCase, typename ReuseStrategy>
void passByObject(uint64_t iters) {
MallocMessageBuilder requestMessage;
MallocMessageBuilder responseMessage;
typename ReuseStrategy::ScratchSpace requestScratch;
typename ReuseStrategy::ScratchSpace responseScratch;
for (; iters > 0; --iters) {
auto request = requestMessage.initRoot<typename TestCase::Request>();
typename ReuseStrategy::MessageBuilder requestMessage(requestScratch);
auto request = requestMessage.template initRoot<typename TestCase::Request>();
typename TestCase::Expectation expected = TestCase::setupRequest(request);
auto response = responseMessage.initRoot<typename TestCase::Response>();
typename ReuseStrategy::MessageBuilder responseMessage(responseScratch);
auto response = responseMessage.template initRoot<typename TestCase::Response>();
TestCase::handleRequest(request.asReader(), response);
if (!TestCase::checkResponse(response.asReader(), expected)) {
......@@ -270,29 +317,32 @@ void passByObject(uint64_t iters) {
}
}
template <typename TestCase>
template <typename TestCase, typename ReuseStrategy>
void passByBytes(uint64_t iters) {
MallocMessageBuilder requestBuilder;
MallocMessageBuilder responseBuilder;
typename ReuseStrategy::ScratchSpace requestScratch;
typename ReuseStrategy::ScratchSpace responseScratch;
for (; iters > 0; --iters) {
typename ReuseStrategy::MessageBuilder requestBuilder(requestScratch);
typename TestCase::Expectation expected = TestCase::setupRequest(
requestBuilder.initRoot<typename TestCase::Request>());
requestBuilder.template initRoot<typename TestCase::Request>());
Array<word> requestBytes = messageToFlatArray(requestBuilder);
FlatArrayMessageReader requestReader(requestBytes.asPtr());
TestCase::handleRequest(requestReader.getRoot<typename TestCase::Request>(),
responseBuilder.initRoot<typename TestCase::Response>());
typename ReuseStrategy::MessageBuilder responseBuilder(responseScratch);
TestCase::handleRequest(requestReader.template getRoot<typename TestCase::Request>(),
responseBuilder.template initRoot<typename TestCase::Response>());
Array<word> responseBytes = messageToFlatArray(responseBuilder);
FlatArrayMessageReader responseReader(responseBytes.asPtr());
if (!TestCase::checkResponse(responseReader.getRoot<typename TestCase::Response>(), expected)) {
if (!TestCase::checkResponse(
responseReader.template getRoot<typename TestCase::Response>(), expected)) {
throw std::logic_error("Incorrect response.");
}
}
}
template <typename TestCase, typename Func>
template <typename TestCase, typename ReuseStrategy, typename Func>
void passByPipe(Func&& clientFunc, uint64_t iters) {
int clientToServer[2];
int serverToClient[2];
......@@ -312,7 +362,7 @@ void passByPipe(Func&& clientFunc, uint64_t iters) {
close(clientToServer[1]);
close(serverToClient[0]);
server<TestCase>(clientToServer[0], serverToClient[1], iters);
server<TestCase, ReuseStrategy>(clientToServer[0], serverToClient[1], iters);
int status;
if (waitpid(child, &status, 0) != child) {
......@@ -324,32 +374,46 @@ void passByPipe(Func&& clientFunc, uint64_t iters) {
}
}
template <typename ReuseStrategy>
void doBenchmark(const std::string& mode, uint64_t iters) {
if (mode == "client") {
syncClient<ExpressionTestCase, ReuseStrategy>(STDIN_FILENO, STDOUT_FILENO, iters);
} else if (mode == "server") {
server<ExpressionTestCase, ReuseStrategy>(STDIN_FILENO, STDOUT_FILENO, iters);
} else if (mode == "object") {
passByObject<ExpressionTestCase, ReuseStrategy>(iters);
} else if (mode == "bytes") {
passByBytes<ExpressionTestCase, ReuseStrategy>(iters);
} else if (mode == "pipe") {
passByPipe<ExpressionTestCase, ReuseStrategy>(
syncClient<ExpressionTestCase, ReuseStrategy>, iters);
} else if (mode == "pipe-async") {
passByPipe<ExpressionTestCase, ReuseStrategy>(
asyncClient<ExpressionTestCase, ReuseStrategy>, iters);
} else {
std::cerr << "Unknown mode: " << mode << std::endl;
exit(1);
}
}
int main(int argc, char* argv[]) {
if (argc != 3) {
std::cerr << "USAGE: " << argv[0] << " MODE ITERATION_COUNT" << std::endl;
if (argc != 4) {
std::cerr << "USAGE: " << argv[0] << " MODE REUSE ITERATION_COUNT" << std::endl;
return 1;
}
uint64_t iters = strtoull(argv[2], nullptr, 0);
uint64_t iters = strtoull(argv[3], nullptr, 0);
srand(123);
std::cerr << "Doing " << iters << " iterations..." << std::endl;
std::string mode = argv[1];
if (mode == "client") {
syncClient<ExpressionTestCase>(STDIN_FILENO, STDOUT_FILENO, iters);
} else if (mode == "server") {
server<ExpressionTestCase>(STDIN_FILENO, STDOUT_FILENO, iters);
} else if (mode == "object") {
passByObject<ExpressionTestCase>(iters);
} else if (mode == "bytes") {
passByBytes<ExpressionTestCase>(iters);
} else if (mode == "pipe") {
passByPipe<ExpressionTestCase>(syncClient<ExpressionTestCase>, iters);
} else if (mode == "pipe-async") {
passByPipe<ExpressionTestCase>(asyncClient<ExpressionTestCase>, iters);
std::string reuse = argv[2];
if (reuse == "reuse") {
doBenchmark<UseScratch<1024>>(argv[1], iters);
} else if (reuse == "no-reuse") {
doBenchmark<NoScratch>(argv[1], iters);
} else {
std::cerr << "Unknown mode: " << mode << std::endl;
std::cerr << "Unknown reuse mode: " << reuse << std::endl;
return 1;
}
......
......@@ -183,7 +183,49 @@ public:
// =======================================================================================
void writeProtoFast(const google::protobuf::MessageLite& message,
// Reuse strategy in which nothing is reused: each SingleUse object is its own
// newly constructed message (or string), and the Reusable placeholders are
// empty tag types.
struct SingleUseMessages {
  template <typename MessageType>
  struct Message {
    // Empty tag; exists only so SingleUse has something to be constructed from.
    struct Reusable {};

    // A MessageType in its own right; the tag argument is ignored.
    struct SingleUse: public MessageType {
      SingleUse(Reusable&) {}
    };
  };

  // String counterparts of Message<T>::Reusable / Message<T>::SingleUse.
  struct ReusableString {};

  struct SingleUseString: public std::string {
    SingleUseString(ReusableString&) {}
  };

  template <typename MessageType>
  static void doneWith(MessageType&) {
    // Deliberately does nothing: a single-use message is never cleared.
  }
};
// Reuse strategy in which one long-lived object is recycled: SingleUse is just
// a reference to the Reusable object, and doneWith() clears it for the next
// round.
struct ReusableMessages {
  template <typename MessageType>
  struct Message {
    // The long-lived message object.
    struct Reusable: public MessageType {};

    // Alias that binds to the long-lived object instead of creating a new one.
    using SingleUse = MessageType&;
  };

  // String counterparts of Message<T>::Reusable / Message<T>::SingleUse.
  using ReusableString = std::string;
  using SingleUseString = std::string&;

  template <typename MessageType>
  static void doneWith(MessageType& message) {
    message.Clear();
  }
};
// =======================================================================================
// The protobuf Java library defines a format for writing multiple protobufs to a stream, in which
// each message is prefixed by a varint size. This was never added to the C++ library. It's easy
// to do naively, but tricky to implement without accidentally losing various optimizations. These
// two functions should be optimal.
void writeDelimited(const google::protobuf::MessageLite& message,
google::protobuf::io::FileOutputStream* rawOutput) {
google::protobuf::io::CodedOutputStream output(rawOutput);
const int size = message.ByteSize();
......@@ -199,7 +241,7 @@ void writeProtoFast(const google::protobuf::MessageLite& message,
}
}
void readProtoFast(google::protobuf::io::ZeroCopyInputStream* rawInput,
void readDelimited(google::protobuf::io::ZeroCopyInputStream* rawInput,
google::protobuf::MessageLite* message) {
google::protobuf::io::CodedInputStream input(rawInput);
uint32_t size;
......@@ -217,139 +259,160 @@ void readProtoFast(google::protobuf::io::ZeroCopyInputStream* rawInput,
input.PopLimit(limit);
}
template <typename TestCase>
// =======================================================================================
// Shorthand for the wrapper types that ReuseStrategy defines for a TestCase
// message type.  Both macros expand to dependent type names, so they can only
// be used where names `ReuseStrategy` and `TestCase` are in scope as template
// parameters.
#define REUSABLE(type) \
typename ReuseStrategy::template Message<typename TestCase::type>::Reusable
#define SINGLE_USE(type) \
typename ReuseStrategy::template Message<typename TestCase::type>::SingleUse
template <typename TestCase, typename ReuseStrategy>
void syncClient(int inputFd, int outputFd, uint64_t iters) {
google::protobuf::io::FileOutputStream output(outputFd);
google::protobuf::io::FileInputStream input(inputFd);
typename TestCase::Request request;
typename TestCase::Response response;
REUSABLE(Request) reusableRequest;
REUSABLE(Response) reusableResponse;
for (; iters > 0; --iters) {
SINGLE_USE(Request) request(reusableRequest);
typename TestCase::Expectation expected = TestCase::setupRequest(&request);
writeProtoFast(request, &output);
writeDelimited(request, &output);
if (!output.Flush()) throw OsException(output.GetErrno());
request.Clear();
ReuseStrategy::doneWith(request);
// std::cerr << "client: wait" << std::endl;
readProtoFast(&input, &response);
SINGLE_USE(Response) response(reusableResponse);
readDelimited(&input, &response);
if (!TestCase::checkResponse(response, expected)) {
throw std::logic_error("Incorrect response.");
}
response.Clear();
ReuseStrategy::doneWith(response);
}
}
template <typename TestCase>
template <typename TestCase, typename ReuseStrategy>
void asyncClientSender(int outputFd,
ProducerConsumerQueue<typename TestCase::Expectation>* expectations,
uint64_t iters) {
google::protobuf::io::FileOutputStream output(outputFd);
typename TestCase::Request request;
REUSABLE(Request) reusableRequest;
for (; iters > 0; --iters) {
SINGLE_USE(Request) request(reusableRequest);
expectations->post(TestCase::setupRequest(&request));
writeProtoFast(request, &output);
request.Clear();
writeDelimited(request, &output);
ReuseStrategy::doneWith(request);
}
if (!output.Flush()) throw OsException(output.GetErrno());
}
template <typename TestCase>
template <typename TestCase, typename ReuseStrategy>
void asyncClientReceiver(int inputFd,
ProducerConsumerQueue<typename TestCase::Expectation>* expectations,
uint64_t iters) {
google::protobuf::io::FileInputStream input(inputFd);
typename TestCase::Response response;
REUSABLE(Response) reusableResponse;
for (; iters > 0; --iters) {
typename TestCase::Expectation expected = expectations->next();
readProtoFast(&input, &response);
SINGLE_USE(Response) response(reusableResponse);
readDelimited(&input, &response);
if (!TestCase::checkResponse(response, expected)) {
throw std::logic_error("Incorrect response.");
}
response.Clear();
ReuseStrategy::doneWith(response);
}
}
template <typename TestCase>
template <typename TestCase, typename ReuseStrategy>
void asyncClient(int inputFd, int outputFd, uint64_t iters) {
ProducerConsumerQueue<typename TestCase::Expectation> expectations;
std::thread receiverThread(asyncClientReceiver<TestCase>, inputFd, &expectations, iters);
asyncClientSender<TestCase>(outputFd, &expectations, iters);
std::thread receiverThread(
asyncClientReceiver<TestCase, ReuseStrategy>, inputFd, &expectations, iters);
asyncClientSender<TestCase, ReuseStrategy>(outputFd, &expectations, iters);
receiverThread.join();
}
template <typename TestCase>
template <typename TestCase, typename ReuseStrategy>
void server(int inputFd, int outputFd, uint64_t iters) {
google::protobuf::io::FileOutputStream output(outputFd);
google::protobuf::io::FileInputStream input(inputFd);
typename TestCase::Request request;
typename TestCase::Response response;
REUSABLE(Request) reusableRequest;
REUSABLE(Response) reusableResponse;
for (; iters > 0; --iters) {
readProtoFast(&input, &request);
SINGLE_USE(Request) request(reusableRequest);
readDelimited(&input, &request);
SINGLE_USE(Response) response(reusableResponse);
TestCase::handleRequest(request, &response);
request.Clear();
ReuseStrategy::doneWith(request);
writeProtoFast(response, &output);
writeDelimited(response, &output);
if (!output.Flush()) throw std::logic_error("Write failed.");
response.Clear();
ReuseStrategy::doneWith(response);
}
}
template <typename TestCase>
template <typename TestCase, typename ReuseStrategy>
void passByObject(uint64_t iters) {
typename TestCase::Request request;
typename TestCase::Response response;
REUSABLE(Request) reusableRequest;
REUSABLE(Response) reusableResponse;
for (; iters > 0; --iters) {
SINGLE_USE(Request) request(reusableRequest);
typename TestCase::Expectation expected = TestCase::setupRequest(&request);
SINGLE_USE(Response) response(reusableResponse);
TestCase::handleRequest(request, &response);
request.Clear();
ReuseStrategy::doneWith(request);
if (!TestCase::checkResponse(response, expected)) {
throw std::logic_error("Incorrect response.");
}
response.Clear();
ReuseStrategy::doneWith(response);
}
}
template <typename TestCase>
template <typename TestCase, typename ReuseStrategy>
void passByBytes(uint64_t iters) {
typename TestCase::Request clientRequest;
typename TestCase::Request serverRequest;
typename TestCase::Response serverResponse;
typename TestCase::Response clientResponse;
std::string requestString, responseString;
REUSABLE(Request) reusableClientRequest;
REUSABLE(Request) reusableServerRequest;
REUSABLE(Response) reusableServerResponse;
REUSABLE(Response) reusableClientResponse;
typename ReuseStrategy::ReusableString reusableRequestString, reusableResponseString;
for (; iters > 0; --iters) {
SINGLE_USE(Request) clientRequest(reusableClientRequest);
typename TestCase::Expectation expected = TestCase::setupRequest(&clientRequest);
typename ReuseStrategy::SingleUseString requestString(reusableRequestString);
clientRequest.SerializePartialToString(&requestString);
clientRequest.Clear();
ReuseStrategy::doneWith(clientRequest);
SINGLE_USE(Request) serverRequest(reusableServerRequest);
serverRequest.ParsePartialFromString(requestString);
requestString.clear();
SINGLE_USE(Response) serverResponse(reusableServerResponse);
TestCase::handleRequest(serverRequest, &serverResponse);
serverRequest.Clear();
ReuseStrategy::doneWith(serverRequest);
typename ReuseStrategy::SingleUseString responseString(reusableResponseString);
serverResponse.SerializePartialToString(&responseString);
serverResponse.Clear();
ReuseStrategy::doneWith(serverResponse);
SINGLE_USE(Response) clientResponse(reusableClientResponse);
clientResponse.ParsePartialFromString(responseString);
responseString.clear();
if (!TestCase::checkResponse(clientResponse, expected)) {
throw std::logic_error("Incorrect response.");
}
clientResponse.Clear();
ReuseStrategy::doneWith(clientResponse);
}
}
template <typename TestCase, typename Func>
template <typename TestCase, typename ReuseStrategy, typename Func>
void passByPipe(Func&& clientFunc, uint64_t iters) {
int clientToServer[2];
int serverToClient[2];
......@@ -369,7 +432,7 @@ void passByPipe(Func&& clientFunc, uint64_t iters) {
close(clientToServer[1]);
close(serverToClient[0]);
server<TestCase>(clientToServer[0], serverToClient[1], iters);
server<TestCase, ReuseStrategy>(clientToServer[0], serverToClient[1], iters);
int status;
if (waitpid(child, &status, 0) != child) {
......@@ -381,32 +444,46 @@ void passByPipe(Func&& clientFunc, uint64_t iters) {
}
}
template <typename ReuseStrategy>
void doBenchmark(const std::string& mode, uint64_t iters) {
if (mode == "client") {
syncClient<ExpressionTestCase, ReuseStrategy>(STDIN_FILENO, STDOUT_FILENO, iters);
} else if (mode == "server") {
server<ExpressionTestCase, ReuseStrategy>(STDIN_FILENO, STDOUT_FILENO, iters);
} else if (mode == "object") {
passByObject<ExpressionTestCase, ReuseStrategy>(iters);
} else if (mode == "bytes") {
passByBytes<ExpressionTestCase, ReuseStrategy>(iters);
} else if (mode == "pipe") {
passByPipe<ExpressionTestCase, ReuseStrategy>(
syncClient<ExpressionTestCase, ReuseStrategy>, iters);
} else if (mode == "pipe-async") {
passByPipe<ExpressionTestCase, ReuseStrategy>(
asyncClient<ExpressionTestCase, ReuseStrategy>, iters);
} else {
std::cerr << "Unknown mode: " << mode << std::endl;
exit(1);
}
}
int main(int argc, char* argv[]) {
if (argc != 3) {
std::cerr << "USAGE: " << argv[0] << " MODE ITERATION_COUNT" << std::endl;
if (argc != 4) {
std::cerr << "USAGE: " << argv[0] << " MODE REUSE ITERATION_COUNT" << std::endl;
return 1;
}
uint64_t iters = strtoull(argv[2], nullptr, 0);
uint64_t iters = strtoull(argv[3], nullptr, 0);
srand(123);
std::cerr << "Doing " << iters << " iterations..." << std::endl;
std::string mode = argv[1];
if (mode == "client") {
syncClient<ExpressionTestCase>(STDIN_FILENO, STDOUT_FILENO, iters);
} else if (mode == "server") {
server<ExpressionTestCase>(STDIN_FILENO, STDOUT_FILENO, iters);
} else if (mode == "object") {
passByObject<ExpressionTestCase>(iters);
} else if (mode == "bytes") {
passByBytes<ExpressionTestCase>(iters);
} else if (mode == "pipe") {
passByPipe<ExpressionTestCase>(syncClient<ExpressionTestCase>, iters);
} else if (mode == "pipe-async") {
passByPipe<ExpressionTestCase>(asyncClient<ExpressionTestCase>, iters);
std::string reuse = argv[2];
if (reuse == "reuse") {
doBenchmark<ReusableMessages>(argv[1], iters);
} else if (reuse == "no-reuse") {
doBenchmark<SingleUseMessages>(argv[1], iters);
} else {
std::cerr << "Unknown mode: " << mode << std::endl;
std::cerr << "Unknown reuse mode: " << reuse << std::endl;
return 1;
}
......
......@@ -38,10 +38,6 @@ MessageReader::~MessageReader() {
}
}
// Resets the arena, but only if one has already been allocated; otherwise
// there is nothing to do.
void MessageReader::reset() {
  if (!allocatedArena) {
    return;
  }
  arena()->reset();
}
internal::StructReader MessageReader::getRoot(const word* defaultValue) {
if (!allocatedArena) {
static_assert(sizeof(internal::ReaderArena) <= sizeof(arenaSpace),
......@@ -71,14 +67,15 @@ MessageBuilder::~MessageBuilder() {
}
}
internal::SegmentBuilder* MessageBuilder::allocateRootSegment() {
if (!allocatedArena) {
internal::SegmentBuilder* MessageBuilder::getRootSegment() {
if (allocatedArena) {
return arena()->getSegment(SegmentId(0));
} else {
static_assert(sizeof(internal::BuilderArena) <= sizeof(arenaSpace),
"arenaSpace is too small to hold a BuilderArena. Please increase it. This will break "
"ABI compatibility.");
new(arena()) internal::BuilderArena(this);
allocatedArena = true;
}
WordCount refSize = 1 * REFERENCES * WORDS_PER_REFERENCE;
internal::SegmentBuilder* segment = arena()->getSegmentWithAvailable(refSize);
......@@ -88,18 +85,17 @@ internal::SegmentBuilder* MessageBuilder::allocateRootSegment() {
CAPNPROTO_ASSERT(location == segment->getPtrUnchecked(0 * WORDS),
"First allocated word of new arena was not the first word in its segment.");
return segment;
}
}
internal::StructBuilder MessageBuilder::initRoot(const word* defaultValue) {
if (allocatedArena) arena()->reset();
internal::SegmentBuilder* rootSegment = allocateRootSegment();
internal::SegmentBuilder* rootSegment = getRootSegment();
return internal::StructBuilder::initRoot(
rootSegment, rootSegment->getPtrUnchecked(0 * WORDS), defaultValue);
}
internal::StructBuilder MessageBuilder::getRoot(const word* defaultValue) {
internal::SegmentBuilder* rootSegment = allocatedArena ?
arena()->getSegment(SegmentId(0)) : allocateRootSegment();
internal::SegmentBuilder* rootSegment = getRootSegment();
return internal::StructBuilder::getRoot(
rootSegment, rootSegment->getPtrUnchecked(0 * WORDS), defaultValue);
}
......@@ -200,10 +196,15 @@ struct MallocMessageBuilder::MoreSegments {
MallocMessageBuilder::MallocMessageBuilder(
uint firstSegmentWords, AllocationStrategy allocationStrategy)
: nextSize(firstSegmentWords), allocationStrategy(allocationStrategy),
firstSegment(nullptr) {}
ownFirstSegment(true), firstSegment(nullptr) {}
// Constructs a builder around a caller-supplied first segment.  nextSize starts
// out as the supplied array's size, firstSegment points at the array's first
// word, and ownFirstSegment is false to record that this storage was not
// allocated by the builder.
MallocMessageBuilder::MallocMessageBuilder(
ArrayPtr<word> firstSegment, AllocationStrategy allocationStrategy)
: nextSize(firstSegment.size()), allocationStrategy(allocationStrategy),
ownFirstSegment(false), firstSegment(firstSegment.begin()) {}
MallocMessageBuilder::~MallocMessageBuilder() {
free(firstSegment);
if (ownFirstSegment) free(firstSegment);
if (moreSegments != nullptr) {
for (void* ptr: moreSegments->segments) {
free(ptr);
......@@ -212,6 +213,19 @@ MallocMessageBuilder::~MallocMessageBuilder() {
}
ArrayPtr<word> MallocMessageBuilder::allocateSegment(uint minimumSize) {
if (!ownFirstSegment) {
ArrayPtr<word> result = arrayPtr(reinterpret_cast<word*>(firstSegment), nextSize);
firstSegment = nullptr;
ownFirstSegment = true;
if (result.size() >= minimumSize) {
memset(result.begin(), 0, result.size() * sizeof(word));
return result;
}
// If the provided first segment wasn't big enough, we discard it and proceed to allocate
// our own. This never happens in practice since minimumSize is always 1 for the first
// segment.
}
uint size = std::max(minimumSize, nextSize);
void* result = calloc(size, sizeof(word));
......
......@@ -133,14 +133,6 @@ public:
template <typename RootType>
typename RootType::Reader getRoot();
protected:
void reset();
// Clear the cached segment table so that the reader can be reused to read another message.
// reset() may call getSegment() again before returning, so you must arrange for the new segment
// set to be active *before* calling this.
//
// This invalidates any Readers currently pointing into this message.
private:
ReaderOptions options;
......@@ -182,7 +174,7 @@ private:
bool allocatedArena = false;
internal::BuilderArena* arena() { return reinterpret_cast<internal::BuilderArena*>(arenaSpace); }
internal::SegmentBuilder* allocateRootSegment();
internal::SegmentBuilder* getRootSegment();
internal::StructBuilder initRoot(const word* defaultValue);
internal::StructBuilder getRoot(const word* defaultValue);
};
......@@ -232,7 +224,7 @@ private:
ArrayPtr<const ArrayPtr<const word>> segments;
};
enum class AllocationStrategy {
enum class AllocationStrategy: uint8_t {
FIXED_SIZE,
// The builder will prefer to allocate the same amount of space for each segment with no
// heuristic growth. It will still allocate larger segments when the preferred size is too small
......@@ -257,7 +249,7 @@ class MallocMessageBuilder: public MessageBuilder {
// a specific location in memory.
public:
MallocMessageBuilder(uint firstSegmentWords = 1024,
explicit MallocMessageBuilder(uint firstSegmentWords = 1024,
AllocationStrategy allocationStrategy = SUGGESTED_ALLOCATION_STRATEGY);
// Creates a BuilderContext which allocates at least the given number of words for the first
// segment, and then uses the given strategy to decide how much to allocate for subsequent
......@@ -271,6 +263,12 @@ public:
// The defaults have been chosen to be reasonable for most people, so don't change them unless you
// have reason to believe you need to.
explicit MallocMessageBuilder(ArrayPtr<word> firstSegment,
AllocationStrategy allocationStrategy = SUGGESTED_ALLOCATION_STRATEGY);
// This version always returns the given array for the first segment, and then proceeds with the
// allocation strategy. This is useful for optimization when building lots of small messages in
// a tight loop: you can reuse the space for the first segment.
CAPNPROTO_DISALLOW_COPY(MallocMessageBuilder);
virtual ~MallocMessageBuilder();
......@@ -280,6 +278,7 @@ private:
uint nextSize;
AllocationStrategy allocationStrategy;
bool ownFirstSegment;
void* firstSegment;
struct MoreSegments;
......
......@@ -83,7 +83,7 @@ TEST(Serialize, FlatArrayOddSegmentCount) {
checkTestMessage(reader.getRoot<TestAllTypes>());
}
TEST(Serialize, FlatArrayEventSegmentCount) {
TEST(Serialize, FlatArrayEvenSegmentCount) {
TestMessageBuilder builder(10);
initTestMessage(builder.initRoot<TestAllTypes>());
......@@ -95,25 +95,24 @@ TEST(Serialize, FlatArrayEventSegmentCount) {
class TestInputStream: public InputStream {
public:
TestInputStream(ArrayPtr<const word> data)
TestInputStream(ArrayPtr<const word> data, bool lazy)
: pos(reinterpret_cast<const char*>(data.begin())),
end(reinterpret_cast<const char*>(data.end())) {}
end(reinterpret_cast<const char*>(data.end())),
lazy(lazy) {}
~TestInputStream() {}
bool read(void* buffer, size_t size) override {
if (size_t(end - pos) < size) {
ADD_FAILURE() << "Overran end of stream.";
return false;
} else {
memcpy(buffer, pos, size);
pos += size;
return true;
}
size_t read(void* buffer, size_t minBytes, size_t maxBytes) override {
CAPNPROTO_ASSERT(maxBytes <= size_t(end - pos), "Overran end of stream.");
size_t amount = lazy ? minBytes : maxBytes;
memcpy(buffer, pos, amount);
pos += amount;
return amount;
}
private:
const char* pos;
const char* end;
bool lazy;
};
TEST(Serialize, InputStream) {
......@@ -122,162 +121,81 @@ TEST(Serialize, InputStream) {
Array<word> serialized = messageToFlatArray(builder);
TestInputStream stream(serialized.asPtr());
InputStreamMessageReader reader(&stream, ReaderOptions(), InputStrategy::EAGER);
checkTestMessage(reader.getRoot<TestAllTypes>());
}
TEST(Serialize, InputStreamLazy) {
TestMessageBuilder builder(1);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputStream stream(serialized.asPtr());
InputStreamMessageReader reader(&stream, ReaderOptions(), InputStrategy::LAZY);
checkTestMessage(reader.getRoot<TestAllTypes>());
}
TEST(Serialize, InputStreamOddSegmentCount) {
TestMessageBuilder builder(7);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputStream stream(serialized.asPtr());
InputStreamMessageReader reader(&stream, ReaderOptions(), InputStrategy::EAGER);
checkTestMessage(reader.getRoot<TestAllTypes>());
}
TEST(Serialize, InputStreamOddSegmentCountLazy) {
TestMessageBuilder builder(7);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputStream stream(serialized.asPtr());
InputStreamMessageReader reader(&stream, ReaderOptions(), InputStrategy::LAZY);
checkTestMessage(reader.getRoot<TestAllTypes>());
}
TEST(Serialize, InputStreamEventSegmentCount) {
TestMessageBuilder builder(10);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputStream stream(serialized.asPtr());
InputStreamMessageReader reader(&stream, ReaderOptions(), InputStrategy::EAGER);
checkTestMessage(reader.getRoot<TestAllTypes>());
}
TEST(Serialize, InputStreamEventSegmentCountLazy) {
TestMessageBuilder builder(10);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputStream stream(serialized.asPtr());
InputStreamMessageReader reader(&stream, ReaderOptions(), InputStrategy::LAZY);
TestInputStream stream(serialized.asPtr(), false);
InputStreamMessageReader reader(stream, ReaderOptions());
checkTestMessage(reader.getRoot<TestAllTypes>());
}
// In-memory InputFile used by the tests.  It serves reads straight out of the word array handed
// to the constructor; only a pointer to that array is kept, so the array must outlive this object.
class TestInputFile: public InputFile {
public:
  TestInputFile(ArrayPtr<const word> data)
      : begin(reinterpret_cast<const char*>(data.begin())),
        size_(data.size() * sizeof(word)) {}
  ~TestInputFile() {}

  // Copies `size` bytes starting at byte `offset` into `buffer` and returns true.  If the
  // requested range is not entirely inside the data, records a test failure and returns false
  // without touching `buffer`.
  bool read(size_t offset, void* buffer, size_t size) override {
    // Compare against the remaining length instead of computing offset + size: that sum can wrap
    // around size_t for huge arguments and would let an out-of-range read through.
    if (offset > size_ || size > size_ - offset) {
      ADD_FAILURE() << "Overran end of file.";
      return false;
    }

    memcpy(buffer, begin + offset, size);
    return true;
  }

private:
  const char* begin;  // First byte of the backing data.
  size_t size_;       // Length of the backing data, in bytes.
};
// Round-trip check: flattens the test message, reads it back through a message reader, and
// verifies the decoded contents.
// NOTE(review): two TEST headers and two `reader` declarations appear below; the InputFile form
// and the InputStream scratch-space form appear to be the two sides of the diff.
TEST(Serialize, InputFile) {
TEST(Serialize, InputStreamScratchSpace) {
TestMessageBuilder builder(1);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputFile file(serialized.asPtr());
InputFileMessageReader reader(&file, ReaderOptions(), InputStrategy::EAGER);
// Caller-provided buffer passed to the reader as its scratchSpace argument.
word scratch[4096];
TestInputStream stream(serialized.asPtr(), false);
InputStreamMessageReader reader(stream, ReaderOptions(), ArrayPtr<word>(scratch, 4096));
checkTestMessage(reader.getRoot<TestAllTypes>());
}
// Round-trip check: flattens the test message, reads it back through a message reader, and
// verifies the decoded contents.
// NOTE(review): two TEST headers and two `reader` declarations appear below; the InputFile form
// and the InputStream form appear to be the two sides of the diff.
TEST(Serialize, InputFileLazy) {
TEST(Serialize, InputStreamLazy) {
TestMessageBuilder builder(1);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputFile file(serialized.asPtr());
InputFileMessageReader reader(&file, ReaderOptions(), InputStrategy::LAZY);
// The second TestInputStream argument is `true` in the *Lazy tests and `false` otherwise.
TestInputStream stream(serialized.asPtr(), true);
InputStreamMessageReader reader(stream, ReaderOptions());
checkTestMessage(reader.getRoot<TestAllTypes>());
}
// Round-trip check with TestMessageBuilder(7): flattens the test message, reads it back through
// a message reader, and verifies the decoded contents.
// NOTE(review): two TEST headers and two `reader` declarations appear below; the InputFile form
// and the InputStream form appear to be the two sides of the diff.
TEST(Serialize, InputFileOddSegmentCount) {
TEST(Serialize, InputStreamOddSegmentCount) {
TestMessageBuilder builder(7);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputFile file(serialized.asPtr());
InputFileMessageReader reader(&file, ReaderOptions(), InputStrategy::EAGER);
TestInputStream stream(serialized.asPtr(), false);
InputStreamMessageReader reader(stream, ReaderOptions());
checkTestMessage(reader.getRoot<TestAllTypes>());
}
// Round-trip check with TestMessageBuilder(7): flattens the test message, reads it back through
// a message reader, and verifies the decoded contents.
// NOTE(review): two TEST headers and two `reader` declarations appear below; the InputFile form
// and the InputStream form appear to be the two sides of the diff.
TEST(Serialize, InputFileOddSegmentCountLazy) {
TEST(Serialize, InputStreamOddSegmentCountLazy) {
TestMessageBuilder builder(7);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputFile file(serialized.asPtr());
InputFileMessageReader reader(&file, ReaderOptions(), InputStrategy::LAZY);
TestInputStream stream(serialized.asPtr(), true);
InputStreamMessageReader reader(stream, ReaderOptions());
checkTestMessage(reader.getRoot<TestAllTypes>());
}
// Round-trip check with TestMessageBuilder(10): flattens the test message, reads it back through
// a message reader, and verifies the decoded contents.
// NOTE(review): two TEST headers and two `reader` declarations appear below; the InputFile form
// and the InputStream form appear to be the two sides of the diff.
TEST(Serialize, InputFileEventSegmentCount) {
TEST(Serialize, InputStreamEvenSegmentCount) {
TestMessageBuilder builder(10);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputFile file(serialized.asPtr());
InputFileMessageReader reader(&file, ReaderOptions(), InputStrategy::EAGER);
TestInputStream stream(serialized.asPtr(), false);
InputStreamMessageReader reader(stream, ReaderOptions());
checkTestMessage(reader.getRoot<TestAllTypes>());
}
// Round-trip check with TestMessageBuilder(10): flattens the test message, reads it back through
// a message reader, and verifies the decoded contents.
// NOTE(review): two TEST headers and two `reader` declarations appear below; the InputFile form
// and the InputStream form appear to be the two sides of the diff.
TEST(Serialize, InputFileEventSegmentCountLazy) {
TEST(Serialize, InputStreamEvenSegmentCountLazy) {
TestMessageBuilder builder(10);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputFile file(serialized.asPtr());
InputFileMessageReader reader(&file, ReaderOptions(), InputStrategy::LAZY);
TestInputStream stream(serialized.asPtr(), true);
InputStreamMessageReader reader(stream, ReaderOptions());
checkTestMessage(reader.getRoot<TestAllTypes>());
}
......@@ -324,7 +242,7 @@ TEST(Serialize, WriteMessageOddSegmentCount) {
EXPECT_TRUE(output.dataEquals(serialized.asPtr()));
}
TEST(Serialize, WriteMessageEventSegmentCount) {
TEST(Serialize, WriteMessageEvenSegmentCount) {
TestMessageBuilder builder(10);
initTestMessage(builder.initRoot<TestAllTypes>());
......@@ -363,22 +281,10 @@ TEST(Serialize, FileDescriptors) {
checkTestMessage(reader.getRoot<TestAllTypes>());
}
{
FileFdMessageReader reader(tmpfile.get(), 0);
checkTestMessage(reader.getRoot<TestAllTypes>());
}
size_t secondStart = lseek(tmpfile, 0, SEEK_CUR);
{
StreamFdMessageReader reader(tmpfile.get());
EXPECT_EQ("second message in file", reader.getRoot<TestAllTypes>().getTextField());
}
{
FileFdMessageReader reader(move(tmpfile), secondStart);
EXPECT_EQ("second message in file", reader.getRoot<TestAllTypes>().getTextField());
}
}
// TODO: Test error cases.
......
......@@ -26,6 +26,8 @@
#include <string.h>
#include <errno.h>
#include <unistd.h>
#include <string>
#include <sys/uio.h>
namespace capnproto {
......@@ -127,179 +129,100 @@ Array<word> messageToFlatArray(ArrayPtr<const ArrayPtr<const word>> segments) {
// =======================================================================================
// Out-of-line definitions of the virtual destructors declared on the stream/file interfaces;
// each body is empty.
InputStream::~InputStream() {}
InputFile::~InputFile() {}
OutputStream::~OutputStream() {}
// -------------------------------------------------------------------
InputStreamMessageReader::InputStreamMessageReader(
InputStream* inputStream, ReaderOptions options, InputStrategy inputStrategy)
: MessageReader(options), inputStream(inputStream), inputStrategy(inputStrategy),
segmentsReadSoFar(0) {
switch (inputStrategy) {
case InputStrategy::EAGER:
case InputStrategy::LAZY:
readNextInternal();
break;
case InputStrategy::EAGER_WAIT_FOR_READ_NEXT:
case InputStrategy::LAZY_WAIT_FOR_READ_NEXT:
break;
// Default skip(): discards `bytes` bytes by repeatedly read()ing them into a throwaway stack
// buffer, one buffer-sized chunk (or less, for the tail) at a time.
void InputStream::skip(size_t bytes) {
  char discard[8192];

  while (bytes != 0) {
    // Never request more than the scratch buffer can hold.
    size_t chunk = bytes < sizeof(discard) ? bytes : sizeof(discard);

    // minBytes == maxBytes, so read() is asked for exactly `chunk` bytes.
    size_t consumed = read(discard, chunk, chunk);
    bytes -= consumed;
  }
}
void InputStreamMessageReader::readNext() {
bool needReset = false;
switch (inputStrategy) {
case InputStrategy::LAZY:
if (moreSegments != nullptr || segment0.size != 0) {
// Make sure we've finished reading the previous message.
// Note that this sort of defeats the purpose of lazy parsing. In theory we could be a
// little more efficient by reading into a stack-allocated scratch buffer rather than
// allocating space for the remaining segments, but people really shouldn't be using
// readNext() when lazy-parsing anyway.
getSegment(moreSegments.size());
}
// no break
case InputStrategy::EAGER:
needReset = true;
// TODO: Save moreSegments for reuse?
moreSegments = nullptr;
segmentsReadSoFar = 0;
segment0.size = 0;
break;
case InputStrategy::EAGER_WAIT_FOR_READ_NEXT:
this->inputStrategy = InputStrategy::EAGER;
break;
case InputStrategy::LAZY_WAIT_FOR_READ_NEXT:
this->inputStrategy = InputStrategy::LAZY;
break;
void OutputStream::write(ArrayPtr<const ArrayPtr<const byte>> pieces) {
for (auto piece: pieces) {
write(piece.begin(), piece.size());
}
if (inputStream != nullptr) {
readNextInternal();
}
if (needReset) reset();
}
void InputStreamMessageReader::readNextInternal() {
// -------------------------------------------------------------------
InputStreamMessageReader::InputStreamMessageReader(
InputStream& inputStream, ReaderOptions options, ArrayPtr<word> scratchSpace)
: MessageReader(options), inputStream(inputStream), readPos(nullptr) {
internal::WireValue<uint32_t> firstWord[2];
if (!inputStream->read(firstWord, sizeof(firstWord))) return;
inputStream.read(firstWord, sizeof(firstWord), sizeof(firstWord));
uint segmentCount = firstWord[0].get();
segment0.size = segmentCount == 0 ? 0 : firstWord[1].get();
uint segment0Size = segmentCount == 0 ? 0 : firstWord[1].get();
if (segmentCount > 1) {
internal::WireValue<uint32_t> sizes[segmentCount - 1];
if (!inputStream->read(sizes, sizeof(sizes))) return;
size_t totalWords = segment0Size;
moreSegments = newArray<LazySegment>(segmentCount - 1);
for (uint i = 1; i < segmentCount; i++) {
moreSegments[i - 1].size = sizes[i - 1].get();
// Read sizes for all segments except the first. Include padding if necessary.
internal::WireValue<uint32_t> moreSizes[segmentCount & ~1];
if (segmentCount > 1) {
inputStream.read(moreSizes, sizeof(moreSizes), sizeof(moreSizes));
for (uint i = 0; i < segmentCount - 1; i++) {
totalWords += moreSizes[i].get();
}
}
if (segmentCount % 2 == 0) {
// Read the padding.
uint32_t pad;
if (!inputStream->read(&pad, sizeof(pad))) return;
if (scratchSpace.size() < totalWords) {
// TODO: Consider allocating each segment as a separate chunk to reduce memory fragmentation.
ownedSpace = newArray<word>(totalWords);
scratchSpace = ownedSpace;
}
if (inputStrategy == InputStrategy::EAGER) {
getSegment(segmentCount - 1);
}
}
segment0 = scratchSpace.slice(0, segment0Size);
ArrayPtr<const word> InputStreamMessageReader::getSegment(uint id) {
if (id > moreSegments.size()) {
return nullptr;
}
if (segmentCount > 1) {
moreSegments = newArray<ArrayPtr<const word>>(segmentCount - 1);
size_t offset = segment0Size;
while (segmentsReadSoFar <= id && inputStream != nullptr) {
LazySegment& segment = segmentsReadSoFar == 0 ? segment0 : moreSegments[segmentsReadSoFar - 1];
if (segment.words.size() < segment.size) {
segment.words = newArray<word>(segment.size);
for (uint i = 0; i < segmentCount - 1; i++) {
uint segmentSize = moreSizes[i].get();
moreSegments[i] = scratchSpace.slice(offset, offset + segmentSize);
offset += segmentSize;
}
if (!inputStream->read(segment.words.begin(), segment.size * sizeof(word))) {
// There was an error but no exception was thrown, so we're supposed to plod along with
// default values. Discard the broken stream.
inputStream = nullptr;
break;
}
++segmentsReadSoFar;
if (segmentCount == 1) {
inputStream.read(scratchSpace.begin(), totalWords * sizeof(word), totalWords * sizeof(word));
} else if (segmentCount > 1) {
readPos = reinterpret_cast<byte*>(scratchSpace.begin());
readPos += inputStream.read(readPos, segment0Size * sizeof(word), totalWords * sizeof(word));
}
LazySegment& segment = id == 0 ? segment0 : moreSegments[id - 1];
return segment.words.slice(0, segment.size);
}
// -------------------------------------------------------------------
InputFileMessageReader::InputFileMessageReader(
InputFile* inputFile, ReaderOptions options, InputStrategy inputStrategy)
: MessageReader(options), inputFile(inputFile) {
internal::WireValue<uint32_t> firstWord[2];
if (!inputFile->read(0, firstWord, sizeof(firstWord))) return;
uint segmentCount = firstWord[0].get();
segment0.size = segmentCount == 0 ? 0 : firstWord[1].get();
if (segmentCount > 1) {
internal::WireValue<uint32_t> sizes[segmentCount - 1];
if (!inputFile->read(sizeof(firstWord), sizes, sizeof(sizes))) return;
uint64_t offset = (segmentCount / 2 + 1) * sizeof(word);
segment0.offset = offset;
offset += segment0.size * sizeof(word);
moreSegments = newArray<LazySegment>(segmentCount - 1);
for (uint i = 1; i < segmentCount; i++) {
uint segmentSize = sizes[i - 1].get();
moreSegments[i - 1].size = segmentSize;
moreSegments[i - 1].offset = offset;
offset += segmentSize * sizeof(word);
}
} else {
segment0.offset = sizeof(firstWord);
}
if (inputStrategy == InputStrategy::EAGER) {
for (uint i = 0; i < segmentCount; i++) {
getSegment(segmentCount);
}
inputFile = nullptr;
// If a lazy read is still outstanding (readPos is non-null), asks the stream to skip the bytes
// between readPos and the end of the last segment, i.e. the part of the message not yet read.
InputStreamMessageReader::~InputStreamMessageReader() {
if (readPos != nullptr) {
// Note that lazy reads only happen when we have multiple segments, so moreSegments.back() is
// valid.
const byte* allEnd = reinterpret_cast<const byte*>(moreSegments.back().end());
inputStream.skip(allEnd - readPos);
}
}
// Returns the words of segment `id`, or nullptr when `id` is past the last segment.  Segment 0 is
// kept in segment0; segment N (N >= 1) is moreSegments[N - 1].
// NOTE(review): this span shows the InputFileMessageReader and InputStreamMessageReader versions
// of getSegment() interleaved line by line (two headers, two `segment` declarations, two
// returns); it is not compilable as written.
ArrayPtr<const word> InputFileMessageReader::getSegment(uint id) {
ArrayPtr<const word> InputStreamMessageReader::getSegment(uint id) {
if (id > moreSegments.size()) {
return nullptr;
}
LazySegment& segment = id == 0 ? segment0 : moreSegments[id - 1];
// InputFile variant: load the segment on first access.
if (segment.words == nullptr && segment.size > 0 && inputFile != nullptr) {
Array<word> words = newArray<word>(segment.size);
ArrayPtr<const word> segment = id == 0 ? segment0 : moreSegments[id - 1];
if (!inputFile->read(segment.offset, words.begin(), words.size() * sizeof(word))) {
// There was an error but no exception was thrown, so we're supposed to plod along with
// default values. Discard the broken stream.
inputFile = nullptr;
} else {
segment.words = move(words);
// InputStream variant: a non-null readPos marks how far the stream has been read so far.
if (readPos != nullptr) {
// May need to lazily read more data.
const byte* segmentEnd = reinterpret_cast<const byte*>(segment.end());
if (readPos < segmentEnd) {
// Note that lazy reads only happen when we have multiple segments, so moreSegments.back() is
// valid.
const byte* allEnd = reinterpret_cast<const byte*>(moreSegments.back().end());
// Require everything up to the end of this segment; accept up to the end of the message.
readPos += inputStream.read(readPos, segmentEnd - readPos, allEnd - readPos);
}
}
return segment.words.asPtr();
return segment;
}
// -------------------------------------------------------------------
......@@ -316,27 +239,45 @@ void writeMessage(OutputStream& output, ArrayPtr<const ArrayPtr<const word>> seg
table[segments.size() + 1].set(0);
}
output.write(table, sizeof(table));
ArrayPtr<const byte> pieces[segments.size() + 1];
pieces[0] = arrayPtr(reinterpret_cast<byte*>(table), sizeof(table));
for (auto& segment: segments) {
output.write(segment.begin(), segment.size() * sizeof(word));
for (uint i = 0; i < segments.size(); i++) {
pieces[i + 1] = arrayPtr(reinterpret_cast<const byte*>(segments[i].begin()),
reinterpret_cast<const byte*>(segments[i].end()));
}
output.write(arrayPtr(pieces, segments.size() + 1));
}
// =======================================================================================
// Exception describing a failed OS call.  The human-readable text is built once, at construction
// time, with strerror_r(), so what() is thread-safe and never consults errno or strerror()'s
// shared static buffer.
class OsException: public std::exception {
public:
  // `error` is an errno value.  what() yields just the system's description of it.
  OsException(int error): message(describe(error)) {}

  // `function` names the call that failed (e.g. "read"); `error` is the errno value it produced.
  // what() yields "<function>: <description of error>".
  OsException(const char* function, int error) {
    message = function;
    message += ": ";
    message += describe(error);
  }

  ~OsException() noexcept {}

  const char* what() const noexcept override {
    return message.c_str();
  }

private:
  std::string message;

  // Turns an errno value into text.
  static std::string describe(int error) {
    char buffer[256];
    buffer[0] = '\0';
    // strerror_r() comes in two flavours with different return types; overload resolution on
    // pickText() selects the right interpretation at compile time.
    return pickText(strerror_r(error, buffer, sizeof(buffer)), buffer);
  }

  // XSI flavour: returns a status code and, on success, fills `buffer`.
  static const char* pickText(int status, const char* buffer) {
    return status == 0 ? buffer : "Unknown error";
  }

  // GNU flavour: returns the text, which may or may not live in `buffer`.
  static const char* pickText(const char* text, const char*) {
    return text;
  }
};
// Signals that a stream reached end-of-file before the data being waited for had arrived.
// Carries no state; what() always yields the same fixed text.
class PrematureEofException: public std::exception {
public:
  PrematureEofException() {}
  ~PrematureEofException() noexcept {}

  const char* what() const noexcept override {
    static const char description[] = "Stream ended prematurely.";
    return description;
  }
};
AutoCloseFd::~AutoCloseFd() {
......@@ -344,78 +285,95 @@ AutoCloseFd::~AutoCloseFd() {
if (std::uncaught_exception()) {
// TODO: Devise some way to report secondary errors during unwind.
} else {
throw OsException(errno);
throw OsException("close", errno);
}
}
}
FdInputStream::~FdInputStream() {}
// Reads from the file descriptor into `buffer`, looping over ::read() until the required amount
// has arrived.
// NOTE(review): this span shows the bool-returning read(buffer, size) and the size_t-returning
// read(buffer, minBytes, maxBytes) interleaved line by line; it is not compilable as written.
bool FdInputStream::read(void* buffer, size_t size) {
char* pos = reinterpret_cast<char*>(buffer);
size_t FdInputStream::read(void* buffer, size_t minBytes, size_t maxBytes) {
byte* pos = reinterpret_cast<byte*>(buffer);
// `min` is the position that must be reached; `max` is the end of the buffer space on offer.
byte* min = pos + minBytes;
byte* max = pos + maxBytes;
while (size > 0) {
ssize_t n = ::read(fd, pos, size);
while (pos < min) {
// Ask for everything up to `max`, but stop looping once `min` has been reached.
ssize_t n = ::read(fd, pos, max - pos);
if (n <= 0) {
if (n < 0) {
// TODO: Use strerror_r or whatever for thread-safety. Ugh.
errorReporter->reportError(strerror(errno));
// Save errno right away, before anything else can overwrite it.
int error = errno;
if (error == EINTR) {
// The call was interrupted by a signal; retry it.
continue;
} else {
throw OsException("read", error);
}
} else if (n == 0) {
errorReporter->reportError("Stream ended prematurely.");
// ::read() returned 0: end of file before the required amount was read.
throw PrematureEofException();
}
return false;
}
pos += n;
size -= n;
}
return true;
// Number of bytes actually placed in `buffer`.
return pos - reinterpret_cast<byte*>(buffer);
}
FdInputFile::~FdInputFile() {}
FdOutputStream::~FdOutputStream() {}
// NOTE(review): this span shows FdInputFile::read() (a ::pread() loop that reports errors via
// errorReporter and returns false) and FdOutputStream::write(buffer, size) (a ::write() loop that
// throws OsException) interleaved line by line; it is not compilable as written.
bool FdInputFile::read(size_t offset, void* buffer, size_t size) {
char* pos = reinterpret_cast<char*>(buffer);
offset += this->offset;
// Writes all `size` bytes of `buffer` to the file descriptor, calling ::write() again after
// each partial write.
void FdOutputStream::write(const void* buffer, size_t size) {
const char* pos = reinterpret_cast<const char*>(buffer);
while (size > 0) {
ssize_t n = ::pread(fd, pos, size, offset);
ssize_t n = ::write(fd, pos, size);
if (n <= 0) {
if (n < 0) {
// TODO: Use strerror_r or whatever for thread-safety. Ugh.
errorReporter->reportError(strerror(errno));
} else if (n == 0) {
errorReporter->reportError("Stream ended prematurely.");
}
return false;
// A zero return is asserted against; a negative return is turned into an OsException.
CAPNPROTO_ASSERT(n < 0, "write() returned zero.");
throw OsException("write", errno);
}
// Advance past the bytes that were transferred.
pos += n;
offset += n;
size -= n;
}
return true;
}
FdOutputStream::~FdOutputStream() {}
// Writes every byte array in `pieces`, in order, to the file descriptor using writev(), so that
// the common case needs a single syscall.  Partial writes are resumed from the first unwritten
// byte.  Throws OsException if writev() fails.
void FdOutputStream::write(ArrayPtr<const ArrayPtr<const byte>> pieces) {
  struct iovec iov[pieces.size()];
  for (uint i = 0; i < pieces.size(); i++) {
    // writev() interface is not const-correct. :(
    iov[i].iov_base = const_cast<byte*>(pieces[i].begin());
    iov[i].iov_len = pieces[i].size();
  }

  // [current, end) is the part of the iovec array that still has bytes to write.
  struct iovec* current = iov;
  struct iovec* end = iov + pieces.size();

  // Make sure we don't do anything on an empty write.
  while (current < end && current->iov_len == 0) {
    ++current;
  }

  while (current < end) {
    // Resume from `current`, not from the start of the array: entries before it have already
    // been written in full and must not be sent again.
    ssize_t n = ::writev(fd, current, end - current);

    if (n <= 0) {
      CAPNPROTO_ASSERT(n < 0, "write() returned zero.");
      throw OsException("writev", errno);
    }

    // Step over every entry that was written completely.  The bounds check matters: once the
    // last entry has been consumed, `current` equals `end` and must not be dereferenced.
    while (current < end && static_cast<size_t>(n) >= current->iov_len) {
      n -= current->iov_len;
      ++current;
    }

    // Whatever is left of n landed partway through the next entry; trim that entry so the next
    // writev() continues from the first unwritten byte.
    if (n > 0) {
      current->iov_base = reinterpret_cast<byte*>(current->iov_base) + n;
      current->iov_len -= n;
    }
  }
}
StreamFdMessageReader::~StreamFdMessageReader() {}
FileFdMessageReader::~FileFdMessageReader() {}
void writeMessageToFd(int fd, ArrayPtr<const ArrayPtr<const word>> segments) {
FdOutputStream stream(fd);
......
......@@ -75,18 +75,25 @@ class InputStream {
public:
virtual ~InputStream();
virtual bool read(void* buffer, size_t size) = 0;
// Always reads the full size requested. Returns true if successful. May throw an exception
// on failure, or report the error through some side channel and return false.
};
class InputFile {
public:
virtual ~InputFile();
virtual bool read(size_t offset, void* buffer, size_t size) = 0;
// Always reads the full size requested. Returns true if successful. May throw an exception
// on failure, or report the error through some side channel and return false.
virtual size_t read(void* buffer, size_t minBytes, size_t maxBytes) = 0;
// Reads at least minBytes and at most maxBytes, copying them into the given buffer. Returns
// the size read. Throws an exception on errors.
//
// maxBytes is the number of bytes the caller really wants, but minBytes is the minimum amount
// needed by the caller before it can start doing useful processing. If the stream returns less
// than maxBytes, the caller will usually call read() again later to get the rest. Returning
// less than maxBytes is useful when it makes sense for the caller to parallelize processing
// with I/O.
//
// Cap'n Proto never asks for more bytes than it knows are part of the message. Therefore, if
// the InputStream happens to know that the stream will never reach maxBytes -- even if it has
// reached minBytes -- it should throw an exception to avoid wasting time processing an incomplete
// message. If it can't even reach minBytes, it MUST throw an exception, as the caller is not
// expected to understand how to deal with partial reads.
virtual void skip(size_t bytes);
// Skips past the given number of bytes, discarding them. The default implementation read()s
// into a scratch buffer.
};
class OutputStream {
......@@ -94,85 +101,34 @@ public:
virtual ~OutputStream();
virtual void write(const void* buffer, size_t size) = 0;
// Throws exception on error, or reports errors via some side channel and returns.
};
enum class InputStrategy {
EAGER,
// Read the whole message into RAM upfront, in the MessageReader constructor. When reading from
// an InputStream, the stream will then be positioned at the byte immediately after the end of
// the message, and will not be accessed again.
LAZY,
// Read segments of the message into RAM as needed while the message structure is being traversed.
// When reading from an InputStream, segments must be read in order, so segments up to the
// required segment will also be read. No guarantee is made about the position of the InputStream
// after reading. When using an InputFile, only the exact segments desired are read.
// Always writes the full size. Throws exception on error.
EAGER_WAIT_FOR_READ_NEXT,
// Like EAGER but don't read the first message until readNext() is called the first time.
LAZY_WAIT_FOR_READ_NEXT,
// Like LAZY but don't read the first message until readNext() is called the first time.
virtual void write(ArrayPtr<const ArrayPtr<const byte>> pieces);
// Equivalent to write()ing each byte array in sequence, which is what the default implementation
// does. Override if you can do something better, e.g. use writev() to do the write in a single
// syscall.
};
class InputStreamMessageReader: public MessageReader {
public:
InputStreamMessageReader(InputStream* inputStream,
InputStreamMessageReader(InputStream& inputStream,
ReaderOptions options = ReaderOptions(),
InputStrategy inputStrategy = InputStrategy::EAGER);
void readNext();
// Progress to the next message in the input stream, reusing the same memory if possible.
// Calling this invalidates any Readers currently pointing into this message.
ArrayPtr<word> scratchSpace = nullptr);
~InputStreamMessageReader();
// implements MessageReader ----------------------------------------
ArrayPtr<const word> getSegment(uint id) override;
private:
InputStream* inputStream;
InputStrategy inputStrategy;
uint segmentsReadSoFar;
struct LazySegment {
uint size;
Array<word> words;
// words may be larger than the desired size in the case where space is being reused from a
// previous read.
inline LazySegment(): size(0), words(nullptr) {}
};
InputStream& inputStream;
byte* readPos;
// Optimize for single-segment case.
LazySegment segment0;
Array<LazySegment> moreSegments;
void readNextInternal();
};
class InputFileMessageReader: public MessageReader {
public:
InputFileMessageReader(InputFile* inputFile,
ReaderOptions options = ReaderOptions(),
InputStrategy inputStrategy = InputStrategy::EAGER);
// implements MessageReader ----------------------------------------
ArrayPtr<const word> getSegment(uint id) override;
private:
InputFile* inputFile;
struct LazySegment {
uint size;
size_t offset;
Array<word> words; // null until actually read
inline LazySegment(): size(0), offset(0), words(nullptr) {}
};
ArrayPtr<const word> segment0;
Array<ArrayPtr<const word>> moreSegments;
// Optimize for single-segment case.
LazySegment segment0;
Array<LazySegment> moreSegments;
Array<word> ownedSpace;
// Only if scratchSpace wasn't big enough.
};
void writeMessage(OutputStream& output, MessageBuilder& builder);
......@@ -218,39 +174,15 @@ class FdInputStream: public InputStream {
// An InputStream wrapping a file descriptor.
public:
FdInputStream(int fd, ErrorReporter* errorReporter = getThrowingErrorReporter())
: fd(fd), errorReporter(errorReporter) {};
FdInputStream(AutoCloseFd fd, ErrorReporter* errorReporter = getThrowingErrorReporter())
: fd(fd), autoclose(move(fd)), errorReporter(errorReporter) {}
FdInputStream(int fd): fd(fd) {};
FdInputStream(AutoCloseFd fd): fd(fd), autoclose(move(fd)) {}
~FdInputStream();
bool read(void* buffer, size_t size) override;
size_t read(void* buffer, size_t minBytes, size_t maxBytes) override;
private:
int fd;
AutoCloseFd autoclose;
ErrorReporter* errorReporter;
};
class FdInputFile: public InputFile {
// An InputFile wrapping a file descriptor. The file descriptor must be seekable. This
// implementation uses pread(), so the stream pointer will not be modified.
public:
FdInputFile(int fd, size_t offset, ErrorReporter* errorReporter = getThrowingErrorReporter())
: fd(fd), offset(offset), errorReporter(errorReporter) {};
FdInputFile(AutoCloseFd fd, size_t offset,
ErrorReporter* errorReporter = getThrowingErrorReporter())
: fd(fd), autoclose(move(fd)), offset(offset), errorReporter(errorReporter) {}
~FdInputFile();
bool read(size_t offset, void* buffer, size_t size) override;
private:
int fd;
AutoCloseFd autoclose;
size_t offset;
ErrorReporter* errorReporter;
};
class FdOutputStream: public OutputStream {
......@@ -262,6 +194,7 @@ public:
~FdOutputStream();
void write(const void* buffer, size_t size) override;
void write(ArrayPtr<const ArrayPtr<const byte>> pieces) override;
private:
int fd;
......@@ -274,8 +207,8 @@ class StreamFdMessageReader: private FdInputStream, public InputStreamMessageRea
public:
StreamFdMessageReader(int fd, ReaderOptions options = ReaderOptions(),
InputStrategy inputStrategy = InputStrategy::EAGER)
: FdInputStream(fd), InputStreamMessageReader(this, options, inputStrategy) {}
ArrayPtr<word> scratchSpace = nullptr)
: FdInputStream(fd), InputStreamMessageReader(*this, options, scratchSpace) {}
// Read message from a file descriptor, without taking ownership of the descriptor.
//
// Since this version implies that the caller intends to read more data from the fd later on, the
......@@ -283,8 +216,8 @@ public:
// deterministically positioned just past the end of the message.
StreamFdMessageReader(AutoCloseFd fd, ReaderOptions options = ReaderOptions(),
InputStrategy inputStrategy = InputStrategy::LAZY)
: FdInputStream(move(fd)), InputStreamMessageReader(this, options, inputStrategy) {}
ArrayPtr<word> scratchSpace = nullptr)
: FdInputStream(move(fd)), InputStreamMessageReader(*this, options, scratchSpace) {}
// Read a message from a file descriptor, taking ownership of the descriptor.
//
// Since this version implies that the caller does not intend to read any more data from the fd,
......@@ -293,31 +226,6 @@ public:
~StreamFdMessageReader();
};
class FileFdMessageReader: private FdInputFile, public InputFileMessageReader {
// A MessageReader that reads from a seekable file descriptor, e.g. disk files. For non-seekable
// file descriptors, use StreamFdMessageReader. This implementation uses pread(), so the file
// descriptor's stream pointer will not be modified.
//
// The object is its own InputFile: the private FdInputFile base is constructed first and `this`
// is then handed to the InputFileMessageReader base.
public:
FileFdMessageReader(int fd, size_t offset, ReaderOptions options = ReaderOptions(),
InputStrategy inputStrategy = InputStrategy::LAZY)
: FdInputFile(fd, offset, options.errorReporter),
InputFileMessageReader(this, options, inputStrategy) {}
// Read message from a file descriptor, without taking ownership of the descriptor.
//
// All reads use pread(), so the file descriptor's stream pointer will not be modified.
FileFdMessageReader(AutoCloseFd fd, size_t offset, ReaderOptions options = ReaderOptions(),
InputStrategy inputStrategy = InputStrategy::LAZY)
: FdInputFile(move(fd), offset, options.errorReporter),
InputFileMessageReader(this, options, inputStrategy) {}
// Read a message from a file descriptor, taking ownership of the descriptor.
//
// All reads use pread(), so the file descriptor's stream pointer will not be modified.
~FileFdMessageReader();
};
void writeMessageToFd(int fd, MessageBuilder& builder);
// Write the message to the given file descriptor.
//
......
......@@ -80,6 +80,8 @@ public:
inline T* begin() const { return ptr; }
inline T* end() const { return ptr + size_; }
inline T& front() const { return *ptr; }
inline T& back() const { return *(ptr + size_ - 1); }
inline ArrayPtr slice(size_t start, size_t end) {
CAPNPROTO_DEBUG_ASSERT(start <= end && end <= size_, "Out-of-bounds ArrayPtr::slice().");
......@@ -137,6 +139,8 @@ public:
inline T* begin() const { return ptr; }
inline T* end() const { return ptr + size_; }
inline T& front() const { return *ptr; }
inline T& back() const { return *(ptr + size_ - 1); }
inline ArrayPtr<T> slice(size_t start, size_t end) {
CAPNPROTO_DEBUG_ASSERT(start <= end && end <= size_, "Out-of-bounds Array::slice().");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment