Commit db505f42 authored by Kenton Varda's avatar Kenton Varda

Refactor assertion macros, specifically with regards to recoverability.

parent 850af66a
......@@ -89,7 +89,9 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) {
}
void ReaderArena::reportReadLimitReached() {
FAIL_VALIDATE_INPUT("Exceeded message traversal limit. See capnproto::ReaderOptions.");
KJ_FAIL_REQUIRE("Exceeded message traversal limit. See capnproto::ReaderOptions.") {
return;
}
}
// =======================================================================================
......@@ -201,8 +203,10 @@ SegmentReader* BuilderArena::tryGetSegment(SegmentId id) {
}
void BuilderArena::reportReadLimitReached() {
FAIL_RECOVERABLE_ASSERT(
"Read limit reached for BuilderArena, but it should have been unlimited.") {}
KJ_FAIL_ASSERT(
"Read limit reached for BuilderArena, but it should have been unlimited.") {
return;
}
}
} // namespace internal
......
This diff is collapsed.
This diff is collapsed.
......@@ -50,9 +50,9 @@ internal::StructReader MessageReader::getRootInternal() {
}
internal::SegmentReader* segment = arena()->tryGetSegment(internal::SegmentId(0));
VALIDATE_INPUT(segment != nullptr &&
segment->containsInterval(segment->getStartPtr(), segment->getStartPtr() + 1),
"Message did not contain a root pointer.") {
KJ_REQUIRE(segment != nullptr &&
segment->containsInterval(segment->getStartPtr(), segment->getStartPtr() + 1),
"Message did not contain a root pointer.") {
return internal::StructReader();
}
......
......@@ -145,9 +145,9 @@ private:
std::map<std::pair<uint, Text::Reader>, uint> members;
#define VALIDATE_SCHEMA(condition, ...) \
VALIDATE_INPUT(condition, ##__VA_ARGS__) { isValid = false; return; }
KJ_REQUIRE(condition, ##__VA_ARGS__) { isValid = false; return; }
#define FAIL_VALIDATE_SCHEMA(...) \
FAIL_VALIDATE_INPUT(__VA_ARGS__) { isValid = false; return; }
KJ_FAIL_REQUIRE(__VA_ARGS__) { isValid = false; return; }
void validate(schema::FileNode::Reader fileNode) {
// Nothing needs validation.
......@@ -472,9 +472,9 @@ private:
Compatibility compatibility;
#define VALIDATE_SCHEMA(condition, ...) \
VALIDATE_INPUT(condition, ##__VA_ARGS__) { compatibility = INCOMPATIBLE; return; }
KJ_REQUIRE(condition, ##__VA_ARGS__) { compatibility = INCOMPATIBLE; return; }
#define FAIL_VALIDATE_SCHEMA(...) \
FAIL_VALIDATE_INPUT(__VA_ARGS__) { compatibility = INCOMPATIBLE; return; }
KJ_FAIL_REQUIRE(__VA_ARGS__) { compatibility = INCOMPATIBLE; return; }
void replacementIsNewer() {
switch (compatibility) {
......@@ -934,7 +934,7 @@ private:
schema::Value::Reader replacement) {
// Note that we test default compatibility only after testing type compatibility, and default
// values have already been validated as matching their types, so this should pass.
RECOVERABLE_ASSERT(value.getBody().which() == replacement.getBody().which()) {
KJ_ASSERT(value.getBody().which() == replacement.getBody().which()) {
compatibility = INCOMPATIBLE;
return;
}
......
......@@ -47,7 +47,7 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
uint8_t* const outMin = reinterpret_cast<uint8_t*>(dst) + minBytes;
kj::ArrayPtr<const byte> buffer = inner.getReadBuffer();
VALIDATE_INPUT(buffer.size() > 0, "Premature end of packed input.") {
KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") {
return minBytes; // garbage
}
const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(buffer.begin());
......@@ -55,7 +55,7 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
#define REFRESH_BUFFER() \
inner.skip(buffer.size()); \
buffer = inner.getReadBuffer(); \
VALIDATE_INPUT(buffer.size() > 0, "Premature end of packed input.") { \
KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") { \
return minBytes; /* garbage */ \
} \
in = reinterpret_cast<const uint8_t*>(buffer.begin())
......@@ -126,8 +126,8 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
uint runLength = *in++ * sizeof(word);
VALIDATE_INPUT(runLength <= outEnd - out,
"Packed input did not end cleanly on a segment boundary.") {
KJ_REQUIRE(runLength <= outEnd - out,
"Packed input did not end cleanly on a segment boundary.") {
return std::max<size_t>(minBytes, out - reinterpret_cast<uint8_t*>(dst)); // garbage
}
memset(out, 0, runLength);
......@@ -138,8 +138,8 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
uint runLength = *in++ * sizeof(word);
VALIDATE_INPUT(runLength <= outEnd - out,
"Packed input did not end cleanly on a segment boundary.") {
KJ_REQUIRE(runLength <= outEnd - out,
"Packed input did not end cleanly on a segment boundary.") {
return std::max<size_t>(minBytes, out - reinterpret_cast<uint8_t*>(dst)); // garbage
}
......@@ -198,7 +198,7 @@ void PackedInputStream::skip(size_t bytes) {
#define REFRESH_BUFFER() \
inner.skip(buffer.size()); \
buffer = inner.getReadBuffer(); \
VALIDATE_INPUT(buffer.size() > 0, "Premature end of packed input.") return; \
KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") { return; } \
in = reinterpret_cast<const uint8_t*>(buffer.begin())
for (;;) {
......@@ -252,8 +252,7 @@ void PackedInputStream::skip(size_t bytes) {
uint runLength = *in++ * sizeof(word);
VALIDATE_INPUT(runLength <= bytes,
"Packed input did not end cleanly on a segment boundary.") {
KJ_REQUIRE(runLength <= bytes, "Packed input did not end cleanly on a segment boundary.") {
return;
}
......@@ -264,8 +263,7 @@ void PackedInputStream::skip(size_t bytes) {
uint runLength = *in++ * sizeof(word);
VALIDATE_INPUT(runLength <= bytes,
"Packed input did not end cleanly on a segment boundary.") {
KJ_REQUIRE(runLength <= bytes, "Packed input did not end cleanly on a segment boundary.") {
return;
}
......
......@@ -106,11 +106,12 @@ void SnappyInputStream::skip(size_t bytes) {
void SnappyInputStream::refill() {
uint32_t length = 0;
InputStreamSnappySource snappySource(inner);
VALIDATE_INPUT(
KJ_REQUIRE(
snappy::RawUncompress(
&snappySource, reinterpret_cast<char*>(buffer.begin()), buffer.size(), &length),
"Snappy decompression failed.") {
length = 1; // garbage
break;
}
bufferAvailable = buffer.slice(0, length);
......
......@@ -42,7 +42,7 @@ FlatArrayMessageReader::FlatArrayMessageReader(
uint segmentCount = table[0].get() + 1;
size_t offset = segmentCount / 2u + 1u;
VALIDATE_INPUT(array.size() >= offset, "Message ends prematurely in segment table.") {
KJ_REQUIRE(array.size() >= offset, "Message ends prematurely in segment table.") {
return;
}
......@@ -52,8 +52,8 @@ FlatArrayMessageReader::FlatArrayMessageReader(
uint segmentSize = table[1].get();
VALIDATE_INPUT(array.size() >= offset + segmentSize,
"Message ends prematurely in first segment.") {
KJ_REQUIRE(array.size() >= offset + segmentSize,
"Message ends prematurely in first segment.") {
return;
}
......@@ -66,7 +66,7 @@ FlatArrayMessageReader::FlatArrayMessageReader(
for (uint i = 1; i < segmentCount; i++) {
uint segmentSize = table[i + 1].get();
VALIDATE_INPUT(array.size() >= offset + segmentSize, "Message ends prematurely.") {
KJ_REQUIRE(array.size() >= offset + segmentSize, "Message ends prematurely.") {
moreSegments = nullptr;
return;
}
......@@ -142,9 +142,10 @@ InputStreamMessageReader::InputStreamMessageReader(
size_t totalWords = segment0Size;
// Reject messages with too many segments for security reasons.
VALIDATE_INPUT(segmentCount < 512, "Message has too many segments.") {
KJ_REQUIRE(segmentCount < 512, "Message has too many segments.") {
segmentCount = 1;
segment0Size = 1;
break;
}
// Read sizes for all segments except the first. Include padding if necessary.
......@@ -159,12 +160,13 @@ InputStreamMessageReader::InputStreamMessageReader(
// Don't accept a message which the receiver couldn't possibly traverse without hitting the
// traversal limit. Without this check, a malicious client could transmit a very large segment
// size to make the receiver allocate excessive space and possibly crash.
VALIDATE_INPUT(totalWords <= options.traversalLimitInWords,
"Message is too large. To increase the limit on the receiving end, see "
"capnproto::ReaderOptions.") {
KJ_REQUIRE(totalWords <= options.traversalLimitInWords,
"Message is too large. To increase the limit on the receiving end, see "
"capnproto::ReaderOptions.") {
segmentCount = 1;
segment0Size = std::min<size_t>(segment0Size, options.traversalLimitInWords);
totalWords = segment0Size;
break;
}
if (scratchSpace.size() < totalWords) {
......
......@@ -162,7 +162,9 @@ static void print(std::ostream& os, DynamicValue::Reader value,
break;
}
case DynamicValue::INTERFACE:
FAIL_RECOVERABLE_ASSERT("Don't know how to print interfaces.") {}
KJ_FAIL_ASSERT("Don't know how to print interfaces.") {
break;
}
break;
case DynamicValue::OBJECT:
os << "(opaque object)";
......
......@@ -30,9 +30,11 @@ namespace internal {
void inlineRequireFailure(const char* file, int line, const char* expectation,
const char* macroArgs, const char* message) {
if (message == nullptr) {
Log::fatalFault(file, line, Exception::Nature::PRECONDITION, expectation, macroArgs);
Log::Fault f(file, line, Exception::Nature::PRECONDITION, 0, expectation, macroArgs);
f.fatal();
} else {
Log::fatalFault(file, line, Exception::Nature::PRECONDITION, expectation, macroArgs, message);
Log::Fault f(file, line, Exception::Nature::PRECONDITION, 0, expectation, macroArgs, message);
f.fatal();
}
}
......
......@@ -30,7 +30,7 @@
#ifndef KJ_COMMON_H_
#define KJ_COMMON_H_
#if __cplusplus < 201103L
#if __cplusplus < 201103L && !__CDT_PARSER__
#error "This code requires C++11. Either your compiler does not support it or it is not enabled."
#ifdef __GNUC__
// Compiler claims compatibility with GCC, so presumably supports -std.
......
......@@ -32,9 +32,8 @@ namespace kj {
ArrayPtr<const char> KJ_STRINGIFY(Exception::Nature nature) {
static const char* NATURE_STRINGS[] = {
"precondition not met",
"requirement not met",
"bug in code",
"invalid input data",
"error from OS",
"network failure",
"error"
......@@ -174,7 +173,7 @@ void ExceptionCallback::logMessage(StringPtr text) {
}
void ExceptionCallback::useProcessWide() {
RECOVERABLE_REQUIRE(globalCallback == nullptr,
KJ_REQUIRE(globalCallback == nullptr,
"Can't register multiple global ExceptionCallbacks at once.") {
return;
}
......
......@@ -49,7 +49,6 @@ public:
PRECONDITION,
LOCAL_BUG,
INPUT,
OS_ERROR,
NETWORK_FAILURE,
OTHER
......
......@@ -188,8 +188,9 @@ ArrayPtr<const byte> ArrayInputStream::getReadBuffer() {
size_t ArrayInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
size_t n = std::min(maxBytes, array.size());
size_t result = n;
VALIDATE_INPUT(n >= minBytes, "ArrayInputStream ended prematurely.") {
KJ_REQUIRE(n >= minBytes, "ArrayInputStream ended prematurely.") {
result = minBytes; // garbage
break;
}
memcpy(dst, array.begin(), n);
array = array.slice(n, array.size());
......@@ -197,8 +198,9 @@ size_t ArrayInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
}
void ArrayInputStream::skip(size_t bytes) {
VALIDATE_INPUT(array.size() >= bytes, "ArrayInputStream ended prematurely.") {
KJ_REQUIRE(array.size() >= bytes, "ArrayInputStream ended prematurely.") {
bytes = array.size();
break;
}
array = array.slice(bytes, array.size());
}
......@@ -228,7 +230,9 @@ void ArrayOutputStream::write(const void* src, size_t size) {
AutoCloseFd::~AutoCloseFd() {
if (fd >= 0 && close(fd) < 0) {
FAIL_RECOVERABLE_SYSCALL("close", errno, fd);
FAIL_SYSCALL("close", errno, fd) {
break;
}
}
}
......@@ -240,8 +244,9 @@ size_t FdInputStream::read(void* buffer, size_t minBytes, size_t maxBytes) {
byte* max = pos + maxBytes;
while (pos < min) {
ssize_t n = KJ_SYSCALL(::read(fd, pos, max - pos), fd);
VALIDATE_INPUT(n > 0, "Premature EOF") {
ssize_t n;
KJ_SYSCALL(n = ::read(fd, pos, max - pos), fd);
KJ_REQUIRE(n > 0, "Premature EOF") {
return minBytes;
}
pos += n;
......@@ -256,7 +261,8 @@ void FdOutputStream::write(const void* buffer, size_t size) {
const char* pos = reinterpret_cast<const char*>(buffer);
while (size > 0) {
ssize_t n = KJ_SYSCALL(::write(fd, pos, size), fd);
ssize_t n;
KJ_SYSCALL(n = ::write(fd, pos, size), fd);
KJ_ASSERT(n > 0, "write() returned zero.");
pos += n;
size -= n;
......@@ -280,7 +286,8 @@ void FdOutputStream::write(ArrayPtr<const ArrayPtr<const byte>> pieces) {
}
while (current < iov.end()) {
ssize_t n = KJ_SYSCALL(::writev(fd, current, iov.end() - current), fd);
ssize_t n;
KJ_SYSCALL(n = ::writev(fd, current, iov.end() - current), fd);
KJ_ASSERT(n > 0, "writev() returned zero.");
while (static_cast<size_t>(n) >= current->iov_len) {
......
......@@ -101,18 +101,39 @@ TEST(Logging, Log) {
mockCallback.text);
mockCallback.text.clear();
KJ_DBG("Some debug text."); line = __LINE__;
EXPECT_EQ("log message: debug: " + fileLine(__FILE__, line) + ": Some debug text.\n",
mockCallback.text);
mockCallback.text.clear();
// INFO logging is disabled by default.
KJ_LOG(INFO, "Info."); line = __LINE__;
EXPECT_EQ("", mockCallback.text);
mockCallback.text.clear();
// Enable it.
Log::setLogLevel(Log::Severity::INFO);
KJ_LOG(INFO, "Some text."); line = __LINE__;
EXPECT_EQ("log message: info: " + fileLine(__FILE__, line) + ": Some text.\n",
mockCallback.text);
mockCallback.text.clear();
// Back to default.
Log::setLogLevel(Log::Severity::WARNING);
KJ_ASSERT(1 == 1);
EXPECT_THROW(KJ_ASSERT(1 == 2), MockException); line = __LINE__;
EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": bug in code: expected "
"1 == 2\n", mockCallback.text);
mockCallback.text.clear();
RECOVERABLE_ASSERT(1 == 1) {
KJ_ASSERT(1 == 1) {
ADD_FAILURE() << "Shouldn't call recovery code when check passes.";
break;
};
bool recovered = false;
RECOVERABLE_ASSERT(1 == 2, "1 is not 2") { recovered = true; } line = __LINE__;
KJ_ASSERT(1 == 2, "1 is not 2") { recovered = true; break; } line = __LINE__;
EXPECT_EQ("recoverable exception: " + fileLine(__FILE__, line) + ": bug in code: expected "
"1 == 2; 1 is not 2\n", mockCallback.text);
EXPECT_TRUE(recovered);
......@@ -124,11 +145,11 @@ TEST(Logging, Log) {
mockCallback.text.clear();
EXPECT_THROW(KJ_REQUIRE(1 == 2, i, "hi", str), MockException); line = __LINE__;
EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": precondition not met: expected "
EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": requirement not met: expected "
"1 == 2; i = 123; hi; str = foo\n", mockCallback.text);
mockCallback.text.clear();
EXPECT_THROW(KJ_ASSERT(false, "foo"), MockException); line = __LINE__;
EXPECT_THROW(KJ_FAIL_ASSERT("foo"), MockException); line = __LINE__;
EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": bug in code: foo\n",
mockCallback.text);
mockCallback.text.clear();
......@@ -142,7 +163,8 @@ TEST(Logging, Syscall) {
int i = 123;
const char* str = "foo";
int fd = KJ_SYSCALL(dup(STDIN_FILENO));
int fd;
KJ_SYSCALL(fd = dup(STDIN_FILENO));
KJ_SYSCALL(close(fd));
EXPECT_THROW(KJ_SYSCALL(close(fd), i, "bar", str), MockException); line = __LINE__;
EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": error from OS: close(fd): "
......@@ -151,7 +173,7 @@ TEST(Logging, Syscall) {
int result = 0;
bool recovered = false;
RECOVERABLE_SYSCALL(result = close(fd), i, "bar", str) { recovered = true; } line = __LINE__;
KJ_SYSCALL(result = close(fd), i, "bar", str) { recovered = true; break; } line = __LINE__;
EXPECT_EQ("recoverable exception: " + fileLine(__FILE__, line) + ": error from OS: close(fd): "
+ strerror(EBADF) + "; i = 123; bar; str = foo\n", mockCallback.text);
EXPECT_LT(result, 0);
......
......@@ -29,14 +29,15 @@
namespace kj {
Log::Severity Log::minSeverity = Log::Severity::INFO;
Log::Severity Log::minSeverity = Log::Severity::WARNING;
ArrayPtr<const char> KJ_STRINGIFY(Log::Severity severity) {
static const char* SEVERITY_STRINGS[] = {
"info",
"warning",
"error",
"fatal"
"fatal",
"debug"
};
const char* s = SEVERITY_STRINGS[static_cast<uint>(severity)];
......@@ -110,6 +111,10 @@ static String makeDescription(DescriptionStyle style, const char* code, int erro
}
}
if (style == ASSERTION && code == nullptr) {
style = LOG;
}
{
StringPtr expected = "expected ";
StringPtr codeArray = style == LOG ? nullptr : StringPtr(code);
......@@ -117,11 +122,6 @@ static String makeDescription(DescriptionStyle style, const char* code, int erro
StringPtr delim = "; ";
StringPtr colon = ": ";
if (style == ASSERTION && strcmp(code, "false") == 0) {
// Don't print "expected false", that's silly.
style = LOG;
}
StringPtr sysErrorArray;
#if __USE_GNU
char buffer[256];
......@@ -194,38 +194,28 @@ void Log::logInternal(const char* file, int line, Severity severity, const char*
makeDescription(LOG, nullptr, 0, macroArgs, argValues), '\n'));
}
void Log::recoverableFaultInternal(
const char* file, int line, Exception::Nature nature,
const char* condition, const char* macroArgs, ArrayPtr<String> argValues) {
getExceptionCallback().onRecoverableException(
Exception(nature, Exception::Durability::PERMANENT, file, line,
makeDescription(ASSERTION, condition, 0, macroArgs, argValues)));
Log::Fault::~Fault() noexcept(false) {
if (exception != nullptr) {
Exception copy = mv(*exception);
delete exception;
getExceptionCallback().onRecoverableException(mv(copy));
}
}
void Log::fatalFaultInternal(
const char* file, int line, Exception::Nature nature,
const char* condition, const char* macroArgs, ArrayPtr<String> argValues) {
getExceptionCallback().onFatalException(
Exception(nature, Exception::Durability::PERMANENT, file, line,
makeDescription(ASSERTION, condition, 0, macroArgs, argValues)));
void Log::Fault::fatal() {
Exception copy = mv(*exception);
delete exception;
exception = nullptr;
getExceptionCallback().onFatalException(mv(copy));
abort();
}
void Log::recoverableFailedSyscallInternal(
const char* file, int line, const char* call,
int errorNumber, const char* macroArgs, ArrayPtr<String> argValues) {
getExceptionCallback().onRecoverableException(
Exception(Exception::Nature::OS_ERROR, Exception::Durability::PERMANENT, file, line,
makeDescription(SYSCALL, call, errorNumber, macroArgs, argValues)));
}
void Log::fatalFailedSyscallInternal(
const char* file, int line, const char* call,
int errorNumber, const char* macroArgs, ArrayPtr<String> argValues) {
getExceptionCallback().onFatalException(
Exception(Exception::Nature::OS_ERROR, Exception::Durability::PERMANENT, file, line,
makeDescription(SYSCALL, call, errorNumber, macroArgs, argValues)));
abort();
void Log::Fault::init(
const char* file, int line, Exception::Nature nature, int errorNumber,
const char* condition, const char* macroArgs, ArrayPtr<String> argValues) {
exception = new Exception(nature, Exception::Durability::PERMANENT, file, line,
makeDescription(nature == Exception::Nature::OS_ERROR ? SYSCALL : ASSERTION,
condition, errorNumber, macroArgs, argValues));
}
void Log::addContextToInternal(Exception& exception, const char* file, int line,
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment