Commit 2bcb6e69 authored by Kenton Varda's avatar Kenton Varda

Simplify reuse strategy, add reuse comparisons to benchmarks.

parent 15aa868b
...@@ -42,16 +42,6 @@ ReaderArena::ReaderArena(MessageReader* message) ...@@ -42,16 +42,6 @@ ReaderArena::ReaderArena(MessageReader* message)
ReaderArena::~ReaderArena() {} ReaderArena::~ReaderArena() {}
void ReaderArena::reset() {
readLimiter.reset(message->getOptions().traversalLimitInWords * WORDS);
ignoreErrors = false;
segment0.~SegmentReader();
new(&segment0) SegmentReader(this, SegmentId(0), this->message->getSegment(0), &readLimiter);
// TODO: Reuse the rest of the SegmentReaders?
moreSegments = nullptr;
}
SegmentReader* ReaderArena::tryGetSegment(SegmentId id) { SegmentReader* ReaderArena::tryGetSegment(SegmentId id) {
if (id == SegmentId(0)) { if (id == SegmentId(0)) {
if (segment0.getArray() == nullptr) { if (segment0.getArray() == nullptr) {
...@@ -110,16 +100,6 @@ BuilderArena::BuilderArena(MessageBuilder* message) ...@@ -110,16 +100,6 @@ BuilderArena::BuilderArena(MessageBuilder* message)
: message(message), segment0(nullptr, SegmentId(0), nullptr, nullptr) {} : message(message), segment0(nullptr, SegmentId(0), nullptr, nullptr) {}
BuilderArena::~BuilderArena() {} BuilderArena::~BuilderArena() {}
void BuilderArena::reset() {
segment0.reset();
if (moreSegments != nullptr) {
// TODO: As mentioned in another TODO below, only the last segment will only be reused.
for (auto& segment: moreSegments->builders) {
segment->reset();
}
}
}
SegmentBuilder* BuilderArena::getSegment(SegmentId id) { SegmentBuilder* BuilderArena::getSegment(SegmentId id) {
// This method is allowed to crash if the segment ID is not valid. // This method is allowed to crash if the segment ID is not valid.
if (id == SegmentId(0)) { if (id == SegmentId(0)) {
......
...@@ -164,8 +164,6 @@ public: ...@@ -164,8 +164,6 @@ public:
~ReaderArena(); ~ReaderArena();
CAPNPROTO_DISALLOW_COPY(ReaderArena); CAPNPROTO_DISALLOW_COPY(ReaderArena);
void reset();
// implements Arena ------------------------------------------------ // implements Arena ------------------------------------------------
SegmentReader* tryGetSegment(SegmentId id) override; SegmentReader* tryGetSegment(SegmentId id) override;
void reportInvalidData(const char* description) override; void reportInvalidData(const char* description) override;
...@@ -189,9 +187,6 @@ public: ...@@ -189,9 +187,6 @@ public:
~BuilderArena(); ~BuilderArena();
CAPNPROTO_DISALLOW_COPY(BuilderArena); CAPNPROTO_DISALLOW_COPY(BuilderArena);
void reset();
// Resets all the segments to be empty, so that a new message can be started.
SegmentBuilder* getSegment(SegmentId id); SegmentBuilder* getSegment(SegmentId id);
// Get the segment with the given id. Crashes or throws an exception if no such segment exists. // Get the segment with the given id. Crashes or throws an exception if no such segment exists.
......
...@@ -185,83 +185,130 @@ public: ...@@ -185,83 +185,130 @@ public:
// ======================================================================================= // =======================================================================================
template <typename TestCase> struct NoScratch {
struct ScratchSpace {};
class MessageReader: public StreamFdMessageReader {
public:
inline MessageReader(int fd, ScratchSpace& scratch)
: StreamFdMessageReader(fd) {}
};
class MessageBuilder: public MallocMessageBuilder {
public:
inline MessageBuilder(ScratchSpace& scratch): MallocMessageBuilder() {}
};
};
template <size_t size>
struct UseScratch {
struct ScratchSpace {
word words[size];
};
class MessageReader: public StreamFdMessageReader {
public:
inline MessageReader(int fd, ScratchSpace& scratch)
: StreamFdMessageReader(fd, ReaderOptions(), arrayPtr(scratch.words, size)) {}
};
class MessageBuilder: public MallocMessageBuilder {
public:
inline MessageBuilder(ScratchSpace& scratch)
: MallocMessageBuilder(arrayPtr(scratch.words, size)) {}
};
};
// =======================================================================================
template <typename TestCase, typename ReuseStrategy>
void syncClient(int inputFd, int outputFd, uint64_t iters) { void syncClient(int inputFd, int outputFd, uint64_t iters) {
MallocMessageBuilder builder; typename ReuseStrategy::ScratchSpace scratch;
// StreamFdMessageReader reader(inputFd, ReaderOptions(), InputStrategy::EAGER_WAIT_FOR_READ_NEXT);
for (; iters > 0; --iters) { for (; iters > 0; --iters) {
typename TestCase::Expectation expected = TestCase::setupRequest( typename TestCase::Expectation expected;
builder.initRoot<typename TestCase::Request>()); {
writeMessageToFd(outputFd, builder); typename ReuseStrategy::MessageBuilder builder(scratch);
expected = TestCase::setupRequest(
builder.template initRoot<typename TestCase::Request>());
writeMessageToFd(outputFd, builder);
}
// reader.readNext(); {
StreamFdMessageReader reader(inputFd); typename ReuseStrategy::MessageReader reader(inputFd, scratch);
if (!TestCase::checkResponse(reader.getRoot<typename TestCase::Response>(), expected)) { if (!TestCase::checkResponse(
throw std::logic_error("Incorrect response."); reader.template getRoot<typename TestCase::Response>(), expected)) {
throw std::logic_error("Incorrect response.");
}
} }
} }
} }
template <typename TestCase> template <typename TestCase, typename ReuseStrategy>
void asyncClientSender(int outputFd, void asyncClientSender(int outputFd,
ProducerConsumerQueue<typename TestCase::Expectation>* expectations, ProducerConsumerQueue<typename TestCase::Expectation>* expectations,
uint64_t iters) { uint64_t iters) {
MallocMessageBuilder builder; typename ReuseStrategy::ScratchSpace scratch;
for (; iters > 0; --iters) { for (; iters > 0; --iters) {
expectations->post(TestCase::setupRequest(builder.initRoot<typename TestCase::Request>())); typename ReuseStrategy::MessageBuilder builder(scratch);
expectations->post(TestCase::setupRequest(
builder.template initRoot<typename TestCase::Request>()));
writeMessageToFd(outputFd, builder); writeMessageToFd(outputFd, builder);
} }
} }
template <typename TestCase> template <typename TestCase, typename ReuseStrategy>
void asyncClientReceiver(int inputFd, void asyncClientReceiver(int inputFd,
ProducerConsumerQueue<typename TestCase::Expectation>* expectations, ProducerConsumerQueue<typename TestCase::Expectation>* expectations,
uint64_t iters) { uint64_t iters) {
StreamFdMessageReader reader(inputFd, ReaderOptions(), InputStrategy::EAGER_WAIT_FOR_READ_NEXT); typename ReuseStrategy::ScratchSpace scratch;
for (; iters > 0; --iters) { for (; iters > 0; --iters) {
typename TestCase::Expectation expected = expectations->next(); typename TestCase::Expectation expected = expectations->next();
reader.readNext(); typename ReuseStrategy::MessageReader reader(inputFd, scratch);
if (!TestCase::checkResponse(reader.getRoot<typename TestCase::Response>(), expected)) { if (!TestCase::checkResponse(
reader.template getRoot<typename TestCase::Response>(), expected)) {
throw std::logic_error("Incorrect response."); throw std::logic_error("Incorrect response.");
} }
} }
} }
template <typename TestCase> template <typename TestCase, typename ReuseStrategy>
void asyncClient(int inputFd, int outputFd, uint64_t iters) { void asyncClient(int inputFd, int outputFd, uint64_t iters) {
ProducerConsumerQueue<typename TestCase::Expectation> expectations; ProducerConsumerQueue<typename TestCase::Expectation> expectations;
std::thread receiverThread(asyncClientReceiver<TestCase>, inputFd, &expectations, iters); std::thread receiverThread(
asyncClientSender<TestCase>(outputFd, &expectations, iters); asyncClientReceiver<TestCase, ReuseStrategy>, inputFd, &expectations, iters);
asyncClientSender<TestCase, ReuseStrategy>(outputFd, &expectations, iters);
receiverThread.join(); receiverThread.join();
} }
template <typename TestCase> template <typename TestCase, typename ReuseStrategy>
void server(int inputFd, int outputFd, uint64_t iters) { void server(int inputFd, int outputFd, uint64_t iters) {
StreamFdMessageReader reader(inputFd, ReaderOptions(), InputStrategy::EAGER_WAIT_FOR_READ_NEXT); typename ReuseStrategy::ScratchSpace builderScratch;
MallocMessageBuilder builder; typename ReuseStrategy::ScratchSpace readerScratch;
for (; iters > 0; --iters) { for (; iters > 0; --iters) {
reader.readNext(); typename ReuseStrategy::MessageBuilder builder(builderScratch);
// StreamFdMessageReader reader(inputFd); typename ReuseStrategy::MessageReader reader(inputFd, readerScratch);
TestCase::handleRequest(reader.getRoot<typename TestCase::Request>(), TestCase::handleRequest(reader.template getRoot<typename TestCase::Request>(),
builder.initRoot<typename TestCase::Response>()); builder.template initRoot<typename TestCase::Response>());
writeMessageToFd(outputFd, builder); writeMessageToFd(outputFd, builder);
} }
} }
template <typename TestCase> template <typename TestCase, typename ReuseStrategy>
void passByObject(uint64_t iters) { void passByObject(uint64_t iters) {
MallocMessageBuilder requestMessage; typename ReuseStrategy::ScratchSpace requestScratch;
MallocMessageBuilder responseMessage; typename ReuseStrategy::ScratchSpace responseScratch;
for (; iters > 0; --iters) { for (; iters > 0; --iters) {
auto request = requestMessage.initRoot<typename TestCase::Request>(); typename ReuseStrategy::MessageBuilder requestMessage(requestScratch);
auto request = requestMessage.template initRoot<typename TestCase::Request>();
typename TestCase::Expectation expected = TestCase::setupRequest(request); typename TestCase::Expectation expected = TestCase::setupRequest(request);
auto response = responseMessage.initRoot<typename TestCase::Response>(); typename ReuseStrategy::MessageBuilder responseMessage(responseScratch);
auto response = responseMessage.template initRoot<typename TestCase::Response>();
TestCase::handleRequest(request.asReader(), response); TestCase::handleRequest(request.asReader(), response);
if (!TestCase::checkResponse(response.asReader(), expected)) { if (!TestCase::checkResponse(response.asReader(), expected)) {
...@@ -270,29 +317,32 @@ void passByObject(uint64_t iters) { ...@@ -270,29 +317,32 @@ void passByObject(uint64_t iters) {
} }
} }
template <typename TestCase> template <typename TestCase, typename ReuseStrategy>
void passByBytes(uint64_t iters) { void passByBytes(uint64_t iters) {
MallocMessageBuilder requestBuilder; typename ReuseStrategy::ScratchSpace requestScratch;
MallocMessageBuilder responseBuilder; typename ReuseStrategy::ScratchSpace responseScratch;
for (; iters > 0; --iters) { for (; iters > 0; --iters) {
typename ReuseStrategy::MessageBuilder requestBuilder(requestScratch);
typename TestCase::Expectation expected = TestCase::setupRequest( typename TestCase::Expectation expected = TestCase::setupRequest(
requestBuilder.initRoot<typename TestCase::Request>()); requestBuilder.template initRoot<typename TestCase::Request>());
Array<word> requestBytes = messageToFlatArray(requestBuilder); Array<word> requestBytes = messageToFlatArray(requestBuilder);
FlatArrayMessageReader requestReader(requestBytes.asPtr()); FlatArrayMessageReader requestReader(requestBytes.asPtr());
TestCase::handleRequest(requestReader.getRoot<typename TestCase::Request>(), typename ReuseStrategy::MessageBuilder responseBuilder(responseScratch);
responseBuilder.initRoot<typename TestCase::Response>()); TestCase::handleRequest(requestReader.template getRoot<typename TestCase::Request>(),
responseBuilder.template initRoot<typename TestCase::Response>());
Array<word> responseBytes = messageToFlatArray(responseBuilder); Array<word> responseBytes = messageToFlatArray(responseBuilder);
FlatArrayMessageReader responseReader(responseBytes.asPtr()); FlatArrayMessageReader responseReader(responseBytes.asPtr());
if (!TestCase::checkResponse(responseReader.getRoot<typename TestCase::Response>(), expected)) { if (!TestCase::checkResponse(
responseReader.template getRoot<typename TestCase::Response>(), expected)) {
throw std::logic_error("Incorrect response."); throw std::logic_error("Incorrect response.");
} }
} }
} }
template <typename TestCase, typename Func> template <typename TestCase, typename ReuseStrategy, typename Func>
void passByPipe(Func&& clientFunc, uint64_t iters) { void passByPipe(Func&& clientFunc, uint64_t iters) {
int clientToServer[2]; int clientToServer[2];
int serverToClient[2]; int serverToClient[2];
...@@ -312,7 +362,7 @@ void passByPipe(Func&& clientFunc, uint64_t iters) { ...@@ -312,7 +362,7 @@ void passByPipe(Func&& clientFunc, uint64_t iters) {
close(clientToServer[1]); close(clientToServer[1]);
close(serverToClient[0]); close(serverToClient[0]);
server<TestCase>(clientToServer[0], serverToClient[1], iters); server<TestCase, ReuseStrategy>(clientToServer[0], serverToClient[1], iters);
int status; int status;
if (waitpid(child, &status, 0) != child) { if (waitpid(child, &status, 0) != child) {
...@@ -324,32 +374,46 @@ void passByPipe(Func&& clientFunc, uint64_t iters) { ...@@ -324,32 +374,46 @@ void passByPipe(Func&& clientFunc, uint64_t iters) {
} }
} }
template <typename ReuseStrategy>
void doBenchmark(const std::string& mode, uint64_t iters) {
if (mode == "client") {
syncClient<ExpressionTestCase, ReuseStrategy>(STDIN_FILENO, STDOUT_FILENO, iters);
} else if (mode == "server") {
server<ExpressionTestCase, ReuseStrategy>(STDIN_FILENO, STDOUT_FILENO, iters);
} else if (mode == "object") {
passByObject<ExpressionTestCase, ReuseStrategy>(iters);
} else if (mode == "bytes") {
passByBytes<ExpressionTestCase, ReuseStrategy>(iters);
} else if (mode == "pipe") {
passByPipe<ExpressionTestCase, ReuseStrategy>(
syncClient<ExpressionTestCase, ReuseStrategy>, iters);
} else if (mode == "pipe-async") {
passByPipe<ExpressionTestCase, ReuseStrategy>(
asyncClient<ExpressionTestCase, ReuseStrategy>, iters);
} else {
std::cerr << "Unknown mode: " << mode << std::endl;
exit(1);
}
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
if (argc != 3) { if (argc != 4) {
std::cerr << "USAGE: " << argv[0] << " MODE ITERATION_COUNT" << std::endl; std::cerr << "USAGE: " << argv[0] << " MODE REUSE ITERATION_COUNT" << std::endl;
return 1; return 1;
} }
uint64_t iters = strtoull(argv[2], nullptr, 0); uint64_t iters = strtoull(argv[3], nullptr, 0);
srand(123); srand(123);
std::cerr << "Doing " << iters << " iterations..." << std::endl; std::cerr << "Doing " << iters << " iterations..." << std::endl;
std::string mode = argv[1]; std::string reuse = argv[2];
if (mode == "client") { if (reuse == "reuse") {
syncClient<ExpressionTestCase>(STDIN_FILENO, STDOUT_FILENO, iters); doBenchmark<UseScratch<1024>>(argv[1], iters);
} else if (mode == "server") { } else if (reuse == "no-reuse") {
server<ExpressionTestCase>(STDIN_FILENO, STDOUT_FILENO, iters); doBenchmark<NoScratch>(argv[1], iters);
} else if (mode == "object") {
passByObject<ExpressionTestCase>(iters);
} else if (mode == "bytes") {
passByBytes<ExpressionTestCase>(iters);
} else if (mode == "pipe") {
passByPipe<ExpressionTestCase>(syncClient<ExpressionTestCase>, iters);
} else if (mode == "pipe-async") {
passByPipe<ExpressionTestCase>(asyncClient<ExpressionTestCase>, iters);
} else { } else {
std::cerr << "Unknown mode: " << mode << std::endl; std::cerr << "Unknown reuse mode: " << reuse << std::endl;
return 1; return 1;
} }
......
...@@ -183,7 +183,49 @@ public: ...@@ -183,7 +183,49 @@ public:
// ======================================================================================= // =======================================================================================
void writeProtoFast(const google::protobuf::MessageLite& message, struct SingleUseMessages {
template <typename MessageType>
struct Message {
struct Reusable {};
struct SingleUse: public MessageType {
inline SingleUse(Reusable&) {}
};
};
struct ReusableString {};
struct SingleUseString: std::string {
inline SingleUseString(ReusableString&) {}
};
template <typename MessageType>
static inline void doneWith(MessageType& message) {
// Don't clear -- single-use.
}
};
struct ReusableMessages {
template <typename MessageType>
struct Message {
struct Reusable: public MessageType {};
typedef MessageType& SingleUse;
};
typedef std::string ReusableString;
typedef std::string& SingleUseString;
template <typename MessageType>
static inline void doneWith(MessageType& message) {
message.Clear();
}
};
// =======================================================================================
// The protobuf Java library defines a format for writing multiple protobufs to a stream, in which
// each message is prefixed by a varint size. This was never added to the C++ library. It's easy
// to do naively, but tricky to implement without accidentally losing various optimizations. These
// two functions should be optimal.
void writeDelimited(const google::protobuf::MessageLite& message,
google::protobuf::io::FileOutputStream* rawOutput) { google::protobuf::io::FileOutputStream* rawOutput) {
google::protobuf::io::CodedOutputStream output(rawOutput); google::protobuf::io::CodedOutputStream output(rawOutput);
const int size = message.ByteSize(); const int size = message.ByteSize();
...@@ -199,7 +241,7 @@ void writeProtoFast(const google::protobuf::MessageLite& message, ...@@ -199,7 +241,7 @@ void writeProtoFast(const google::protobuf::MessageLite& message,
} }
} }
void readProtoFast(google::protobuf::io::ZeroCopyInputStream* rawInput, void readDelimited(google::protobuf::io::ZeroCopyInputStream* rawInput,
google::protobuf::MessageLite* message) { google::protobuf::MessageLite* message) {
google::protobuf::io::CodedInputStream input(rawInput); google::protobuf::io::CodedInputStream input(rawInput);
uint32_t size; uint32_t size;
...@@ -217,139 +259,160 @@ void readProtoFast(google::protobuf::io::ZeroCopyInputStream* rawInput, ...@@ -217,139 +259,160 @@ void readProtoFast(google::protobuf::io::ZeroCopyInputStream* rawInput,
input.PopLimit(limit); input.PopLimit(limit);
} }
template <typename TestCase> // =======================================================================================
#define REUSABLE(type) \
typename ReuseStrategy::template Message<typename TestCase::type>::Reusable
#define SINGLE_USE(type) \
typename ReuseStrategy::template Message<typename TestCase::type>::SingleUse
template <typename TestCase, typename ReuseStrategy>
void syncClient(int inputFd, int outputFd, uint64_t iters) { void syncClient(int inputFd, int outputFd, uint64_t iters) {
google::protobuf::io::FileOutputStream output(outputFd); google::protobuf::io::FileOutputStream output(outputFd);
google::protobuf::io::FileInputStream input(inputFd); google::protobuf::io::FileInputStream input(inputFd);
typename TestCase::Request request; REUSABLE(Request) reusableRequest;
typename TestCase::Response response; REUSABLE(Response) reusableResponse;
for (; iters > 0; --iters) { for (; iters > 0; --iters) {
SINGLE_USE(Request) request(reusableRequest);
typename TestCase::Expectation expected = TestCase::setupRequest(&request); typename TestCase::Expectation expected = TestCase::setupRequest(&request);
writeProtoFast(request, &output); writeDelimited(request, &output);
if (!output.Flush()) throw OsException(output.GetErrno()); if (!output.Flush()) throw OsException(output.GetErrno());
request.Clear(); ReuseStrategy::doneWith(request);
// std::cerr << "client: wait" << std::endl; SINGLE_USE(Response) response(reusableResponse);
readProtoFast(&input, &response); readDelimited(&input, &response);
if (!TestCase::checkResponse(response, expected)) { if (!TestCase::checkResponse(response, expected)) {
throw std::logic_error("Incorrect response."); throw std::logic_error("Incorrect response.");
} }
response.Clear(); ReuseStrategy::doneWith(response);
} }
} }
template <typename TestCase> template <typename TestCase, typename ReuseStrategy>
void asyncClientSender(int outputFd, void asyncClientSender(int outputFd,
ProducerConsumerQueue<typename TestCase::Expectation>* expectations, ProducerConsumerQueue<typename TestCase::Expectation>* expectations,
uint64_t iters) { uint64_t iters) {
google::protobuf::io::FileOutputStream output(outputFd); google::protobuf::io::FileOutputStream output(outputFd);
typename TestCase::Request request; REUSABLE(Request) reusableRequest;
for (; iters > 0; --iters) { for (; iters > 0; --iters) {
SINGLE_USE(Request) request(reusableRequest);
expectations->post(TestCase::setupRequest(&request)); expectations->post(TestCase::setupRequest(&request));
writeProtoFast(request, &output); writeDelimited(request, &output);
request.Clear(); ReuseStrategy::doneWith(request);
} }
if (!output.Flush()) throw OsException(output.GetErrno()); if (!output.Flush()) throw OsException(output.GetErrno());
} }
template <typename TestCase> template <typename TestCase, typename ReuseStrategy>
void asyncClientReceiver(int inputFd, void asyncClientReceiver(int inputFd,
ProducerConsumerQueue<typename TestCase::Expectation>* expectations, ProducerConsumerQueue<typename TestCase::Expectation>* expectations,
uint64_t iters) { uint64_t iters) {
google::protobuf::io::FileInputStream input(inputFd); google::protobuf::io::FileInputStream input(inputFd);
typename TestCase::Response response; REUSABLE(Response) reusableResponse;
for (; iters > 0; --iters) { for (; iters > 0; --iters) {
typename TestCase::Expectation expected = expectations->next(); typename TestCase::Expectation expected = expectations->next();
readProtoFast(&input, &response); SINGLE_USE(Response) response(reusableResponse);
readDelimited(&input, &response);
if (!TestCase::checkResponse(response, expected)) { if (!TestCase::checkResponse(response, expected)) {
throw std::logic_error("Incorrect response."); throw std::logic_error("Incorrect response.");
} }
response.Clear(); ReuseStrategy::doneWith(response);
} }
} }
template <typename TestCase> template <typename TestCase, typename ReuseStrategy>
void asyncClient(int inputFd, int outputFd, uint64_t iters) { void asyncClient(int inputFd, int outputFd, uint64_t iters) {
ProducerConsumerQueue<typename TestCase::Expectation> expectations; ProducerConsumerQueue<typename TestCase::Expectation> expectations;
std::thread receiverThread(asyncClientReceiver<TestCase>, inputFd, &expectations, iters); std::thread receiverThread(
asyncClientSender<TestCase>(outputFd, &expectations, iters); asyncClientReceiver<TestCase, ReuseStrategy>, inputFd, &expectations, iters);
asyncClientSender<TestCase, ReuseStrategy>(outputFd, &expectations, iters);
receiverThread.join(); receiverThread.join();
} }
template <typename TestCase> template <typename TestCase, typename ReuseStrategy>
void server(int inputFd, int outputFd, uint64_t iters) { void server(int inputFd, int outputFd, uint64_t iters) {
google::protobuf::io::FileOutputStream output(outputFd); google::protobuf::io::FileOutputStream output(outputFd);
google::protobuf::io::FileInputStream input(inputFd); google::protobuf::io::FileInputStream input(inputFd);
typename TestCase::Request request; REUSABLE(Request) reusableRequest;
typename TestCase::Response response; REUSABLE(Response) reusableResponse;
for (; iters > 0; --iters) { for (; iters > 0; --iters) {
readProtoFast(&input, &request); SINGLE_USE(Request) request(reusableRequest);
readDelimited(&input, &request);
SINGLE_USE(Response) response(reusableResponse);
TestCase::handleRequest(request, &response); TestCase::handleRequest(request, &response);
request.Clear(); ReuseStrategy::doneWith(request);
writeProtoFast(response, &output); writeDelimited(response, &output);
if (!output.Flush()) throw std::logic_error("Write failed."); if (!output.Flush()) throw std::logic_error("Write failed.");
response.Clear(); ReuseStrategy::doneWith(response);
} }
} }
template <typename TestCase> template <typename TestCase, typename ReuseStrategy>
void passByObject(uint64_t iters) { void passByObject(uint64_t iters) {
typename TestCase::Request request; REUSABLE(Request) reusableRequest;
typename TestCase::Response response; REUSABLE(Response) reusableResponse;
for (; iters > 0; --iters) { for (; iters > 0; --iters) {
SINGLE_USE(Request) request(reusableRequest);
typename TestCase::Expectation expected = TestCase::setupRequest(&request); typename TestCase::Expectation expected = TestCase::setupRequest(&request);
SINGLE_USE(Response) response(reusableResponse);
TestCase::handleRequest(request, &response); TestCase::handleRequest(request, &response);
request.Clear(); ReuseStrategy::doneWith(request);
if (!TestCase::checkResponse(response, expected)) { if (!TestCase::checkResponse(response, expected)) {
throw std::logic_error("Incorrect response."); throw std::logic_error("Incorrect response.");
} }
response.Clear(); ReuseStrategy::doneWith(response);
} }
} }
template <typename TestCase> template <typename TestCase, typename ReuseStrategy>
void passByBytes(uint64_t iters) { void passByBytes(uint64_t iters) {
typename TestCase::Request clientRequest; REUSABLE(Request) reusableClientRequest;
typename TestCase::Request serverRequest; REUSABLE(Request) reusableServerRequest;
typename TestCase::Response serverResponse; REUSABLE(Response) reusableServerResponse;
typename TestCase::Response clientResponse; REUSABLE(Response) reusableClientResponse;
std::string requestString, responseString; typename ReuseStrategy::ReusableString reusableRequestString, reusableResponseString;
for (; iters > 0; --iters) { for (; iters > 0; --iters) {
SINGLE_USE(Request) clientRequest(reusableClientRequest);
typename TestCase::Expectation expected = TestCase::setupRequest(&clientRequest); typename TestCase::Expectation expected = TestCase::setupRequest(&clientRequest);
typename ReuseStrategy::SingleUseString requestString(reusableRequestString);
clientRequest.SerializePartialToString(&requestString); clientRequest.SerializePartialToString(&requestString);
clientRequest.Clear(); ReuseStrategy::doneWith(clientRequest);
SINGLE_USE(Request) serverRequest(reusableServerRequest);
serverRequest.ParsePartialFromString(requestString); serverRequest.ParsePartialFromString(requestString);
requestString.clear();
SINGLE_USE(Response) serverResponse(reusableServerResponse);
TestCase::handleRequest(serverRequest, &serverResponse); TestCase::handleRequest(serverRequest, &serverResponse);
serverRequest.Clear(); ReuseStrategy::doneWith(serverRequest);
typename ReuseStrategy::SingleUseString responseString(reusableResponseString);
serverResponse.SerializePartialToString(&responseString); serverResponse.SerializePartialToString(&responseString);
serverResponse.Clear(); ReuseStrategy::doneWith(serverResponse);
SINGLE_USE(Response) clientResponse(reusableClientResponse);
clientResponse.ParsePartialFromString(responseString); clientResponse.ParsePartialFromString(responseString);
responseString.clear();
if (!TestCase::checkResponse(clientResponse, expected)) { if (!TestCase::checkResponse(clientResponse, expected)) {
throw std::logic_error("Incorrect response."); throw std::logic_error("Incorrect response.");
} }
clientResponse.Clear(); ReuseStrategy::doneWith(clientResponse);
} }
} }
template <typename TestCase, typename Func> template <typename TestCase, typename ReuseStrategy, typename Func>
void passByPipe(Func&& clientFunc, uint64_t iters) { void passByPipe(Func&& clientFunc, uint64_t iters) {
int clientToServer[2]; int clientToServer[2];
int serverToClient[2]; int serverToClient[2];
...@@ -369,7 +432,7 @@ void passByPipe(Func&& clientFunc, uint64_t iters) { ...@@ -369,7 +432,7 @@ void passByPipe(Func&& clientFunc, uint64_t iters) {
close(clientToServer[1]); close(clientToServer[1]);
close(serverToClient[0]); close(serverToClient[0]);
server<TestCase>(clientToServer[0], serverToClient[1], iters); server<TestCase, ReuseStrategy>(clientToServer[0], serverToClient[1], iters);
int status; int status;
if (waitpid(child, &status, 0) != child) { if (waitpid(child, &status, 0) != child) {
...@@ -381,32 +444,46 @@ void passByPipe(Func&& clientFunc, uint64_t iters) { ...@@ -381,32 +444,46 @@ void passByPipe(Func&& clientFunc, uint64_t iters) {
} }
} }
template <typename ReuseStrategy>
void doBenchmark(const std::string& mode, uint64_t iters) {
if (mode == "client") {
syncClient<ExpressionTestCase, ReuseStrategy>(STDIN_FILENO, STDOUT_FILENO, iters);
} else if (mode == "server") {
server<ExpressionTestCase, ReuseStrategy>(STDIN_FILENO, STDOUT_FILENO, iters);
} else if (mode == "object") {
passByObject<ExpressionTestCase, ReuseStrategy>(iters);
} else if (mode == "bytes") {
passByBytes<ExpressionTestCase, ReuseStrategy>(iters);
} else if (mode == "pipe") {
passByPipe<ExpressionTestCase, ReuseStrategy>(
syncClient<ExpressionTestCase, ReuseStrategy>, iters);
} else if (mode == "pipe-async") {
passByPipe<ExpressionTestCase, ReuseStrategy>(
asyncClient<ExpressionTestCase, ReuseStrategy>, iters);
} else {
std::cerr << "Unknown mode: " << mode << std::endl;
exit(1);
}
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
if (argc != 3) { if (argc != 4) {
std::cerr << "USAGE: " << argv[0] << " MODE ITERATION_COUNT" << std::endl; std::cerr << "USAGE: " << argv[0] << " MODE REUSE ITERATION_COUNT" << std::endl;
return 1; return 1;
} }
uint64_t iters = strtoull(argv[2], nullptr, 0); uint64_t iters = strtoull(argv[3], nullptr, 0);
srand(123); srand(123);
std::cerr << "Doing " << iters << " iterations..." << std::endl; std::cerr << "Doing " << iters << " iterations..." << std::endl;
std::string mode = argv[1]; std::string reuse = argv[2];
if (mode == "client") { if (reuse == "reuse") {
syncClient<ExpressionTestCase>(STDIN_FILENO, STDOUT_FILENO, iters); doBenchmark<ReusableMessages>(argv[1], iters);
} else if (mode == "server") { } else if (reuse == "no-reuse") {
server<ExpressionTestCase>(STDIN_FILENO, STDOUT_FILENO, iters); doBenchmark<SingleUseMessages>(argv[1], iters);
} else if (mode == "object") {
passByObject<ExpressionTestCase>(iters);
} else if (mode == "bytes") {
passByBytes<ExpressionTestCase>(iters);
} else if (mode == "pipe") {
passByPipe<ExpressionTestCase>(syncClient<ExpressionTestCase>, iters);
} else if (mode == "pipe-async") {
passByPipe<ExpressionTestCase>(asyncClient<ExpressionTestCase>, iters);
} else { } else {
std::cerr << "Unknown mode: " << mode << std::endl; std::cerr << "Unknown reuse mode: " << reuse << std::endl;
return 1; return 1;
} }
......
...@@ -38,10 +38,6 @@ MessageReader::~MessageReader() { ...@@ -38,10 +38,6 @@ MessageReader::~MessageReader() {
} }
} }
void MessageReader::reset() {
if (allocatedArena) arena()->reset();
}
internal::StructReader MessageReader::getRoot(const word* defaultValue) { internal::StructReader MessageReader::getRoot(const word* defaultValue) {
if (!allocatedArena) { if (!allocatedArena) {
static_assert(sizeof(internal::ReaderArena) <= sizeof(arenaSpace), static_assert(sizeof(internal::ReaderArena) <= sizeof(arenaSpace),
...@@ -71,35 +67,35 @@ MessageBuilder::~MessageBuilder() { ...@@ -71,35 +67,35 @@ MessageBuilder::~MessageBuilder() {
} }
} }
internal::SegmentBuilder* MessageBuilder::allocateRootSegment() { internal::SegmentBuilder* MessageBuilder::getRootSegment() {
if (!allocatedArena) { if (allocatedArena) {
return arena()->getSegment(SegmentId(0));
} else {
static_assert(sizeof(internal::BuilderArena) <= sizeof(arenaSpace), static_assert(sizeof(internal::BuilderArena) <= sizeof(arenaSpace),
"arenaSpace is too small to hold a BuilderArena. Please increase it. This will break " "arenaSpace is too small to hold a BuilderArena. Please increase it. This will break "
"ABI compatibility."); "ABI compatibility.");
new(arena()) internal::BuilderArena(this); new(arena()) internal::BuilderArena(this);
allocatedArena = true; allocatedArena = true;
}
WordCount refSize = 1 * REFERENCES * WORDS_PER_REFERENCE; WordCount refSize = 1 * REFERENCES * WORDS_PER_REFERENCE;
internal::SegmentBuilder* segment = arena()->getSegmentWithAvailable(refSize); internal::SegmentBuilder* segment = arena()->getSegmentWithAvailable(refSize);
CAPNPROTO_ASSERT(segment->getSegmentId() == SegmentId(0), CAPNPROTO_ASSERT(segment->getSegmentId() == SegmentId(0),
"First allocated word of new arena was not in segment ID 0."); "First allocated word of new arena was not in segment ID 0.");
word* location = segment->allocate(refSize); word* location = segment->allocate(refSize);
CAPNPROTO_ASSERT(location == segment->getPtrUnchecked(0 * WORDS), CAPNPROTO_ASSERT(location == segment->getPtrUnchecked(0 * WORDS),
"First allocated word of new arena was not the first word in its segment."); "First allocated word of new arena was not the first word in its segment.");
return segment; return segment;
}
} }
internal::StructBuilder MessageBuilder::initRoot(const word* defaultValue) { internal::StructBuilder MessageBuilder::initRoot(const word* defaultValue) {
if (allocatedArena) arena()->reset(); internal::SegmentBuilder* rootSegment = getRootSegment();
internal::SegmentBuilder* rootSegment = allocateRootSegment();
return internal::StructBuilder::initRoot( return internal::StructBuilder::initRoot(
rootSegment, rootSegment->getPtrUnchecked(0 * WORDS), defaultValue); rootSegment, rootSegment->getPtrUnchecked(0 * WORDS), defaultValue);
} }
internal::StructBuilder MessageBuilder::getRoot(const word* defaultValue) { internal::StructBuilder MessageBuilder::getRoot(const word* defaultValue) {
internal::SegmentBuilder* rootSegment = allocatedArena ? internal::SegmentBuilder* rootSegment = getRootSegment();
arena()->getSegment(SegmentId(0)) : allocateRootSegment();
return internal::StructBuilder::getRoot( return internal::StructBuilder::getRoot(
rootSegment, rootSegment->getPtrUnchecked(0 * WORDS), defaultValue); rootSegment, rootSegment->getPtrUnchecked(0 * WORDS), defaultValue);
} }
...@@ -200,10 +196,15 @@ struct MallocMessageBuilder::MoreSegments { ...@@ -200,10 +196,15 @@ struct MallocMessageBuilder::MoreSegments {
MallocMessageBuilder::MallocMessageBuilder( MallocMessageBuilder::MallocMessageBuilder(
uint firstSegmentWords, AllocationStrategy allocationStrategy) uint firstSegmentWords, AllocationStrategy allocationStrategy)
: nextSize(firstSegmentWords), allocationStrategy(allocationStrategy), : nextSize(firstSegmentWords), allocationStrategy(allocationStrategy),
firstSegment(nullptr) {} ownFirstSegment(true), firstSegment(nullptr) {}
MallocMessageBuilder::MallocMessageBuilder(
ArrayPtr<word> firstSegment, AllocationStrategy allocationStrategy)
: nextSize(firstSegment.size()), allocationStrategy(allocationStrategy),
ownFirstSegment(false), firstSegment(firstSegment.begin()) {}
MallocMessageBuilder::~MallocMessageBuilder() { MallocMessageBuilder::~MallocMessageBuilder() {
free(firstSegment); if (ownFirstSegment) free(firstSegment);
if (moreSegments != nullptr) { if (moreSegments != nullptr) {
for (void* ptr: moreSegments->segments) { for (void* ptr: moreSegments->segments) {
free(ptr); free(ptr);
...@@ -212,6 +213,19 @@ MallocMessageBuilder::~MallocMessageBuilder() { ...@@ -212,6 +213,19 @@ MallocMessageBuilder::~MallocMessageBuilder() {
} }
ArrayPtr<word> MallocMessageBuilder::allocateSegment(uint minimumSize) { ArrayPtr<word> MallocMessageBuilder::allocateSegment(uint minimumSize) {
if (!ownFirstSegment) {
ArrayPtr<word> result = arrayPtr(reinterpret_cast<word*>(firstSegment), nextSize);
firstSegment = nullptr;
ownFirstSegment = true;
if (result.size() >= minimumSize) {
memset(result.begin(), 0, result.size() * sizeof(word));
return result;
}
// If the provided first segment wasn't big enough, we discard it and proceed to allocate
// our own. This never happens in practice since minimumSize is always 1 for the first
// segment.
}
uint size = std::max(minimumSize, nextSize); uint size = std::max(minimumSize, nextSize);
void* result = calloc(size, sizeof(word)); void* result = calloc(size, sizeof(word));
......
...@@ -133,14 +133,6 @@ public: ...@@ -133,14 +133,6 @@ public:
template <typename RootType> template <typename RootType>
typename RootType::Reader getRoot(); typename RootType::Reader getRoot();
protected:
void reset();
// Clear the cached segment table so that the reader can be reused to read another message.
// reset() may call getSegment() again before returning, so you must arrange for the new segment
// set to be active *before* calling this.
//
// This invalidates any Readers currently pointing into this message.
private: private:
ReaderOptions options; ReaderOptions options;
...@@ -182,7 +174,7 @@ private: ...@@ -182,7 +174,7 @@ private:
bool allocatedArena = false; bool allocatedArena = false;
internal::BuilderArena* arena() { return reinterpret_cast<internal::BuilderArena*>(arenaSpace); } internal::BuilderArena* arena() { return reinterpret_cast<internal::BuilderArena*>(arenaSpace); }
internal::SegmentBuilder* allocateRootSegment(); internal::SegmentBuilder* getRootSegment();
internal::StructBuilder initRoot(const word* defaultValue); internal::StructBuilder initRoot(const word* defaultValue);
internal::StructBuilder getRoot(const word* defaultValue); internal::StructBuilder getRoot(const word* defaultValue);
}; };
...@@ -232,7 +224,7 @@ private: ...@@ -232,7 +224,7 @@ private:
ArrayPtr<const ArrayPtr<const word>> segments; ArrayPtr<const ArrayPtr<const word>> segments;
}; };
enum class AllocationStrategy { enum class AllocationStrategy: uint8_t {
FIXED_SIZE, FIXED_SIZE,
// The builder will prefer to allocate the same amount of space for each segment with no // The builder will prefer to allocate the same amount of space for each segment with no
// heuristic growth. It will still allocate larger segments when the preferred size is too small // heuristic growth. It will still allocate larger segments when the preferred size is too small
...@@ -257,7 +249,7 @@ class MallocMessageBuilder: public MessageBuilder { ...@@ -257,7 +249,7 @@ class MallocMessageBuilder: public MessageBuilder {
// a specific location in memory. // a specific location in memory.
public: public:
MallocMessageBuilder(uint firstSegmentWords = 1024, explicit MallocMessageBuilder(uint firstSegmentWords = 1024,
AllocationStrategy allocationStrategy = SUGGESTED_ALLOCATION_STRATEGY); AllocationStrategy allocationStrategy = SUGGESTED_ALLOCATION_STRATEGY);
// Creates a BuilderContext which allocates at least the given number of words for the first // Creates a BuilderContext which allocates at least the given number of words for the first
// segment, and then uses the given strategy to decide how much to allocate for subsequent // segment, and then uses the given strategy to decide how much to allocate for subsequent
...@@ -271,6 +263,12 @@ public: ...@@ -271,6 +263,12 @@ public:
// The defaults have been chosen to be reasonable for most people, so don't change them unless you // The defaults have been chosen to be reasonable for most people, so don't change them unless you
// have reason to believe you need to. // have reason to believe you need to.
explicit MallocMessageBuilder(ArrayPtr<word> firstSegment,
AllocationStrategy allocationStrategy = SUGGESTED_ALLOCATION_STRATEGY);
// This version always returns the given array for the first segment, and then proceeds with the
// allocation strategy. This is useful for optimization when building lots of small messages in
// a tight loop: you can reuse the space for the first segment.
CAPNPROTO_DISALLOW_COPY(MallocMessageBuilder); CAPNPROTO_DISALLOW_COPY(MallocMessageBuilder);
virtual ~MallocMessageBuilder(); virtual ~MallocMessageBuilder();
...@@ -280,6 +278,7 @@ private: ...@@ -280,6 +278,7 @@ private:
uint nextSize; uint nextSize;
AllocationStrategy allocationStrategy; AllocationStrategy allocationStrategy;
bool ownFirstSegment;
void* firstSegment; void* firstSegment;
struct MoreSegments; struct MoreSegments;
......
...@@ -83,7 +83,7 @@ TEST(Serialize, FlatArrayOddSegmentCount) { ...@@ -83,7 +83,7 @@ TEST(Serialize, FlatArrayOddSegmentCount) {
checkTestMessage(reader.getRoot<TestAllTypes>()); checkTestMessage(reader.getRoot<TestAllTypes>());
} }
TEST(Serialize, FlatArrayEventSegmentCount) { TEST(Serialize, FlatArrayEvenSegmentCount) {
TestMessageBuilder builder(10); TestMessageBuilder builder(10);
initTestMessage(builder.initRoot<TestAllTypes>()); initTestMessage(builder.initRoot<TestAllTypes>());
...@@ -95,25 +95,24 @@ TEST(Serialize, FlatArrayEventSegmentCount) { ...@@ -95,25 +95,24 @@ TEST(Serialize, FlatArrayEventSegmentCount) {
class TestInputStream: public InputStream { class TestInputStream: public InputStream {
public: public:
TestInputStream(ArrayPtr<const word> data) TestInputStream(ArrayPtr<const word> data, bool lazy)
: pos(reinterpret_cast<const char*>(data.begin())), : pos(reinterpret_cast<const char*>(data.begin())),
end(reinterpret_cast<const char*>(data.end())) {} end(reinterpret_cast<const char*>(data.end())),
lazy(lazy) {}
~TestInputStream() {} ~TestInputStream() {}
bool read(void* buffer, size_t size) override { size_t read(void* buffer, size_t minBytes, size_t maxBytes) override {
if (size_t(end - pos) < size) { CAPNPROTO_ASSERT(maxBytes <= size_t(end - pos), "Overran end of stream.");
ADD_FAILURE() << "Overran end of stream."; size_t amount = lazy ? minBytes : maxBytes;
return false; memcpy(buffer, pos, amount);
} else { pos += amount;
memcpy(buffer, pos, size); return amount;
pos += size;
return true;
}
} }
private: private:
const char* pos; const char* pos;
const char* end; const char* end;
bool lazy;
}; };
TEST(Serialize, InputStream) { TEST(Serialize, InputStream) {
...@@ -122,162 +121,81 @@ TEST(Serialize, InputStream) { ...@@ -122,162 +121,81 @@ TEST(Serialize, InputStream) {
Array<word> serialized = messageToFlatArray(builder); Array<word> serialized = messageToFlatArray(builder);
TestInputStream stream(serialized.asPtr()); TestInputStream stream(serialized.asPtr(), false);
InputStreamMessageReader reader(&stream, ReaderOptions(), InputStrategy::EAGER); InputStreamMessageReader reader(stream, ReaderOptions());
checkTestMessage(reader.getRoot<TestAllTypes>());
}
TEST(Serialize, InputStreamLazy) {
TestMessageBuilder builder(1);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputStream stream(serialized.asPtr());
InputStreamMessageReader reader(&stream, ReaderOptions(), InputStrategy::LAZY);
checkTestMessage(reader.getRoot<TestAllTypes>());
}
TEST(Serialize, InputStreamOddSegmentCount) {
TestMessageBuilder builder(7);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputStream stream(serialized.asPtr());
InputStreamMessageReader reader(&stream, ReaderOptions(), InputStrategy::EAGER);
checkTestMessage(reader.getRoot<TestAllTypes>());
}
TEST(Serialize, InputStreamOddSegmentCountLazy) {
TestMessageBuilder builder(7);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputStream stream(serialized.asPtr());
InputStreamMessageReader reader(&stream, ReaderOptions(), InputStrategy::LAZY);
checkTestMessage(reader.getRoot<TestAllTypes>());
}
TEST(Serialize, InputStreamEventSegmentCount) {
TestMessageBuilder builder(10);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputStream stream(serialized.asPtr());
InputStreamMessageReader reader(&stream, ReaderOptions(), InputStrategy::EAGER);
checkTestMessage(reader.getRoot<TestAllTypes>());
}
TEST(Serialize, InputStreamEventSegmentCountLazy) {
TestMessageBuilder builder(10);
initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder);
TestInputStream stream(serialized.asPtr());
InputStreamMessageReader reader(&stream, ReaderOptions(), InputStrategy::LAZY);
checkTestMessage(reader.getRoot<TestAllTypes>()); checkTestMessage(reader.getRoot<TestAllTypes>());
} }
class TestInputFile: public InputFile { TEST(Serialize, InputStreamScratchSpace) {
public:
TestInputFile(ArrayPtr<const word> data)
: begin(reinterpret_cast<const char*>(data.begin())),
size_(data.size() * sizeof(word)) {}
~TestInputFile() {}
bool read(size_t offset, void* buffer, size_t size) override {
if (size_ < offset + size) {
ADD_FAILURE() << "Overran end of file.";
return false;
} else {
memcpy(buffer, begin + offset, size);
return true;
}
}
private:
const char* begin;
size_t size_;
};
TEST(Serialize, InputFile) {
TestMessageBuilder builder(1); TestMessageBuilder builder(1);
initTestMessage(builder.initRoot<TestAllTypes>()); initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder); Array<word> serialized = messageToFlatArray(builder);
TestInputFile file(serialized.asPtr()); word scratch[4096];
InputFileMessageReader reader(&file, ReaderOptions(), InputStrategy::EAGER); TestInputStream stream(serialized.asPtr(), false);
InputStreamMessageReader reader(stream, ReaderOptions(), ArrayPtr<word>(scratch, 4096));
checkTestMessage(reader.getRoot<TestAllTypes>()); checkTestMessage(reader.getRoot<TestAllTypes>());
} }
TEST(Serialize, InputFileLazy) { TEST(Serialize, InputStreamLazy) {
TestMessageBuilder builder(1); TestMessageBuilder builder(1);
initTestMessage(builder.initRoot<TestAllTypes>()); initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder); Array<word> serialized = messageToFlatArray(builder);
TestInputFile file(serialized.asPtr()); TestInputStream stream(serialized.asPtr(), true);
InputFileMessageReader reader(&file, ReaderOptions(), InputStrategy::LAZY); InputStreamMessageReader reader(stream, ReaderOptions());
checkTestMessage(reader.getRoot<TestAllTypes>()); checkTestMessage(reader.getRoot<TestAllTypes>());
} }
TEST(Serialize, InputFileOddSegmentCount) { TEST(Serialize, InputStreamOddSegmentCount) {
TestMessageBuilder builder(7); TestMessageBuilder builder(7);
initTestMessage(builder.initRoot<TestAllTypes>()); initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder); Array<word> serialized = messageToFlatArray(builder);
TestInputFile file(serialized.asPtr()); TestInputStream stream(serialized.asPtr(), false);
InputFileMessageReader reader(&file, ReaderOptions(), InputStrategy::EAGER); InputStreamMessageReader reader(stream, ReaderOptions());
checkTestMessage(reader.getRoot<TestAllTypes>()); checkTestMessage(reader.getRoot<TestAllTypes>());
} }
TEST(Serialize, InputFileOddSegmentCountLazy) { TEST(Serialize, InputStreamOddSegmentCountLazy) {
TestMessageBuilder builder(7); TestMessageBuilder builder(7);
initTestMessage(builder.initRoot<TestAllTypes>()); initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder); Array<word> serialized = messageToFlatArray(builder);
TestInputFile file(serialized.asPtr()); TestInputStream stream(serialized.asPtr(), true);
InputFileMessageReader reader(&file, ReaderOptions(), InputStrategy::LAZY); InputStreamMessageReader reader(stream, ReaderOptions());
checkTestMessage(reader.getRoot<TestAllTypes>()); checkTestMessage(reader.getRoot<TestAllTypes>());
} }
TEST(Serialize, InputFileEventSegmentCount) { TEST(Serialize, InputStreamEvenSegmentCount) {
TestMessageBuilder builder(10); TestMessageBuilder builder(10);
initTestMessage(builder.initRoot<TestAllTypes>()); initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder); Array<word> serialized = messageToFlatArray(builder);
TestInputFile file(serialized.asPtr()); TestInputStream stream(serialized.asPtr(), false);
InputFileMessageReader reader(&file, ReaderOptions(), InputStrategy::EAGER); InputStreamMessageReader reader(stream, ReaderOptions());
checkTestMessage(reader.getRoot<TestAllTypes>()); checkTestMessage(reader.getRoot<TestAllTypes>());
} }
TEST(Serialize, InputFileEventSegmentCountLazy) { TEST(Serialize, InputStreamEvenSegmentCountLazy) {
TestMessageBuilder builder(10); TestMessageBuilder builder(10);
initTestMessage(builder.initRoot<TestAllTypes>()); initTestMessage(builder.initRoot<TestAllTypes>());
Array<word> serialized = messageToFlatArray(builder); Array<word> serialized = messageToFlatArray(builder);
TestInputFile file(serialized.asPtr()); TestInputStream stream(serialized.asPtr(), true);
InputFileMessageReader reader(&file, ReaderOptions(), InputStrategy::LAZY); InputStreamMessageReader reader(stream, ReaderOptions());
checkTestMessage(reader.getRoot<TestAllTypes>()); checkTestMessage(reader.getRoot<TestAllTypes>());
} }
...@@ -324,7 +242,7 @@ TEST(Serialize, WriteMessageOddSegmentCount) { ...@@ -324,7 +242,7 @@ TEST(Serialize, WriteMessageOddSegmentCount) {
EXPECT_TRUE(output.dataEquals(serialized.asPtr())); EXPECT_TRUE(output.dataEquals(serialized.asPtr()));
} }
TEST(Serialize, WriteMessageEventSegmentCount) { TEST(Serialize, WriteMessageEvenSegmentCount) {
TestMessageBuilder builder(10); TestMessageBuilder builder(10);
initTestMessage(builder.initRoot<TestAllTypes>()); initTestMessage(builder.initRoot<TestAllTypes>());
...@@ -363,22 +281,10 @@ TEST(Serialize, FileDescriptors) { ...@@ -363,22 +281,10 @@ TEST(Serialize, FileDescriptors) {
checkTestMessage(reader.getRoot<TestAllTypes>()); checkTestMessage(reader.getRoot<TestAllTypes>());
} }
{
FileFdMessageReader reader(tmpfile.get(), 0);
checkTestMessage(reader.getRoot<TestAllTypes>());
}
size_t secondStart = lseek(tmpfile, 0, SEEK_CUR);
{ {
StreamFdMessageReader reader(tmpfile.get()); StreamFdMessageReader reader(tmpfile.get());
EXPECT_EQ("second message in file", reader.getRoot<TestAllTypes>().getTextField()); EXPECT_EQ("second message in file", reader.getRoot<TestAllTypes>().getTextField());
} }
{
FileFdMessageReader reader(move(tmpfile), secondStart);
EXPECT_EQ("second message in file", reader.getRoot<TestAllTypes>().getTextField());
}
} }
// TODO: Test error cases. // TODO: Test error cases.
......
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
#include <string.h> #include <string.h>
#include <errno.h> #include <errno.h>
#include <unistd.h> #include <unistd.h>
#include <string>
#include <sys/uio.h>
namespace capnproto { namespace capnproto {
...@@ -127,179 +129,100 @@ Array<word> messageToFlatArray(ArrayPtr<const ArrayPtr<const word>> segments) { ...@@ -127,179 +129,100 @@ Array<word> messageToFlatArray(ArrayPtr<const ArrayPtr<const word>> segments) {
// ======================================================================================= // =======================================================================================
InputStream::~InputStream() {} InputStream::~InputStream() {}
InputFile::~InputFile() {}
OutputStream::~OutputStream() {} OutputStream::~OutputStream() {}
// ------------------------------------------------------------------- void InputStream::skip(size_t bytes) {
char scratch[8192];
InputStreamMessageReader::InputStreamMessageReader( while (bytes > 0) {
InputStream* inputStream, ReaderOptions options, InputStrategy inputStrategy) size_t amount = std::min(bytes, sizeof(scratch));
: MessageReader(options), inputStream(inputStream), inputStrategy(inputStrategy), bytes -= read(scratch, amount, amount);
segmentsReadSoFar(0) {
switch (inputStrategy) {
case InputStrategy::EAGER:
case InputStrategy::LAZY:
readNextInternal();
break;
case InputStrategy::EAGER_WAIT_FOR_READ_NEXT:
case InputStrategy::LAZY_WAIT_FOR_READ_NEXT:
break;
} }
} }
void InputStreamMessageReader::readNext() { void OutputStream::write(ArrayPtr<const ArrayPtr<const byte>> pieces) {
bool needReset = false; for (auto piece: pieces) {
write(piece.begin(), piece.size());
switch (inputStrategy) {
case InputStrategy::LAZY:
if (moreSegments != nullptr || segment0.size != 0) {
// Make sure we've finished reading the previous message.
// Note that this sort of defeats the purpose of lazy parsing. In theory we could be a
// little more efficient by reading into a stack-allocated scratch buffer rather than
// allocating space for the remaining segments, but people really shouldn't be using
// readNext() when lazy-parsing anyway.
getSegment(moreSegments.size());
}
// no break
case InputStrategy::EAGER:
needReset = true;
// TODO: Save moreSegments for reuse?
moreSegments = nullptr;
segmentsReadSoFar = 0;
segment0.size = 0;
break;
case InputStrategy::EAGER_WAIT_FOR_READ_NEXT:
this->inputStrategy = InputStrategy::EAGER;
break;
case InputStrategy::LAZY_WAIT_FOR_READ_NEXT:
this->inputStrategy = InputStrategy::LAZY;
break;
}
if (inputStream != nullptr) {
readNextInternal();
} }
if (needReset) reset();
} }
void InputStreamMessageReader::readNextInternal() { // -------------------------------------------------------------------
InputStreamMessageReader::InputStreamMessageReader(
InputStream& inputStream, ReaderOptions options, ArrayPtr<word> scratchSpace)
: MessageReader(options), inputStream(inputStream), readPos(nullptr) {
internal::WireValue<uint32_t> firstWord[2]; internal::WireValue<uint32_t> firstWord[2];
if (!inputStream->read(firstWord, sizeof(firstWord))) return; inputStream.read(firstWord, sizeof(firstWord), sizeof(firstWord));
uint segmentCount = firstWord[0].get(); uint segmentCount = firstWord[0].get();
segment0.size = segmentCount == 0 ? 0 : firstWord[1].get(); uint segment0Size = segmentCount == 0 ? 0 : firstWord[1].get();
if (segmentCount > 1) { size_t totalWords = segment0Size;
internal::WireValue<uint32_t> sizes[segmentCount - 1];
if (!inputStream->read(sizes, sizeof(sizes))) return;
moreSegments = newArray<LazySegment>(segmentCount - 1); // Read sizes for all segments except the first. Include padding if necessary.
for (uint i = 1; i < segmentCount; i++) { internal::WireValue<uint32_t> moreSizes[segmentCount & ~1];
moreSegments[i - 1].size = sizes[i - 1].get(); if (segmentCount > 1) {
inputStream.read(moreSizes, sizeof(moreSizes), sizeof(moreSizes));
for (uint i = 0; i < segmentCount - 1; i++) {
totalWords += moreSizes[i].get();
} }
} }
if (segmentCount % 2 == 0) { if (scratchSpace.size() < totalWords) {
// Read the padding. // TODO: Consider allocating each segment as a separate chunk to reduce memory fragmentation.
uint32_t pad; ownedSpace = newArray<word>(totalWords);
if (!inputStream->read(&pad, sizeof(pad))) return; scratchSpace = ownedSpace;
} }
if (inputStrategy == InputStrategy::EAGER) { segment0 = scratchSpace.slice(0, segment0Size);
getSegment(segmentCount - 1);
}
}
ArrayPtr<const word> InputStreamMessageReader::getSegment(uint id) {
if (id > moreSegments.size()) {
return nullptr;
}
while (segmentsReadSoFar <= id && inputStream != nullptr) { if (segmentCount > 1) {
LazySegment& segment = segmentsReadSoFar == 0 ? segment0 : moreSegments[segmentsReadSoFar - 1]; moreSegments = newArray<ArrayPtr<const word>>(segmentCount - 1);
if (segment.words.size() < segment.size) { size_t offset = segment0Size;
segment.words = newArray<word>(segment.size);
}
if (!inputStream->read(segment.words.begin(), segment.size * sizeof(word))) { for (uint i = 0; i < segmentCount - 1; i++) {
// There was an error but no exception was thrown, so we're supposed to plod along with uint segmentSize = moreSizes[i].get();
// default values. Discard the broken stream. moreSegments[i] = scratchSpace.slice(offset, offset + segmentSize);
inputStream = nullptr; offset += segmentSize;
break;
} }
++segmentsReadSoFar;
} }
LazySegment& segment = id == 0 ? segment0 : moreSegments[id - 1]; if (segmentCount == 1) {
return segment.words.slice(0, segment.size); inputStream.read(scratchSpace.begin(), totalWords * sizeof(word), totalWords * sizeof(word));
} } else if (segmentCount > 1) {
readPos = reinterpret_cast<byte*>(scratchSpace.begin());
// ------------------------------------------------------------------- readPos += inputStream.read(readPos, segment0Size * sizeof(word), totalWords * sizeof(word));
InputFileMessageReader::InputFileMessageReader(
InputFile* inputFile, ReaderOptions options, InputStrategy inputStrategy)
: MessageReader(options), inputFile(inputFile) {
internal::WireValue<uint32_t> firstWord[2];
if (!inputFile->read(0, firstWord, sizeof(firstWord))) return;
uint segmentCount = firstWord[0].get();
segment0.size = segmentCount == 0 ? 0 : firstWord[1].get();
if (segmentCount > 1) {
internal::WireValue<uint32_t> sizes[segmentCount - 1];
if (!inputFile->read(sizeof(firstWord), sizes, sizeof(sizes))) return;
uint64_t offset = (segmentCount / 2 + 1) * sizeof(word);
segment0.offset = offset;
offset += segment0.size * sizeof(word);
moreSegments = newArray<LazySegment>(segmentCount - 1);
for (uint i = 1; i < segmentCount; i++) {
uint segmentSize = sizes[i - 1].get();
moreSegments[i - 1].size = segmentSize;
moreSegments[i - 1].offset = offset;
offset += segmentSize * sizeof(word);
}
} else {
segment0.offset = sizeof(firstWord);
} }
}
if (inputStrategy == InputStrategy::EAGER) { InputStreamMessageReader::~InputStreamMessageReader() {
for (uint i = 0; i < segmentCount; i++) { if (readPos != nullptr) {
getSegment(segmentCount); // Note that lazy reads only happen when we have multiple segments, so moreSegments.back() is
} // valid.
inputFile = nullptr; const byte* allEnd = reinterpret_cast<const byte*>(moreSegments.back().end());
inputStream.skip(allEnd - readPos);
} }
} }
ArrayPtr<const word> InputFileMessageReader::getSegment(uint id) { ArrayPtr<const word> InputStreamMessageReader::getSegment(uint id) {
if (id > moreSegments.size()) { if (id > moreSegments.size()) {
return nullptr; return nullptr;
} }
LazySegment& segment = id == 0 ? segment0 : moreSegments[id - 1]; ArrayPtr<const word> segment = id == 0 ? segment0 : moreSegments[id - 1];
if (segment.words == nullptr && segment.size > 0 && inputFile != nullptr) { if (readPos != nullptr) {
Array<word> words = newArray<word>(segment.size); // May need to lazily read more data.
const byte* segmentEnd = reinterpret_cast<const byte*>(segment.end());
if (!inputFile->read(segment.offset, words.begin(), words.size() * sizeof(word))) { if (readPos < segmentEnd) {
// There was an error but no exception was thrown, so we're supposed to plod along with // Note that lazy reads only happen when we have multiple segments, so moreSegments.back() is
// default values. Discard the broken stream. // valid.
inputFile = nullptr; const byte* allEnd = reinterpret_cast<const byte*>(moreSegments.back().end());
} else { readPos += inputStream.read(readPos, segmentEnd - readPos, allEnd - readPos);
segment.words = move(words);
} }
} }
return segment.words.asPtr(); return segment;
} }
// ------------------------------------------------------------------- // -------------------------------------------------------------------
...@@ -316,27 +239,45 @@ void writeMessage(OutputStream& output, ArrayPtr<const ArrayPtr<const word>> seg ...@@ -316,27 +239,45 @@ void writeMessage(OutputStream& output, ArrayPtr<const ArrayPtr<const word>> seg
table[segments.size() + 1].set(0); table[segments.size() + 1].set(0);
} }
output.write(table, sizeof(table)); ArrayPtr<const byte> pieces[segments.size() + 1];
pieces[0] = arrayPtr(reinterpret_cast<byte*>(table), sizeof(table));
for (auto& segment: segments) { for (uint i = 0; i < segments.size(); i++) {
output.write(segment.begin(), segment.size() * sizeof(word)); pieces[i + 1] = arrayPtr(reinterpret_cast<const byte*>(segments[i].begin()),
reinterpret_cast<const byte*>(segments[i].end()));
} }
output.write(arrayPtr(pieces, segments.size() + 1));
} }
// ======================================================================================= // =======================================================================================
class OsException: public std::exception { class OsException: public std::exception {
public: public:
OsException(int error): error(error) {} OsException(const char* function, int error) {
char buffer[256];
message = function;
message += ": ";
message.append(strerror_r(error, buffer, sizeof(buffer)));
}
~OsException() noexcept {} ~OsException() noexcept {}
const char* what() const noexcept override { const char* what() const noexcept override {
// TODO: Use strerror_r or whatever for thread-safety. Ugh. return message.c_str();
return strerror(error);
} }
private: private:
int error; std::string message;
};
class PrematureEofException: public std::exception {
public:
PrematureEofException() {}
~PrematureEofException() noexcept {}
const char* what() const noexcept override {
return "Stream ended prematurely.";
}
}; };
AutoCloseFd::~AutoCloseFd() { AutoCloseFd::~AutoCloseFd() {
...@@ -344,78 +285,95 @@ AutoCloseFd::~AutoCloseFd() { ...@@ -344,78 +285,95 @@ AutoCloseFd::~AutoCloseFd() {
if (std::uncaught_exception()) { if (std::uncaught_exception()) {
// TODO: Devise some way to report secondary errors during unwind. // TODO: Devise some way to report secondary errors during unwind.
} else { } else {
throw OsException(errno); throw OsException("close", errno);
} }
} }
} }
FdInputStream::~FdInputStream() {} FdInputStream::~FdInputStream() {}
bool FdInputStream::read(void* buffer, size_t size) { size_t FdInputStream::read(void* buffer, size_t minBytes, size_t maxBytes) {
char* pos = reinterpret_cast<char*>(buffer); byte* pos = reinterpret_cast<byte*>(buffer);
byte* min = pos + minBytes;
byte* max = pos + maxBytes;
while (size > 0) { while (pos < min) {
ssize_t n = ::read(fd, pos, size); ssize_t n = ::read(fd, pos, max - pos);
if (n <= 0) { if (n <= 0) {
if (n < 0) { if (n < 0) {
// TODO: Use strerror_r or whatever for thread-safety. Ugh. int error = errno;
errorReporter->reportError(strerror(errno)); if (error == EINTR) {
continue;
} else {
throw OsException("read", error);
}
} else if (n == 0) { } else if (n == 0) {
errorReporter->reportError("Stream ended prematurely."); throw PrematureEofException();
} }
return false; return false;
} }
pos += n; pos += n;
size -= n;
} }
return true; return pos - reinterpret_cast<byte*>(buffer);
} }
FdInputFile::~FdInputFile() {} FdOutputStream::~FdOutputStream() {}
bool FdInputFile::read(size_t offset, void* buffer, size_t size) { void FdOutputStream::write(const void* buffer, size_t size) {
char* pos = reinterpret_cast<char*>(buffer); const char* pos = reinterpret_cast<const char*>(buffer);
offset += this->offset;
while (size > 0) { while (size > 0) {
ssize_t n = ::pread(fd, pos, size, offset); ssize_t n = ::write(fd, pos, size);
if (n <= 0) { if (n <= 0) {
if (n < 0) { CAPNPROTO_ASSERT(n < 0, "write() returned zero.");
// TODO: Use strerror_r or whatever for thread-safety. Ugh. throw OsException("write", errno);
errorReporter->reportError(strerror(errno));
} else if (n == 0) {
errorReporter->reportError("Stream ended prematurely.");
}
return false;
} }
pos += n; pos += n;
offset += n;
size -= n; size -= n;
} }
return true;
} }
FdOutputStream::~FdOutputStream() {} void FdOutputStream::write(ArrayPtr<const ArrayPtr<const byte>> pieces) {
struct iovec iov[pieces.size()];
for (uint i = 0; i < pieces.size(); i++) {
// writev() interface is not const-correct. :(
iov[i].iov_base = const_cast<byte*>(pieces[i].begin());
iov[i].iov_len = pieces[i].size();
}
void FdOutputStream::write(const void* buffer, size_t size) { struct iovec* current = iov;
const char* pos = reinterpret_cast<const char*>(buffer); struct iovec* end = iov + pieces.size();
// Make sure we don't do anything on an empty write.
while (current < end && current->iov_len == 0) {
++current;
}
while (current < end) {
ssize_t n = ::writev(fd, iov, end - current);
while (size > 0) {
ssize_t n = ::write(fd, pos, size);
if (n <= 0) { if (n <= 0) {
throw OsException(n == 0 ? EIO : errno); if (n <= 0) {
CAPNPROTO_ASSERT(n < 0, "write() returned zero.");
throw OsException("writev", errno);
}
}
while (static_cast<size_t>(n) >= current->iov_len) {
n -= current->iov_len;
++current;
}
if (n > 0) {
current->iov_base = reinterpret_cast<byte*>(current->iov_base) + n;
current->iov_len -= n;
} }
pos += n;
size -= n;
} }
} }
StreamFdMessageReader::~StreamFdMessageReader() {} StreamFdMessageReader::~StreamFdMessageReader() {}
FileFdMessageReader::~FileFdMessageReader() {}
void writeMessageToFd(int fd, ArrayPtr<const ArrayPtr<const word>> segments) { void writeMessageToFd(int fd, ArrayPtr<const ArrayPtr<const word>> segments) {
FdOutputStream stream(fd); FdOutputStream stream(fd);
......
...@@ -75,18 +75,25 @@ class InputStream { ...@@ -75,18 +75,25 @@ class InputStream {
public: public:
virtual ~InputStream(); virtual ~InputStream();
virtual bool read(void* buffer, size_t size) = 0; virtual size_t read(void* buffer, size_t minBytes, size_t maxBytes) = 0;
// Always reads the full size requested. Returns true if successful. May throw an exception // Reads at least minBytes and at most maxBytes, copying them into the given buffer. Returns
// on failure, or report the error through some side channel and return false. // the size read. Throws an exception on errors.
}; //
// maxBytes is the number of bytes the caller really wants, but minBytes is the minimum amount
class InputFile { // needed by the caller before it can start doing useful processing. If the stream returns less
public: // than maxBytes, the caller will usually call read() again later to get the rest. Returning
virtual ~InputFile(); // less than maxBytes is useful when it makes sense for the caller to parallelize processing
// with I/O.
virtual bool read(size_t offset, void* buffer, size_t size) = 0; //
// Always reads the full size requested. Returns true if successful. May throw an exception // Cap'n Proto never asks for more bytes than it knows are part of the message. Therefore, if
// on failure, or report the error through some side channel and return false. // the InputStream happens to know that the stream will never reach maxBytes -- even if it has
// reached minBytes -- it should throw an exception to avoid wasting time processing an incomplete
// message. If it can't even reach minBytes, it MUST throw an exception, as the caller is not
// expected to understand how to deal with partial reads.
virtual void skip(size_t bytes);
// Skips past the given number of bytes, discarding them. The default implementation read()s
// into a scratch buffer.
}; };
class OutputStream { class OutputStream {
...@@ -94,85 +101,34 @@ public: ...@@ -94,85 +101,34 @@ public:
virtual ~OutputStream(); virtual ~OutputStream();
virtual void write(const void* buffer, size_t size) = 0; virtual void write(const void* buffer, size_t size) = 0;
// Throws exception on error, or reports errors via some side channel and returns. // Always writes the full size. Throws exception on error.
};
enum class InputStrategy {
EAGER,
// Read the whole message into RAM upfront, in the MessageReader constructor. When reading from
// an InputStream, the stream will then be positioned at the byte immediately after the end of
// the message, and will not be accessed again.
LAZY,
// Read segments of the message into RAM as needed while the message structure is being traversed.
// When reading from an InputStream, segments must be read in order, so segments up to the
// required segment will also be read. No guarantee is made about the position of the InputStream
// after reading. When using an InputFile, only the exact segments desired are read.
EAGER_WAIT_FOR_READ_NEXT, virtual void write(ArrayPtr<const ArrayPtr<const byte>> pieces);
// Like EAGER but don't read the first mesasge until readNext() is called the first time. // Equivalent to write()ing each byte array in sequence, which is what the default implementation
// does. Override if you can do something better, e.g. use writev() to do the write in a single
LAZY_WAIT_FOR_READ_NEXT, // syscall.
// Like LAZY but don't read the first mesasge until readNext() is called the first time.
}; };
class InputStreamMessageReader: public MessageReader { class InputStreamMessageReader: public MessageReader {
public: public:
InputStreamMessageReader(InputStream* inputStream, InputStreamMessageReader(InputStream& inputStream,
ReaderOptions options = ReaderOptions(), ReaderOptions options = ReaderOptions(),
InputStrategy inputStrategy = InputStrategy::EAGER); ArrayPtr<word> scratchSpace = nullptr);
~InputStreamMessageReader();
void readNext();
// Progress to the next message in the input stream, reusing the same memory if possible.
// Calling this invalidates any Readers currently pointing into this message.
// implements MessageReader ---------------------------------------- // implements MessageReader ----------------------------------------
ArrayPtr<const word> getSegment(uint id) override; ArrayPtr<const word> getSegment(uint id) override;
private: private:
InputStream* inputStream; InputStream& inputStream;
InputStrategy inputStrategy; byte* readPos;
uint segmentsReadSoFar;
struct LazySegment {
uint size;
Array<word> words;
// words may be larger than the desired size in the case where space is being reused from a
// previous read.
inline LazySegment(): size(0), words(nullptr) {}
};
// Optimize for single-segment case. // Optimize for single-segment case.
LazySegment segment0; ArrayPtr<const word> segment0;
Array<LazySegment> moreSegments; Array<ArrayPtr<const word>> moreSegments;
void readNextInternal();
};
class InputFileMessageReader: public MessageReader {
public:
InputFileMessageReader(InputFile* inputFile,
ReaderOptions options = ReaderOptions(),
InputStrategy inputStrategy = InputStrategy::EAGER);
// implements MessageReader ----------------------------------------
ArrayPtr<const word> getSegment(uint id) override;
private:
InputFile* inputFile;
struct LazySegment {
uint size;
size_t offset;
Array<word> words; // null until actually read
inline LazySegment(): size(0), offset(0), words(nullptr) {}
};
// Optimize for single-segment case. Array<word> ownedSpace;
LazySegment segment0; // Only if scratchSpace wasn't big enough.
Array<LazySegment> moreSegments;
}; };
void writeMessage(OutputStream& output, MessageBuilder& builder); void writeMessage(OutputStream& output, MessageBuilder& builder);
...@@ -218,39 +174,15 @@ class FdInputStream: public InputStream { ...@@ -218,39 +174,15 @@ class FdInputStream: public InputStream {
// An InputStream wrapping a file descriptor. // An InputStream wrapping a file descriptor.
public: public:
FdInputStream(int fd, ErrorReporter* errorReporter = getThrowingErrorReporter()) FdInputStream(int fd): fd(fd) {};
: fd(fd), errorReporter(errorReporter) {}; FdInputStream(AutoCloseFd fd): fd(fd), autoclose(move(fd)) {}
FdInputStream(AutoCloseFd fd, ErrorReporter* errorReporter = getThrowingErrorReporter())
: fd(fd), autoclose(move(fd)), errorReporter(errorReporter) {}
~FdInputStream(); ~FdInputStream();
bool read(void* buffer, size_t size) override; size_t read(void* buffer, size_t minBytes, size_t maxBytes) override;
private: private:
int fd; int fd;
AutoCloseFd autoclose; AutoCloseFd autoclose;
ErrorReporter* errorReporter;
};
class FdInputFile: public InputFile {
// An InputFile wrapping a file descriptor. The file descriptor must be seekable. This
// implementation uses pread(), so the stream pointer will not be modified.
public:
FdInputFile(int fd, size_t offset, ErrorReporter* errorReporter = getThrowingErrorReporter())
: fd(fd), offset(offset), errorReporter(errorReporter) {};
FdInputFile(AutoCloseFd fd, size_t offset,
ErrorReporter* errorReporter = getThrowingErrorReporter())
: fd(fd), autoclose(move(fd)), offset(offset), errorReporter(errorReporter) {}
~FdInputFile();
bool read(size_t offset, void* buffer, size_t size) override;
private:
int fd;
AutoCloseFd autoclose;
size_t offset;
ErrorReporter* errorReporter;
}; };
class FdOutputStream: public OutputStream { class FdOutputStream: public OutputStream {
...@@ -262,6 +194,7 @@ public: ...@@ -262,6 +194,7 @@ public:
~FdOutputStream(); ~FdOutputStream();
void write(const void* buffer, size_t size) override; void write(const void* buffer, size_t size) override;
void write(ArrayPtr<const ArrayPtr<const byte>> pieces) override;
private: private:
int fd; int fd;
...@@ -274,8 +207,8 @@ class StreamFdMessageReader: private FdInputStream, public InputStreamMessageRea ...@@ -274,8 +207,8 @@ class StreamFdMessageReader: private FdInputStream, public InputStreamMessageRea
public: public:
StreamFdMessageReader(int fd, ReaderOptions options = ReaderOptions(), StreamFdMessageReader(int fd, ReaderOptions options = ReaderOptions(),
InputStrategy inputStrategy = InputStrategy::EAGER) ArrayPtr<word> scratchSpace = nullptr)
: FdInputStream(fd), InputStreamMessageReader(this, options, inputStrategy) {} : FdInputStream(fd), InputStreamMessageReader(*this, options, scratchSpace) {}
// Read message from a file descriptor, without taking ownership of the descriptor. // Read message from a file descriptor, without taking ownership of the descriptor.
// //
// Since this version implies that the caller intends to read more data from the fd later on, the // Since this version implies that the caller intends to read more data from the fd later on, the
...@@ -283,8 +216,8 @@ public: ...@@ -283,8 +216,8 @@ public:
// deterministically positioned just past the end of the message. // deterministically positioned just past the end of the message.
StreamFdMessageReader(AutoCloseFd fd, ReaderOptions options = ReaderOptions(), StreamFdMessageReader(AutoCloseFd fd, ReaderOptions options = ReaderOptions(),
InputStrategy inputStrategy = InputStrategy::LAZY) ArrayPtr<word> scratchSpace = nullptr)
: FdInputStream(move(fd)), InputStreamMessageReader(this, options, inputStrategy) {} : FdInputStream(move(fd)), InputStreamMessageReader(*this, options, scratchSpace) {}
// Read a message from a file descriptor, taking ownership of the descriptor. // Read a message from a file descriptor, taking ownership of the descriptor.
// //
// Since this version implies that the caller does not intend to read any more data from the fd, // Since this version implies that the caller does not intend to read any more data from the fd,
...@@ -293,31 +226,6 @@ public: ...@@ -293,31 +226,6 @@ public:
~StreamFdMessageReader(); ~StreamFdMessageReader();
}; };
class FileFdMessageReader: private FdInputFile, public InputFileMessageReader {
// A MessageReader that reads from a seekable file descriptor, e.g. disk files. For non-seekable
// file descriptors, use FdStreamMessageReader. This implementation uses pread(), so the file
// descriptor's stream pointer will not be modified.
public:
FileFdMessageReader(int fd, size_t offset, ReaderOptions options = ReaderOptions(),
InputStrategy inputStrategy = InputStrategy::LAZY)
: FdInputFile(fd, offset, options.errorReporter),
InputFileMessageReader(this, options, inputStrategy) {}
// Read message from a file descriptor, without taking ownership of the descriptor.
//
// All reads use pread(), so the file descriptor's stream pointer will not be modified.
FileFdMessageReader(AutoCloseFd fd, size_t offset, ReaderOptions options = ReaderOptions(),
InputStrategy inputStrategy = InputStrategy::LAZY)
: FdInputFile(move(fd), offset, options.errorReporter),
InputFileMessageReader(this, options, inputStrategy) {}
// Read a message from a file descriptor, taking ownership of the descriptor.
//
// All reads use pread(), so the file descriptor's stream pointer will not be modified.
~FileFdMessageReader();
};
void writeMessageToFd(int fd, MessageBuilder& builder); void writeMessageToFd(int fd, MessageBuilder& builder);
// Write the message to the given file descriptor. // Write the message to the given file descriptor.
// //
......
...@@ -80,6 +80,8 @@ public: ...@@ -80,6 +80,8 @@ public:
inline T* begin() const { return ptr; } inline T* begin() const { return ptr; }
inline T* end() const { return ptr + size_; } inline T* end() const { return ptr + size_; }
inline T& front() const { return *ptr; }
inline T& back() const { return *(ptr + size_ - 1); }
inline ArrayPtr slice(size_t start, size_t end) { inline ArrayPtr slice(size_t start, size_t end) {
CAPNPROTO_DEBUG_ASSERT(start <= end && end <= size_, "Out-of-bounds ArrayPtr::slice()."); CAPNPROTO_DEBUG_ASSERT(start <= end && end <= size_, "Out-of-bounds ArrayPtr::slice().");
...@@ -137,6 +139,8 @@ public: ...@@ -137,6 +139,8 @@ public:
inline T* begin() const { return ptr; } inline T* begin() const { return ptr; }
inline T* end() const { return ptr + size_; } inline T* end() const { return ptr + size_; }
inline T& front() const { return *ptr; }
inline T& back() const { return *(ptr + size_ - 1); }
inline ArrayPtr<T> slice(size_t start, size_t end) { inline ArrayPtr<T> slice(size_t start, size_t end) {
CAPNPROTO_DEBUG_ASSERT(start <= end && end <= size_, "Out-of-bounds Array::slice()."); CAPNPROTO_DEBUG_ASSERT(start <= end && end <= size_, "Out-of-bounds Array::slice().");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment