Commit 42951751 authored by Kenton Varda's avatar Kenton Varda

Fix and test compiler's encoding of unions. Also update list pointer encoding to match docs.

parent eb8404a1
...@@ -252,6 +252,35 @@ TEST(Encoding, UnionLayout) { ...@@ -252,6 +252,35 @@ TEST(Encoding, UnionLayout) {
EXPECT_EQ(UnionState({0,0,0,4}, 384), initUnion(&TestUnion::Builder::setU3f0s64)); EXPECT_EQ(UnionState({0,0,0,4}, 384), initUnion(&TestUnion::Builder::setU3f0s64));
} }
TEST(Encoding, UnionDefault) {
MallocMessageBuilder builder;
TestUnionDefaults::Reader reader = builder.getRoot<TestUnionDefaults>().asReader();
{
auto field = reader.getS16s8s64s8Set();
EXPECT_EQ(TestUnion::Union0::U0F0S16, field.whichUnion0());
EXPECT_EQ(TestUnion::Union1::U1F0S8 , field.whichUnion1());
EXPECT_EQ(TestUnion::Union2::U2F0S64, field.whichUnion2());
EXPECT_EQ(TestUnion::Union3::U3F0S8 , field.whichUnion3());
EXPECT_EQ(321, field.getU0f0s16());
EXPECT_EQ(123, field.getU1f0s8());
EXPECT_EQ(12345678901234567ll, field.getU2f0s64());
EXPECT_EQ(55, field.getU3f0s8());
}
{
auto field = reader.getS0sps1s32Set();
EXPECT_EQ(TestUnion::Union0::U0F1S0 , field.whichUnion0());
EXPECT_EQ(TestUnion::Union1::U1F0SP , field.whichUnion1());
EXPECT_EQ(TestUnion::Union2::U2F0S1 , field.whichUnion2());
EXPECT_EQ(TestUnion::Union3::U3F0S32, field.whichUnion3());
EXPECT_EQ(Void::VOID, field.getU0f1s0());
EXPECT_EQ("foo", field.getU1f0sp());
EXPECT_EQ(true, field.getU2f0s1());
EXPECT_EQ(12345678, field.getU3f0s32());
}
}
} // namespace } // namespace
} // namespace internal } // namespace internal
} // namespace capnproto } // namespace capnproto
...@@ -140,10 +140,10 @@ struct WireReference { ...@@ -140,10 +140,10 @@ struct WireReference {
WireValue<uint32_t> elementSizeAndCount; WireValue<uint32_t> elementSizeAndCount;
CAPNPROTO_ALWAYS_INLINE(FieldSize elementSize() const) { CAPNPROTO_ALWAYS_INLINE(FieldSize elementSize() const) {
return static_cast<FieldSize>(elementSizeAndCount.get() >> 29); return static_cast<FieldSize>(elementSizeAndCount.get() & 7);
} }
CAPNPROTO_ALWAYS_INLINE(ElementCount elementCount() const) { CAPNPROTO_ALWAYS_INLINE(ElementCount elementCount() const) {
return (elementSizeAndCount.get() & 0x1fffffffu) * ELEMENTS; return (elementSizeAndCount.get() >> 3) * ELEMENTS;
} }
CAPNPROTO_ALWAYS_INLINE(WordCount inlineCompositeWordCount() const) { CAPNPROTO_ALWAYS_INLINE(WordCount inlineCompositeWordCount() const) {
return elementCount() * (1 * WORDS / ELEMENTS); return elementCount() * (1 * WORDS / ELEMENTS);
...@@ -152,14 +152,14 @@ struct WireReference { ...@@ -152,14 +152,14 @@ struct WireReference {
CAPNPROTO_ALWAYS_INLINE(void set(FieldSize es, ElementCount ec)) { CAPNPROTO_ALWAYS_INLINE(void set(FieldSize es, ElementCount ec)) {
CAPNPROTO_DEBUG_ASSERT(ec < (1 << 29) * ELEMENTS, CAPNPROTO_DEBUG_ASSERT(ec < (1 << 29) * ELEMENTS,
"Lists are limited to 2**29 elements."); "Lists are limited to 2**29 elements.");
elementSizeAndCount.set((static_cast<int>(es) << 29) | (ec / ELEMENTS)); elementSizeAndCount.set(((ec / ELEMENTS) << 3) | static_cast<int>(es));
} }
CAPNPROTO_ALWAYS_INLINE(void setInlineComposite(WordCount wc)) { CAPNPROTO_ALWAYS_INLINE(void setInlineComposite(WordCount wc)) {
CAPNPROTO_DEBUG_ASSERT(wc < (1 << 29) * WORDS, CAPNPROTO_DEBUG_ASSERT(wc < (1 << 29) * WORDS,
"Inline composite lists are limited to 2**29 words."); "Inline composite lists are limited to 2**29 words.");
elementSizeAndCount.set( elementSizeAndCount.set(((wc / WORDS) << 3) |
(static_cast<int>(FieldSize::INLINE_COMPOSITE) << 29) | (wc / WORDS)); static_cast<int>(FieldSize::INLINE_COMPOSITE));
} }
} listRef; } listRef;
......
...@@ -226,3 +226,10 @@ struct TestUnion { ...@@ -226,3 +226,10 @@ struct TestUnion {
u3f0s1 @46 in union3: Bool; u3f0s1 @46 in union3: Bool;
u2f0s1 @45 in union2: Bool; u2f0s1 @45 in union2: Bool;
} }
struct TestUnionDefaults {
s16s8s64s8Set @0 :TestUnion =
(u0f0s16 = 321, u1f0s8 = 123, u2f0s64 = 12345678901234567, u3f0s8 = 55);
s0sps1s32Set @1 :TestUnion =
(u0f1s0 = void, u1f0sp = "foo", u2f0s1 = true, u3f0s32 = 12345678);
}
...@@ -210,7 +210,7 @@ compileValue pos (StructType desc) (RecordFieldValue fields) = do ...@@ -210,7 +210,7 @@ compileValue pos (StructType desc) (RecordFieldValue fields) = do
-- Check for multiple assignments in the same union. -- Check for multiple assignments in the same union.
_ <- let _ <- let
dupes = findDupesBy (\(_, u) -> unionName u) dupes = findDupesBy (\(_, u) -> unionName u)
[(f, u) | (f@(FieldDesc {fieldUnion = Just u}), _) <- assignments] [(f, u) | (f@(FieldDesc {fieldUnion = Just (u, _)}), _) <- assignments]
errors = map dupUnionError dupes errors = map dupUnionError dupes
dupUnionError [] = error "empty group?" dupUnionError [] = error "empty group?"
dupUnionError dupFields@((_, u):_) = makeError pos (printf dupUnionError dupFields@((_, u):_) = makeError pos (printf
...@@ -322,7 +322,7 @@ requireNoDuplicateNames decls = Active () (loop (List.sort locatedNames)) where ...@@ -322,7 +322,7 @@ requireNoDuplicateNames decls = Active () (loop (List.sort locatedNames)) where
fieldInUnion name f = case fieldUnion f of fieldInUnion name f = case fieldUnion f of
Nothing -> False Nothing -> False
Just x -> unionName x == name Just (x, _) -> unionName x == name
requireNoMoreThanOneFieldNumberLessThan name pos num fields = Active () errors where requireNoMoreThanOneFieldNumberLessThan name pos num fields = Active () errors where
retroFields = [fieldName f | f <- fields, fieldNumber f < num] retroFields = [fieldName f | f <- fields, fieldNumber f < num]
...@@ -437,7 +437,7 @@ packField fieldDesc state unionState = ...@@ -437,7 +437,7 @@ packField fieldDesc state unionState =
Nothing -> let Nothing -> let
(offset, newState) = packValue (fieldSize $ fieldType fieldDesc) state (offset, newState) = packValue (fieldSize $ fieldType fieldDesc) state
in (offset, newState, unionState) in (offset, newState, unionState)
Just unionDesc -> let Just (unionDesc, _) -> let
n = unionNumber unionDesc n = unionNumber unionDesc
oldUnionPacking = fromMaybe initialUnionPackingState (Map.lookup n unionState) oldUnionPacking = fromMaybe initialUnionPackingState (Map.lookup n unionState)
(offset, newUnionPacking, newState) = (offset, newUnionPacking, newState) =
...@@ -577,9 +577,9 @@ compileDecl scope (StructDecl (Located _ name) decls) = ...@@ -577,9 +577,9 @@ compileDecl scope (StructDecl (Located _ name) decls) =
compileDecl (DescStruct parent) (UnionDecl (Located _ name) (Located numPos number) decls) = compileDecl (DescStruct parent) (UnionDecl (Located _ name) (Located numPos number) decls) =
CompiledMemberStatus name (feedback (\desc -> do CompiledMemberStatus name (feedback (\desc -> do
(_, _, options, statements) <- compileChildDecls desc decls (_, _, options, statements) <- compileChildDecls desc decls
let compareFieldNumbers a b = compare (fieldNumber a) (fieldNumber b) let fields = [f | f <- structFields parent, fieldInUnion name f]
fields = List.sortBy compareFieldNumbers orderedFieldNumbers = List.sort $ map fieldNumber fields
[f | f <- structFields parent, fieldInUnion name f] discriminantMap = Map.fromList $ zip orderedFieldNumbers [0..]
requireNoMoreThanOneFieldNumberLessThan name numPos number fields requireNoMoreThanOneFieldNumberLessThan name numPos number fields
return (let return (let
(tagOffset, tagPacking) = structFieldPackingMap parent ! number (tagOffset, tagPacking) = structFieldPackingMap parent ! number
...@@ -592,6 +592,7 @@ compileDecl (DescStruct parent) (UnionDecl (Located _ name) (Located numPos numb ...@@ -592,6 +592,7 @@ compileDecl (DescStruct parent) (UnionDecl (Located _ name) (Located numPos numb
, unionFields = fields , unionFields = fields
, unionOptions = options , unionOptions = options
, unionStatements = statements , unionStatements = statements
, unionFieldDiscriminantMap = discriminantMap
}))) })))
compileDecl _ (UnionDecl (Located pos name) _ _) = compileDecl _ (UnionDecl (Located pos name) _ _) =
CompiledMemberStatus name (makeError pos "Unions can only appear inside structs.") CompiledMemberStatus name (makeError pos "Unions can only appear inside structs.")
...@@ -605,7 +606,7 @@ compileDecl scope@(DescStruct parent) ...@@ -605,7 +606,7 @@ compileDecl scope@(DescStruct parent)
udesc <- maybeError (descMember n scope) p udesc <- maybeError (descMember n scope) p
(printf "No union '%s' defined in '%s'." n (structName parent)) (printf "No union '%s' defined in '%s'." n (structName parent))
case udesc of case udesc of
DescUnion d -> return (Just d) DescUnion d -> return (Just (d, unionFieldDiscriminantMap d ! number))
_ -> makeError p (printf "'%s' is not a union." n) _ -> makeError p (printf "'%s' is not a union." n)
typeDesc <- compileType scope typeExp typeDesc <- compileType scope typeExp
defaultDesc <- case defaultValue of defaultDesc <- case defaultValue of
......
...@@ -205,7 +205,10 @@ fieldContext parent desc = mkStrContext context where ...@@ -205,7 +205,10 @@ fieldContext parent desc = mkStrContext context where
context "fieldElementType" = context "fieldElementType" =
MuVariable $ cxxTypeString $ elementType $ fieldType desc MuVariable $ cxxTypeString $ elementType $ fieldType desc
context "fieldUnion" = case fieldUnion desc of context "fieldUnion" = case fieldUnion desc of
Just u -> MuList [unionContext context u] Just (u, _) -> MuList [unionContext context u]
Nothing -> muNull
context "fieldUnionDiscriminant" = case fieldUnion desc of
Just (_, n) -> MuVariable n
Nothing -> muNull Nothing -> muNull
context s = parent s context s = parent s
......
...@@ -130,7 +130,7 @@ compareErrors a b = compare (errorPos a) (errorPos b) ...@@ -130,7 +130,7 @@ compareErrors a b = compare (errorPos a) (errorPos b)
-- TODO: This is a fairly hacky way to make showErrorMessages' output not suck. We could do better -- TODO: This is a fairly hacky way to make showErrorMessages' output not suck. We could do better
-- by interpreting the error structure ourselves. -- by interpreting the error structure ourselves.
printError e = hPutStr stderr $ printf "%s:%d:%d: %s\n" f l c m' where printError e = hPutStr stderr $ printf "%s:%d:%d: error: %s\n" f l c m' where
pos = errorPos e pos = errorPos e
f = sourceName pos f = sourceName pos
l = sourceLine pos l = sourceLine pos
......
...@@ -321,9 +321,12 @@ data UnionDesc = UnionDesc ...@@ -321,9 +321,12 @@ data UnionDesc = UnionDesc
, unionNumber :: Integer , unionNumber :: Integer
, unionTagOffset :: Integer , unionTagOffset :: Integer
, unionTagPacking :: PackingState , unionTagPacking :: PackingState
, unionFields :: [FieldDesc] -- ordered by field number , unionFields :: [FieldDesc]
, unionOptions :: OptionMap , unionOptions :: OptionMap
, unionStatements :: [CompiledStatement] , unionStatements :: [CompiledStatement]
-- Maps field numbers to discriminants for all fields in the union.
, unionFieldDiscriminantMap :: Map.Map Integer Integer
} }
unionHasRetro desc = case unionFields desc of unionHasRetro desc = case unionFields desc of
...@@ -336,7 +339,7 @@ data FieldDesc = FieldDesc ...@@ -336,7 +339,7 @@ data FieldDesc = FieldDesc
, fieldNumber :: Integer , fieldNumber :: Integer
, fieldOffset :: Integer , fieldOffset :: Integer
, fieldPacking :: PackingState -- PackingState for the struct *if* this were the final field. , fieldPacking :: PackingState -- PackingState for the struct *if* this were the final field.
, fieldUnion :: Maybe UnionDesc , fieldUnion :: Maybe (UnionDesc, Integer) -- Integer is value of union discriminant.
, fieldType :: TypeDesc , fieldType :: TypeDesc
, fieldDefaultValue :: Maybe ValueDesc , fieldDefaultValue :: Maybe ValueDesc
, fieldOptions :: OptionMap , fieldOptions :: OptionMap
...@@ -408,7 +411,7 @@ descToCode indent (DescStruct desc) = printf "%sstruct %s%s" indent ...@@ -408,7 +411,7 @@ descToCode indent (DescStruct desc) = printf "%sstruct %s%s" indent
(blockCode indent (structStatements desc)) (blockCode indent (structStatements desc))
descToCode indent (DescField desc) = printf "%s%s@%d%s: %s%s; # %s\n" indent descToCode indent (DescField desc) = printf "%s%s@%d%s: %s%s; # %s\n" indent
(fieldName desc) (fieldNumber desc) (fieldName desc) (fieldNumber desc)
(case fieldUnion desc of { Nothing -> ""; Just u -> " in " ++ unionName u}) (case fieldUnion desc of { Nothing -> ""; Just (u, _) -> " in " ++ unionName u})
(typeName (DescStruct (fieldParent desc)) (fieldType desc)) (typeName (DescStruct (fieldParent desc)) (fieldType desc))
(case fieldDefaultValue desc of { Nothing -> ""; Just v -> " = " ++ valueString v; }) (case fieldDefaultValue desc of { Nothing -> ""; Just v -> " = " ++ valueString v; })
(case fieldSize $ fieldType desc of (case fieldSize $ fieldType desc of
......
...@@ -145,10 +145,10 @@ encodeStructReference desc offset = ...@@ -145,10 +145,10 @@ encodeStructReference desc offset =
encodeListReference elemSize@(SizeInlineComposite ds rc) elementCount offset = encodeListReference elemSize@(SizeInlineComposite ds rc) elementCount offset =
bytes (offset * 4 + listTag) 4 ++ bytes (offset * 4 + listTag) 4 ++
bytes (shiftL (fieldSizeEnum elemSize) 29 + elementCount * (ds + rc)) 4 bytes (fieldSizeEnum elemSize + shiftL (elementCount * (ds + rc)) 3) 4
encodeListReference elemSize elementCount offset = encodeListReference elemSize elementCount offset =
bytes (offset * 4 + listTag) 4 ++ bytes (offset * 4 + listTag) 4 ++
bytes (shiftL (fieldSizeEnum elemSize) 29 + elementCount) 4 bytes (fieldSizeEnum elemSize + shiftL elementCount 3) 4
fieldSizeEnum Size0 = 0 fieldSizeEnum Size0 = 0
fieldSizeEnum Size1 = 1 fieldSizeEnum Size1 = 1
...@@ -162,11 +162,6 @@ fieldSizeEnum (SizeInlineComposite _ _) = 7 ...@@ -162,11 +162,6 @@ fieldSizeEnum (SizeInlineComposite _ _) = 7
structTag = 0 structTag = 0
listTag = 1 listTag = 1
-- What is this union's default tag value? If there is a retroactive field, it is that field's
-- number, otherwise it is the union's number (meaning no field set).
unionDefault desc = UInt8Desc $ fromIntegral $
max (minimum $ map fieldNumber $ unionFields desc) (unionNumber desc)
-- childOffset = number of words between the last reference and the location where children will -- childOffset = number of words between the last reference and the location where children will
-- be allocated. -- be allocated.
encodeStruct desc assignments childOffset = (dataBytes, referenceBytes, children) where encodeStruct desc assignments childOffset = (dataBytes, referenceBytes, children) where
...@@ -174,9 +169,9 @@ encodeStruct desc assignments childOffset = (dataBytes, referenceBytes, children ...@@ -174,9 +169,9 @@ encodeStruct desc assignments childOffset = (dataBytes, referenceBytes, children
explicitValues = [(fieldOffset f, fieldType f, v, fieldDefaultValue f) | (f, v) <- assignments] explicitValues = [(fieldOffset f, fieldType f, v, fieldDefaultValue f) | (f, v) <- assignments]
-- Values of union tags. -- Values of union tags.
unionValues = [(unionTagOffset u, BuiltinType BuiltinUInt8, UInt8Desc $ fromIntegral n, unionValues = [(unionTagOffset u, BuiltinType BuiltinUInt16, UInt16Desc $ fromIntegral n,
Just $ unionDefault u) Nothing)
| (FieldDesc {fieldUnion = Just u, fieldNumber = n}, _) <- assignments] | (FieldDesc {fieldUnion = Just (u, n)}, _) <- assignments]
allValues = explicitValues ++ unionValues allValues = explicitValues ++ unionValues
allData = [ (o * sizeInBits (fieldValueSize v), t, v, d) allData = [ (o * sizeInBits (fieldValueSize v), t, v, d)
......
...@@ -55,7 +55,7 @@ struct {{structName}} { ...@@ -55,7 +55,7 @@ struct {{structName}} {
// {{unionDecl}} // {{unionDecl}}
enum class {{unionTitleCase}}: uint16_t { enum class {{unionTitleCase}}: uint16_t {
{{#unionFields}} {{#unionFields}}
{{fieldUpperCase}}, {{fieldUpperCase}} = {{fieldUnionDiscriminant}},
{{/unionFields}} {{/unionFields}}
}; };
{{/structUnions}} {{/structUnions}}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment