Commit 227a4911 authored by Kenton Varda's avatar Kenton Varda

Refactor compiler binary into a multi-tool with sub-commands like 'compile' and…

Refactor compiler binary into a multi-tool with sub-commands like 'compile' and 'decode'.  The latter dynamically decodes binary input to text based on a provided schema file.
parent 7c9309be
......@@ -38,6 +38,8 @@
#include <sys/types.h>
#include <sys/wait.h>
#include <capnp/serialize.h>
#include <capnp/serialize-packed.h>
#include <limits>
namespace capnp {
namespace compiler {
......@@ -63,43 +65,110 @@ public:
}
};
static const char VERSION_STRING[] = "Cap'n Proto version 0.2";
class CompilerMain final: public GlobalErrorReporter {
public:
explicit CompilerMain(kj::ProcessContext& context)
: context(context), loader(*this) {}
kj::MainFunc getMain() {
return kj::MainBuilder(
context, "Cap'n Proto compiler version 0.2",
if (context.getProgramName().endsWith("capnpc")) {
kj::MainBuilder builder(context, VERSION_STRING,
"Compiles Cap'n Proto schema files and generates corresponding source code in one or "
"more languages.");
addGlobalOptions(builder);
addCompileOptions(builder);
builder.addOption({'i', "generate-id"}, KJ_BIND_METHOD(*this, generateId),
"Generate a new 64-bit unique ID for use in a Cap'n Proto schema.");
return builder.build();
} else {
kj::MainBuilder builder(context, VERSION_STRING,
"Command-line tool for Cap'n Proto development and debugging.");
builder.addSubCommand("compile", KJ_BIND_METHOD(*this, getCompileMain),
"Generate source code from schema files.")
.addSubCommand("id", KJ_BIND_METHOD(*this, getGenIdMain),
"Generate a new unique ID.")
.addSubCommand("decode", KJ_BIND_METHOD(*this, getDecodeMain),
"Decode binary Cap'n Proto message to text.");
// TODO(someday): "encode" -- requires the ability to parse text format.
addGlobalOptions(builder);
return builder.build();
}
}
kj::MainFunc getCompileMain() {
kj::MainBuilder builder(context, VERSION_STRING,
"Compiles Cap'n Proto schema files and generates corresponding source code in one or "
"more languages.")
.addOptionWithArg({'I', "import-path"}, KJ_BIND_METHOD(*this, addImportPath), "<dir>",
"more languages.");
addGlobalOptions(builder);
addCompileOptions(builder);
return builder.build();
}
kj::MainFunc getGenIdMain() {
return kj::MainBuilder(
context, "Cap'n Proto multi-tool 0.2",
"Generates a new 64-bit unique ID for use in a Cap'n Proto schema.")
.callAfterParsing(KJ_BIND_METHOD(*this, generateId))
.build();
}
kj::MainFunc getDecodeMain() {
// Only parse the schemas we actually need for decoding.
compileMode = Compiler::LAZY;
kj::MainBuilder builder(context, VERSION_STRING,
"Decodes one or more encoded Cap'n Proto messages as text. The messages have root "
"type <type> defined in <schema-file>. Messages are read from standard input and "
"by default are expected to be in standard Cap'n Proto serialization format.");
addGlobalOptions(builder);
builder.addOption({'f', "flat"}, KJ_BIND_METHOD(*this, codeFlat),
"Interpret the input as one large single-segment message rather than a "
"stream in standard serialization format.")
.addOption({'p', "packed"}, KJ_BIND_METHOD(*this, codePacked),
"Expect the input to be packed using standard Cap'n Proto packing, which "
"deflates zero-valued bytes.")
.addOption({"short"}, KJ_BIND_METHOD(*this, printShort),
"Print in short (non-pretty) format. Each message will be printed on one "
"line, without using whitespace to improve readability.")
.expectArg("<schema-file>", KJ_BIND_METHOD(*this, addSource))
.expectArg("<type>", KJ_BIND_METHOD(*this, setRootType))
.callAfterParsing(KJ_BIND_METHOD(*this, decode));
return builder.build();
}
void addGlobalOptions(kj::MainBuilder& builder) {
builder.addOptionWithArg({'I', "import-path"}, KJ_BIND_METHOD(*this, addImportPath), "<dir>",
"Add <dir> to the list of directories searched for non-relative "
"imports (ones that start with a '/').")
.addOption({"no-standard-import"}, KJ_BIND_METHOD(*this, noStandardImport),
"Do not add any default import paths; use only those specified by -I. "
"Otherwise, typically /usr/include and /usr/local/include are added by "
"default.")
.addOptionWithArg({'o', "output"}, KJ_BIND_METHOD(*this, addOutput), "<lang>[:<dir>]",
"Generate source code for language <lang> in directory <dir> (default: "
"current directory). <lang> actually specifies a plugin to use. If "
"<lang> is a simple word, the compiler for a plugin called "
"'capnpc-<lang>' in $PATH. If <lang> is a file path containing slashes, "
"it is interpreted as the exact plugin executable file name, and $PATH "
"is not searched.")
"default.");
}
void addCompileOptions(kj::MainBuilder& builder) {
builder.addOptionWithArg({'o', "output"}, KJ_BIND_METHOD(*this, addOutput), "<lang>[:<dir>]",
"Generate source code for language <lang> in directory <dir> "
"(default: current directory). <lang> actually specifies a plugin "
"to use. If <lang> is a simple word, the compiler for a plugin "
"called 'capnpc-<lang>' in $PATH. If <lang> is a file path "
"containing slashes, it is interpreted as the exact plugin "
"executable file name, and $PATH is not searched.")
.addOptionWithArg({"src-prefix"}, KJ_BIND_METHOD(*this, addSourcePrefix), "<prefix>",
"If a file specified for compilation starts with <prefix>, remove "
"the prefix for the purpose of deciding the names of output files. "
"For example, the following command:\n"
" capnp --src-prefix=foo/bar -oc++:corge foo/bar/baz/qux.capnp\n"
"would generate the files corge/baz/qux.capnp.{h,c++}.")
.addOption({'i', "generate-id"}, KJ_BIND_METHOD(*this, generateId),
"Generate a new 64-bit unique ID for use in a Cap'n Proto schema.")
.expectOneOrMoreArgs("source", KJ_BIND_METHOD(*this, addSource))
.callAfterParsing(KJ_BIND_METHOD(*this, generateOutput))
.build();
.expectOneOrMoreArgs("<source>", KJ_BIND_METHOD(*this, addSource))
.callAfterParsing(KJ_BIND_METHOD(*this, generateOutput));
}
// =====================================================================================
// shared options
kj::MainBuilder::Validity addImportPath(kj::StringPtr path) {
loader.addImportPath(kj::heapString(path));
return true;
......@@ -110,34 +179,6 @@ public:
return true;
}
kj::MainBuilder::Validity addOutput(kj::StringPtr spec) {
KJ_IF_MAYBE(split, spec.findFirst(':')) {
kj::StringPtr dir = spec.slice(*split + 1);
struct stat stats;
if (stat(dir.cStr(), &stats) < 0 || !S_ISDIR(stats.st_mode)) {
return "output location is inaccessible or is not a directory";
}
outputs.add(OutputDirective { spec.slice(0, *split), dir });
} else {
outputs.add(OutputDirective { spec.asArray(), nullptr });
}
return true;
}
kj::MainBuilder::Validity addSourcePrefix(kj::StringPtr prefix) {
if (prefix.endsWith("/")) {
sourcePrefixes.add(kj::heapString(prefix));
} else {
sourcePrefixes.add(kj::str(prefix, '/'));
}
return true;
}
kj::MainBuilder::Validity generateId() {
context.exitInfo(kj::str("@0x", kj::hex(generateRandomId())));
}
kj::MainBuilder::Validity addSource(kj::StringPtr file) {
if (addStandardImportPaths) {
loader.addImportPath(kj::heapString("/usr/local/include"));
......@@ -162,6 +203,40 @@ public:
return true;
}
// =====================================================================================
// "id" command
kj::MainBuilder::Validity generateId() {
context.exitInfo(kj::str("@0x", kj::hex(generateRandomId())));
}
// =====================================================================================
// "compile" command
kj::MainBuilder::Validity addOutput(kj::StringPtr spec) {
KJ_IF_MAYBE(split, spec.findFirst(':')) {
kj::StringPtr dir = spec.slice(*split + 1);
struct stat stats;
if (stat(dir.cStr(), &stats) < 0 || !S_ISDIR(stats.st_mode)) {
return "output location is inaccessible or is not a directory";
}
outputs.add(OutputDirective { spec.slice(0, *split), dir });
} else {
outputs.add(OutputDirective { spec.asArray(), nullptr });
}
return true;
}
kj::MainBuilder::Validity addSourcePrefix(kj::StringPtr prefix) {
if (prefix.endsWith("/")) {
sourcePrefixes.add(kj::heapString(prefix));
} else {
sourcePrefixes.add(kj::str(prefix, '/'));
}
return true;
}
kj::MainBuilder::Validity generateOutput() {
if (hadErrors()) {
// Skip output if we had any errors.
......@@ -242,6 +317,147 @@ public:
return true;
}
// =====================================================================================
// "decode" command
kj::MainBuilder::Validity codeFlat() {
if (packed) return "cannot be used with --packed";
flat = true;
return true;
}
kj::MainBuilder::Validity codePacked() {
if (flat) return "cannot be used with --flat";
packed = true;
return true;
}
kj::MainBuilder::Validity printShort() {
pretty = false;
return true;
}
kj::MainBuilder::Validity setRootType(kj::StringPtr type) {
KJ_ASSERT(sourceIds.size() == 1);
uint64_t id = sourceIds[0];
while (type.size() > 0) {
kj::String temp;
kj::StringPtr part;
KJ_IF_MAYBE(dotpos, type.findFirst('.')) {
temp = kj::heapString(type.slice(0, *dotpos));
part = temp;
type = type.slice(*dotpos + 1);
} else {
part = type;
type = nullptr;
}
KJ_IF_MAYBE(childId, compiler.lookup(id, part)) {
id = *childId;
} else {
return "no such type";
}
}
Schema schema = compiler.getLoader().get(id);
if (schema.getProto().getBody().which() != schema::Node::Body::STRUCT_NODE) {
return "not a struct type";
}
rootType = schema.asStruct();
return true;
}
kj::MainBuilder::Validity decode() {
kj::FdInputStream rawInput(STDIN_FILENO);
kj::BufferedInputStreamWrapper input(rawInput);
if (flat) {
// Read in the whole input to decode as one segment.
kj::Array<word> words;
{
kj::Vector<byte> allBytes;
for (;;) {
auto buffer = input.tryGetReadBuffer();
if (buffer.size() == 0) break;
allBytes.addAll(buffer);
input.skip(buffer.size());
}
// Technically we don't know if the bytes are aligned so we'd better copy them to a new
// array. Note that if we have a non-whole number of words we chop off the straggler bytes.
// This is fine because if those bytes are actually part of the message we will hit an error
// later and if they are not then who cares?
words = kj::heapArray<word>(allBytes.size() / sizeof(word));
memcpy(words.begin(), allBytes.begin(), words.size() * sizeof(word));
}
kj::ArrayPtr<const word> segments = words;
decodeInner<SegmentArrayMessageReader>(arrayPtr(&segments, 1));
} else {
while (input.tryGetReadBuffer().size() > 0) {
if (packed) {
decodeInner<PackedMessageReader>(input);
} else {
decodeInner<InputStreamMessageReader>(input);
}
}
}
return true;
}
private:
struct ParseErrorCatcher: public kj::ExceptionCallback {
void onRecoverableException(kj::Exception&& e) {
// Only capture the first exception, on the assumption that later exceptions are probably
// just cascading problems.
if (exception == nullptr) {
exception = kj::mv(e);
}
}
kj::Maybe<kj::Exception> exception;
};
template <typename MessageReaderType, typename Input>
void decodeInner(Input&& input) {
// Since this is a debug tool, lift the usual security limits. Worse case is the process
// crashes or has to be killed.
ReaderOptions options;
options.nestingLimit = std::numeric_limits<decltype(options.nestingLimit)>::max() >> 1;
options.traversalLimitInWords =
std::numeric_limits<decltype(options.traversalLimitInWords)>::max();
MessageReaderType reader(input, options);
auto root = reader.template getRoot<DynamicStruct>(rootType);
kj::String text;
kj::Maybe<kj::Exception> exception;
{
ParseErrorCatcher catcher;
if (pretty) {
text = prettyPrint(root);
} else {
text = kj::str(root);
}
exception = kj::mv(catcher.exception);
}
kj::ArrayPtr<const byte> pieces[2];
pieces[0] = kj::arrayPtr(reinterpret_cast<const byte*>(text.begin()), text.size());
pieces[1] = kj::arrayPtr(reinterpret_cast<const byte*>("\n"), 1);
kj::FdOutputStream(STDOUT_FILENO).write(kj::arrayPtr(pieces, KJ_ARRAY_SIZE(pieces)));
KJ_IF_MAYBE(e, exception) {
context.error(kj::str("*** error in previous message ***\n", *e, "\n*** end error ***"));
}
}
public:
// =====================================================================================
void addError(kj::StringPtr file, SourcePos start, SourcePos end,
kj::StringPtr message) const override {
kj::String wholeMessage;
......@@ -270,10 +486,17 @@ private:
kj::ProcessContext& context;
ModuleLoader loader;
Compiler compiler;
Compiler::Mode compileMode = Compiler::EAGER;
kj::Vector<kj::String> sourcePrefixes;
bool addStandardImportPaths = true;
bool flat = false;
bool packed = false;
bool pretty = true;
StructSchema rootType;
// For the "decode" command.
kj::Vector<uint64_t> sourceIds;
struct OutputDirective {
......
......@@ -214,6 +214,7 @@ public:
virtual ~Impl();
uint64_t add(const Module& module, Mode mode) const;
kj::Maybe<uint64_t> lookup(uint64_t parent, kj::StringPtr childName) const;
const CompiledModule& add(const Module& parsedModule) const;
struct Workspace {
......@@ -785,6 +786,15 @@ uint64_t Compiler::Impl::add(const Module& module, Mode mode) const {
return node.getId();
}
kj::Maybe<uint64_t> Compiler::Impl::lookup(uint64_t parent, kj::StringPtr childName) const {
// We know this won't use the workspace, so we need not lock it.
KJ_IF_MAYBE(parentNode, findNode(parent)) {
return parentNode->lookupMember(childName).map([](const Node& n) { return n.getId(); });
} else {
KJ_FAIL_REQUIRE("lookup()s parameter 'parent' must be a known ID.", parent);
}
}
void Compiler::Impl::load(const SchemaLoader& loader, uint64_t id) const {
KJ_IF_MAYBE(node, findNode(id)) {
if (&loader == &finalLoader) {
......@@ -806,6 +816,10 @@ uint64_t Compiler::add(const Module& module, Mode mode) const {
return impl->add(module, mode);
}
kj::Maybe<uint64_t> Compiler::lookup(uint64_t parent, kj::StringPtr childName) const {
return impl->lookup(parent, childName);
}
const SchemaLoader& Compiler::getLoader() const {
return impl->getFinalLoader();
}
......
......@@ -82,6 +82,11 @@ public:
// errors while compiling (reported via `module.addError()`), then the SchemaLoader may behave as
// if the node doesn't exist, or may return an invalid partial Schema.
kj::Maybe<uint64_t> lookup(uint64_t parent, kj::StringPtr childName) const;
// Given the type ID of a schema node, find the ID of a node nested within it, without actually
// building either node. Throws an exception if the parent ID is not recognized; returns null
// if the parent has no child of the given name.
const SchemaLoader& getLoader() const;
// Get a SchemaLoader backed by this compiler. Schema nodes will be lazily constructed as you
// traverse them using this loader.
......
......@@ -59,7 +59,7 @@ public:
data.append(reinterpret_cast<const char*>(buffer), size);
}
size_t read(void* buffer, size_t minBytes, size_t maxBytes) override {
size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
KJ_ASSERT(maxBytes <= data.size() - readPos, "Overran end of stream.");
size_t amount = std::min(maxBytes, std::max(minBytes, preferredReadSize));
memcpy(buffer, data.data() + readPos, amount);
......@@ -72,7 +72,7 @@ public:
readPos += bytes;
}
kj::ArrayPtr<const byte> getReadBuffer() override {
kj::ArrayPtr<const byte> tryGetReadBuffer() override {
size_t amount = std::min(data.size() - readPos, preferredReadSize);
return kj::arrayPtr(reinterpret_cast<const byte*>(data.data() + readPos), amount);
}
......
......@@ -33,7 +33,7 @@ namespace _ { // private
PackedInputStream::PackedInputStream(kj::BufferedInputStream& inner): inner(inner) {}
PackedInputStream::~PackedInputStream() noexcept(false) {}
size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
size_t PackedInputStream::tryRead(void* dst, size_t minBytes, size_t maxBytes) {
if (maxBytes == 0) {
return 0;
}
......@@ -46,8 +46,8 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
uint8_t* const outMin = reinterpret_cast<uint8_t*>(dst) + minBytes;
kj::ArrayPtr<const byte> buffer = inner.getReadBuffer();
KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") {
return minBytes; // garbage
if (buffer.size() == 0) {
return 0;
}
const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(buffer.begin());
......@@ -55,7 +55,7 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
inner.skip(buffer.size()); \
buffer = inner.getReadBuffer(); \
KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") { \
return minBytes; /* garbage */ \
return out - reinterpret_cast<uint8_t*>(dst); \
} \
in = reinterpret_cast<const uint8_t*>(buffer.begin())
......@@ -127,7 +127,7 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
KJ_REQUIRE(runLength <= outEnd - out,
"Packed input did not end cleanly on a segment boundary.") {
return std::max<size_t>(minBytes, out - reinterpret_cast<uint8_t*>(dst)); // garbage
return out - reinterpret_cast<uint8_t*>(dst);
}
memset(out, 0, runLength);
out += runLength;
......@@ -139,7 +139,7 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
KJ_REQUIRE(runLength <= outEnd - out,
"Packed input did not end cleanly on a segment boundary.") {
return std::max<size_t>(minBytes, out - reinterpret_cast<uint8_t*>(dst)); // garbage
return out - reinterpret_cast<uint8_t*>(dst);
}
uint inRemaining = BUFFER_REMAINING;
......
......@@ -40,7 +40,7 @@ public:
~PackedInputStream() noexcept(false);
// implements InputStream ------------------------------------------
size_t read(void* buffer, size_t minBytes, size_t maxBytes) override;
size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override;
void skip(size_t bytes) override;
private:
......
......@@ -94,7 +94,7 @@ public:
data.append(reinterpret_cast<const char*>(buffer), size);
}
size_t read(void* buffer, size_t minBytes, size_t maxBytes) override {
size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
KJ_ASSERT(maxBytes <= data.size() - readPos, "Overran end of stream.");
size_t amount = std::min(maxBytes, std::max(minBytes, preferredReadSize));
memcpy(buffer, data.data() + readPos, amount);
......@@ -107,7 +107,7 @@ public:
readPos += bytes;
}
kj::ArrayPtr<const byte> getReadBuffer() override {
kj::ArrayPtr<const byte> tryGetReadBuffer() override {
size_t amount = std::min(data.size() - readPos, preferredReadSize);
return kj::arrayPtr(reinterpret_cast<const byte*>(data.data() + readPos), amount);
}
......
......@@ -36,6 +36,10 @@ public:
: inputStream(inputStream) {}
inline ~InputStreamSnappySource() noexcept {}
bool atEnd() {
return inputStream.getReadBuffer().size() == 0;
}
// implements snappy::Source ---------------------------------------
size_t Available() const override {
......@@ -68,7 +72,7 @@ SnappyInputStream::SnappyInputStream(BufferedInputStream& inner, kj::ArrayPtr<by
SnappyInputStream::~SnappyInputStream() noexcept(false) {}
kj::ArrayPtr<const byte> SnappyInputStream::getReadBuffer() {
kj::ArrayPtr<const byte> SnappyInputStream::tryGetReadBuffer() {
if (bufferAvailable.size() == 0) {
refill();
}
......@@ -76,44 +80,53 @@ kj::ArrayPtr<const byte> SnappyInputStream::getReadBuffer() {
return bufferAvailable;
}
size_t SnappyInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
size_t SnappyInputStream::tryRead(void* dst, size_t minBytes, size_t maxBytes) {
size_t total = 0;
while (minBytes > bufferAvailable.size()) {
memcpy(dst, bufferAvailable.begin(), bufferAvailable.size());
dst = reinterpret_cast<byte*>(dst) + bufferAvailable.size();
total += bufferAvailable.size();
minBytes -= bufferAvailable.size();
maxBytes -= bufferAvailable.size();
refill();
if (!refill()) {
return total;
}
}
// Serve from current buffer.
size_t n = std::min(bufferAvailable.size(), maxBytes);
memcpy(dst, bufferAvailable.begin(), n);
bufferAvailable = bufferAvailable.slice(n, bufferAvailable.size());
return n;
return total + n;
}
void SnappyInputStream::skip(size_t bytes) {
while (bytes > bufferAvailable.size()) {
bytes -= bufferAvailable.size();
refill();
KJ_REQUIRE(refill(), "Premature EOF");
}
bufferAvailable = bufferAvailable.slice(bytes, bufferAvailable.size());
}
void SnappyInputStream::refill() {
bool SnappyInputStream::refill() {
uint32_t length = 0;
InputStreamSnappySource snappySource(inner);
if (snappySource.atEnd()) {
return false;
}
KJ_REQUIRE(
snappy::RawUncompress(
&snappySource, reinterpret_cast<char*>(buffer.begin()), buffer.size(), &length),
"Snappy decompression failed.") {
length = 1; // garbage
break;
return false;
}
bufferAvailable = buffer.slice(0, length);
return true;
}
// =======================================================================================
......
......@@ -39,8 +39,8 @@ public:
~SnappyInputStream() noexcept(false);
// implements BufferedInputStream ----------------------------------
kj::ArrayPtr<const byte> getReadBuffer() override;
size_t read(void* buffer, size_t minBytes, size_t maxBytes) override;
kj::ArrayPtr<const byte> tryGetReadBuffer() override;
size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override;
void skip(size_t bytes) override;
private:
......@@ -51,7 +51,7 @@ private:
kj::ArrayPtr<byte> buffer;
kj::ArrayPtr<byte> bufferAvailable;
void refill();
bool refill();
};
class SnappyOutputStream: public kj::BufferedOutputStream {
......
......@@ -101,7 +101,7 @@ public:
lazy(lazy) {}
~TestInputStream() {}
size_t read(void* buffer, size_t minBytes, size_t maxBytes) override {
size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
KJ_ASSERT(maxBytes <= size_t(end - pos), "Overran end of stream.");
size_t amount = lazy ? minBytes : maxBytes;
memcpy(buffer, pos, amount);
......
......@@ -252,7 +252,7 @@ static void print(std::ostream& os, const DynamicValue::Reader& value,
}
break;
case DynamicValue::OBJECT:
os << "(opaque object)";
os << "<opaque object>";
break;
}
}
......
......@@ -35,6 +35,16 @@ OutputStream::~OutputStream() noexcept(false) {}
BufferedInputStream::~BufferedInputStream() noexcept(false) {}
BufferedOutputStream::~BufferedOutputStream() noexcept(false) {}
size_t InputStream::read(void* buffer, size_t minBytes, size_t maxBytes) {
size_t n = tryRead(buffer, minBytes, maxBytes);
KJ_REQUIRE(n >= minBytes, "Premature EOF") {
// Pretend we read zeros from the input.
memset(reinterpret_cast<byte*>(buffer) + n, 0, minBytes - n);
return minBytes;
}
return n;
}
void InputStream::skip(size_t bytes) {
char scratch[8192];
while (bytes > 0) {
......@@ -50,6 +60,12 @@ void OutputStream::write(ArrayPtr<const ArrayPtr<const byte>> pieces) {
}
}
ArrayPtr<const byte> BufferedInputStream::getReadBuffer() {
auto result = tryGetReadBuffer();
KJ_REQUIRE(result.size() > 0, "Premature EOF");
return result;
}
// =======================================================================================
BufferedInputStreamWrapper::BufferedInputStreamWrapper(InputStream& inner, ArrayPtr<byte> buffer)
......@@ -58,16 +74,16 @@ BufferedInputStreamWrapper::BufferedInputStreamWrapper(InputStream& inner, Array
BufferedInputStreamWrapper::~BufferedInputStreamWrapper() noexcept(false) {}
ArrayPtr<const byte> BufferedInputStreamWrapper::getReadBuffer() {
ArrayPtr<const byte> BufferedInputStreamWrapper::tryGetReadBuffer() {
if (bufferAvailable.size() == 0) {
size_t n = inner.read(buffer.begin(), 1, buffer.size());
size_t n = inner.tryRead(buffer.begin(), 1, buffer.size());
bufferAvailable = buffer.slice(0, n);
}
return bufferAvailable;
}
size_t BufferedInputStreamWrapper::read(void* dst, size_t minBytes, size_t maxBytes) {
size_t BufferedInputStreamWrapper::tryRead(void* dst, size_t minBytes, size_t maxBytes) {
if (minBytes <= bufferAvailable.size()) {
// Serve from current buffer.
size_t n = std::min(bufferAvailable.size(), maxBytes);
......@@ -174,20 +190,15 @@ void BufferedOutputStreamWrapper::write(const void* src, size_t size) {
ArrayInputStream::ArrayInputStream(ArrayPtr<const byte> array): array(array) {}
ArrayInputStream::~ArrayInputStream() noexcept(false) {}
ArrayPtr<const byte> ArrayInputStream::getReadBuffer() {
ArrayPtr<const byte> ArrayInputStream::tryGetReadBuffer() {
return array;
}
size_t ArrayInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
size_t ArrayInputStream::tryRead(void* dst, size_t minBytes, size_t maxBytes) {
size_t n = std::min(maxBytes, array.size());
size_t result = n;
KJ_REQUIRE(n >= minBytes, "ArrayInputStream ended prematurely.") {
result = minBytes; // garbage
break;
}
memcpy(dst, array.begin(), n);
array = array.slice(n, array.size());
return result;
return n;
}
void ArrayInputStream::skip(size_t bytes) {
......@@ -234,7 +245,7 @@ AutoCloseFd::~AutoCloseFd() noexcept(false) {
FdInputStream::~FdInputStream() noexcept(false) {}
size_t FdInputStream::read(void* buffer, size_t minBytes, size_t maxBytes) {
size_t FdInputStream::tryRead(void* buffer, size_t minBytes, size_t maxBytes) {
byte* pos = reinterpret_cast<byte*>(buffer);
byte* min = pos + minBytes;
byte* max = pos + maxBytes;
......@@ -242,8 +253,8 @@ size_t FdInputStream::read(void* buffer, size_t minBytes, size_t maxBytes) {
while (pos < min) {
ssize_t n;
KJ_SYSCALL(n = ::read(fd, pos, max - pos), fd);
KJ_REQUIRE(n > 0, "Premature EOF") {
return minBytes;
if (n == 0) {
break;
}
pos += n;
}
......
......@@ -38,9 +38,9 @@ class InputStream {
public:
virtual ~InputStream() noexcept(false);
virtual size_t read(void* buffer, size_t minBytes, size_t maxBytes) = 0;
size_t read(void* buffer, size_t minBytes, size_t maxBytes);
// Reads at least minBytes and at most maxBytes, copying them into the given buffer. Returns
// the size read. Throws an exception on errors.
// the size read. Throws an exception on errors. Implemented in terms of tryRead().
//
// maxBytes is the number of bytes the caller really wants, but minBytes is the minimum amount
// needed by the caller before it can start doing useful processing. If the stream returns less
......@@ -48,12 +48,19 @@ public:
// less than maxBytes is useful when it makes sense for the caller to parallelize processing
// with I/O.
//
// Never blocks if minBytes is zero. If minBytes is zero and maxBytes is non-zero, this may
// attempt a non-blocking read or may just return zero. To force a read, use a non-zero minBytes.
// To detect EOF without throwing an exception, use tryRead().
//
// Cap'n Proto never asks for more bytes than it knows are part of the message. Therefore, if
// the InputStream happens to know that the stream will never reach maxBytes -- even if it has
// reached minBytes -- it should throw an exception to avoid wasting time processing an incomplete
// message. If it can't even reach minBytes, it MUST throw an exception, as the caller is not
// expected to understand how to deal with partial reads.
virtual size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) = 0;
// Like read(), but may return fewer than minBytes on EOF.
inline void read(void* buffer, size_t bytes) { read(buffer, bytes, bytes); }
// Convenience method for reading an exact number of bytes.
......@@ -84,10 +91,13 @@ class BufferedInputStream: public InputStream {
public:
virtual ~BufferedInputStream() noexcept(false);
virtual ArrayPtr<const byte> getReadBuffer() = 0;
ArrayPtr<const byte> getReadBuffer();
// Get a direct pointer into the read buffer, which contains the next bytes in the input. If the
// caller consumes any bytes, it should then call skip() to indicate this. This always returns a
// non-empty buffer unless at EOF.
// non-empty buffer or throws an exception. Implemented in terms of tryGetReadBuffer().
virtual ArrayPtr<const byte> tryGetReadBuffer() = 0;
// Like getReadBuffer() but may return an empty buffer on EOF.
};
class BufferedOutputStream: public OutputStream {
......@@ -130,8 +140,8 @@ public:
~BufferedInputStreamWrapper() noexcept(false);
// implements BufferedInputStream ----------------------------------
ArrayPtr<const byte> getReadBuffer() override;
size_t read(void* buffer, size_t minBytes, size_t maxBytes) override;
ArrayPtr<const byte> tryGetReadBuffer() override;
size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override;
void skip(size_t bytes) override;
private:
......@@ -182,8 +192,8 @@ public:
~ArrayInputStream() noexcept(false);
// implements BufferedInputStream ----------------------------------
ArrayPtr<const byte> getReadBuffer() override;
size_t read(void* buffer, size_t minBytes, size_t maxBytes) override;
ArrayPtr<const byte> tryGetReadBuffer() override;
size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override;
void skip(size_t bytes) override;
private:
......@@ -250,7 +260,7 @@ public:
KJ_DISALLOW_COPY(FdInputStream);
~FdInputStream() noexcept(false);
size_t read(void* buffer, size_t minBytes, size_t maxBytes) override;
size_t tryRead(void* buffer, size_t minBytes, size_t maxBytes) override;
private:
int fd;
......
......@@ -411,7 +411,7 @@ void MainBuilder::MainImpl::operator()(StringPtr programName, ArrayPtr<const Str
const Impl::Option& option = *iter->second;
if (option.hasArg) {
// Argument expected.
if (j + 1 < params.size()) {
if (j + 1 < param.size()) {
// Rest of flag is argument.
StringPtr arg = param.slice(j + 1);
KJ_IF_MAYBE(error, (*option.funcWithArg)(arg).releaseError()) {
......@@ -439,21 +439,23 @@ void MainBuilder::MainImpl::operator()(StringPtr programName, ArrayPtr<const Str
} else if (!impl->subCommands.empty()) {
// A sub-command name.
auto iter = impl->subCommands.find(param);
if (iter == impl->subCommands.end()) {
if (iter != impl->subCommands.end()) {
MainFunc subMain = iter->second.func();
subMain(str(programName, ' ', param), params.slice(i + 1, params.size()));
return;
} else if (param == "help") {
if (i + 1 < params.size()) {
iter = impl->subCommands.find(params[i + 1]);
if (iter == impl->subCommands.end()) {
usageError(programName, str(params[i + 1], ": unknown command"));
} else {
if (iter != impl->subCommands.end()) {
// Run the sub-command with "--help" as the argument.
MainFunc subMain = iter->second.func();
StringPtr dummyArg = "--help";
subMain(str(programName, ' ', params[i + 1]), arrayPtr(&dummyArg, 1));
return;
} else if (params[i + 1] == "help") {
impl->context.exitInfo("Help, I'm trapped in a help text factory!");
} else {
usageError(programName, str(params[i + 1], ": unknown command"));
}
} else {
printHelp(programName);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment