Commit 9bb2fad3 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by nmostafa

[MLIR] Add NG integer type. Map float types to std types

parent 3bd00e23
...@@ -148,21 +148,32 @@ namespace ngraph ...@@ -148,21 +148,32 @@ namespace ngraph
{ {
case ngraph::element::Type_t::undefined: case ngraph::element::Type_t::undefined:
case ngraph::element::Type_t::dynamic: case ngraph::element::Type_t::dynamic:
case ngraph::element::Type_t::boolean: default: NGRAPH_FAIL() << "MLIR: Unsupported NGraph types"; break;
case ngraph::element::Type_t::bf16:
default: NGRAPH_ASSERT(false) << "MLIR: Unsupported NGraph types"; break; case ngraph::element::Type_t::bf16: return NGFloatType::getBF16(&m_context);
case ngraph::element::Type_t::f32: return mlir::FloatType::getF32(&m_context);
case ngraph::element::Type_t::f64: return mlir::FloatType::getF64(&m_context); case ngraph::element::Type_t::f32: return NGFloatType::getF32(&m_context);
case ngraph::element::Type_t::i8:
case ngraph::element::Type_t::u8: return mlir::IntegerType::get(8, &m_context); case ngraph::element::Type_t::f64: return NGFloatType::getF64(&m_context);
case ngraph::element::Type_t::i16:
case ngraph::element::Type_t::u16: return mlir::IntegerType::get(16, &m_context); case ngraph::element::Type_t::i8: return NGIntegerType::getInt8(&m_context);
case ngraph::element::Type_t::i32:
case ngraph::element::Type_t::u32: return mlir::IntegerType::get(32, &m_context); case ngraph::element::Type_t::u8:
case ngraph::element::Type_t::i64: case ngraph::element::Type_t::boolean: return NGIntegerType::getUInt8(&m_context);
case ngraph::element::Type_t::u64: return mlir::IntegerType::get(64, &m_context);
case ngraph::element::Type_t::i16: return NGIntegerType::getInt16(&m_context);
case ngraph::element::Type_t::u16: return NGIntegerType::getInt16(&m_context);
case ngraph::element::Type_t::i32: return NGIntegerType::getInt32(&m_context);
case ngraph::element::Type_t::u32: return NGIntegerType::getUInt32(&m_context);
case ngraph::element::Type_t::i64: return NGIntegerType::getInt64(&m_context);
case ngraph::element::Type_t::u64: return NGIntegerType::getUInt64(&m_context);
} }
NGRAPH_ASSERT(false) << "Unreachable"; NGRAPH_FAIL(); // Unreachable
return mlir::Type(); return mlir::Type();
} }
...@@ -378,8 +389,7 @@ namespace ngraph ...@@ -378,8 +389,7 @@ namespace ngraph
auto memRefType = type.dyn_cast<mlir::MemRefType>(); auto memRefType = type.dyn_cast<mlir::MemRefType>();
if (!memRefType) if (!memRefType)
return nullptr; return nullptr;
if (memRefType.getNumDynamicDims() != 0) NGRAPH_ASSERT(memRefType.getNumDynamicDims() == 0) << "No support for dynamic shapes";
NGRAPH_FAIL();
// We only use StaticFloatMemRef because that's what MLIR currently offers. // We only use StaticFloatMemRef because that's what MLIR currently offers.
// We should expand this with different types and dynamic MemRefs // We should expand this with different types and dynamic MemRefs
......
...@@ -17,36 +17,58 @@ ...@@ -17,36 +17,58 @@
#include "dialect.hpp" #include "dialect.hpp"
#include "ops.hpp" #include "ops.hpp"
#include "type.hpp" #include "type.hpp"
namespace ngraph
using namespace ngraph::runtime::ngmlir;
/// Register a dialect and its types
/// Usage:
/// mlir::registerDialect<ngraph::runtime::ngmlir::Dialect>();
NGDialect::NGDialect(mlir::MLIRContext* ctx)
: mlir::Dialect("ng", ctx)
{ {
using namespace runtime::ngmlir; addTypes<NGTensorType>();
addTypes<NGIntegerType>();
addTypes<NGBoolType>();
addOperations<NG_AddOp>();
addOperations<NG_MatmulBiasOp>();
addOperations<NG_ReturnOp>();
addOperations<NG_FakeInput>();
}
/// Register a dialect and its types void NGDialect::printType(mlir::Type type, raw_ostream& os) const
/// Usage: {
/// mlir::registerDialect<ngraph::runtime::ngmlir::Dialect>(); switch (type.getKind())
NGDialect::NGDialect(mlir::MLIRContext* ctx)
: mlir::Dialect("ng", ctx)
{ {
addTypes<NGTensorType>(); case NG_TENSOR_TYPE_ID:
addOperations<NG_AddOp>();
addOperations<NG_MatmulBiasOp>();
addOperations<NG_ReturnOp>();
addOperations<NG_FakeInput>();
}
void NGDialect::printType(mlir::Type type, raw_ostream& os) const
{ {
auto arrayTy = type.dyn_cast<NGTensorType>(); os << "tensor<";
if (!arrayTy) auto tensor_ty = type.cast<NGTensorType>();
{ for (auto dim : tensor_ty.getShape())
NGRAPH_ASSERT(0) << "Incorrect type to print?";
}
os << "tensor";
if (!arrayTy.getShape().empty())
{ {
os << "<"; os << dim << 'x';
mlir::interleaveComma(arrayTy.getShape(), os);
os << ">";
} }
os << tensor_ty.getElementType() << '>';
return;
}
case NG_I8_TYPE_ID:
case NG_I16_TYPE_ID:
case NG_I32_TYPE_ID:
case NG_I64_TYPE_ID:
case NG_U8_TYPE_ID:
case NG_U16_TYPE_ID:
case NG_U32_TYPE_ID:
case NG_U64_TYPE_ID:
{
auto int_ty = type.cast<NGIntegerType>();
os << "i" << int_ty.getWidth();
return;
}
case NG_BOOL_TYPE_ID:
{
os << "bool";
return;
}
default: { NGRAPH_ASSERT(0) << "Incorrect type to print?";
}
} }
} }
...@@ -70,8 +70,8 @@ namespace ngraph ...@@ -70,8 +70,8 @@ namespace ngraph
} }
void runtime::ngmlir::NG_FakeInput::build(mlir::Builder* builder, void runtime::ngmlir::NG_FakeInput::build(mlir::Builder* builder,
mlir::OperationState* state, mlir::OperationState* state,
mlir::Type resultType) mlir::Type resultType)
{ {
state->types.push_back(std::move(resultType)); state->types.push_back(std::move(resultType));
} }
...@@ -83,9 +83,9 @@ namespace ngraph ...@@ -83,9 +83,9 @@ namespace ngraph
} }
void runtime::ngmlir::NG_AddOp::build(mlir::Builder* builder, void runtime::ngmlir::NG_AddOp::build(mlir::Builder* builder,
mlir::OperationState* state, mlir::OperationState* state,
mlir::Value* lhs, mlir::Value* lhs,
mlir::Value* rhs) mlir::Value* rhs)
{ {
state->types.push_back(lhs->getType()); state->types.push_back(lhs->getType());
state->operands.push_back(lhs); state->operands.push_back(lhs);
...@@ -100,9 +100,9 @@ namespace ngraph ...@@ -100,9 +100,9 @@ namespace ngraph
} }
void runtime::ngmlir::NG_MatmulBiasOp::build(mlir::Builder* builder, void runtime::ngmlir::NG_MatmulBiasOp::build(mlir::Builder* builder,
mlir::OperationState* state, mlir::OperationState* state,
mlir::Value* lhs, mlir::Value* lhs,
mlir::Value* rhs) mlir::Value* rhs)
{ {
state->types.push_back(lhs->getType()); state->types.push_back(lhs->getType());
state->operands.push_back(lhs); state->operands.push_back(lhs);
...@@ -147,8 +147,8 @@ namespace ngraph ...@@ -147,8 +147,8 @@ namespace ngraph
} }
void runtime::ngmlir::NG_ReturnOp::build(mlir::Builder* builder, void runtime::ngmlir::NG_ReturnOp::build(mlir::Builder* builder,
mlir::OperationState* state, mlir::OperationState* state,
std::vector<mlir::Value*> value_list) std::vector<mlir::Value*> value_list)
{ {
for (auto value : value_list) for (auto value : value_list)
{ {
......
...@@ -34,11 +34,46 @@ using llvm::Twine; ...@@ -34,11 +34,46 @@ using llvm::Twine;
namespace ngraph namespace ngraph
{ {
using namespace runtime::ngmlir; using namespace runtime::ngmlir;
unsigned NGIntegerType::getWidth() const
{
switch (getKind())
{
case NG_I8_TYPE_ID:
case NG_U8_TYPE_ID: return 8;
case NG_I16_TYPE_ID:
case NG_U16_TYPE_ID: return 16;
case NG_I32_TYPE_ID:
case NG_U32_TYPE_ID: return 32;
case NG_I64_TYPE_ID:
case NG_U64_TYPE_ID: return 64;
default: NGRAPH_FAIL() << "Invalid type ID";
}
return 0;
}
bool NGIntegerType::isSigned() const
{
switch (getKind())
{
case NG_I8_TYPE_ID:
case NG_I16_TYPE_ID:
case NG_I32_TYPE_ID:
case NG_I64_TYPE_ID: return true;
case NG_U8_TYPE_ID:
case NG_U16_TYPE_ID:
case NG_U32_TYPE_ID:
case NG_U64_TYPE_ID: return false;
default: NGRAPH_FAIL() << "Invalid type ID";
}
return false;
}
/// Creates TensorType objects. They all point to the same storage if /// Creates TensorType objects. They all point to the same storage if
/// element type and shape are the same. /// element type and shape are the same.
NGTensorType NGTensorType::get(mlir::MLIRContext* context, EltType eltType, Shape shape) NGTensorType NGTensorType::get(mlir::MLIRContext* context, EltType eltType, Shape shape)
{ {
return Base::get(context, NGTypeKind::TENSOR_TYPE_ID, eltType, shape); return Base::get(context, NGTypeKind::NG_TENSOR_TYPE_ID, eltType, shape);
} }
mlir::MemRefType NGTensorType::toMemref() mlir::MemRefType NGTensorType::toMemref()
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#pragma once #pragma once
#include "assertion.hpp"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
...@@ -22,7 +23,6 @@ ...@@ -22,7 +23,6 @@
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h" #include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
...@@ -36,9 +36,146 @@ namespace ngraph ...@@ -36,9 +36,146 @@ namespace ngraph
// The enum starts at the range reserved for this dialect. // The enum starts at the range reserved for this dialect.
// These values are pre-defined in MLIR lib and not configurable from here. // These values are pre-defined in MLIR lib and not configurable from here.
NG_TYPE = mlir::Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE, NG_TYPE = mlir::Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
TENSOR_TYPE_ID // Element types that are added by the dialect.
// Other types are just re-use of std dialect types.
NG_FIRST_INT_TYPE_ID,
NG_I8_TYPE_ID = NG_FIRST_INT_TYPE_ID,
NG_I16_TYPE_ID,
NG_I32_TYPE_ID,
NG_I64_TYPE_ID,
NG_U8_TYPE_ID,
NG_U16_TYPE_ID,
NG_U32_TYPE_ID,
NG_U64_TYPE_ID,
NG_LAST_INT_TYPE_ID = NG_U64_TYPE_ID,
NG_BOOL_TYPE_ID,
// Tensor type
NG_TENSOR_TYPE_ID
};
// reuse std float types as-is
using NGFloatType = mlir::FloatType;
/// Integer type. It represents an integer of width 8,16,32,64. Signed or not.
class NGIntegerType : public mlir::Type::TypeBase<NGIntegerType, mlir::Type>
{
public:
using Base::Base;
static NGIntegerType get(NGTypeKind kind, mlir::MLIRContext* context)
{
NGRAPH_ASSERT(kindof(kind)) << "Not an integer kind.";
return Base::get(context, kind);
}
/// Create signed Int8
static NGIntegerType getInt8(mlir::MLIRContext* ctx)
{
return get(NGTypeKind::NG_I8_TYPE_ID, ctx);
}
/// Create signed Int16
static NGIntegerType getInt16(mlir::MLIRContext* ctx)
{
return get(NGTypeKind::NG_I16_TYPE_ID, ctx);
}
/// Create signed Int32
static NGIntegerType getInt32(mlir::MLIRContext* ctx)
{
return get(NGTypeKind::NG_I32_TYPE_ID, ctx);
}
/// Create signed Int64
static NGIntegerType getInt64(mlir::MLIRContext* ctx)
{
return get(NGTypeKind::NG_I64_TYPE_ID, ctx);
}
/// Create unsigned Int8
static NGIntegerType getUInt8(mlir::MLIRContext* ctx)
{
return get(NGTypeKind::NG_U8_TYPE_ID, ctx);
}
/// Create unsigned Int16
static NGIntegerType getUInt16(mlir::MLIRContext* ctx)
{
return get(NGTypeKind::NG_U16_TYPE_ID, ctx);
}
/// Create unsigned Int32
static NGIntegerType getUInt32(mlir::MLIRContext* ctx)
{
return get(NGTypeKind::NG_U32_TYPE_ID, ctx);
}
/// Create unsigned Int64
static NGIntegerType getUInt64(mlir::MLIRContext* ctx)
{
return get(NGTypeKind::NG_U64_TYPE_ID, ctx);
}
/// RTTI support. So we can do obj->isa<NGIntegerType>()
static bool kindof(unsigned kind)
{
return kind >= NGTypeKind::NG_FIRST_INT_TYPE_ID &&
kind <= NGTypeKind::NG_LAST_INT_TYPE_ID;
}
/// Return the bitwidth of this integer type.
unsigned getWidth() const;
/// Convert to equivalent std type
/// std types are sign-agnostic.
mlir::Type toStdType() const
{
return mlir::IntegerType::get(getWidth(), getContext());
}
/// Check if signed type
bool isSigned() const;
/// Check if Int8
bool isInt8() const { return getKind() == NG_I8_TYPE_ID; }
/// Check if UInt8
bool isUInt8() const { return getKind() == NG_U8_TYPE_ID; }
/// Check if Int16
bool isInt16() const { return getKind() == NG_I16_TYPE_ID; }
/// Check if UInt16
bool isUInt16() const { return getKind() == NG_U16_TYPE_ID; }
/// Check if Int32
bool isInt32() const { return getKind() == NG_I32_TYPE_ID; }
/// Check if UInt32
bool isUInt32() const { return getKind() == NG_U32_TYPE_ID; }
/// Check if Int64
bool isInt64() const { return getKind() == NG_I64_TYPE_ID; }
/// Check if UInt64
bool isUInt64() const { return getKind() == NG_U64_TYPE_ID; }
// Delete convenience methods inherited from MLIR Type class.
// This would avoid confusion if we do something like this and get false.
//
// if (type->cast<NGIntegerType>()->isInteger(32)) {}
//
// Those helpers use type id, and since we have our own Integer type id, they
// don't apply.
bool isInteger(unsigned width) const = delete;
unsigned getIntOrFloatBitWidth() const = delete;
bool isIntOrIndex() const = delete;
bool isIntOrIndexOrFloat() const = delete;
bool isIntOrFloat() const = delete;
};
/// Boolean Type.
class NGBoolType : public mlir::Type::TypeBase<NGBoolType, mlir::Type>
{
public:
using Base::Base;
static NGBoolType get(NGTypeKind kind, mlir::MLIRContext* context)
{
NGRAPH_ASSERT(kindof(kind)) << "Not a bool type.";
return Base::get(context, kind);
}
static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_BOOL_TYPE_ID; }
static NGBoolType get(mlir::MLIRContext* ctx) { return get(NG_BOOL_TYPE_ID, ctx); }
/// Convert to equivalent std type. Integer of width 1 in that case
mlir::Type toStdType() const { return mlir::IntegerType::get(1, getContext()); }
}; };
// Note that dialect types don't add new data members, so always possible
// to use NG or std types here
using EltType = mlir::Type; using EltType = mlir::Type;
// TODO: Can we use ngraph::shape here (given the hashing requirements) // TODO: Can we use ngraph::shape here (given the hashing requirements)
using Shape = llvm::ArrayRef<int64_t>; using Shape = llvm::ArrayRef<int64_t>;
...@@ -86,6 +223,7 @@ namespace ngraph ...@@ -86,6 +223,7 @@ namespace ngraph
Shape m_shape; Shape m_shape;
}; };
/// NGraph Tensor Type
class NGTensorType class NGTensorType
: public mlir::Type::TypeBase<NGTensorType, mlir::Type, NGTensorTypeStorage> : public mlir::Type::TypeBase<NGTensorType, mlir::Type, NGTensorTypeStorage>
{ {
...@@ -93,7 +231,9 @@ namespace ngraph ...@@ -93,7 +231,9 @@ namespace ngraph
using Base::Base; using Base::Base;
EltType getElementType() const { return getImpl()->getElementType(); } EltType getElementType() const { return getImpl()->getElementType(); }
Shape getShape() const { return getImpl()->getShape(); } Shape getShape() const { return getImpl()->getShape(); }
/// Tensor Rank. Static shape only for now
int getRank() { return getShape().size(); } int getRank() { return getShape().size(); }
/// Computes tensor size in bytes
size_t getSizeInBytes() size_t getSizeInBytes()
{ {
size_t s = 1; size_t s = 1;
...@@ -113,7 +253,7 @@ namespace ngraph ...@@ -113,7 +253,7 @@ namespace ngraph
/// create a unique tensor type based on element type and shape. /// create a unique tensor type based on element type and shape.
static NGTensorType get(mlir::MLIRContext* context, EltType eltType, Shape shape); static NGTensorType get(mlir::MLIRContext* context, EltType eltType, Shape shape);
/// for llvm RTTI /// for llvm RTTI
static bool kindof(unsigned kind) { return kind == NGTypeKind::TENSOR_TYPE_ID; } static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_TENSOR_TYPE_ID; }
}; };
} }
} }
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include "lowerer.hpp" #include "lowerer.hpp"
#include <map> #include <map>
#include "compiler.hpp" #include "compiler.hpp"
#include "dialect/ops.hpp"
#include "dialect/type.hpp"
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
#include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Helpers.h" #include "mlir/EDSC/Helpers.h"
...@@ -25,8 +27,6 @@ ...@@ -25,8 +27,6 @@
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "ngraph/assertion.hpp" #include "ngraph/assertion.hpp"
#include "dialect/ops.hpp"
#include "dialect/type.hpp"
using namespace ngraph::runtime::ngmlir; using namespace ngraph::runtime::ngmlir;
// anonymous namespace // anonymous namespace
...@@ -272,6 +272,23 @@ namespace ...@@ -272,6 +272,23 @@ namespace
{ {
return tensor.toMemref(); return tensor.toMemref();
} }
// element type
if (auto type = t.dyn_cast<NGFloatType>())
{
// Float
// float types are already std type
return type;
}
if (auto type = t.dyn_cast<NGIntegerType>())
{
// map it to std type
return type.toStdType();
}
if (auto type = t.dyn_cast<NGBoolType>())
{
return type.toStdType();
}
NGRAPH_FAIL() << "Unsupported type to lower";
return t; return t;
} }
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <memory>
#include "memory_manager.hpp" #include "memory_manager.hpp"
#include <memory>
#include "ngraph/ngraph_visibility.hpp" #include "ngraph/ngraph_visibility.hpp"
using namespace ngraph::runtime::ngmlir; using namespace ngraph::runtime::ngmlir;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment