Commit 842e0d59 authored by Kenton Varda's avatar Kenton Varda

Fix bug with JSON discriminated unions with non-flattened members.

The code was init()ing the union field when the discriminant was seen, but this only works if the field type is a struct or group. Instead we store the field schema and initialize it when we see the value later.
parent 263b72ed
...@@ -844,11 +844,11 @@ R"({ "names-can_contain!anything Really": "foo", ...@@ -844,11 +844,11 @@ R"({ "names-can_contain!anything Really": "foo",
"testBase64": "ZnJlZA==", "testBase64": "ZnJlZA==",
"testHex": "706c756768", "testHex": "706c756768",
"bUnion": "renamed-bar", "bUnion": "renamed-bar",
"bValue": {"hi": 678} })"_kj; "bValue": 678 })"_kj;
static constexpr kj::StringPtr GOLDEN_ANNOTATED_REVERSE = static constexpr kj::StringPtr GOLDEN_ANNOTATED_REVERSE =
R"({ R"({
"bValue": {"hi": 678}, "bValue": 678,
"bUnion": "renamed-bar", "bUnion": "renamed-bar",
"testHex": "706c756768", "testHex": "706c756768",
"testBase64": "ZnJlZA==", "testBase64": "ZnJlZA==",
...@@ -938,7 +938,7 @@ KJ_TEST("rename fields") { ...@@ -938,7 +938,7 @@ KJ_TEST("rename fields") {
root.setTestBase64("fred"_kj.asBytes()); root.setTestBase64("fred"_kj.asBytes());
root.setTestHex("plugh"_kj.asBytes()); root.setTestHex("plugh"_kj.asBytes());
root.getBUnion().initBar().setHi(678); root.getBUnion().setBar(678);
auto encoded = json.encode(root.asReader()); auto encoded = json.encode(root.asReader());
KJ_EXPECT(encoded == GOLDEN_ANNOTATED, encoded); KJ_EXPECT(encoded == GOLDEN_ANNOTATED, encoded);
......
...@@ -80,7 +80,7 @@ struct TestJsonAnnotations { ...@@ -80,7 +80,7 @@ struct TestJsonAnnotations {
bUnion :union $Json.flatten() $Json.discriminator(valueName = "bValue") { bUnion :union $Json.flatten() $Json.discriminator(valueName = "bValue") {
foo @20 :Text; foo @20 :Text;
bar :group $Json.name("renamed-bar") { hi @21 :UInt32; } bar @21 :UInt32 $Json.name("renamed-bar");
} }
} }
......
...@@ -25,13 +25,13 @@ ...@@ -25,13 +25,13 @@
#include <errno.h> // for strtod errors #include <errno.h> // for strtod errors
#include <unordered_map> #include <unordered_map>
#include <map> #include <map>
#include <set>
#include <capnp/orphan.h> #include <capnp/orphan.h>
#include <kj/debug.h> #include <kj/debug.h>
#include <kj/function.h> #include <kj/function.h>
#include <kj/vector.h> #include <kj/vector.h>
#include <kj/one-of.h> #include <kj/one-of.h>
#include <kj/encoding.h> #include <kj/encoding.h>
#include <kj/map.h>
namespace capnp { namespace capnp {
...@@ -1126,7 +1126,7 @@ public: ...@@ -1126,7 +1126,7 @@ public:
void decode(const JsonCodec& codec, JsonValue::Reader input, void decode(const JsonCodec& codec, JsonValue::Reader input,
DynamicStruct::Builder output) const override { DynamicStruct::Builder output) const override {
KJ_REQUIRE(input.isObject()); KJ_REQUIRE(input.isObject());
std::set<const void*> unionsSeen; kj::HashMap<const void*, StructSchema::Field> unionsSeen;
kj::Vector<JsonValue::Field::Reader> retries; kj::Vector<JsonValue::Field::Reader> retries;
for (auto field: input.getObject()) { for (auto field: input.getObject()) {
if (!decodeField(codec, field.getName(), field.getValue(), output, unionsSeen)) { if (!decodeField(codec, field.getName(), field.getValue(), output, unionsSeen)) {
...@@ -1261,7 +1261,8 @@ private: ...@@ -1261,7 +1261,8 @@ private:
} }
bool decodeField(const JsonCodec& codec, kj::StringPtr name, JsonValue::Reader value, bool decodeField(const JsonCodec& codec, kj::StringPtr name, JsonValue::Reader value,
DynamicStruct::Builder output, std::set<const void*>& unionsSeen) const { DynamicStruct::Builder output,
kj::HashMap<const void*, StructSchema::Field>& unionsSeen) const {
KJ_ASSERT(output.getSchema() == schema); KJ_ASSERT(output.getSchema() == schema);
auto iter = fieldsByName.find(name); auto iter = fieldsByName.find(name);
...@@ -1287,36 +1288,36 @@ private: ...@@ -1287,36 +1288,36 @@ private:
// Mark that we've seen a union tag for this struct. // Mark that we've seen a union tag for this struct.
const void* ptr = getUnionInstanceIdentifier(output); const void* ptr = getUnionInstanceIdentifier(output);
KJ_REQUIRE(unionsSeen.insert(ptr).second, "Duplicate field name.");
auto iter = unionTagValues.find(value.getString()); auto iter = unionTagValues.find(value.getString());
if (iter != unionTagValues.end()) { if (iter != unionTagValues.end()) {
output.init(iter->second); unionsSeen.insert(ptr, iter->second);
} }
return true; return true;
} }
case FieldNameInfo::FLATTENED_FROM_UNION: { case FieldNameInfo::FLATTENED_FROM_UNION: {
const void* ptr = getUnionInstanceIdentifier(output); const void* ptr = getUnionInstanceIdentifier(output);
if (unionsSeen.count(ptr) == 0) { KJ_IF_MAYBE(variant, unionsSeen.find(ptr)) {
bool alreadyInitialized = output.which()
.map([&](auto f) { return f == *variant; })
.orDefault(false);
auto child = alreadyInitialized ? output.get(*variant) : output.init(*variant);
return KJ_ASSERT_NONNULL(fields[variant->getIndex()].flattenHandler)
.decodeField(codec, name.slice(info.prefixLength), value,
child.as<DynamicStruct>(), unionsSeen);
} else {
// We haven't seen the union tag yet, so we can't parse this field yet. Try again later. // We haven't seen the union tag yet, so we can't parse this field yet. Try again later.
return false; return false;
} }
auto variant = KJ_ASSERT_NONNULL(output.which());
return KJ_ASSERT_NONNULL(fields[variant.getIndex()].flattenHandler)
.decodeField(codec, name.slice(info.prefixLength), value,
output.get(variant).as<DynamicStruct>(), unionsSeen);
} }
case FieldNameInfo::UNION_VALUE: { case FieldNameInfo::UNION_VALUE: {
const void* ptr = getUnionInstanceIdentifier(output); const void* ptr = getUnionInstanceIdentifier(output);
if (unionsSeen.count(ptr) == 0) { KJ_IF_MAYBE(variant, unionsSeen.find(ptr)) {
codec.decodeField(*variant, value, Orphanage::getForMessageContaining(output), output);
return true;
} else {
// We haven't seen the union tag yet, so we can't parse this field yet. Try again later. // We haven't seen the union tag yet, so we can't parse this field yet. Try again later.
return false; return false;
} }
auto variant = KJ_ASSERT_NONNULL(output.which());
codec.decodeField(variant, value, Orphanage::getForMessageContaining(output), output);
return true;
} }
} }
......
...@@ -17,4 +17,4 @@ ...@@ -17,4 +17,4 @@
"testBase64": "ZnJlZA==", "testBase64": "ZnJlZA==",
"testHex": "706c756768", "testHex": "706c756768",
"bUnion": "renamed-bar", "bUnion": "renamed-bar",
"bValue": {"hi": 678} } "bValue": 678 }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment