//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose a public API to the rest of the nGraph codebase and depends heavily on the MLIR API.

20
#include "dialect.hpp"
21
#include "ngraph/check.hpp"
22 23
#include "ops.hpp"
#include "type.hpp"
24

25 26
#include <mlir/Parser.h>

27
using namespace mlir;
28

29 30
// Builds the nGraph dialect in `ctx`: registers its custom types first, then every
// operation listed in the generated "ops.cpp.inc".
NGraphOpsDialect::NGraphOpsDialect(mlir::MLIRContext* ctx)
    : mlir::Dialect(getDialectNamespace(), ctx)
{
    // Tensor, integer and boolean types, registered in this order.
    addTypes<NGTensorType, NGIntegerType, NGBoolType>();

    // With GET_OP_LIST defined, the generated include expands to the comma-separated
    // list of op classes (presumably mlir-tblgen output -- confirm in the build).
    addOperations<
#define GET_OP_LIST
#include "ops.cpp.inc"
        >();
}
41

42 43 44 45
// Parses the textual body of an nGraph dialect type.
//
// Accepted spellings:
//   - `tensor<D0xD1x...xELT>`: integer dimensions followed by an element type, which
//     is handed to `mlir::parseType`.
//   - `iN` / `uN` with N in {8, 16, 32, 64}: signed / unsigned nGraph integers.
//   - `bool`: the nGraph boolean type (the spelling `printType` emits).
//
// On malformed input an error is emitted at `loc` and a null `Type()` is returned.
// A spelling starting with 'f' trips an NGRAPH_CHECK: nGraph reuses the standard
// dialect floating point types, so those must not reach this parser.
mlir::Type NGraphOpsDialect::parseType(llvm::StringRef tyData, mlir::Location loc) const
{
    // `tyData` is consumed while parsing; keep the full spelling for diagnostics.
    StringRef origTypeStr = tyData;
    MLIRContext* context = getContext();

    // Process nGraph tensor type.
    if (tyData.consume_front("tensor"))
    {
        if (!tyData.consume_front("<") || !tyData.consume_back(">"))
        {
            return (emitError(loc,
                              "expected '<' and '>' enclosing the tensor shape: " + origTypeStr),
                    Type());
        }

        // Get x-separated sub-strings: all entries but the last are dimensions and the
        // last one is the element type. With its default KeepEmpty=true, `split`
        // yields at least one entry, so `size() - 1` below cannot underflow.
        SmallVector<StringRef, 8> subStrings;
        tyData.split(subStrings, "x");

        // Parse shape dimensions.
        SmallVector<int64_t, 4> shape;
        for (unsigned i = 0, end = subStrings.size() - 1; i < end; ++i)
        {
            StringRef dimStr = subStrings[i];
            int64_t dim = -1;
            // NOTE: `consumeInteger` returns false if an integer was parsed successfully.
            if (dimStr.consumeInteger(/*Radix=*/10, dim) || !dimStr.empty())
            {
                return (emitError(loc,
                                  "expected a list of '[0-9]+x' dimension specifiers: " +
                                      origTypeStr),
                        Type());
            }

            shape.push_back(dim);
        }

        // Parse the element type with the generic MLIR type parser.
        auto elem_ty = mlir::parseType(subStrings.back(), context);
        if (!elem_ty)
        {
            return (emitError(loc, "Unexpected element type in tensor type: " + origTypeStr),
                    Type());
        }

        return NGTensorType::get(context, elem_ty, shape);
    }

    // Process nGraph boolean type. `printType` emits exactly this spelling, so it
    // must be accepted here for printed IR to round-trip.
    if (tyData == "bool")
    {
        return NGBoolType::get(context);
    }

    // Process nGraph integer element types.
    if (tyData.startswith("i") || tyData.startswith("u"))
    {
        // Consume exactly one prefix character. Consuming "i" and then "u" would let
        // malformed input such as "iu8" reach a hard check failure instead of the
        // diagnostic below.
        const bool isSigned = tyData.front() == 'i';
        tyData = tyData.drop_front();

        unsigned width = 0;
        // NOTE: `consumeInteger` returns false if an integer was parsed successfully.
        if (tyData.consumeInteger(/*Radix=*/10, width) || width == 0 || !tyData.empty())
        {
            return (emitError(loc, "Unexpected nGraph integer type: " + origTypeStr), Type());
        }

        switch (width)
        {
        case 8:
            return isSigned ? NGIntegerType::getInt8(context) : NGIntegerType::getUInt8(context);
        case 16:
            return isSigned ? NGIntegerType::getInt16(context) : NGIntegerType::getUInt16(context);
        case 32:
            return isSigned ? NGIntegerType::getInt32(context) : NGIntegerType::getUInt32(context);
        case 64:
            return isSigned ? NGIntegerType::getInt64(context) : NGIntegerType::getUInt64(context);
        default:
            return (emitError(loc, "Unexpected width for nGraph integer type: " + origTypeStr),
                    Type());
        }
    }

    // nGraph reuses standard dialect floating point element types.
    NGRAPH_CHECK(!tyData.startswith("f"),
                 "Floating point types should be processed by standard parser");

    // NOTE: We may hit this error if the nGraph type is not yet supported in parser.
    return (emitError(loc, "Unknown nGraph type: " + origTypeStr), Type());
}

125
// Prints an nGraph dialect type in the textual form accepted by `parseType`:
//   tensor<D0xD1x...xELT>, iN (signed integers), uN (unsigned integers), bool.
// Any other kind fails an NGRAPH_CHECK.
void NGraphOpsDialect::printType(mlir::Type type, raw_ostream& os) const
{
    switch (type.getKind())
    {
    case NG_TENSOR_TYPE_ID:
    {
        // Every dimension is followed by 'x', then the element type closes the list.
        os << "tensor<";
        auto tensorTy = type.cast<NGTensorType>();
        for (auto dim : tensorTy.getShape())
        {
            os << dim << 'x';
        }
        os << tensorTy.getElementType() << '>';
        return;
    }
    case NG_I8_TYPE_ID:
    case NG_I16_TYPE_ID:
    case NG_I32_TYPE_ID:
    case NG_I64_TYPE_ID:
    {
        auto intTy = type.cast<NGIntegerType>();
        os << "i" << intTy.getWidth();
        return;
    }
    case NG_U8_TYPE_ID:
    case NG_U16_TYPE_ID:
    case NG_U32_TYPE_ID:
    case NG_U64_TYPE_ID:
    {
        // Unsigned kinds need the "u" prefix. Printing "i" would make `parseType`
        // read them back as signed integers and lose the signedness.
        auto intTy = type.cast<NGIntegerType>();
        os << "u" << intTy.getWidth();
        return;
    }
    case NG_BOOL_TYPE_ID:
    {
        os << "bool";
        return;
    }
    default:
    {
        NGRAPH_CHECK(false, "Incorrect type to print?");
    }
    }
}