Commit db505f42 authored by Kenton Varda's avatar Kenton Varda

Refactor assertion macros, specifically with regards to recoverability.

parent 850af66a
...@@ -89,7 +89,9 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) { ...@@ -89,7 +89,9 @@ SegmentReader* ReaderArena::tryGetSegment(SegmentId id) {
} }
void ReaderArena::reportReadLimitReached() { void ReaderArena::reportReadLimitReached() {
FAIL_VALIDATE_INPUT("Exceeded message traversal limit. See capnproto::ReaderOptions."); KJ_FAIL_REQUIRE("Exceeded message traversal limit. See capnproto::ReaderOptions.") {
return;
}
} }
// ======================================================================================= // =======================================================================================
...@@ -201,8 +203,10 @@ SegmentReader* BuilderArena::tryGetSegment(SegmentId id) { ...@@ -201,8 +203,10 @@ SegmentReader* BuilderArena::tryGetSegment(SegmentId id) {
} }
void BuilderArena::reportReadLimitReached() { void BuilderArena::reportReadLimitReached() {
FAIL_RECOVERABLE_ASSERT( KJ_FAIL_ASSERT(
"Read limit reached for BuilderArena, but it should have been unlimited.") {} "Read limit reached for BuilderArena, but it should have been unlimited.") {
return;
}
} }
} // namespace internal } // namespace internal
......
...@@ -112,9 +112,10 @@ kj::Maybe<EnumSchema::Enumerant> DynamicEnum::getEnumerant() { ...@@ -112,9 +112,10 @@ kj::Maybe<EnumSchema::Enumerant> DynamicEnum::getEnumerant() {
} }
uint16_t DynamicEnum::asImpl(uint64_t requestedTypeId) { uint16_t DynamicEnum::asImpl(uint64_t requestedTypeId) {
RECOVERABLE_REQUIRE(requestedTypeId == schema.getProto().getId(), KJ_REQUIRE(requestedTypeId == schema.getProto().getId(),
"Type mismatch in DynamicEnum.as().") { "Type mismatch in DynamicEnum.as().") {
// use it anyway // use it anyway
break;
} }
return value; return value;
} }
...@@ -125,7 +126,7 @@ DynamicStruct::Reader DynamicObject::as(StructSchema schema) { ...@@ -125,7 +126,7 @@ DynamicStruct::Reader DynamicObject::as(StructSchema schema) {
if (reader.kind == internal::ObjectKind::NULL_POINTER) { if (reader.kind == internal::ObjectKind::NULL_POINTER) {
return DynamicStruct::Reader(schema, internal::StructReader()); return DynamicStruct::Reader(schema, internal::StructReader());
} }
RECOVERABLE_REQUIRE(reader.kind == internal::ObjectKind::STRUCT, "Object is not a struct.") { KJ_REQUIRE(reader.kind == internal::ObjectKind::STRUCT, "Object is not a struct.") {
// Return default struct. // Return default struct.
return DynamicStruct::Reader(schema, internal::StructReader()); return DynamicStruct::Reader(schema, internal::StructReader());
} }
...@@ -136,7 +137,7 @@ DynamicList::Reader DynamicObject::as(ListSchema schema) { ...@@ -136,7 +137,7 @@ DynamicList::Reader DynamicObject::as(ListSchema schema) {
if (reader.kind == internal::ObjectKind::NULL_POINTER) { if (reader.kind == internal::ObjectKind::NULL_POINTER) {
return DynamicList::Reader(schema, internal::ListReader()); return DynamicList::Reader(schema, internal::ListReader());
} }
RECOVERABLE_REQUIRE(reader.kind == internal::ObjectKind::LIST, "Object is not a list.") { KJ_REQUIRE(reader.kind == internal::ObjectKind::LIST, "Object is not a list.") {
// Return empty list. // Return empty list.
return DynamicList::Reader(schema, internal::ListReader()); return DynamicList::Reader(schema, internal::ListReader());
} }
...@@ -880,7 +881,7 @@ void DynamicStruct::Builder::setImpl( ...@@ -880,7 +881,7 @@ void DynamicStruct::Builder::setImpl(
getImpl(builder, member).as<DynamicUnion>().set(member, src.get()); getImpl(builder, member).as<DynamicUnion>().set(member, src.get());
return; return;
} else { } else {
FAIL_RECOVERABLE_REQUIRE( KJ_FAIL_REQUIRE(
"Trying to copy a union value, but the union's discriminant is not recognized. It " "Trying to copy a union value, but the union's discriminant is not recognized. It "
"was probably constructed using a newer version of the schema.") { "was probably constructed using a newer version of the schema.") {
// Just don't copy anything. // Just don't copy anything.
...@@ -928,8 +929,8 @@ void DynamicStruct::Builder::setImpl( ...@@ -928,8 +929,8 @@ void DynamicStruct::Builder::setImpl(
rawValue = enumSchema.getEnumerantByName(value.as<Text>()).getOrdinal(); rawValue = enumSchema.getEnumerantByName(value.as<Text>()).getOrdinal();
} else { } else {
DynamicEnum enumValue = value.as<DynamicEnum>(); DynamicEnum enumValue = value.as<DynamicEnum>();
RECOVERABLE_REQUIRE(enumValue.getSchema() == enumSchema, KJ_REQUIRE(enumValue.getSchema() == enumSchema,
"Type mismatch when using DynamicList::Builder::set().") { "Type mismatch when using DynamicList::Builder::set().") {
return; return;
} }
rawValue = enumValue.getRaw(); rawValue = enumValue.getRaw();
...@@ -967,8 +968,9 @@ void DynamicStruct::Builder::setImpl( ...@@ -967,8 +968,9 @@ void DynamicStruct::Builder::setImpl(
return; return;
} }
FAIL_RECOVERABLE_REQUIRE("can't set field of unknown type", (uint)type.which()); KJ_FAIL_REQUIRE("can't set field of unknown type", (uint)type.which()) {
return; return;
}
} }
} }
...@@ -1109,8 +1111,9 @@ DynamicValue::Reader DynamicList::Reader::operator[](uint index) const { ...@@ -1109,8 +1111,9 @@ DynamicValue::Reader DynamicList::Reader::operator[](uint index) const {
reader.getObjectElement(index * ELEMENTS))); reader.getObjectElement(index * ELEMENTS)));
case schema::Type::Body::INTERFACE_TYPE: case schema::Type::Body::INTERFACE_TYPE:
FAIL_RECOVERABLE_ASSERT("Interfaces not implemented.") {} KJ_FAIL_ASSERT("Interfaces not implemented.") {
return nullptr; return nullptr;
}
} }
return nullptr; return nullptr;
...@@ -1171,15 +1174,16 @@ DynamicValue::Builder DynamicList::Builder::operator[](uint index) const { ...@@ -1171,15 +1174,16 @@ DynamicValue::Builder DynamicList::Builder::operator[](uint index) const {
return nullptr; return nullptr;
case schema::Type::Body::INTERFACE_TYPE: case schema::Type::Body::INTERFACE_TYPE:
FAIL_RECOVERABLE_ASSERT("Interfaces not implemented.") {} KJ_FAIL_ASSERT("Interfaces not implemented.") {
return nullptr; return nullptr;
}
} }
return nullptr; return nullptr;
} }
void DynamicList::Builder::set(uint index, DynamicValue::Reader value) { void DynamicList::Builder::set(uint index, DynamicValue::Reader value) {
RECOVERABLE_REQUIRE(index < size(), "List index out-of-bounds.") { KJ_REQUIRE(index < size(), "List index out-of-bounds.") {
return; return;
} }
...@@ -1219,8 +1223,9 @@ void DynamicList::Builder::set(uint index, DynamicValue::Reader value) { ...@@ -1219,8 +1223,9 @@ void DynamicList::Builder::set(uint index, DynamicValue::Reader value) {
// Not supported for the same reason List<struct> doesn't support it -- the space for the // Not supported for the same reason List<struct> doesn't support it -- the space for the
// element is already allocated, and if it's smaller than the input value the copy would // element is already allocated, and if it's smaller than the input value the copy would
// have to be lossy. // have to be lossy.
FAIL_RECOVERABLE_ASSERT("DynamicList of structs does not support set()."); KJ_FAIL_ASSERT("DynamicList of structs does not support set().") {
return; return;
}
case schema::Type::Body::ENUM_TYPE: { case schema::Type::Body::ENUM_TYPE: {
uint16_t rawValue; uint16_t rawValue;
...@@ -1229,8 +1234,8 @@ void DynamicList::Builder::set(uint index, DynamicValue::Reader value) { ...@@ -1229,8 +1234,8 @@ void DynamicList::Builder::set(uint index, DynamicValue::Reader value) {
rawValue = schema.getEnumElementType().getEnumerantByName(value.as<Text>()).getOrdinal(); rawValue = schema.getEnumElementType().getEnumerantByName(value.as<Text>()).getOrdinal();
} else { } else {
DynamicEnum enumValue = value.as<DynamicEnum>(); DynamicEnum enumValue = value.as<DynamicEnum>();
RECOVERABLE_REQUIRE(schema.getEnumElementType() == enumValue.getSchema(), KJ_REQUIRE(schema.getEnumElementType() == enumValue.getSchema(),
"Type mismatch when using DynamicList::Builder::set().") { "Type mismatch when using DynamicList::Builder::set().") {
return; return;
} }
rawValue = enumValue.getRaw(); rawValue = enumValue.getRaw();
...@@ -1240,15 +1245,19 @@ void DynamicList::Builder::set(uint index, DynamicValue::Reader value) { ...@@ -1240,15 +1245,19 @@ void DynamicList::Builder::set(uint index, DynamicValue::Reader value) {
} }
case schema::Type::Body::OBJECT_TYPE: case schema::Type::Body::OBJECT_TYPE:
FAIL_RECOVERABLE_ASSERT("List(Object) not supported."); KJ_FAIL_ASSERT("List(Object) not supported.") {
return; return;
}
case schema::Type::Body::INTERFACE_TYPE: case schema::Type::Body::INTERFACE_TYPE:
FAIL_RECOVERABLE_ASSERT("Interfaces not implemented.") {} KJ_FAIL_ASSERT("Interfaces not implemented.") {
return; return;
}
} }
FAIL_RECOVERABLE_REQUIRE("can't set element of unknown type", (uint)schema.whichElementType()); KJ_FAIL_REQUIRE("can't set element of unknown type", (uint)schema.whichElementType()) {
return;
}
} }
DynamicValue::Builder DynamicList::Builder::init(uint index, uint size) { DynamicValue::Builder DynamicList::Builder::init(uint index, uint size) {
...@@ -1343,42 +1352,46 @@ namespace { ...@@ -1343,42 +1352,46 @@ namespace {
template <typename T> template <typename T>
T signedToUnsigned(long long value) { T signedToUnsigned(long long value) {
RECOVERABLE_REQUIRE(value >= 0 && T(value) == value, KJ_REQUIRE(value >= 0 && T(value) == value, "Value out-of-range for requested type.", value) {
"Value out-of-range for requested type.", value) {
// Use it anyway. // Use it anyway.
break;
} }
return value; return value;
} }
template <> template <>
uint64_t signedToUnsigned<uint64_t>(long long value) { uint64_t signedToUnsigned<uint64_t>(long long value) {
RECOVERABLE_REQUIRE(value >= 0, "Value out-of-range for requested type.", value) { KJ_REQUIRE(value >= 0, "Value out-of-range for requested type.", value) {
// Use it anyway. // Use it anyway.
break;
} }
return value; return value;
} }
template <typename T> template <typename T>
T unsignedToSigned(unsigned long long value) { T unsignedToSigned(unsigned long long value) {
RECOVERABLE_REQUIRE(T(value) >= 0 && (unsigned long long)T(value) == value, KJ_REQUIRE(T(value) >= 0 && (unsigned long long)T(value) == value,
"Value out-of-range for requested type.", value) { "Value out-of-range for requested type.", value) {
// Use it anyway. // Use it anyway.
break;
} }
return value; return value;
} }
template <> template <>
int64_t unsignedToSigned<int64_t>(unsigned long long value) { int64_t unsignedToSigned<int64_t>(unsigned long long value) {
RECOVERABLE_REQUIRE(int64_t(value) >= 0, "Value out-of-range for requested type.", value) { KJ_REQUIRE(int64_t(value) >= 0, "Value out-of-range for requested type.", value) {
// Use it anyway. // Use it anyway.
break;
} }
return value; return value;
} }
template <typename T, typename U> template <typename T, typename U>
T checkRoundTrip(U value) { T checkRoundTrip(U value) {
RECOVERABLE_REQUIRE(T(value) == value, "Value out-of-range for requested type.", value) { KJ_REQUIRE(T(value) == value, "Value out-of-range for requested type.", value) {
// Use it anyway. // Use it anyway.
break;
} }
return value; return value;
} }
...@@ -1395,10 +1408,9 @@ typeName DynamicValue::Reader::AsImpl<typeName>::apply(Reader reader) { \ ...@@ -1395,10 +1408,9 @@ typeName DynamicValue::Reader::AsImpl<typeName>::apply(Reader reader) { \
case FLOAT: \ case FLOAT: \
return ifFloat<typeName>(reader.floatValue); \ return ifFloat<typeName>(reader.floatValue); \
default: \ default: \
FAIL_RECOVERABLE_REQUIRE("Type mismatch when using DynamicValue::Reader::as().") { \ KJ_FAIL_REQUIRE("Type mismatch when using DynamicValue::Reader::as().") { \
/* use zero */ \ return 0; \
} \ } \
return 0; \
} \ } \
} \ } \
typeName DynamicValue::Builder::AsImpl<typeName>::apply(Builder builder) { \ typeName DynamicValue::Builder::AsImpl<typeName>::apply(Builder builder) { \
...@@ -1410,10 +1422,9 @@ typeName DynamicValue::Builder::AsImpl<typeName>::apply(Builder builder) { \ ...@@ -1410,10 +1422,9 @@ typeName DynamicValue::Builder::AsImpl<typeName>::apply(Builder builder) { \
case FLOAT: \ case FLOAT: \
return ifFloat<typeName>(builder.floatValue); \ return ifFloat<typeName>(builder.floatValue); \
default: \ default: \
FAIL_RECOVERABLE_REQUIRE("Type mismatch when using DynamicValue::Builder::as().") { \ KJ_FAIL_REQUIRE("Type mismatch when using DynamicValue::Builder::as().") { \
/* use zero */ \ return 0; \
} \ } \
return 0; \
} \ } \
} }
...@@ -1459,8 +1470,7 @@ Data::Reader DynamicValue::Reader::AsImpl<Data>::apply(Reader reader) { ...@@ -1459,8 +1470,7 @@ Data::Reader DynamicValue::Reader::AsImpl<Data>::apply(Reader reader) {
// Implicitly convert from text. // Implicitly convert from text.
return reader.textValue; return reader.textValue;
} }
RECOVERABLE_REQUIRE(reader.type == DATA, KJ_REQUIRE(reader.type == DATA, "Type mismatch when using DynamicValue::Reader::as().") {
"Type mismatch when using DynamicValue::Reader::as().") {
return Data::Reader(); return Data::Reader();
} }
return reader.dataValue; return reader.dataValue;
...@@ -1470,8 +1480,7 @@ Data::Builder DynamicValue::Builder::AsImpl<Data>::apply(Builder builder) { ...@@ -1470,8 +1480,7 @@ Data::Builder DynamicValue::Builder::AsImpl<Data>::apply(Builder builder) {
// Implicitly convert from text. // Implicitly convert from text.
return builder.textValue; return builder.textValue;
} }
RECOVERABLE_REQUIRE(builder.type == DATA, KJ_REQUIRE(builder.type == DATA, "Type mismatch when using DynamicValue::Builder::as().") {
"Type mismatch when using DynamicValue::Builder::as().") {
return Data::Builder(); return Data::Builder();
} }
return builder.dataValue; return builder.dataValue;
...@@ -1479,15 +1488,13 @@ Data::Builder DynamicValue::Builder::AsImpl<Data>::apply(Builder builder) { ...@@ -1479,15 +1488,13 @@ Data::Builder DynamicValue::Builder::AsImpl<Data>::apply(Builder builder) {
// As in the header, HANDLE_TYPE(void, VOID, Void) crashes GCC 4.7. // As in the header, HANDLE_TYPE(void, VOID, Void) crashes GCC 4.7.
Void DynamicValue::Reader::AsImpl<Void>::apply(Reader reader) { Void DynamicValue::Reader::AsImpl<Void>::apply(Reader reader) {
RECOVERABLE_REQUIRE(reader.type == VOID, KJ_REQUIRE(reader.type == VOID, "Type mismatch when using DynamicValue::Reader::as().") {
"Type mismatch when using DynamicValue::Reader::as().") {
return Void(); return Void();
} }
return reader.voidValue; return reader.voidValue;
} }
Void DynamicValue::Builder::AsImpl<Void>::apply(Builder builder) { Void DynamicValue::Builder::AsImpl<Void>::apply(Builder builder) {
RECOVERABLE_REQUIRE(builder.type == VOID, KJ_REQUIRE(builder.type == VOID, "Type mismatch when using DynamicValue::Builder::as().") {
"Type mismatch when using DynamicValue::Builder::as().") {
return Void(); return Void();
} }
return builder.voidValue; return builder.voidValue;
......
...@@ -273,15 +273,15 @@ struct WireHelpers { ...@@ -273,15 +273,15 @@ struct WireHelpers {
if (segment != nullptr && ref->kind() == WirePointer::FAR) { if (segment != nullptr && ref->kind() == WirePointer::FAR) {
// Look up the segment containing the landing pad. // Look up the segment containing the landing pad.
segment = segment->getArena()->tryGetSegment(ref->farRef.segmentId.get()); segment = segment->getArena()->tryGetSegment(ref->farRef.segmentId.get());
VALIDATE_INPUT(segment != nullptr, "Message contains far pointer to unknown segment.") { KJ_REQUIRE(segment != nullptr, "Message contains far pointer to unknown segment.") {
return nullptr; return nullptr;
} }
// Find the landing pad and check that it is within bounds. // Find the landing pad and check that it is within bounds.
const word* ptr = segment->getStartPtr() + ref->farPositionInSegment(); const word* ptr = segment->getStartPtr() + ref->farPositionInSegment();
WordCount padWords = (1 + ref->isDoubleFar()) * POINTER_SIZE_IN_WORDS; WordCount padWords = (1 + ref->isDoubleFar()) * POINTER_SIZE_IN_WORDS;
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + padWords), KJ_REQUIRE(boundsCheck(segment, ptr, ptr + padWords),
"Message contains out-of-bounds far pointer.") { "Message contains out-of-bounds far pointer.") {
return nullptr; return nullptr;
} }
...@@ -298,8 +298,7 @@ struct WireHelpers { ...@@ -298,8 +298,7 @@ struct WireHelpers {
ref = pad + 1; ref = pad + 1;
segment = segment->getArena()->tryGetSegment(pad->farRef.segmentId.get()); segment = segment->getArena()->tryGetSegment(pad->farRef.segmentId.get());
VALIDATE_INPUT(segment != nullptr, KJ_REQUIRE(segment != nullptr, "Message contains double-far pointer to unknown segment.") {
"Message contains double-far pointer to unknown segment.") {
return nullptr; return nullptr;
} }
...@@ -338,7 +337,9 @@ struct WireHelpers { ...@@ -338,7 +337,9 @@ struct WireHelpers {
break; break;
} }
case WirePointer::RESERVED_3: case WirePointer::RESERVED_3:
FAIL_RECOVERABLE_ASSERT("Don't know how to handle RESERVED_3.") {} KJ_FAIL_ASSERT("Don't know how to handle RESERVED_3.") {
break;
}
break; break;
} }
} }
...@@ -404,10 +405,14 @@ struct WireHelpers { ...@@ -404,10 +405,14 @@ struct WireHelpers {
break; break;
} }
case WirePointer::FAR: case WirePointer::FAR:
FAIL_RECOVERABLE_ASSERT("Unexpected FAR pointer.") {} KJ_FAIL_ASSERT("Unexpected FAR pointer.") {
break;
}
break; break;
case WirePointer::RESERVED_3: case WirePointer::RESERVED_3:
FAIL_RECOVERABLE_ASSERT("Don't know how to handle RESERVED_3.") {} KJ_FAIL_ASSERT("Don't know how to handle RESERVED_3.") {
break;
}
break; break;
} }
} }
...@@ -435,7 +440,7 @@ struct WireHelpers { ...@@ -435,7 +440,7 @@ struct WireHelpers {
return 0 * WORDS; return 0 * WORDS;
} }
VALIDATE_INPUT(nestingLimit > 0, "Message is too deeply-nested.") { KJ_REQUIRE(nestingLimit > 0, "Message is too deeply-nested.") {
return 0 * WORDS; return 0 * WORDS;
} }
--nestingLimit; --nestingLimit;
...@@ -446,9 +451,9 @@ struct WireHelpers { ...@@ -446,9 +451,9 @@ struct WireHelpers {
switch (ref->kind()) { switch (ref->kind()) {
case WirePointer::STRUCT: { case WirePointer::STRUCT: {
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()), KJ_REQUIRE(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") { "Message contained out-of-bounds struct pointer.") {
break; return result;
} }
result += ref->structRef.wordSize(); result += ref->structRef.wordSize();
...@@ -473,9 +478,9 @@ struct WireHelpers { ...@@ -473,9 +478,9 @@ struct WireHelpers {
WordCount totalWords = roundUpToWords( WordCount totalWords = roundUpToWords(
ElementCount64(ref->listRef.elementCount()) * ElementCount64(ref->listRef.elementCount()) *
dataBitsPerElement(ref->listRef.elementSize())); dataBitsPerElement(ref->listRef.elementSize()));
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + totalWords), KJ_REQUIRE(boundsCheck(segment, ptr, ptr + totalWords),
"Message contained out-of-bounds list pointer.") { "Message contained out-of-bounds list pointer.") {
break; return result;
} }
result += totalWords; result += totalWords;
break; break;
...@@ -483,9 +488,9 @@ struct WireHelpers { ...@@ -483,9 +488,9 @@ struct WireHelpers {
case FieldSize::POINTER: { case FieldSize::POINTER: {
WirePointerCount count = ref->listRef.elementCount() * (POINTERS / ELEMENTS); WirePointerCount count = ref->listRef.elementCount() * (POINTERS / ELEMENTS);
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + count * WORDS_PER_POINTER), KJ_REQUIRE(boundsCheck(segment, ptr, ptr + count * WORDS_PER_POINTER),
"Message contained out-of-bounds list pointer.") { "Message contained out-of-bounds list pointer.") {
break; return result;
} }
result += count * WORDS_PER_POINTER; result += count * WORDS_PER_POINTER;
...@@ -498,10 +503,9 @@ struct WireHelpers { ...@@ -498,10 +503,9 @@ struct WireHelpers {
} }
case FieldSize::INLINE_COMPOSITE: { case FieldSize::INLINE_COMPOSITE: {
WordCount wordCount = ref->listRef.inlineCompositeWordCount(); WordCount wordCount = ref->listRef.inlineCompositeWordCount();
VALIDATE_INPUT( KJ_REQUIRE(boundsCheck(segment, ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS),
boundsCheck(segment, ptr, ptr + wordCount + POINTER_SIZE_IN_WORDS), "Message contained out-of-bounds list pointer.") {
"Message contained out-of-bounds list pointer.") { return result;
break;
} }
result += wordCount + POINTER_SIZE_IN_WORDS; result += wordCount + POINTER_SIZE_IN_WORDS;
...@@ -509,14 +513,14 @@ struct WireHelpers { ...@@ -509,14 +513,14 @@ struct WireHelpers {
const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* elementTag = reinterpret_cast<const WirePointer*>(ptr);
ElementCount count = elementTag->inlineCompositeListElementCount(); ElementCount count = elementTag->inlineCompositeListElementCount();
VALIDATE_INPUT(elementTag->kind() == WirePointer::STRUCT, KJ_REQUIRE(elementTag->kind() == WirePointer::STRUCT,
"Don't know how to handle non-STRUCT inline composite.") { "Don't know how to handle non-STRUCT inline composite.") {
break; return result;
} }
VALIDATE_INPUT(elementTag->structRef.wordSize() / ELEMENTS * count <= wordCount, KJ_REQUIRE(elementTag->structRef.wordSize() / ELEMENTS * count <= wordCount,
"Struct list pointer's elements overran size.") { "Struct list pointer's elements overran size.") {
break; return result;
} }
WordCount dataSize = elementTag->structRef.dataSize.get(); WordCount dataSize = elementTag->structRef.dataSize.get();
...@@ -538,13 +542,13 @@ struct WireHelpers { ...@@ -538,13 +542,13 @@ struct WireHelpers {
break; break;
} }
case WirePointer::FAR: case WirePointer::FAR:
FAIL_RECOVERABLE_ASSERT("Unexpected FAR pointer.") { KJ_FAIL_ASSERT("Unexpected FAR pointer.") {
break; break;
} }
break; break;
case WirePointer::RESERVED_3: case WirePointer::RESERVED_3:
FAIL_VALIDATE_INPUT("Don't know how to handle RESERVED_3.") { KJ_FAIL_REQUIRE("Don't know how to handle RESERVED_3.") {
break; return result;
} }
break; break;
} }
...@@ -743,7 +747,7 @@ struct WireHelpers { ...@@ -743,7 +747,7 @@ struct WireHelpers {
SegmentBuilder* oldSegment = segment; SegmentBuilder* oldSegment = segment;
word* oldPtr = followFars(oldRef, oldSegment); word* oldPtr = followFars(oldRef, oldSegment);
VALIDATE_INPUT(oldRef->kind() == WirePointer::STRUCT, KJ_REQUIRE(oldRef->kind() == WirePointer::STRUCT,
"Message contains non-struct pointer where struct pointer was expected.") { "Message contains non-struct pointer where struct pointer was expected.") {
goto useDefault; goto useDefault;
} }
...@@ -871,7 +875,7 @@ struct WireHelpers { ...@@ -871,7 +875,7 @@ struct WireHelpers {
SegmentBuilder* segment = origSegment; SegmentBuilder* segment = origSegment;
word* ptr = followFars(ref, segment); word* ptr = followFars(ref, segment);
VALIDATE_INPUT(ref->kind() == WirePointer::LIST, KJ_REQUIRE(ref->kind() == WirePointer::LIST,
"Called getList{Field,Element}() but existing pointer is not a list.") { "Called getList{Field,Element}() but existing pointer is not a list.") {
goto useDefault; goto useDefault;
} }
...@@ -904,13 +908,17 @@ struct WireHelpers { ...@@ -904,13 +908,17 @@ struct WireHelpers {
case FieldSize::TWO_BYTES: case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES: case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES: case FieldSize::EIGHT_BYTES:
VALIDATE_INPUT(dataSize >= 1 * WORDS, KJ_REQUIRE(dataSize >= 1 * WORDS,
"Existing list value is incompatible with expected type."); "Existing list value is incompatible with expected type.") {
goto useDefault;
}
break; break;
case FieldSize::POINTER: case FieldSize::POINTER:
VALIDATE_INPUT(pointerCount >= 1 * POINTERS, KJ_REQUIRE(pointerCount >= 1 * POINTERS,
"Existing list value is incompatible with expected type."); "Existing list value is incompatible with expected type.") {
goto useDefault;
}
// Adjust the pointer to point at the reference segment. // Adjust the pointer to point at the reference segment.
ptr += dataSize; ptr += dataSize;
break; break;
...@@ -930,10 +938,14 @@ struct WireHelpers { ...@@ -930,10 +938,14 @@ struct WireHelpers {
BitCount dataSize = dataBitsPerElement(oldSize) * ELEMENTS; BitCount dataSize = dataBitsPerElement(oldSize) * ELEMENTS;
WirePointerCount pointerCount = pointersPerElement(oldSize) * ELEMENTS; WirePointerCount pointerCount = pointersPerElement(oldSize) * ELEMENTS;
VALIDATE_INPUT(dataSize >= dataBitsPerElement(elementSize) * ELEMENTS, KJ_REQUIRE(dataSize >= dataBitsPerElement(elementSize) * ELEMENTS,
"Existing list value is incompatible with expected type."); "Existing list value is incompatible with expected type.") {
VALIDATE_INPUT(pointerCount >= pointersPerElement(elementSize) * ELEMENTS, goto useDefault;
"Existing list value is incompatible with expected type."); }
KJ_REQUIRE(pointerCount >= pointersPerElement(elementSize) * ELEMENTS,
"Existing list value is incompatible with expected type.") {
goto useDefault;
}
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS; auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
return ListBuilder(segment, ptr, step, ref->listRef.elementCount(), return ListBuilder(segment, ptr, step, ref->listRef.elementCount(),
...@@ -960,8 +972,8 @@ struct WireHelpers { ...@@ -960,8 +972,8 @@ struct WireHelpers {
SegmentBuilder* oldSegment = origSegment; SegmentBuilder* oldSegment = origSegment;
word* oldPtr = followFars(oldRef, oldSegment); word* oldPtr = followFars(oldRef, oldSegment);
VALIDATE_INPUT(oldRef->kind() == WirePointer::LIST, KJ_REQUIRE(oldRef->kind() == WirePointer::LIST,
"Called getList{Field,Element}() but existing pointer is not a list.") { "Called getList{Field,Element}() but existing pointer is not a list.") {
goto useDefault; goto useDefault;
} }
...@@ -972,8 +984,8 @@ struct WireHelpers { ...@@ -972,8 +984,8 @@ struct WireHelpers {
WirePointer* oldTag = reinterpret_cast<WirePointer*>(oldPtr); WirePointer* oldTag = reinterpret_cast<WirePointer*>(oldPtr);
oldPtr += POINTER_SIZE_IN_WORDS; oldPtr += POINTER_SIZE_IN_WORDS;
VALIDATE_INPUT(oldTag->kind() == WirePointer::STRUCT, KJ_REQUIRE(oldTag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE list with non-STRUCT elements not supported.") { "INLINE_COMPOSITE list with non-STRUCT elements not supported.") {
goto useDefault; goto useDefault;
} }
...@@ -1046,8 +1058,8 @@ struct WireHelpers { ...@@ -1046,8 +1058,8 @@ struct WireHelpers {
// No expectations. // No expectations.
break; break;
case FieldSize::POINTER: case FieldSize::POINTER:
VALIDATE_INPUT(oldSize == FieldSize::POINTER || oldSize == FieldSize::VOID, KJ_REQUIRE(oldSize == FieldSize::POINTER || oldSize == FieldSize::VOID,
"Struct list has incompatible element size.") { "Struct list has incompatible element size.") {
goto useDefault; goto useDefault;
} }
break; break;
...@@ -1060,8 +1072,8 @@ struct WireHelpers { ...@@ -1060,8 +1072,8 @@ struct WireHelpers {
case FieldSize::FOUR_BYTES: case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES: case FieldSize::EIGHT_BYTES:
// Preferred size is data-only. // Preferred size is data-only.
VALIDATE_INPUT(oldSize != FieldSize::POINTER, KJ_REQUIRE(oldSize != FieldSize::POINTER,
"Struct list has incompatible element size.") { "Struct list has incompatible element size.") {
goto useDefault; goto useDefault;
} }
break; break;
...@@ -1433,8 +1445,8 @@ struct WireHelpers { ...@@ -1433,8 +1445,8 @@ struct WireHelpers {
defaultValue = nullptr; // If the default value is itself invalid, don't use it again. defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
} }
VALIDATE_INPUT(nestingLimit > 0, KJ_REQUIRE(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") { "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault; goto useDefault;
} }
...@@ -1444,13 +1456,13 @@ struct WireHelpers { ...@@ -1444,13 +1456,13 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(ref->kind() == WirePointer::STRUCT, KJ_REQUIRE(ref->kind() == WirePointer::STRUCT,
"Message contains non-struct pointer where struct pointer was expected.") { "Message contains non-struct pointer where struct pointer was expected.") {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()), KJ_REQUIRE(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") { "Message contained out-of-bounds struct pointer.") {
goto useDefault; goto useDefault;
} }
...@@ -1475,8 +1487,8 @@ struct WireHelpers { ...@@ -1475,8 +1487,8 @@ struct WireHelpers {
defaultValue = nullptr; // If the default value is itself invalid, don't use it again. defaultValue = nullptr; // If the default value is itself invalid, don't use it again.
} }
VALIDATE_INPUT(nestingLimit > 0, KJ_REQUIRE(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") { "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault; goto useDefault;
} }
...@@ -1486,8 +1498,8 @@ struct WireHelpers { ...@@ -1486,8 +1498,8 @@ struct WireHelpers {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(ref->kind() == WirePointer::LIST, KJ_REQUIRE(ref->kind() == WirePointer::LIST,
"Message contains non-list pointer where list pointer was expected.") { "Message contains non-list pointer where list pointer was expected.") {
goto useDefault; goto useDefault;
} }
...@@ -1501,21 +1513,21 @@ struct WireHelpers { ...@@ -1501,21 +1513,21 @@ struct WireHelpers {
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
ptr += POINTER_SIZE_IN_WORDS; ptr += POINTER_SIZE_IN_WORDS;
VALIDATE_INPUT(boundsCheck(segment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount), KJ_REQUIRE(boundsCheck(segment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount),
"Message contains out-of-bounds list pointer.") { "Message contains out-of-bounds list pointer.") {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(tag->kind() == WirePointer::STRUCT, KJ_REQUIRE(tag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") { "INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault; goto useDefault;
} }
size = tag->inlineCompositeListElementCount(); size = tag->inlineCompositeListElementCount();
wordsPerElement = tag->structRef.wordSize() / ELEMENTS; wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
VALIDATE_INPUT(size * wordsPerElement <= wordCount, KJ_REQUIRE(size * wordsPerElement <= wordCount,
"INLINE_COMPOSITE list's elements overrun its word count.") { "INLINE_COMPOSITE list's elements overrun its word count.") {
goto useDefault; goto useDefault;
} }
...@@ -1530,7 +1542,7 @@ struct WireHelpers { ...@@ -1530,7 +1542,7 @@ struct WireHelpers {
break; break;
case FieldSize::BIT: case FieldSize::BIT:
FAIL_VALIDATE_INPUT("Expected a bit list, but got a list of structs.") { KJ_FAIL_REQUIRE("Expected a bit list, but got a list of structs.") {
goto useDefault; goto useDefault;
} }
break; break;
...@@ -1539,8 +1551,8 @@ struct WireHelpers { ...@@ -1539,8 +1551,8 @@ struct WireHelpers {
case FieldSize::TWO_BYTES: case FieldSize::TWO_BYTES:
case FieldSize::FOUR_BYTES: case FieldSize::FOUR_BYTES:
case FieldSize::EIGHT_BYTES: case FieldSize::EIGHT_BYTES:
VALIDATE_INPUT(tag->structRef.dataSize.get() > 0 * WORDS, KJ_REQUIRE(tag->structRef.dataSize.get() > 0 * WORDS,
"Expected a primitive list, but got a list of pointer-only structs.") { "Expected a primitive list, but got a list of pointer-only structs.") {
goto useDefault; goto useDefault;
} }
break; break;
...@@ -1550,8 +1562,8 @@ struct WireHelpers { ...@@ -1550,8 +1562,8 @@ struct WireHelpers {
// in the struct is the pointer we were looking for, we want to munge the pointer to // in the struct is the pointer we were looking for, we want to munge the pointer to
// point at the first element's pointer segment. // point at the first element's pointer segment.
ptr += tag->structRef.dataSize.get(); ptr += tag->structRef.dataSize.get();
VALIDATE_INPUT(tag->structRef.ptrCount.get() > 0 * POINTERS, KJ_REQUIRE(tag->structRef.ptrCount.get() > 0 * POINTERS,
"Expected a pointer list, but got a list of data-only structs.") { "Expected a pointer list, but got a list of data-only structs.") {
goto useDefault; goto useDefault;
} }
break; break;
...@@ -1573,9 +1585,9 @@ struct WireHelpers { ...@@ -1573,9 +1585,9 @@ struct WireHelpers {
pointersPerElement(ref->listRef.elementSize()) * ELEMENTS; pointersPerElement(ref->listRef.elementSize()) * ELEMENTS;
auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS; auto step = (dataSize + pointerCount * BITS_PER_POINTER) / ELEMENTS;
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + KJ_REQUIRE(boundsCheck(segment, ptr, ptr +
roundUpToWords(ElementCount64(ref->listRef.elementCount()) * step)), roundUpToWords(ElementCount64(ref->listRef.elementCount()) * step)),
"Message contains out-of-bounds list pointer.") { "Message contains out-of-bounds list pointer.") {
goto useDefault; goto useDefault;
} }
...@@ -1589,12 +1601,12 @@ struct WireHelpers { ...@@ -1589,12 +1601,12 @@ struct WireHelpers {
WirePointerCount expectedPointersPerElement = WirePointerCount expectedPointersPerElement =
pointersPerElement(expectedElementSize) * ELEMENTS; pointersPerElement(expectedElementSize) * ELEMENTS;
VALIDATE_INPUT(expectedDataBitsPerElement <= dataSize, KJ_REQUIRE(expectedDataBitsPerElement <= dataSize,
"Message contained list with incompatible element type.") { "Message contained list with incompatible element type.") {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(expectedPointersPerElement <= pointerCount, KJ_REQUIRE(expectedPointersPerElement <= pointerCount,
"Message contained list with incompatible element type.") { "Message contained list with incompatible element type.") {
goto useDefault; goto useDefault;
} }
...@@ -1620,30 +1632,30 @@ struct WireHelpers { ...@@ -1620,30 +1632,30 @@ struct WireHelpers {
uint size = ref->listRef.elementCount() / ELEMENTS; uint size = ref->listRef.elementCount() / ELEMENTS;
VALIDATE_INPUT(ref->kind() == WirePointer::LIST, KJ_REQUIRE(ref->kind() == WirePointer::LIST,
"Message contains non-list pointer where text was expected.") { "Message contains non-list pointer where text was expected.") {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(ref->listRef.elementSize() == FieldSize::BYTE, KJ_REQUIRE(ref->listRef.elementSize() == FieldSize::BYTE,
"Message contains list pointer of non-bytes where text was expected.") { "Message contains list pointer of non-bytes where text was expected.") {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + KJ_REQUIRE(boundsCheck(segment, ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))), roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))),
"Message contained out-of-bounds text pointer.") { "Message contained out-of-bounds text pointer.") {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(size > 0, "Message contains text that is not NUL-terminated.") { KJ_REQUIRE(size > 0, "Message contains text that is not NUL-terminated.") {
goto useDefault; goto useDefault;
} }
const char* cptr = reinterpret_cast<const char*>(ptr); const char* cptr = reinterpret_cast<const char*>(ptr);
--size; // NUL terminator --size; // NUL terminator
VALIDATE_INPUT(cptr[size] == '\0', "Message contains text that is not NUL-terminated.") { KJ_REQUIRE(cptr[size] == '\0', "Message contains text that is not NUL-terminated.") {
goto useDefault; goto useDefault;
} }
...@@ -1667,19 +1679,19 @@ struct WireHelpers { ...@@ -1667,19 +1679,19 @@ struct WireHelpers {
uint size = ref->listRef.elementCount() / ELEMENTS; uint size = ref->listRef.elementCount() / ELEMENTS;
VALIDATE_INPUT(ref->kind() == WirePointer::LIST, KJ_REQUIRE(ref->kind() == WirePointer::LIST,
"Message contains non-list pointer where data was expected.") { "Message contains non-list pointer where data was expected.") {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(ref->listRef.elementSize() == FieldSize::BYTE, KJ_REQUIRE(ref->listRef.elementSize() == FieldSize::BYTE,
"Message contains list pointer of non-bytes where data was expected.") { "Message contains list pointer of non-bytes where data was expected.") {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + KJ_REQUIRE(boundsCheck(segment, ptr, ptr +
roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))), roundUpToWords(ref->listRef.elementCount() * (1 * BYTES / ELEMENTS))),
"Message contained out-of-bounds data pointer.") { "Message contained out-of-bounds data pointer.") {
goto useDefault; goto useDefault;
} }
...@@ -1716,13 +1728,13 @@ struct WireHelpers { ...@@ -1716,13 +1728,13 @@ struct WireHelpers {
switch (ref->kind()) { switch (ref->kind()) {
case WirePointer::STRUCT: case WirePointer::STRUCT:
VALIDATE_INPUT(nestingLimit > 0, KJ_REQUIRE(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") { "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()), KJ_REQUIRE(boundsCheck(segment, ptr, ptr + ref->structRef.wordSize()),
"Message contained out-of-bounds struct pointer.") { "Message contained out-of-bounds struct pointer.") {
goto useDefault; goto useDefault;
} }
return ObjectReader( return ObjectReader(
...@@ -1734,7 +1746,7 @@ struct WireHelpers { ...@@ -1734,7 +1746,7 @@ struct WireHelpers {
case WirePointer::LIST: { case WirePointer::LIST: {
FieldSize elementSize = ref->listRef.elementSize(); FieldSize elementSize = ref->listRef.elementSize();
VALIDATE_INPUT(nestingLimit > 0, KJ_REQUIRE(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") { "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
goto useDefault; goto useDefault;
} }
...@@ -1744,21 +1756,23 @@ struct WireHelpers { ...@@ -1744,21 +1756,23 @@ struct WireHelpers {
const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr); const WirePointer* tag = reinterpret_cast<const WirePointer*>(ptr);
ptr += POINTER_SIZE_IN_WORDS; ptr += POINTER_SIZE_IN_WORDS;
VALIDATE_INPUT(boundsCheck(segment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount), KJ_REQUIRE(boundsCheck(segment, ptr - POINTER_SIZE_IN_WORDS, ptr + wordCount),
"Message contains out-of-bounds list pointer.") { "Message contains out-of-bounds list pointer.") {
goto useDefault; goto useDefault;
} }
VALIDATE_INPUT(tag->kind() == WirePointer::STRUCT, KJ_REQUIRE(tag->kind() == WirePointer::STRUCT,
"INLINE_COMPOSITE lists of non-STRUCT type are not supported.") { "INLINE_COMPOSITE lists of non-STRUCT type are not supported.") {
goto useDefault; goto useDefault;
} }
ElementCount elementCount = tag->inlineCompositeListElementCount(); ElementCount elementCount = tag->inlineCompositeListElementCount();
auto wordsPerElement = tag->structRef.wordSize() / ELEMENTS; auto wordsPerElement = tag->structRef.wordSize() / ELEMENTS;
VALIDATE_INPUT(wordsPerElement * elementCount <= wordCount, KJ_REQUIRE(wordsPerElement * elementCount <= wordCount,
"INLINE_COMPOSITE list's elements overrun its word count."); "INLINE_COMPOSITE list's elements overrun its word count.") {
goto useDefault;
}
return ObjectReader( return ObjectReader(
ListReader(segment, ptr, elementCount, wordsPerElement * BITS_PER_WORD, ListReader(segment, ptr, elementCount, wordsPerElement * BITS_PER_WORD,
...@@ -1771,8 +1785,8 @@ struct WireHelpers { ...@@ -1771,8 +1785,8 @@ struct WireHelpers {
ElementCount elementCount = ref->listRef.elementCount(); ElementCount elementCount = ref->listRef.elementCount();
WordCount wordCount = roundUpToWords(ElementCount64(elementCount) * step); WordCount wordCount = roundUpToWords(ElementCount64(elementCount) * step);
VALIDATE_INPUT(boundsCheck(segment, ptr, ptr + wordCount), KJ_REQUIRE(boundsCheck(segment, ptr, ptr + wordCount),
"Message contains out-of-bounds list pointer.") { "Message contains out-of-bounds list pointer.") {
goto useDefault; goto useDefault;
} }
...@@ -1782,8 +1796,9 @@ struct WireHelpers { ...@@ -1782,8 +1796,9 @@ struct WireHelpers {
} }
} }
default: default:
FAIL_VALIDATE_INPUT("Message contained invalid pointer.") {} KJ_FAIL_REQUIRE("Message contained invalid pointer.") {
goto useDefault; goto useDefault;
}
} }
} }
}; };
...@@ -1909,8 +1924,8 @@ StructReader StructReader::readRootUnchecked(const word* location) { ...@@ -1909,8 +1924,8 @@ StructReader StructReader::readRootUnchecked(const word* location) {
StructReader StructReader::readRoot( StructReader StructReader::readRoot(
const word* location, SegmentReader* segment, int nestingLimit) { const word* location, SegmentReader* segment, int nestingLimit) {
VALIDATE_INPUT(WireHelpers::boundsCheck(segment, location, location + POINTER_SIZE_IN_WORDS), KJ_REQUIRE(WireHelpers::boundsCheck(segment, location, location + POINTER_SIZE_IN_WORDS),
"Root location out-of-bounds.") { "Root location out-of-bounds.") {
location = nullptr; location = nullptr;
} }
...@@ -1980,21 +1995,21 @@ WordCount64 StructReader::totalSize() const { ...@@ -1980,21 +1995,21 @@ WordCount64 StructReader::totalSize() const {
// ListBuilder // ListBuilder
Text::Builder ListBuilder::asText() { Text::Builder ListBuilder::asText() {
VALIDATE_INPUT(structDataSize == 8 * BITS && structPointerCount == 0 * POINTERS, KJ_REQUIRE(structDataSize == 8 * BITS && structPointerCount == 0 * POINTERS,
"Expected Text, got list of non-bytes.") { "Expected Text, got list of non-bytes.") {
return Text::Builder(); return Text::Builder();
} }
size_t size = elementCount / ELEMENTS; size_t size = elementCount / ELEMENTS;
VALIDATE_INPUT(size > 0, "Message contains text that is not NUL-terminated.") { KJ_REQUIRE(size > 0, "Message contains text that is not NUL-terminated.") {
return Text::Builder(); return Text::Builder();
} }
char* cptr = reinterpret_cast<char*>(ptr); char* cptr = reinterpret_cast<char*>(ptr);
--size; // NUL terminator --size; // NUL terminator
VALIDATE_INPUT(cptr[size] == '\0', "Message contains text that is not NUL-terminated.") { KJ_REQUIRE(cptr[size] == '\0', "Message contains text that is not NUL-terminated.") {
return Text::Builder(); return Text::Builder();
} }
...@@ -2002,8 +2017,8 @@ Text::Builder ListBuilder::asText() { ...@@ -2002,8 +2017,8 @@ Text::Builder ListBuilder::asText() {
} }
Data::Builder ListBuilder::asData() { Data::Builder ListBuilder::asData() {
VALIDATE_INPUT(structDataSize == 8 * BITS && structPointerCount == 0 * POINTERS, KJ_REQUIRE(structDataSize == 8 * BITS && structPointerCount == 0 * POINTERS,
"Expected Text, got list of non-bytes.") { "Expected Text, got list of non-bytes.") {
return Data::Builder(); return Data::Builder();
} }
...@@ -2101,21 +2116,21 @@ ListReader ListBuilder::asReader() const { ...@@ -2101,21 +2116,21 @@ ListReader ListBuilder::asReader() const {
// ListReader // ListReader
Text::Reader ListReader::asText() { Text::Reader ListReader::asText() {
VALIDATE_INPUT(structDataSize == 8 * BITS && structPointerCount == 0 * POINTERS, KJ_REQUIRE(structDataSize == 8 * BITS && structPointerCount == 0 * POINTERS,
"Expected Text, got list of non-bytes.") { "Expected Text, got list of non-bytes.") {
return Text::Reader(); return Text::Reader();
} }
size_t size = elementCount / ELEMENTS; size_t size = elementCount / ELEMENTS;
VALIDATE_INPUT(size > 0, "Message contains text that is not NUL-terminated.") { KJ_REQUIRE(size > 0, "Message contains text that is not NUL-terminated.") {
return Text::Reader(); return Text::Reader();
} }
const char* cptr = reinterpret_cast<const char*>(ptr); const char* cptr = reinterpret_cast<const char*>(ptr);
--size; // NUL terminator --size; // NUL terminator
VALIDATE_INPUT(cptr[size] == '\0', "Message contains text that is not NUL-terminated.") { KJ_REQUIRE(cptr[size] == '\0', "Message contains text that is not NUL-terminated.") {
return Text::Reader(); return Text::Reader();
} }
...@@ -2123,8 +2138,8 @@ Text::Reader ListReader::asText() { ...@@ -2123,8 +2138,8 @@ Text::Reader ListReader::asText() {
} }
Data::Reader ListReader::asData() { Data::Reader ListReader::asData() {
VALIDATE_INPUT(structDataSize == 8 * BITS && structPointerCount == 0 * POINTERS, KJ_REQUIRE(structDataSize == 8 * BITS && structPointerCount == 0 * POINTERS,
"Expected Text, got list of non-bytes.") { "Expected Text, got list of non-bytes.") {
return Data::Reader(); return Data::Reader();
} }
...@@ -2132,8 +2147,8 @@ Data::Reader ListReader::asData() { ...@@ -2132,8 +2147,8 @@ Data::Reader ListReader::asData() {
} }
StructReader ListReader::getStructElement(ElementCount index) const { StructReader ListReader::getStructElement(ElementCount index) const {
VALIDATE_INPUT(nestingLimit > 0, KJ_REQUIRE(nestingLimit > 0,
"Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") { "Message is too deeply-nested or contains cycles. See capnproto::ReadOptions.") {
return StructReader(); return StructReader();
} }
......
...@@ -50,9 +50,9 @@ internal::StructReader MessageReader::getRootInternal() { ...@@ -50,9 +50,9 @@ internal::StructReader MessageReader::getRootInternal() {
} }
internal::SegmentReader* segment = arena()->tryGetSegment(internal::SegmentId(0)); internal::SegmentReader* segment = arena()->tryGetSegment(internal::SegmentId(0));
VALIDATE_INPUT(segment != nullptr && KJ_REQUIRE(segment != nullptr &&
segment->containsInterval(segment->getStartPtr(), segment->getStartPtr() + 1), segment->containsInterval(segment->getStartPtr(), segment->getStartPtr() + 1),
"Message did not contain a root pointer.") { "Message did not contain a root pointer.") {
return internal::StructReader(); return internal::StructReader();
} }
......
...@@ -145,9 +145,9 @@ private: ...@@ -145,9 +145,9 @@ private:
std::map<std::pair<uint, Text::Reader>, uint> members; std::map<std::pair<uint, Text::Reader>, uint> members;
#define VALIDATE_SCHEMA(condition, ...) \ #define VALIDATE_SCHEMA(condition, ...) \
VALIDATE_INPUT(condition, ##__VA_ARGS__) { isValid = false; return; } KJ_REQUIRE(condition, ##__VA_ARGS__) { isValid = false; return; }
#define FAIL_VALIDATE_SCHEMA(...) \ #define FAIL_VALIDATE_SCHEMA(...) \
FAIL_VALIDATE_INPUT(__VA_ARGS__) { isValid = false; return; } KJ_FAIL_REQUIRE(__VA_ARGS__) { isValid = false; return; }
void validate(schema::FileNode::Reader fileNode) { void validate(schema::FileNode::Reader fileNode) {
// Nothing needs validation. // Nothing needs validation.
...@@ -472,9 +472,9 @@ private: ...@@ -472,9 +472,9 @@ private:
Compatibility compatibility; Compatibility compatibility;
#define VALIDATE_SCHEMA(condition, ...) \ #define VALIDATE_SCHEMA(condition, ...) \
VALIDATE_INPUT(condition, ##__VA_ARGS__) { compatibility = INCOMPATIBLE; return; } KJ_REQUIRE(condition, ##__VA_ARGS__) { compatibility = INCOMPATIBLE; return; }
#define FAIL_VALIDATE_SCHEMA(...) \ #define FAIL_VALIDATE_SCHEMA(...) \
FAIL_VALIDATE_INPUT(__VA_ARGS__) { compatibility = INCOMPATIBLE; return; } KJ_FAIL_REQUIRE(__VA_ARGS__) { compatibility = INCOMPATIBLE; return; }
void replacementIsNewer() { void replacementIsNewer() {
switch (compatibility) { switch (compatibility) {
...@@ -934,7 +934,7 @@ private: ...@@ -934,7 +934,7 @@ private:
schema::Value::Reader replacement) { schema::Value::Reader replacement) {
// Note that we test default compatibility only after testing type compatibility, and default // Note that we test default compatibility only after testing type compatibility, and default
// values have already been validated as matching their types, so this should pass. // values have already been validated as matching their types, so this should pass.
RECOVERABLE_ASSERT(value.getBody().which() == replacement.getBody().which()) { KJ_ASSERT(value.getBody().which() == replacement.getBody().which()) {
compatibility = INCOMPATIBLE; compatibility = INCOMPATIBLE;
return; return;
} }
......
...@@ -47,7 +47,7 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) { ...@@ -47,7 +47,7 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
uint8_t* const outMin = reinterpret_cast<uint8_t*>(dst) + minBytes; uint8_t* const outMin = reinterpret_cast<uint8_t*>(dst) + minBytes;
kj::ArrayPtr<const byte> buffer = inner.getReadBuffer(); kj::ArrayPtr<const byte> buffer = inner.getReadBuffer();
VALIDATE_INPUT(buffer.size() > 0, "Premature end of packed input.") { KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") {
return minBytes; // garbage return minBytes; // garbage
} }
const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(buffer.begin()); const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(buffer.begin());
...@@ -55,7 +55,7 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) { ...@@ -55,7 +55,7 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
#define REFRESH_BUFFER() \ #define REFRESH_BUFFER() \
inner.skip(buffer.size()); \ inner.skip(buffer.size()); \
buffer = inner.getReadBuffer(); \ buffer = inner.getReadBuffer(); \
VALIDATE_INPUT(buffer.size() > 0, "Premature end of packed input.") { \ KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") { \
return minBytes; /* garbage */ \ return minBytes; /* garbage */ \
} \ } \
in = reinterpret_cast<const uint8_t*>(buffer.begin()) in = reinterpret_cast<const uint8_t*>(buffer.begin())
...@@ -126,8 +126,8 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) { ...@@ -126,8 +126,8 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
uint runLength = *in++ * sizeof(word); uint runLength = *in++ * sizeof(word);
VALIDATE_INPUT(runLength <= outEnd - out, KJ_REQUIRE(runLength <= outEnd - out,
"Packed input did not end cleanly on a segment boundary.") { "Packed input did not end cleanly on a segment boundary.") {
return std::max<size_t>(minBytes, out - reinterpret_cast<uint8_t*>(dst)); // garbage return std::max<size_t>(minBytes, out - reinterpret_cast<uint8_t*>(dst)); // garbage
} }
memset(out, 0, runLength); memset(out, 0, runLength);
...@@ -138,8 +138,8 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) { ...@@ -138,8 +138,8 @@ size_t PackedInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
uint runLength = *in++ * sizeof(word); uint runLength = *in++ * sizeof(word);
VALIDATE_INPUT(runLength <= outEnd - out, KJ_REQUIRE(runLength <= outEnd - out,
"Packed input did not end cleanly on a segment boundary.") { "Packed input did not end cleanly on a segment boundary.") {
return std::max<size_t>(minBytes, out - reinterpret_cast<uint8_t*>(dst)); // garbage return std::max<size_t>(minBytes, out - reinterpret_cast<uint8_t*>(dst)); // garbage
} }
...@@ -198,7 +198,7 @@ void PackedInputStream::skip(size_t bytes) { ...@@ -198,7 +198,7 @@ void PackedInputStream::skip(size_t bytes) {
#define REFRESH_BUFFER() \ #define REFRESH_BUFFER() \
inner.skip(buffer.size()); \ inner.skip(buffer.size()); \
buffer = inner.getReadBuffer(); \ buffer = inner.getReadBuffer(); \
VALIDATE_INPUT(buffer.size() > 0, "Premature end of packed input.") return; \ KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") { return; } \
in = reinterpret_cast<const uint8_t*>(buffer.begin()) in = reinterpret_cast<const uint8_t*>(buffer.begin())
for (;;) { for (;;) {
...@@ -252,8 +252,7 @@ void PackedInputStream::skip(size_t bytes) { ...@@ -252,8 +252,7 @@ void PackedInputStream::skip(size_t bytes) {
uint runLength = *in++ * sizeof(word); uint runLength = *in++ * sizeof(word);
VALIDATE_INPUT(runLength <= bytes, KJ_REQUIRE(runLength <= bytes, "Packed input did not end cleanly on a segment boundary.") {
"Packed input did not end cleanly on a segment boundary.") {
return; return;
} }
...@@ -264,8 +263,7 @@ void PackedInputStream::skip(size_t bytes) { ...@@ -264,8 +263,7 @@ void PackedInputStream::skip(size_t bytes) {
uint runLength = *in++ * sizeof(word); uint runLength = *in++ * sizeof(word);
VALIDATE_INPUT(runLength <= bytes, KJ_REQUIRE(runLength <= bytes, "Packed input did not end cleanly on a segment boundary.") {
"Packed input did not end cleanly on a segment boundary.") {
return; return;
} }
......
...@@ -106,11 +106,12 @@ void SnappyInputStream::skip(size_t bytes) { ...@@ -106,11 +106,12 @@ void SnappyInputStream::skip(size_t bytes) {
void SnappyInputStream::refill() { void SnappyInputStream::refill() {
uint32_t length = 0; uint32_t length = 0;
InputStreamSnappySource snappySource(inner); InputStreamSnappySource snappySource(inner);
VALIDATE_INPUT( KJ_REQUIRE(
snappy::RawUncompress( snappy::RawUncompress(
&snappySource, reinterpret_cast<char*>(buffer.begin()), buffer.size(), &length), &snappySource, reinterpret_cast<char*>(buffer.begin()), buffer.size(), &length),
"Snappy decompression failed.") { "Snappy decompression failed.") {
length = 1; // garbage length = 1; // garbage
break;
} }
bufferAvailable = buffer.slice(0, length); bufferAvailable = buffer.slice(0, length);
......
...@@ -42,7 +42,7 @@ FlatArrayMessageReader::FlatArrayMessageReader( ...@@ -42,7 +42,7 @@ FlatArrayMessageReader::FlatArrayMessageReader(
uint segmentCount = table[0].get() + 1; uint segmentCount = table[0].get() + 1;
size_t offset = segmentCount / 2u + 1u; size_t offset = segmentCount / 2u + 1u;
VALIDATE_INPUT(array.size() >= offset, "Message ends prematurely in segment table.") { KJ_REQUIRE(array.size() >= offset, "Message ends prematurely in segment table.") {
return; return;
} }
...@@ -52,8 +52,8 @@ FlatArrayMessageReader::FlatArrayMessageReader( ...@@ -52,8 +52,8 @@ FlatArrayMessageReader::FlatArrayMessageReader(
uint segmentSize = table[1].get(); uint segmentSize = table[1].get();
VALIDATE_INPUT(array.size() >= offset + segmentSize, KJ_REQUIRE(array.size() >= offset + segmentSize,
"Message ends prematurely in first segment.") { "Message ends prematurely in first segment.") {
return; return;
} }
...@@ -66,7 +66,7 @@ FlatArrayMessageReader::FlatArrayMessageReader( ...@@ -66,7 +66,7 @@ FlatArrayMessageReader::FlatArrayMessageReader(
for (uint i = 1; i < segmentCount; i++) { for (uint i = 1; i < segmentCount; i++) {
uint segmentSize = table[i + 1].get(); uint segmentSize = table[i + 1].get();
VALIDATE_INPUT(array.size() >= offset + segmentSize, "Message ends prematurely.") { KJ_REQUIRE(array.size() >= offset + segmentSize, "Message ends prematurely.") {
moreSegments = nullptr; moreSegments = nullptr;
return; return;
} }
...@@ -142,9 +142,10 @@ InputStreamMessageReader::InputStreamMessageReader( ...@@ -142,9 +142,10 @@ InputStreamMessageReader::InputStreamMessageReader(
size_t totalWords = segment0Size; size_t totalWords = segment0Size;
// Reject messages with too many segments for security reasons. // Reject messages with too many segments for security reasons.
VALIDATE_INPUT(segmentCount < 512, "Message has too many segments.") { KJ_REQUIRE(segmentCount < 512, "Message has too many segments.") {
segmentCount = 1; segmentCount = 1;
segment0Size = 1; segment0Size = 1;
break;
} }
// Read sizes for all segments except the first. Include padding if necessary. // Read sizes for all segments except the first. Include padding if necessary.
...@@ -159,12 +160,13 @@ InputStreamMessageReader::InputStreamMessageReader( ...@@ -159,12 +160,13 @@ InputStreamMessageReader::InputStreamMessageReader(
// Don't accept a message which the receiver couldn't possibly traverse without hitting the // Don't accept a message which the receiver couldn't possibly traverse without hitting the
// traversal limit. Without this check, a malicious client could transmit a very large segment // traversal limit. Without this check, a malicious client could transmit a very large segment
// size to make the receiver allocate excessive space and possibly crash. // size to make the receiver allocate excessive space and possibly crash.
VALIDATE_INPUT(totalWords <= options.traversalLimitInWords, KJ_REQUIRE(totalWords <= options.traversalLimitInWords,
"Message is too large. To increase the limit on the receiving end, see " "Message is too large. To increase the limit on the receiving end, see "
"capnproto::ReaderOptions.") { "capnproto::ReaderOptions.") {
segmentCount = 1; segmentCount = 1;
segment0Size = std::min<size_t>(segment0Size, options.traversalLimitInWords); segment0Size = std::min<size_t>(segment0Size, options.traversalLimitInWords);
totalWords = segment0Size; totalWords = segment0Size;
break;
} }
if (scratchSpace.size() < totalWords) { if (scratchSpace.size() < totalWords) {
......
...@@ -162,7 +162,9 @@ static void print(std::ostream& os, DynamicValue::Reader value, ...@@ -162,7 +162,9 @@ static void print(std::ostream& os, DynamicValue::Reader value,
break; break;
} }
case DynamicValue::INTERFACE: case DynamicValue::INTERFACE:
FAIL_RECOVERABLE_ASSERT("Don't know how to print interfaces.") {} KJ_FAIL_ASSERT("Don't know how to print interfaces.") {
break;
}
break; break;
case DynamicValue::OBJECT: case DynamicValue::OBJECT:
os << "(opaque object)"; os << "(opaque object)";
......
...@@ -30,9 +30,11 @@ namespace internal { ...@@ -30,9 +30,11 @@ namespace internal {
void inlineRequireFailure(const char* file, int line, const char* expectation, void inlineRequireFailure(const char* file, int line, const char* expectation,
const char* macroArgs, const char* message) { const char* macroArgs, const char* message) {
if (message == nullptr) { if (message == nullptr) {
Log::fatalFault(file, line, Exception::Nature::PRECONDITION, expectation, macroArgs); Log::Fault f(file, line, Exception::Nature::PRECONDITION, 0, expectation, macroArgs);
f.fatal();
} else { } else {
Log::fatalFault(file, line, Exception::Nature::PRECONDITION, expectation, macroArgs, message); Log::Fault f(file, line, Exception::Nature::PRECONDITION, 0, expectation, macroArgs, message);
f.fatal();
} }
} }
......
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
#ifndef KJ_COMMON_H_ #ifndef KJ_COMMON_H_
#define KJ_COMMON_H_ #define KJ_COMMON_H_
#if __cplusplus < 201103L #if __cplusplus < 201103L && !__CDT_PARSER__
#error "This code requires C++11. Either your compiler does not support it or it is not enabled." #error "This code requires C++11. Either your compiler does not support it or it is not enabled."
#ifdef __GNUC__ #ifdef __GNUC__
// Compiler claims compatibility with GCC, so presumably supports -std. // Compiler claims compatibility with GCC, so presumably supports -std.
......
...@@ -32,9 +32,8 @@ namespace kj { ...@@ -32,9 +32,8 @@ namespace kj {
ArrayPtr<const char> KJ_STRINGIFY(Exception::Nature nature) { ArrayPtr<const char> KJ_STRINGIFY(Exception::Nature nature) {
static const char* NATURE_STRINGS[] = { static const char* NATURE_STRINGS[] = {
"precondition not met", "requirement not met",
"bug in code", "bug in code",
"invalid input data",
"error from OS", "error from OS",
"network failure", "network failure",
"error" "error"
...@@ -174,7 +173,7 @@ void ExceptionCallback::logMessage(StringPtr text) { ...@@ -174,7 +173,7 @@ void ExceptionCallback::logMessage(StringPtr text) {
} }
void ExceptionCallback::useProcessWide() { void ExceptionCallback::useProcessWide() {
RECOVERABLE_REQUIRE(globalCallback == nullptr, KJ_REQUIRE(globalCallback == nullptr,
"Can't register multiple global ExceptionCallbacks at once.") { "Can't register multiple global ExceptionCallbacks at once.") {
return; return;
} }
......
...@@ -49,7 +49,6 @@ public: ...@@ -49,7 +49,6 @@ public:
PRECONDITION, PRECONDITION,
LOCAL_BUG, LOCAL_BUG,
INPUT,
OS_ERROR, OS_ERROR,
NETWORK_FAILURE, NETWORK_FAILURE,
OTHER OTHER
......
...@@ -188,8 +188,9 @@ ArrayPtr<const byte> ArrayInputStream::getReadBuffer() { ...@@ -188,8 +188,9 @@ ArrayPtr<const byte> ArrayInputStream::getReadBuffer() {
size_t ArrayInputStream::read(void* dst, size_t minBytes, size_t maxBytes) { size_t ArrayInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
size_t n = std::min(maxBytes, array.size()); size_t n = std::min(maxBytes, array.size());
size_t result = n; size_t result = n;
VALIDATE_INPUT(n >= minBytes, "ArrayInputStream ended prematurely.") { KJ_REQUIRE(n >= minBytes, "ArrayInputStream ended prematurely.") {
result = minBytes; // garbage result = minBytes; // garbage
break;
} }
memcpy(dst, array.begin(), n); memcpy(dst, array.begin(), n);
array = array.slice(n, array.size()); array = array.slice(n, array.size());
...@@ -197,8 +198,9 @@ size_t ArrayInputStream::read(void* dst, size_t minBytes, size_t maxBytes) { ...@@ -197,8 +198,9 @@ size_t ArrayInputStream::read(void* dst, size_t minBytes, size_t maxBytes) {
} }
void ArrayInputStream::skip(size_t bytes) { void ArrayInputStream::skip(size_t bytes) {
VALIDATE_INPUT(array.size() >= bytes, "ArrayInputStream ended prematurely.") { KJ_REQUIRE(array.size() >= bytes, "ArrayInputStream ended prematurely.") {
bytes = array.size(); bytes = array.size();
break;
} }
array = array.slice(bytes, array.size()); array = array.slice(bytes, array.size());
} }
...@@ -228,7 +230,9 @@ void ArrayOutputStream::write(const void* src, size_t size) { ...@@ -228,7 +230,9 @@ void ArrayOutputStream::write(const void* src, size_t size) {
AutoCloseFd::~AutoCloseFd() { AutoCloseFd::~AutoCloseFd() {
if (fd >= 0 && close(fd) < 0) { if (fd >= 0 && close(fd) < 0) {
FAIL_RECOVERABLE_SYSCALL("close", errno, fd); FAIL_SYSCALL("close", errno, fd) {
break;
}
} }
} }
...@@ -240,8 +244,9 @@ size_t FdInputStream::read(void* buffer, size_t minBytes, size_t maxBytes) { ...@@ -240,8 +244,9 @@ size_t FdInputStream::read(void* buffer, size_t minBytes, size_t maxBytes) {
byte* max = pos + maxBytes; byte* max = pos + maxBytes;
while (pos < min) { while (pos < min) {
ssize_t n = KJ_SYSCALL(::read(fd, pos, max - pos), fd); ssize_t n;
VALIDATE_INPUT(n > 0, "Premature EOF") { KJ_SYSCALL(n = ::read(fd, pos, max - pos), fd);
KJ_REQUIRE(n > 0, "Premature EOF") {
return minBytes; return minBytes;
} }
pos += n; pos += n;
...@@ -256,7 +261,8 @@ void FdOutputStream::write(const void* buffer, size_t size) { ...@@ -256,7 +261,8 @@ void FdOutputStream::write(const void* buffer, size_t size) {
const char* pos = reinterpret_cast<const char*>(buffer); const char* pos = reinterpret_cast<const char*>(buffer);
while (size > 0) { while (size > 0) {
ssize_t n = KJ_SYSCALL(::write(fd, pos, size), fd); ssize_t n;
KJ_SYSCALL(n = ::write(fd, pos, size), fd);
KJ_ASSERT(n > 0, "write() returned zero."); KJ_ASSERT(n > 0, "write() returned zero.");
pos += n; pos += n;
size -= n; size -= n;
...@@ -280,7 +286,8 @@ void FdOutputStream::write(ArrayPtr<const ArrayPtr<const byte>> pieces) { ...@@ -280,7 +286,8 @@ void FdOutputStream::write(ArrayPtr<const ArrayPtr<const byte>> pieces) {
} }
while (current < iov.end()) { while (current < iov.end()) {
ssize_t n = KJ_SYSCALL(::writev(fd, current, iov.end() - current), fd); ssize_t n;
KJ_SYSCALL(n = ::writev(fd, current, iov.end() - current), fd);
KJ_ASSERT(n > 0, "writev() returned zero."); KJ_ASSERT(n > 0, "writev() returned zero.");
while (static_cast<size_t>(n) >= current->iov_len) { while (static_cast<size_t>(n) >= current->iov_len) {
......
...@@ -101,18 +101,39 @@ TEST(Logging, Log) { ...@@ -101,18 +101,39 @@ TEST(Logging, Log) {
mockCallback.text); mockCallback.text);
mockCallback.text.clear(); mockCallback.text.clear();
KJ_DBG("Some debug text."); line = __LINE__;
EXPECT_EQ("log message: debug: " + fileLine(__FILE__, line) + ": Some debug text.\n",
mockCallback.text);
mockCallback.text.clear();
// INFO logging is disabled by default.
KJ_LOG(INFO, "Info."); line = __LINE__;
EXPECT_EQ("", mockCallback.text);
mockCallback.text.clear();
// Enable it.
Log::setLogLevel(Log::Severity::INFO);
KJ_LOG(INFO, "Some text."); line = __LINE__;
EXPECT_EQ("log message: info: " + fileLine(__FILE__, line) + ": Some text.\n",
mockCallback.text);
mockCallback.text.clear();
// Back to default.
Log::setLogLevel(Log::Severity::WARNING);
KJ_ASSERT(1 == 1); KJ_ASSERT(1 == 1);
EXPECT_THROW(KJ_ASSERT(1 == 2), MockException); line = __LINE__; EXPECT_THROW(KJ_ASSERT(1 == 2), MockException); line = __LINE__;
EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": bug in code: expected " EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": bug in code: expected "
"1 == 2\n", mockCallback.text); "1 == 2\n", mockCallback.text);
mockCallback.text.clear(); mockCallback.text.clear();
RECOVERABLE_ASSERT(1 == 1) { KJ_ASSERT(1 == 1) {
ADD_FAILURE() << "Shouldn't call recovery code when check passes."; ADD_FAILURE() << "Shouldn't call recovery code when check passes.";
break;
}; };
bool recovered = false; bool recovered = false;
RECOVERABLE_ASSERT(1 == 2, "1 is not 2") { recovered = true; } line = __LINE__; KJ_ASSERT(1 == 2, "1 is not 2") { recovered = true; break; } line = __LINE__;
EXPECT_EQ("recoverable exception: " + fileLine(__FILE__, line) + ": bug in code: expected " EXPECT_EQ("recoverable exception: " + fileLine(__FILE__, line) + ": bug in code: expected "
"1 == 2; 1 is not 2\n", mockCallback.text); "1 == 2; 1 is not 2\n", mockCallback.text);
EXPECT_TRUE(recovered); EXPECT_TRUE(recovered);
...@@ -124,11 +145,11 @@ TEST(Logging, Log) { ...@@ -124,11 +145,11 @@ TEST(Logging, Log) {
mockCallback.text.clear(); mockCallback.text.clear();
EXPECT_THROW(KJ_REQUIRE(1 == 2, i, "hi", str), MockException); line = __LINE__; EXPECT_THROW(KJ_REQUIRE(1 == 2, i, "hi", str), MockException); line = __LINE__;
EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": precondition not met: expected " EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": requirement not met: expected "
"1 == 2; i = 123; hi; str = foo\n", mockCallback.text); "1 == 2; i = 123; hi; str = foo\n", mockCallback.text);
mockCallback.text.clear(); mockCallback.text.clear();
EXPECT_THROW(KJ_ASSERT(false, "foo"), MockException); line = __LINE__; EXPECT_THROW(KJ_FAIL_ASSERT("foo"), MockException); line = __LINE__;
EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": bug in code: foo\n", EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": bug in code: foo\n",
mockCallback.text); mockCallback.text);
mockCallback.text.clear(); mockCallback.text.clear();
...@@ -142,7 +163,8 @@ TEST(Logging, Syscall) { ...@@ -142,7 +163,8 @@ TEST(Logging, Syscall) {
int i = 123; int i = 123;
const char* str = "foo"; const char* str = "foo";
int fd = KJ_SYSCALL(dup(STDIN_FILENO)); int fd;
KJ_SYSCALL(fd = dup(STDIN_FILENO));
KJ_SYSCALL(close(fd)); KJ_SYSCALL(close(fd));
EXPECT_THROW(KJ_SYSCALL(close(fd), i, "bar", str), MockException); line = __LINE__; EXPECT_THROW(KJ_SYSCALL(close(fd), i, "bar", str), MockException); line = __LINE__;
EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": error from OS: close(fd): " EXPECT_EQ("fatal exception: " + fileLine(__FILE__, line) + ": error from OS: close(fd): "
...@@ -151,7 +173,7 @@ TEST(Logging, Syscall) { ...@@ -151,7 +173,7 @@ TEST(Logging, Syscall) {
int result = 0; int result = 0;
bool recovered = false; bool recovered = false;
RECOVERABLE_SYSCALL(result = close(fd), i, "bar", str) { recovered = true; } line = __LINE__; KJ_SYSCALL(result = close(fd), i, "bar", str) { recovered = true; break; } line = __LINE__;
EXPECT_EQ("recoverable exception: " + fileLine(__FILE__, line) + ": error from OS: close(fd): " EXPECT_EQ("recoverable exception: " + fileLine(__FILE__, line) + ": error from OS: close(fd): "
+ strerror(EBADF) + "; i = 123; bar; str = foo\n", mockCallback.text); + strerror(EBADF) + "; i = 123; bar; str = foo\n", mockCallback.text);
EXPECT_LT(result, 0); EXPECT_LT(result, 0);
......
...@@ -29,14 +29,15 @@ ...@@ -29,14 +29,15 @@
namespace kj { namespace kj {
Log::Severity Log::minSeverity = Log::Severity::INFO; Log::Severity Log::minSeverity = Log::Severity::WARNING;
ArrayPtr<const char> KJ_STRINGIFY(Log::Severity severity) { ArrayPtr<const char> KJ_STRINGIFY(Log::Severity severity) {
static const char* SEVERITY_STRINGS[] = { static const char* SEVERITY_STRINGS[] = {
"info", "info",
"warning", "warning",
"error", "error",
"fatal" "fatal",
"debug"
}; };
const char* s = SEVERITY_STRINGS[static_cast<uint>(severity)]; const char* s = SEVERITY_STRINGS[static_cast<uint>(severity)];
...@@ -110,6 +111,10 @@ static String makeDescription(DescriptionStyle style, const char* code, int erro ...@@ -110,6 +111,10 @@ static String makeDescription(DescriptionStyle style, const char* code, int erro
} }
} }
if (style == ASSERTION && code == nullptr) {
style = LOG;
}
{ {
StringPtr expected = "expected "; StringPtr expected = "expected ";
StringPtr codeArray = style == LOG ? nullptr : StringPtr(code); StringPtr codeArray = style == LOG ? nullptr : StringPtr(code);
...@@ -117,11 +122,6 @@ static String makeDescription(DescriptionStyle style, const char* code, int erro ...@@ -117,11 +122,6 @@ static String makeDescription(DescriptionStyle style, const char* code, int erro
StringPtr delim = "; "; StringPtr delim = "; ";
StringPtr colon = ": "; StringPtr colon = ": ";
if (style == ASSERTION && strcmp(code, "false") == 0) {
// Don't print "expected false", that's silly.
style = LOG;
}
StringPtr sysErrorArray; StringPtr sysErrorArray;
#if __USE_GNU #if __USE_GNU
char buffer[256]; char buffer[256];
...@@ -194,38 +194,28 @@ void Log::logInternal(const char* file, int line, Severity severity, const char* ...@@ -194,38 +194,28 @@ void Log::logInternal(const char* file, int line, Severity severity, const char*
makeDescription(LOG, nullptr, 0, macroArgs, argValues), '\n')); makeDescription(LOG, nullptr, 0, macroArgs, argValues), '\n'));
} }
void Log::recoverableFaultInternal( Log::Fault::~Fault() noexcept(false) {
const char* file, int line, Exception::Nature nature, if (exception != nullptr) {
const char* condition, const char* macroArgs, ArrayPtr<String> argValues) { Exception copy = mv(*exception);
getExceptionCallback().onRecoverableException( delete exception;
Exception(nature, Exception::Durability::PERMANENT, file, line, getExceptionCallback().onRecoverableException(mv(copy));
makeDescription(ASSERTION, condition, 0, macroArgs, argValues))); }
} }
void Log::fatalFaultInternal( void Log::Fault::fatal() {
const char* file, int line, Exception::Nature nature, Exception copy = mv(*exception);
const char* condition, const char* macroArgs, ArrayPtr<String> argValues) { delete exception;
getExceptionCallback().onFatalException( exception = nullptr;
Exception(nature, Exception::Durability::PERMANENT, file, line, getExceptionCallback().onFatalException(mv(copy));
makeDescription(ASSERTION, condition, 0, macroArgs, argValues)));
abort(); abort();
} }
void Log::recoverableFailedSyscallInternal( void Log::Fault::init(
const char* file, int line, const char* call, const char* file, int line, Exception::Nature nature, int errorNumber,
int errorNumber, const char* macroArgs, ArrayPtr<String> argValues) { const char* condition, const char* macroArgs, ArrayPtr<String> argValues) {
getExceptionCallback().onRecoverableException( exception = new Exception(nature, Exception::Durability::PERMANENT, file, line,
Exception(Exception::Nature::OS_ERROR, Exception::Durability::PERMANENT, file, line, makeDescription(nature == Exception::Nature::OS_ERROR ? SYSCALL : ASSERTION,
makeDescription(SYSCALL, call, errorNumber, macroArgs, argValues))); condition, errorNumber, macroArgs, argValues));
}
void Log::fatalFailedSyscallInternal(
const char* file, int line, const char* call,
int errorNumber, const char* macroArgs, ArrayPtr<String> argValues) {
getExceptionCallback().onFatalException(
Exception(Exception::Nature::OS_ERROR, Exception::Durability::PERMANENT, file, line,
makeDescription(SYSCALL, call, errorNumber, macroArgs, argValues)));
abort();
} }
void Log::addContextToInternal(Exception& exception, const char* file, int line, void Log::addContextToInternal(Exception& exception, const char* file, int line,
......
...@@ -36,43 +36,48 @@ ...@@ -36,43 +36,48 @@
// //
// * `KJ_LOG(severity, ...)`: Just writes a log message, to stderr by default (but you can // * `KJ_LOG(severity, ...)`: Just writes a log message, to stderr by default (but you can
// intercept messages by implementing an ExceptionCallback). `severity` is `INFO`, `WARNING`, // intercept messages by implementing an ExceptionCallback). `severity` is `INFO`, `WARNING`,
// `ERROR`, or `FATAL`. If the severity is not higher than the global logging threshold, nothing // `ERROR`, or `FATAL`. By default, `INFO` logs are not written, but for command-line apps the
// will be written and in fact the log message won't even be evaluated. // user should be able to pass a flag like `--verbose` to enable them. Other log levels are
// enabled by default. Log messages -- like exceptions -- can be intercepted by registering an
// ExceptionCallback.
//
// * `KJ_DBG(...)`: Like `KJ_LOG`, but intended specifically for temporary log lines added while
// debugging a particular problem. Calls to `KJ_DBG` should always be deleted before committing
// code. It is suggested that you set up a pre-commit hook that checks for this.
// //
// * `KJ_ASSERT(condition, ...)`: Throws an exception if `condition` is false, or aborts if // * `KJ_ASSERT(condition, ...)`: Throws an exception if `condition` is false, or aborts if
// exceptions are disabled. This macro should be used to check for bugs in the surrounding code // exceptions are disabled. This macro should be used to check for bugs in the surrounding code
// and its dependencies, but NOT to check for invalid input. // and its dependencies, but NOT to check for invalid input. The macro may be followed by a
// brace-delimited code block; if so, the block will be executed in the case where the assertion
// fails, before throwing the exception. If control jumps out of the block (e.g. with "break",
// "return", or "goto"), then the error is considered "recoverable" -- in this case, if
// exceptions are disabled, execution will continue normally rather than aborting (but if
// exceptions are enabled, an exception will still be thrown on exiting the block). A "break"
// statement in particular will jump to the code immediately after the block (it does not break
// any surrounding loop or switch). Example:
//
// KJ_ASSERT(value >= 0, "Value cannot be negative.", value) {
// // Assertion failed. Set value to zero to "recover".
// value = 0;
// // Don't abort if exceptions are disabled. Continue normally.
// // (Still throw an exception if they are enabled, though.)
// break;
// }
// // When exceptions are disabled, we'll get here even if the assertion fails.
// // Otherwise, we get here only if the assertion passes.
// //
// * `KJ_REQUIRE(condition, ...)`: Like `KJ_ASSERT` but used to check preconditions -- e.g. to // * `KJ_REQUIRE(condition, ...)`: Like `KJ_ASSERT` but used to check preconditions -- e.g. to
// validate parameters passed from a caller. A failure indicates that the caller is buggy. // validate parameters passed from a caller. A failure indicates that the caller is buggy.
// //
// * `RECOVERABLE_ASSERT(condition, ...) { ... }`: Like `KJ_ASSERT` except that if exceptions are // * `KJ_SYSCALL(code, ...)`: Executes `code` assuming it makes a system call. A negative result
// disabled, instead of aborting, the following code block will be executed. This block should // is considered an error, with error code reported via `errno`. EINTR is handled by retrying.
// do whatever it can to fill in dummy values so that the code can continue executing, even if // Other errors are handled by throwing an exception. If you need to examine the return code,
// this means the eventual output will be garbage. // assign it to a variable like so:
//
// * `RECOVERABLE_REQUIRE(condition, ...) { ... }`: Like `RECOVERABLE_ASSERT` and `KJ_REQUIRE`.
//
// * `VALIDATE_INPUT(condition, ...) { ... }`: Like `RECOVERABLE_PRECOND` but used to validate
// input that may have come from the user or some other untrusted source. Recoverability is
// required in this case.
//
// * `KJ_SYSCALL(code, ...)`: Executes `code` assuming it makes a system call. A negative return
// value is considered an error. EINTR is handled by retrying. Other errors are handled by
// throwing an exception. The macro also returns the call's result. For example, the following
// calls `open()` and includes the file name in any error message:
//
// int fd = KJ_SYSCALL(open(filename, O_RDONLY), filename);
//
// * `RECOVERABLE_SYSCALL(code, ...) { ... }`: Like `RECOVERABLE_ASSERT` and `SYSCALL`. Note that
// unfortunately this macro cannot return a value since it implements control flow, but you can
// assign to a variable *inside* the parameter instead:
// //
// int fd; // int fd;
// RECOVERABLE_SYSCALL(fd = open(filename, O_RDONLY), filename) { // KJ_SYSCALL(fd = open(filename, O_RDONLY), filename);
// // Failed. Open /dev/null instead. //
// fd = SYSCALL(open("/dev/null", O_RDONLY)); // `KJ_SYSCALL` can be followed by a recovery block, just like `KJ_ASSERT`.
// }
// //
// * `KJ_CONTEXT(...)`: Notes additional contextual information relevant to any exceptions thrown // * `KJ_CONTEXT(...)`: Notes additional contextual information relevant to any exceptions thrown
// from within the current scope. That is, until control exits the block in which KJ_CONTEXT() // from within the current scope. That is, until control exits the block in which KJ_CONTEXT()
...@@ -102,12 +107,17 @@ ...@@ -102,12 +107,17 @@
namespace kj { namespace kj {
class Log { class Log {
// Mostly-internal
public: public:
enum class Severity { enum class Severity {
INFO, // Information useful for debugging. No problem detected. INFO, // Information describing what the code is up to, which users may request to see
// with a flag like `--verbose`. Does not indicate a problem. Not printed by
// default; you must call setLogLevel(INFO) to enable.
WARNING, // A problem was detected but execution can continue with correct output. WARNING, // A problem was detected but execution can continue with correct output.
ERROR, // Something is wrong, but execution can continue with garbage output. ERROR, // Something is wrong, but execution can continue with garbage output.
FATAL // Something went wrong, and execution cannot continue. FATAL, // Something went wrong, and execution cannot continue.
DEBUG // Temporary debug logging. See KJ_DBG.
// Make sure to update the stringifier if you add a new severity level. // Make sure to update the stringifier if you add a new severity level.
}; };
...@@ -122,32 +132,35 @@ public: ...@@ -122,32 +132,35 @@ public:
static void log(const char* file, int line, Severity severity, const char* macroArgs, static void log(const char* file, int line, Severity severity, const char* macroArgs,
Params&&... params); Params&&... params);
template <typename... Params> class Fault {
static void recoverableFault(const char* file, int line, Exception::Nature nature, public:
const char* condition, const char* macroArgs, Params&&... params); template <typename... Params>
Fault(const char* file, int line, Exception::Nature nature, int errorNumber,
const char* condition, const char* macroArgs, Params&&... params);
~Fault() noexcept(false);
template <typename... Params> void fatal() KJ_NORETURN;
static void fatalFault(const char* file, int line, Exception::Nature nature, // Throw the exception.
const char* condition, const char* macroArgs, Params&&... params)
KJ_NORETURN;
template <typename Call, typename... Params> private:
static bool recoverableSyscall(Call&& call, const char* file, int line, const char* callText, void init(const char* file, int line, Exception::Nature nature, int errorNumber,
const char* macroArgs, Params&&... params); const char* condition, const char* macroArgs, ArrayPtr<String> argValues);
template <typename Call, typename... Params> Exception* exception;
static auto syscall(Call&& call, const char* file, int line, const char* callText, };
const char* macroArgs, Params&&... params) -> decltype(call());
template <typename... Params> class SyscallResult {
static void reportFailedRecoverableSyscall( public:
int errorNumber, const char* file, int line, const char* callText, inline SyscallResult(int errorNumber): errorNumber(errorNumber) {}
const char* macroArgs, Params&&... params); inline operator void*() { return errorNumber == 0 ? this : nullptr; }
inline int getErrorNumber() { return errorNumber; }
template <typename... Params> private:
static void reportFailedSyscall( int errorNumber;
int errorNumber, const char* file, int line, const char* callText, };
const char* macroArgs, Params&&... params);
template <typename Call>
static SyscallResult syscall(Call&& call);
class Context: public ExceptionCallback { class Context: public ExceptionCallback {
public: public:
...@@ -187,20 +200,6 @@ private: ...@@ -187,20 +200,6 @@ private:
static void logInternal(const char* file, int line, Severity severity, const char* macroArgs, static void logInternal(const char* file, int line, Severity severity, const char* macroArgs,
ArrayPtr<String> argValues); ArrayPtr<String> argValues);
static void recoverableFaultInternal(
const char* file, int line, Exception::Nature nature,
const char* condition, const char* macroArgs, ArrayPtr<String> argValues);
static void fatalFaultInternal(
const char* file, int line, Exception::Nature nature,
const char* condition, const char* macroArgs, ArrayPtr<String> argValues)
KJ_NORETURN;
static void recoverableFailedSyscallInternal(
const char* file, int line, const char* call,
int errorNumber, const char* macroArgs, ArrayPtr<String> argValues);
static void fatalFailedSyscallInternal(
const char* file, int line, const char* call,
int errorNumber, const char* macroArgs, ArrayPtr<String> argValues)
KJ_NORETURN;
static void addContextToInternal(Exception& exception, const char* file, int line, static void addContextToInternal(Exception& exception, const char* file, int line,
const char* macroArgs, ArrayPtr<String> argValues); const char* macroArgs, ArrayPtr<String> argValues);
...@@ -215,53 +214,33 @@ ArrayPtr<const char> KJ_STRINGIFY(Log::Severity severity); ...@@ -215,53 +214,33 @@ ArrayPtr<const char> KJ_STRINGIFY(Log::Severity severity);
::kj::Log::log(__FILE__, __LINE__, ::kj::Log::Severity::severity, \ ::kj::Log::log(__FILE__, __LINE__, ::kj::Log::Severity::severity, \
#__VA_ARGS__, __VA_ARGS__) #__VA_ARGS__, __VA_ARGS__)
#define KJ_FAULT(nature, cond, ...) \ #define KJ_DBG(...) KJ_LOG(DEBUG, ##__VA_ARGS__)
if (KJ_EXPECT_TRUE(cond)) {} else \
::kj::Log::fatalFault(__FILE__, __LINE__, \
::kj::Exception::Nature::nature, #cond, #__VA_ARGS__, ##__VA_ARGS__)
#define RECOVERABLE_FAULT(nature, cond, ...) \ #define _kJ_FAULT(nature, cond, ...) \
if (KJ_EXPECT_TRUE(cond)) {} else \ if (KJ_EXPECT_TRUE(cond)) {} else \
if (::kj::Log::recoverableFault(__FILE__, __LINE__, \ for (::kj::Log::Fault f(__FILE__, __LINE__, ::kj::Exception::Nature::nature, 0, \
::kj::Exception::Nature::nature, #cond, #__VA_ARGS__, ##__VA_ARGS__), false) {} \ #cond, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal())
else
#define KJ_ASSERT(...) KJ_FAULT(LOCAL_BUG, __VA_ARGS__)
#define RECOVERABLE_ASSERT(...) RECOVERABLE_FAULT(LOCAL_BUG, __VA_ARGS__)
#define KJ_REQUIRE(...) KJ_FAULT(PRECONDITION, __VA_ARGS__)
#define RECOVERABLE_REQUIRE(...) RECOVERABLE_FAULT(PRECONDITION, __VA_ARGS__)
#define VALIDATE_INPUT(...) RECOVERABLE_FAULT(INPUT, __VA_ARGS__)
#define KJ_FAIL_ASSERT(...) KJ_ASSERT(false, ##__VA_ARGS__)
#define FAIL_RECOVERABLE_ASSERT(...) RECOVERABLE_ASSERT(false, ##__VA_ARGS__)
#define KJ_FAIL_REQUIRE(...) KJ_REQUIRE(false, ##__VA_ARGS__)
#define FAIL_RECOVERABLE_REQUIRE(...) RECOVERABLE_REQUIRE(false, ##__VA_ARGS__)
#define FAIL_VALIDATE_INPUT(...) VALIDATE_INPUT(false, ##__VA_ARGS__)
#define KJ_SYSCALL(call, ...) \ #define _kJ_FAIL_FAULT(nature, ...) \
::kj::Log::syscall( \ for (::kj::Log::Fault f(__FILE__, __LINE__, ::kj::Exception::Nature::nature, 0, \
[&](){return (call);}, __FILE__, __LINE__, #call, #__VA_ARGS__, ##__VA_ARGS__) nullptr, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal())
#define RECOVERABLE_SYSCALL(call, ...) \ #define KJ_ASSERT(...) _kJ_FAULT(LOCAL_BUG, ##__VA_ARGS__)
if (::kj::Log::recoverableSyscall( \ #define KJ_REQUIRE(...) _kJ_FAULT(PRECONDITION, ##__VA_ARGS__)
[&](){return (call);}, __FILE__, __LINE__, #call, #__VA_ARGS__, ##__VA_ARGS__)) {} \
else #define KJ_FAIL_ASSERT(...) _kJ_FAIL_FAULT(LOCAL_BUG, ##__VA_ARGS__)
#define KJ_FAIL_REQUIRE(...) _kJ_FAIL_FAULT(PRECONDITION, ##__VA_ARGS__)
#define KJ_SYSCALL(call, ...) \
if (auto _kjSyscallResult = ::kj::Log::syscall([&](){return (call);})) {} else \
for (::kj::Log::Fault f( \
__FILE__, __LINE__, ::kj::Exception::Nature::OS_ERROR, \
_kjSyscallResult.getErrorNumber(), #call, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal())
#define FAIL_SYSCALL(code, errorNumber, ...) \ #define FAIL_SYSCALL(code, errorNumber, ...) \
do { \ for (::kj::Log::Fault f( \
/* make sure to read error number before doing anything else that could change it */ \ __FILE__, __LINE__, ::kj::Exception::Nature::OS_ERROR, \
int _errorNumber = errorNumber; \ errorNumber, code, #__VA_ARGS__, ##__VA_ARGS__);; f.fatal())
::kj::Log::reportFailedSyscall( \
_errorNumber, __FILE__, __LINE__, #code, #__VA_ARGS__, ##__VA_ARGS__); \
} while (false)
#define FAIL_RECOVERABLE_SYSCALL(code, errorNumber, ...) \
do { \
/* make sure to read error number before doing anything else that could change it */ \
int _errorNumber = errorNumber; \
::kj::Log::reportFailedRecoverableSyscall( \
_errorNumber, __FILE__, __LINE__, #code, #__VA_ARGS__, ##__VA_ARGS__); \
} while (false)
#define KJ_CONTEXT(...) \ #define KJ_CONTEXT(...) \
auto _kjContextFunc = [&](::kj::Exception& exception) { \ auto _kjContextFunc = [&](::kj::Exception& exception) { \
...@@ -273,15 +252,11 @@ ArrayPtr<const char> KJ_STRINGIFY(Log::Severity severity); ...@@ -273,15 +252,11 @@ ArrayPtr<const char> KJ_STRINGIFY(Log::Severity severity);
#ifdef NDEBUG #ifdef NDEBUG
#define KJ_DLOG(...) do {} while (false) #define KJ_DLOG(...) do {} while (false)
#define KJ_DASSERT(...) do {} while (false) #define KJ_DASSERT(...) do {} while (false)
#define KJ_RECOVERABLE_DASSERT(...) do {} while (false)
#define KJ_DREQUIRE(...) do {} while (false) #define KJ_DREQUIRE(...) do {} while (false)
#define KJ_RECOVERABLE_DREQUIRE(...) do {} while (false)
#else #else
#define KJ_DLOG LOG #define KJ_DLOG LOG
#define KJ_DASSERT KJ_ASSERT #define KJ_DASSERT KJ_ASSERT
#define KJ_RECOVERABLE_DASSERT RECOVERABLE_ASSERT
#define KJ_DREQUIRE KJ_REQUIRE #define KJ_DREQUIRE KJ_REQUIRE
#define KJ_RECOVERABLE_DREQUIRE RECOVERABLE_REQUIRE
#endif #endif
template <typename... Params> template <typename... Params>
...@@ -292,72 +267,24 @@ void Log::log(const char* file, int line, Severity severity, const char* macroAr ...@@ -292,72 +267,24 @@ void Log::log(const char* file, int line, Severity severity, const char* macroAr
} }
template <typename... Params> template <typename... Params>
void Log::recoverableFault(const char* file, int line, Exception::Nature nature, Log::Fault::Fault(const char* file, int line, Exception::Nature nature, int errorNumber,
const char* condition, const char* macroArgs, Params&&... params) { const char* condition, const char* macroArgs, Params&&... params)
String argValues[sizeof...(Params)] = {str(params)...}; : exception(nullptr) {
recoverableFaultInternal(file, line, nature, condition, macroArgs,
arrayPtr(argValues, sizeof...(Params)));
}
template <typename... Params>
void Log::fatalFault(const char* file, int line, Exception::Nature nature,
const char* condition, const char* macroArgs, Params&&... params) {
String argValues[sizeof...(Params)] = {str(params)...}; String argValues[sizeof...(Params)] = {str(params)...};
fatalFaultInternal(file, line, nature, condition, macroArgs, init(file, line, nature, errorNumber, condition, macroArgs,
arrayPtr(argValues, sizeof...(Params))); arrayPtr(argValues, sizeof...(Params)));
} }
template <typename Call, typename... Params> template <typename Call>
bool Log::recoverableSyscall(Call&& call, const char* file, int line, const char* callText, Log::SyscallResult Log::syscall(Call&& call) {
const char* macroArgs, Params&&... params) { while (call() < 0) {
int result;
while ((result = call()) < 0) {
int errorNum = getOsErrorNumber(); int errorNum = getOsErrorNumber();
// getOsErrorNumber() returns -1 to indicate EINTR // getOsErrorNumber() returns -1 to indicate EINTR
if (errorNum != -1) { if (errorNum != -1) {
String argValues[sizeof...(Params)] = {str(params)...}; return SyscallResult(errorNum);
recoverableFailedSyscallInternal(file, line, callText, errorNum,
macroArgs, arrayPtr(argValues, sizeof...(Params)));
return false;
} }
} }
return true; return SyscallResult(0);
}
#ifndef __CDT_PARSER__ // Eclipse dislikes the late return spec.
template <typename Call, typename... Params>
auto Log::syscall(Call&& call, const char* file, int line, const char* callText,
const char* macroArgs, Params&&... params) -> decltype(call()) {
decltype(call()) result;
while ((result = call()) < 0) {
int errorNum = getOsErrorNumber();
// getOsErrorNumber() returns -1 to indicate EINTR
if (errorNum != -1) {
String argValues[sizeof...(Params)] = {str(params)...};
fatalFailedSyscallInternal(file, line, callText, errorNum,
macroArgs, arrayPtr(argValues, sizeof...(Params)));
}
}
return result;
}
#endif
template <typename... Params>
void Log::reportFailedRecoverableSyscall(
int errorNumber, const char* file, int line, const char* callText,
const char* macroArgs, Params&&... params) {
String argValues[sizeof...(Params)] = {str(params)...};
recoverableFailedSyscallInternal(file, line, callText, errorNumber, macroArgs,
arrayPtr(argValues, sizeof...(Params)));
}
template <typename... Params>
void Log::reportFailedSyscall(
int errorNumber, const char* file, int line, const char* callText,
const char* macroArgs, Params&&... params) {
String argValues[sizeof...(Params)] = {str(params)...};
fatalFailedSyscallInternal(file, line, callText, errorNumber, macroArgs,
arrayPtr(argValues, sizeof...(Params)));
} }
template <typename... Params> template <typename... Params>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment