CxxGenerator.hs 25.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
-- Copyright (c) 2013, Kenton Varda <temporal@gmail.com>
-- All rights reserved.
--
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
--
-- 1. Redistributions of source code must retain the above copyright notice, this
--    list of conditions and the following disclaimer.
-- 2. Redistributions in binary form must reproduce the above copyright notice,
--    this list of conditions and the following disclaimer in the documentation
--    and/or other materials provided with the distribution.
--
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-- ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-- WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
-- ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-- (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-- LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-- ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-- (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-- SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

{-# LANGUAGE TemplateHaskell #-}

26
module CxxGenerator(generateCxx) where
27

28
import qualified Data.ByteString.UTF8 as ByteStringUTF8
29
import Data.FileEmbed(embedFile)
30
import Data.Word(Word8)
31
import qualified Data.Digest.MD5 as MD5
32
import qualified Data.Map as Map
33
import qualified Data.Set as Set
34
import qualified Data.List as List
35
import Data.Maybe(catMaybes, mapMaybe)
36
import Data.Binary.IEEE754(floatToWord, doubleToWord)
37
import Data.Map((!))
38
import Data.Function(on)
39 40 41
import Text.Printf(printf)
import Text.Hastache
import Text.Hastache.Context
42 43
import qualified Codec.Binary.UTF8.String as UTF8
import System.FilePath(takeBaseName)
44 45 46

import Semantics
import Util
47
import WireFormat
48 49 50 51 52

-- Hastache does not treat MuNothing as false when expanding a {{#variable}} section, so this
-- false boolean is used instead wherever a section should render nothing.
muNull = MuBool False

53 54 55 56
-- A MuType cannot be built from a single MuContext (as an optional sub-context would need), but
-- a one-element MuList renders the same way, so the context is wrapped in a singleton list.
muJust ctx = MuList (ctx : [])

57
-- ID of the text annotation that gives a file its C++ namespace (looked up by fileNamespace).
namespaceAnnotationId = 0xB9C6F99EBF805F2C
58 59 60 61 62 63 64

-- The value of the file's namespace annotation, or Nothing if the file has none.
fileNamespace desc = fmap testAnnotation $ Map.lookup namespaceAnnotationId $ fileAnnotations desc

-- Extract the text of a text-valued annotation; any other value kind is reported as an error
-- naming the offending annotation.
testAnnotation (_, TextDesc x) = x
testAnnotation (desc, _) =
    -- `$` is required: without it this parses as `(error "...") ++ annotationName desc`, and the
    -- error fires before the annotation name is ever appended to the message.
    error $ "Annotation was supposed to be text, but wasn't: " ++ annotationName desc

65 66 67 68 69
-- Name of a descriptor relative to its file: the names of all enclosing descriptors and its own
-- name, joined with "::".
fullName desc = concat [scopePrefix (descParent desc), descName desc]

-- "::"-terminated prefix for names nested inside `desc`; empty at file level.
scopePrefix desc = case desc of
    DescFile _ -> ""
    _ -> fullName desc ++ "::"

70
-- Fully-qualified C++ name rooted at the global namespace.  A file contributes " ::<namespace>"
-- when it has a namespace annotation and just " " otherwise, so every nested name starts with
-- " ::" (the leading space presumably keeps the name safe after '<' in template arguments).
globalName desc = case desc of
    DescFile file -> case fileNamespace file of
        Nothing -> " "
        Just ns -> " ::" ++ ns
    _ -> concat [globalName (descParent desc), "::", descName desc]

73
-- Flatten the descriptor tree in pre-order, returning struct, union, interface, and enum
-- descriptors only.  Each struct, union, or interface is immediately followed by its flattened
-- nested members; enums are leaves.  Any other descriptor is dropped along with everything
-- nested beneath it.
flattenTypes :: [Desc] -> [Desc]
flattenTypes = concatMap flatten where
    flatten d@(DescStruct s) = d : nested (structMemberMap s)
    flatten d@(DescUnion u) = d : nested (unionMemberMap u)
    flatten d@(DescInterface i) = d : nested (interfaceMemberMap i)
    flatten d@(DescEnum _) = [d]
    flatten _ = []

    -- Member maps hold Maybe values; Nothing entries are skipped.
    nested members = flattenTypes $ catMaybes $ Map.elems members

86 87 88 89
-- Lower-case hexadecimal MD5 digest of the string's UTF-8 encoding.
hashString :: String -> String
hashString = concatMap toHex . MD5.hash . UTF8.encode
  where
    toHex byte = printf "%02x" (fromEnum byte)
91

92
-- Whether a type is a primitive value: an enum, or any builtin other than Object and the blob
-- builtins (Text, Data).  Structs, interfaces, lists, and inline data are never primitive.
isPrimitive t = case t of
    BuiltinType BuiltinObject -> False
    BuiltinType _ -> not (isBlob t)
    EnumType _ -> True
    StructType _ -> False
    InlineStructType _ -> False
    InterfaceType _ -> False
    ListType _ -> False
    InlineListType _ _ -> False
    InlineDataType _ -> False
101

Kenton Varda's avatar
Kenton Varda committed
102 103
-- Whether a type is a blob: Text, Data, or fixed-size inline data.
isBlob t = case t of
    BuiltinType BuiltinText -> True
    BuiltinType BuiltinData -> True
    InlineDataType _ -> True
    _ -> False

107 108 109
-- Whether a type is fixed-size inline data.
isInlineBlob t = case t of
    InlineDataType _ -> True
    _ -> False

110
-- Whether a type is a struct, either regular or inline.
isStruct t = case t of
    StructType _ -> True
    InlineStructType _ -> True
    _ -> False

Kenton Varda's avatar
Kenton Varda committed
114 115 116
-- Whether a type is an inline struct.
isInlineStruct t = case t of
    InlineStructType _ -> True
    _ -> False

117
-- Whether a type is a list, either regular or inline.
isList t = case t of
    ListType _ -> True
    InlineListType _ _ -> True
    _ -> False

Kenton Varda's avatar
Kenton Varda committed
121
-- Whether a type is a (regular or inline) list whose element type is not a struct.
isNonStructList t = case t of
    ListType element -> not (isStruct element)
    InlineListType element _ -> not (isStruct element)
    _ -> False
124

125
-- Whether a type is a (regular or inline) list of primitive elements.
isPrimitiveList t = case t of
    ListType element -> isPrimitive element
    InlineListType element _ -> isPrimitive element
    _ -> False

129 130 131 132 133 134 135 136 137 138 139 140
-- Whether a list element type is none of: inline data, primitive, struct, or inline list.
-- What remains (Text, Data, Object, interfaces, regular lists) is presumably stored by pointer.
isPointerElement t = case t of
    InlineDataType _ -> False
    _ -> not (isPrimitive t) && not (isStruct t) && not (isInlineList t)

-- Whether a type is a (regular or inline) list whose elements satisfy isPointerElement.
isPointerList t = case t of
    ListType element -> isPointerElement element
    InlineListType element _ -> isPointerElement element
    _ -> False

-- Whether a type is a regular list of inline data (inline lists never qualify).
isInlineBlobList t = case t of
    ListType element -> isInlineBlob element
    _ -> False

-- Whether a (regular or inline) list ultimately holds structs, looking through any nesting of
-- inline lists to the innermost element type.
isStructList t = case t of
    ListType element -> elementIsStruct element
    InlineListType element _ -> elementIsStruct element
    _ -> False
  where
    elementIsStruct element@(InlineListType _ _) = isStructList element
    elementIsStruct element = isStruct element

145 146 147
-- Whether a type is an inline list.
isInlineList t = case t of
    InlineListType _ _ -> True
    _ -> False

148 149 150
-- Whether a type is the builtin Object type.
isGenericObject t = case t of
    BuiltinType BuiltinObject -> True
    _ -> False

Kenton Varda's avatar
Kenton Varda committed
151 152
-- "Text" or "Data" for a blob type, looking through list element types; errors for any type
-- that is not (a list of) a blob.
blobTypeString t = case t of
    BuiltinType BuiltinText -> "Text"
    BuiltinType BuiltinData -> "Data"
    InlineDataType _ -> "Data"
    ListType element -> blobTypeString element
    InlineListType element _ -> blobTypeString element
    _ -> error "Not a blob."

158 159 160 161 162 163 164 165 166
-- How many units a type expands to inline: the product of all nested inline-list sizes times the
-- size of any innermost inline data; 1 for every other type.
inlineMultiplier t = case t of
    InlineListType element count -> count * inlineMultiplier element
    InlineDataType count -> count
    _ -> 1

-- For a regular list, the C++ suffix " * N" where N is the element type's inline multiplier, or
-- "" when N is 1.  Errors for anything that is not a regular list.
listInlineMultiplierString t = case t of
    ListType element ->
        let multiplier = inlineMultiplier element
        in if multiplier == 1 then "" else " * " ++ show multiplier
    _ -> error "Not a list."

167
-- C++ spelling of a schema type.  Names qualified from the global namespace are written with a
-- leading space (" ::int8_t", " ::capnproto::List<...>"); user-defined types use globalName.
cxxTypeString t = case t of
    BuiltinType builtin -> builtinName builtin
    EnumType desc -> globalName (DescEnum desc)
    StructType desc -> globalName (DescStruct desc)
    InlineStructType desc -> globalName (DescStruct desc)
    InterfaceType desc -> globalName (DescInterface desc)
    ListType element -> concat [" ::capnproto::List<", cxxTypeString element, ">"]
    InlineListType element size ->
        concat [" ::capnproto::InlineList<", cxxTypeString element, ", ", show size, ">"]
    InlineDataType size -> concat [" ::capnproto::InlineData<", show size, ">"]
  where
    builtinName b = case b of
        BuiltinVoid -> " ::capnproto::Void"
        BuiltinBool -> "bool"
        BuiltinInt8 -> " ::int8_t"
        BuiltinInt16 -> " ::int16_t"
        BuiltinInt32 -> " ::int32_t"
        BuiltinInt64 -> " ::int64_t"
        BuiltinUInt8 -> " ::uint8_t"
        BuiltinUInt16 -> " ::uint16_t"
        BuiltinUInt32 -> " ::uint32_t"
        BuiltinUInt64 -> " ::uint64_t"
        BuiltinFloat32 -> "float"
        BuiltinFloat64 -> "double"
        BuiltinText -> " ::capnproto::Text"
        BuiltinData -> " ::capnproto::Data"
        BuiltinObject -> " ::capnproto::Object"
191

Kenton Varda's avatar
Kenton Varda committed
192 193 194 195 196 197
-- Upper-case size name emitted into the generated C++ for a field size.
cxxFieldSizeString size = case size of
    SizeVoid -> "VOID"
    SizeData Size1 -> "BIT"
    SizeData Size8 -> "BYTE"
    SizeData Size16 -> "TWO_BYTES"
    SizeData Size32 -> "FOUR_BYTES"
    SizeData Size64 -> "EIGHT_BYTES"
    SizePointer -> "POINTER"
    SizeInlineComposite _ _ -> "INLINE_COMPOSITE"

Kenton Varda's avatar
Kenton Varda committed
201 202 203 204
-- C++ text for a field's offset.  Void fields give "0"; data and pointer offsets give the bare
-- number.  Inline composites give four comma-separated terms: data offset and data size in
-- ::capnproto::BYTES, then pointer offset and pointer count in ::capnproto::POINTERS.  The data
-- offset is scaled by 8 bytes for word-sized data sections and by the section's byte size otherwise.
fieldOffsetInteger offset = case offset of
    VoidOffset -> "0"
    DataOffset _ o -> show o
    PointerOffset o -> show o
    InlineCompositeOffset dataOff ptrOff dataSize ptrSize ->
        let byteSize = dataSectionBits dataSize `div` 8
            byteOffset = case dataSize of
                DataSectionWords _ -> dataOff * 8
                _ -> dataOff * byteSize
        in printf ("%d * ::capnproto::BYTES, %d * ::capnproto::BYTES, "
                   ++ "%d * ::capnproto::POINTERS, %d * ::capnproto::POINTERS")
                  byteOffset byteSize ptrOff ptrSize
Kenton Varda's avatar
Kenton Varda committed
211

212 213 214 215 216 217 218 219 220 221 222 223
-- Whether a non-aggregate default value is all-zero, in which case fieldDefaultMask emits no mask.
-- Errors for aggregate values (Text, Data, struct, list).
--
-- Floats are tested by bit pattern rather than with (==): -0.0 == 0 holds, but its bit pattern
-- (and therefore its defaultMask) is non-zero.  Treating it as zero used to drop the mask and
-- silently turn a -0.0 default into +0.0.
isDefaultZero VoidDesc = True
isDefaultZero (BoolDesc    b) = not b
isDefaultZero (Int8Desc    i) = i == 0
isDefaultZero (Int16Desc   i) = i == 0
isDefaultZero (Int32Desc   i) = i == 0
isDefaultZero (Int64Desc   i) = i == 0
isDefaultZero (UInt8Desc   i) = i == 0
isDefaultZero (UInt16Desc  i) = i == 0
isDefaultZero (UInt32Desc  i) = i == 0
isDefaultZero (UInt64Desc  i) = i == 0
isDefaultZero (Float32Desc x) = floatToWord x == 0
isDefaultZero (Float64Desc x) = doubleToWord x == 0
isDefaultZero (EnumerantValueDesc v) = enumerantNumber v == 0
isDefaultZero (TextDesc _) = error "Can't call isDefaultZero on aggregate types."
isDefaultZero (DataDesc _) = error "Can't call isDefaultZero on aggregate types."
isDefaultZero (StructValueDesc _) = error "Can't call isDefaultZero on aggregate types."
isDefaultZero (ListDesc _) = error "Can't call isDefaultZero on aggregate types."

-- C++ literal for a non-aggregate default value, emitted by fieldDefaultMask after a ", ".
-- Floats are emitted as the integer bit pattern of the value.  Errors for aggregate values.
--
-- 64-bit values use the "ll"/"llu" suffixes so the literal is 64 bits wide even where
-- `unsigned long` is 32 bits; Float64 previously used "ul", inconsistent with UInt64.
defaultMask VoidDesc = "0"
defaultMask (BoolDesc    b) = if b then "true" else "false"
defaultMask (Int8Desc    i) = show i
defaultMask (Int16Desc   i) = show i
defaultMask (Int32Desc   i) = show i
defaultMask (Int64Desc   i) = show i ++ "ll"
defaultMask (UInt8Desc   i) = show i
defaultMask (UInt16Desc  i) = show i
defaultMask (UInt32Desc  i) = show i ++ "u"
defaultMask (UInt64Desc  i) = show i ++ "llu"
defaultMask (Float32Desc x) = show (floatToWord x) ++ "u"
defaultMask (Float64Desc x) = show (doubleToWord x) ++ "llu"
defaultMask (EnumerantValueDesc v) = show (enumerantNumber v)
defaultMask (TextDesc _) = error "Can't call defaultMask on aggregate types."
defaultMask (DataDesc _) = error "Can't call defaultMask on aggregate types."
defaultMask (StructValueDesc _) = error "Can't call defaultMask on aggregate types."
defaultMask (ListDesc _) = error "Can't call defaultMask on aggregate types."
247 248 249 250 251 252

-- Encoded bytes for an aggregate default value: NUL-terminated UTF-8 for Text, the raw bytes for
-- Data, and an encoded message (using the field type `t`) for structs and lists.  Nothing for
-- every other value.
defaultValueBytes t value = case value of
    TextDesc s -> Just (UTF8.encode s ++ [0])
    DataDesc d -> Just d
    StructValueDesc _ -> Just (encodeMessage t value)
    ListDesc _ -> Just (encodeMessage t value)
    _ -> Nothing
253 254

-- Element type of a regular or inline list; errors for non-list types.
elementType t = case t of
    ListType element -> element
    InlineListType element _ -> element
    _ -> error "Called elementType on non-list."

258 259 260 261
-- Element type of a list after looking through any nested inline lists; errors for non-lists.
inlineElementType t = case elementType t of
    nested@(InlineListType _ _) -> inlineElementType nested
    element -> element

262 263 264
-- Split a list into consecutive chunks of `n` elements; the last chunk may be shorter.
-- An empty list gives no chunks.  A non-positive chunk size is rejected: previously it produced
-- an endless stream of empty chunks for any non-empty input.
repeatedlyTake :: Int -> [a] -> [[a]]
repeatedlyTake _ [] = []
repeatedlyTake n l
    | n <= 0 = error "repeatedlyTake: chunk size must be positive"
    | otherwise = chunk : repeatedlyTake n rest
  where
    (chunk, rest) = splitAt n l

265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280
-- IDs of the user-defined types a type refers to: a struct, enum, or interface yields its own ID,
-- and a regular list yields its element's dependencies.  Everything else yields none.
-- NOTE(review): inline structs and inline lists are not traversed here; confirm this is intended.
typeDependencies t = case t of
    StructType s -> [structId s]
    EnumType e -> [enumId e]
    InterfaceType i -> [interfaceId i]
    ListType element -> typeDependencies element
    _ -> []

-- Type IDs referenced by a method parameter's type.
paramDependencies param = typeDependencies (paramType param)

-- Type IDs referenced anywhere inside a descriptor: field types, method return and parameter
-- types, recursing through struct, union, and interface members.  May contain duplicates.
descDependencies desc = case desc of
    DescStruct d -> concatMap descDependencies (structMembers d)
    DescUnion d -> concatMap descDependencies (unionMembers d)
    DescField d -> typeDependencies (fieldType d)
    DescInterface d -> concatMap descDependencies (interfaceMembers d)
    DescMethod d ->
        typeDependencies (methodReturnType d) ++ concatMap paramDependencies (methodParams d)
    _ -> []

281 282 283 284 285 286 287 288 289
-- Infinite list of (unionIndex, memberIndex) pairs with memberIndex counting up from 0; zipping
-- it against a member list numbers those members within the given union index.
memberIndexes :: Int -> [(Int, Int)]
memberIndexes unionIndex = [(unionIndex, memberIndex) | memberIndex <- [0 ..]]

-- Named members of a type, each paired with a (union index, member index) key.
--
-- Struct: fields and unions, sorted by number, are numbered under union index 0.  Then, for the
-- union at 0-based position k of that same sorted member list, its own fields (sorted by number)
-- are numbered under union index k + 1.
-- Enum: enumerants sorted by number, under union index 0.
-- Interface: methods sorted by number, under union index 0.
-- Any other descriptor: empty.
memberTable (DescStruct desc) = concat (topLevel : perUnion)
  where
    sortedMembers = sortByOrdinal (structMembers desc)
    topLevel = zip (memberIndexes 0) (mapMaybe memberName sortedMembers)
    perUnion = catMaybes (zipWith unionEntries [1 ..] sortedMembers)

    unionEntries i (DescUnion u) =
        Just (zip (memberIndexes i) (mapMaybe memberName (sortByOrdinal (unionMembers u))))
    unionEntries _ _ = Nothing

    sortByOrdinal = List.sortBy (compare `on` ordinal)

    ordinal (DescField f) = fieldNumber f
    ordinal (DescUnion u) = unionNumber u
    -- Other members are unnamed and get filtered out, but their sort position still matters:
    -- this value sorts them after every field/union (assuming numbers stay below 65536), so they
    -- cannot shift the union positions used as union indexes above.
    ordinal _ = 65536

    memberName (DescField f) = Just (fieldName f)
    memberName (DescUnion u) = Just (unionName u)
    memberName _ = Nothing
memberTable (DescEnum desc) =
    zip (memberIndexes 0)
        [enumerantName e | e <- List.sortBy (compare `on` enumerantNumber) (enumerants desc)]
memberTable (DescInterface desc) =
    zip (memberIndexes 0)
        [methodName m | m <- List.sortBy (compare `on` methodNumber) (interfaceMethods desc)]
memberTable _ = []

314 315 316 317 318
-- Builds the Hastache context used to render both the header and the source template.
--
-- `outerFileContext schemaNodes` returns a function from a file descriptor to the root context
-- for that file.  `schemaNodes` maps a descriptor ID to that type's encoded schema; its elements
-- are printed with "%3d" and counted eight to a word, so they are presumably bytes.
--
-- Every nested context answers its own variables and delegates anything else to its parent, so
-- a template section can also read the variables of enclosing sections.  The root (file)
-- context raises an error for any variable it does not define.
outerFileContext schemaNodes = fileContext where
    -- One entry of a schema's dependency list, as a 16-digit hex ID.
    schemaDepContext parent i = mkStrContext context where
        context "dependencyId" = MuVariable (printf "%016x" i :: String)
        context s = parent s

    -- One entry of a schema's by-name member table: a (union index, member index) pair.
    schemaMemberByNameContext parent (ui, mi) = mkStrContext context where
        context "memberUnionIndex" = MuVariable ui
        context "memberIndex" = MuVariable mi
        context s = parent s

    -- Encoded schema, dependency IDs, and by-name member table for one type.
    schemaContext parent desc = mkStrContext context where
        node = schemaNodes ! descId desc

        -- Schema bytes as comma-separated rows of eight.
        codeLines = map (delimit ", ") $ repeatedlyTake 8 $ map (printf "%3d") node

        -- Referenced type IDs, sorted and de-duplicated.
        depIds = map head $ List.group $ List.sort $ descDependencies desc

        -- Member keys ordered by (union index, member name).
        membersByName = map fst $ List.sortBy (compare `on` memberByNameKey) $ memberTable desc
        memberByNameKey ((unionIndex, _), name) = (unionIndex, name)

        context "schemaWordCount" = MuVariable $ div (length node + 7) 8
        context "schemaBytes" = MuVariable $ delimit ",\n    " codeLines
        context "schemaId" = MuVariable (printf "%016x" (descId desc) :: String)
        context "schemaDependencyCount" = MuVariable $ length depIds
        context "schemaDependencies" =
            MuList $ map (schemaDepContext context) depIds
        context "schemaMemberCount" = MuVariable $ length membersByName
        context "schemaMembersByName" =
            MuList $ map (schemaMemberByNameContext context) membersByName
        context s = parent s

    enumerantContext parent desc = mkStrContext context where
        context "enumerantName" = MuVariable $ toUpperCaseWithUnderscores $ enumerantName desc
        context "enumerantNumber" = MuVariable $ enumerantNumber desc
        context s = parent s

    enumContext parent desc = mkStrContext context where
        context "enumName" = MuVariable $ enumName desc
        context "enumId" = MuVariable (printf "%016x" (enumId desc) ::String)
        context "enumerants" = MuList $ map (enumerantContext context) $ enumerants desc
        context s = parent s

    -- Encoded bytes of an aggregate field default (as produced by defaultValueBytes).
    defaultBytesContext :: Monad m => (String -> MuType m) -> TypeDesc -> [Word8] -> MuContext m
    defaultBytesContext parent t bytes = mkStrContext context where
        codeLines = map (delimit ", ") $ repeatedlyTake 8 $ map (printf "%3d") bytes
        context "defaultByteList" = MuVariable $ delimit ",\n    " codeLines
        context "defaultWordCount" = MuVariable $ div (length bytes + 7) 8
        context "defaultBlobSize" = case t of
            BuiltinType BuiltinText -> MuVariable (length bytes - 1)  -- Don't include NUL terminator.
            BuiltinType BuiltinData -> MuVariable (length bytes)
            -- Inline data is a blob as well (isBlob / fieldIsBlob, blobTypeString "Data"), and
            -- defaultValueBytes passes a DataDesc default through unterminated, so it is sized
            -- like Data instead of failing below.
            InlineDataType _ -> MuVariable (length bytes)
            _ -> error "defaultBlobSize used on non-blob."
        context s = parent s

    -- First line of the descriptor's rendered schema source.
    descDecl desc = head $ lines $ descToCode "" desc

    fieldContext parent desc = mkStrContext context where
        context "fieldName" = MuVariable $ fieldName desc
        context "fieldDecl" = MuVariable $ descDecl $ DescField desc
        context "fieldTitleCase" = MuVariable $ toTitleCase $ fieldName desc
        context "fieldUpperCase" = MuVariable $ toUpperCaseWithUnderscores $ fieldName desc
        context "fieldIsPrimitive" = MuBool $ isPrimitive $ fieldType desc

        context "fieldIsListOrBlob" = MuBool $ isBlob (fieldType desc) || isList (fieldType desc)

        context "fieldIsBlob" = MuBool $ isBlob $ fieldType desc
        context "fieldIsInlineBlob" = MuBool $ isInlineBlob $ fieldType desc
        context "fieldIsStruct" = MuBool $ isStruct $ fieldType desc
        context "fieldIsInlineStruct" = MuBool $ isInlineStruct $ fieldType desc
        context "fieldIsList" = MuBool $ isList $ fieldType desc
        context "fieldIsNonStructList" = MuBool $ isNonStructList $ fieldType desc
        context "fieldIsPrimitiveList" = MuBool $ isPrimitiveList $ fieldType desc
        context "fieldIsPointerList" = MuBool $ isPointerList $ fieldType desc
        context "fieldIsInlineBlobList" = MuBool $ isInlineBlobList $ fieldType desc
        context "fieldIsStructList" = MuBool $ isStructList $ fieldType desc
        context "fieldIsInlineList" = MuBool $ isInlineList $ fieldType desc
        context "fieldIsGenericObject" = MuBool $ isGenericObject $ fieldType desc
        context "fieldDefaultBytes" =
            case fieldDefaultValue desc >>= defaultValueBytes (fieldType desc) of
                Just v -> muJust $ defaultBytesContext context (fieldType desc) v
                Nothing -> muNull
        context "fieldType" = MuVariable $ cxxTypeString $ fieldType desc
        context "fieldBlobType" = MuVariable $ blobTypeString $ fieldType desc
        context "fieldOffset" = MuVariable $ fieldOffsetInteger $ fieldOffset desc
        context "fieldInlineListSize" = case fieldType desc of
            InlineListType _ n -> MuVariable n
            InlineDataType n -> MuVariable n
            _ -> muNull
        context "fieldInlineDataOffset" = case fieldOffset desc of
            InlineCompositeOffset off _ size _ ->
                MuVariable (off * div (dataSizeInBits (dataSectionAlignment size)) 8)
            _ -> muNull
        context "fieldInlineDataSize" = case fieldOffset desc of
            InlineCompositeOffset _ _ size _ ->
                MuVariable $ div (dataSectionBits size) 8
            _ -> muNull
        context "fieldInlinePointerOffset" = case fieldOffset desc of
            InlineCompositeOffset _ off _ _ -> MuVariable off
            _ -> muNull
        context "fieldInlinePointerSize" = case fieldOffset desc of
            InlineCompositeOffset _ _ _ size -> MuVariable size
            _ -> muNull
        context "fieldInlineMultiplier" = MuVariable $ listInlineMultiplierString $ fieldType desc
        context "fieldDefaultMask" = case fieldDefaultValue desc of
            Nothing -> MuVariable ""
            Just v -> MuVariable (if isDefaultZero v then "" else ", " ++ defaultMask v)
        context "fieldElementSize" =
            MuVariable $ cxxFieldSizeString $ fieldSize $ inlineElementType $ fieldType desc
        context "fieldElementType" =
            MuVariable $ cxxTypeString $ elementType $ fieldType desc
        context "fieldElementReaderType" = MuVariable readerString where
            readerString = if isPrimitiveList $ fieldType desc
                then tString
                else tString ++ "::Reader"
            tString = cxxTypeString $ elementType $ fieldType desc
        context "fieldInlineElementType" =
            MuVariable $ cxxTypeString $ inlineElementType $ fieldType desc
        context "fieldUnion" = case fieldUnion desc of
            Just (u, _) -> muJust $ unionContext context u
            Nothing -> muNull
        context "fieldUnionDiscriminant" = case fieldUnion desc of
            Just (_, n) -> MuVariable n
            Nothing -> muNull
        context "fieldSetterDefault" = case fieldType desc of
            BuiltinType BuiltinVoid -> MuVariable " = ::capnproto::Void::VOID"
            _ -> MuVariable ""
        context s = parent s

    unionContext parent desc = mkStrContext context where
        titleCase = toTitleCase $ unionName desc

        -- 0-based position of this union's number among the keys of the parent struct's
        -- structMembersByNumber map.
        unionIndex = Map.findIndex (unionNumber desc) $ structMembersByNumber $ unionParent desc

        context "typeStruct" = MuBool False
        context "typeUnion" = MuBool True
        context "typeName" = MuVariable titleCase
        context "typeFullName" = context "unionFullName"
        context "typeFields" = context "unionFields"

        context "unionName" = MuVariable $ unionName desc
        context "unionFullName" = MuVariable $ fullName (DescStruct $ unionParent desc) ++
                                 "::" ++ titleCase
        context "unionDecl" = MuVariable $ descDecl $ DescUnion desc
        context "unionTitleCase" = MuVariable titleCase
        context "unionTagOffset" = MuVariable $ unionTagOffset desc
        context "unionFields" = MuList $ map (fieldContext context) $ unionFields desc
        context "unionIndex" = MuVariable unionIndex
        context s = parent s

    childContext parent name = mkStrContext context where
        context "nestedName" = MuVariable name
        context s = parent s

    structContext parent desc = mkStrContext context where
        context "typeStruct" = MuBool True
        context "typeUnion" = MuBool False
        context "typeName" = context "structName"
        context "typeFullName" = context "structFullName"
        context "typeFields" = context "structFields"

        context "structName" = MuVariable $ structName desc
        context "structId" = MuVariable (printf "%016x" (structId desc) ::String)
        context "structFullName" = MuVariable $ fullName (DescStruct desc)
        context "structFields" = MuList $ map (fieldContext context) $ structFields desc
        context "structUnions" = MuList $ map (unionContext context) $ structUnions desc
        context "structDataSize" = MuVariable $ dataSectionWordSize $ structDataSize desc
        context "structPointerCount" = MuVariable $ structPointerCount desc
        context "structPreferredListEncoding" = case (structDataSize desc, structPointerCount desc) of
            (DataSectionWords 0, 0) -> MuVariable "VOID"
            (DataSection1, 0) -> MuVariable "BIT"
            (DataSection8, 0) -> MuVariable "BYTE"
            (DataSection16, 0) -> MuVariable "TWO_BYTES"
            (DataSection32, 0) -> MuVariable "FOUR_BYTES"
            (DataSectionWords 1, 0) -> MuVariable "EIGHT_BYTES"
            (DataSectionWords 0, 1) -> MuVariable "POINTER"
            _ -> MuVariable "INLINE_COMPOSITE"
        context "structNestedEnums" =
            MuList $ map (enumContext context) [m | DescEnum m <- structMembers desc]
        context "structNestedStructs" =
            MuList $ map (childContext context . structName) [m | DescStruct m <- structMembers desc]
        context "structNestedInterfaces" =
            MuList $ map (childContext context . interfaceName) [m | DescInterface m <- structMembers desc]
        context s = parent s

    -- One flattened type: exposed as a struct/union section, an enum section, and (for
    -- everything except unions) a schema section.
    typeContext parent desc = mkStrContext context where
        context "typeStructOrUnion" = case desc of
            DescStruct d -> muJust $ structContext context d
            DescUnion u -> muJust $ unionContext context u
            _ -> muNull
        context "typeEnum" = case desc of
            DescEnum d -> muJust $ enumContext context d
            _ -> muNull
        context "typeSchema" = case desc of
            DescUnion _ -> muNull
            _ -> muJust $ schemaContext context desc
        context s = parent s

    -- An import path; a leading '/' marks a system import and is stripped.
    importContext parent ('/':filename) = mkStrContext context where
        context "importFilename" = MuVariable filename
        context "importIsSystem" = MuBool True
        context s = parent s
    importContext parent filename = mkStrContext context where
        context "importFilename" = MuVariable filename
        context "importIsSystem" = MuBool False
        context s = parent s

    namespaceContext parent part = mkStrContext context where
        context "namespaceName" = MuVariable part
        context s = parent s

    -- Root context for a file.
    fileContext desc = mkStrContext context where
        flattenedMembers = flattenTypes $ catMaybes $ Map.elems $ fileMemberMap desc

        -- Namespace annotation split on "::"; empty when the file has none.
        namespace = maybe [] (splitOn "::") $ fileNamespace desc

        -- Only imports listed in fileRuntimeImports are emitted.
        isImportUsed (_, dep) = Set.member (fileName dep) (fileRuntimeImports desc)

        context "fileName" = MuVariable $ fileName desc
        context "fileBasename" = MuVariable $ takeBaseName $ fileName desc
        context "fileIncludeGuard" = MuVariable $
            "CAPNPROTO_INCLUDED_" ++ hashString (fileName desc ++ ':':show (fileId desc))
        context "fileNamespaces" = MuList $ map (namespaceContext context) namespace
        context "fileEnums" = MuList $ map (enumContext context) [e | DescEnum e <- fileMembers desc]
        context "fileTypes" = MuList $ map (typeContext context) flattenedMembers
        context "fileImports" = MuList $ map (importContext context . fst)
                              $ filter isImportUsed $ Map.toList $ fileImportMap desc
        context s = error ("Template variable not defined: " ++ s)
540 541

-- Mustache template for generated C++ headers, embedded at compile time and decoded as UTF-8.
headerTemplate :: String
headerTemplate = ByteStringUTF8.toString headerBytes
  where headerBytes = $(embedFile "src/c++-header.mustache")

-- Mustache template for generated C++ sources, embedded at compile time and decoded as UTF-8.
srcTemplate :: String
srcTemplate = ByteStringUTF8.toString sourceBytes
  where sourceBytes = $(embedFile "src/c++-source.mustache")
546 547 548 549 550 551 552 553 554 555 556

-- Hastache settings: escaping uses emptyEscape, and template inclusion is disabled (every read
-- yields Nothing).  Sadly, hastache appears to require the IO monad even with inclusion disabled.
hastacheConfig :: MuConfig IO
hastacheConfig = MuConfig
    { muTemplateRead = const (return Nothing)
    , muTemplateFileExt = Nothing
    , muTemplateFileDir = Nothing
    , muEscapeFunc = emptyEscape
    }

557 558 559 560
-- Render the C++ header template for one file.
generateCxxHeader file schemaNodes = hastacheStr hastacheConfig template context
  where
    template = encodeStr headerTemplate
    context = outerFileContext schemaNodes file

-- Render the C++ source template for one file.
generateCxxSource file schemaNodes = hastacheStr hastacheConfig template context
  where
    template = encodeStr srcTemplate
    context = outerFileContext schemaNodes file
561

562
-- Generate "<fileName>.h" and "<fileName>.c++" for every input file, returning
-- (output name, rendered contents) pairs in file order.  The second argument is unused.
generateCxx files _ schemaNodes = do
    perFile <- mapM generateFile files
    return (concat perFile)
  where
    generateFile file = do
        let base = fileName file
        header <- generateCxxHeader file schemaNodes
        source <- generateCxxSource file schemaNodes
        return [(base ++ ".h", header), (base ++ ".c++", source)]