Commit 8bb48c81 authored by Diego Caballero's avatar Diego Caballero Committed by nmostafa

[MLIR] Fix NG tensor type lowering (#29)

Element type was not lowered.
parent 4df55e63
...@@ -73,9 +73,3 @@ NGTensorType NGTensorType::get(MLIRContext* context, EltType eltType, Shape shap ...@@ -73,9 +73,3 @@ NGTensorType NGTensorType::get(MLIRContext* context, EltType eltType, Shape shap
{ {
return Base::get(context, NGTypeKind::NG_TENSOR_TYPE_ID, eltType, shape); return Base::get(context, NGTypeKind::NG_TENSOR_TYPE_ID, eltType, shape);
} }
MemRefType NGTensorType::toMemref()
{
auto memRefType = MemRefType::get(getShape(), getElementType(), {/* no map used */}, 0);
return memRefType;
}
...@@ -114,9 +114,6 @@ namespace mlir ...@@ -114,9 +114,6 @@ namespace mlir
/// Return the bitwidth of this integer type. /// Return the bitwidth of this integer type.
unsigned getWidth() const; unsigned getWidth() const;
/// Convert to equivalent std type
/// std types are sign-agnostic.
mlir::Type toStdType() { return mlir::IntegerType::get(getWidth(), getContext()); }
/// Check if signed type /// Check if signed type
bool isSigned() const; bool isSigned() const;
...@@ -163,8 +160,6 @@ namespace mlir ...@@ -163,8 +160,6 @@ namespace mlir
static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_BOOL_TYPE_ID; } static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_BOOL_TYPE_ID; }
static NGBoolType get(mlir::MLIRContext* ctx) { return get(NG_BOOL_TYPE_ID, ctx); } static NGBoolType get(mlir::MLIRContext* ctx) { return get(NG_BOOL_TYPE_ID, ctx); }
/// Convert to equivalent std type. Integer of width 1 in that case
mlir::Type toStdType() { return mlir::IntegerType::get(1, getContext()); }
}; };
// Note that dialect types don't add new data members, so always possible // Note that dialect types don't add new data members, so always possible
...@@ -240,8 +235,6 @@ namespace mlir ...@@ -240,8 +235,6 @@ namespace mlir
// Multiply times element size // Multiply times element size
return s * llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8); return s * llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
} }
/// convert to memref native MLIR type. Used for lowering.
mlir::MemRefType toMemref();
/// create a unique tensor type based on element type and shape. /// create a unique tensor type based on element type and shape.
static NGTensorType get(mlir::MLIRContext* context, EltType eltType, Shape shape); static NGTensorType get(mlir::MLIRContext* context, EltType eltType, Shape shape);
/// for llvm RTTI /// for llvm RTTI
......
...@@ -185,7 +185,7 @@ namespace ...@@ -185,7 +185,7 @@ namespace
auto tensorType = origResult->getType().cast<NGTensorType>(); auto tensorType = origResult->getType().cast<NGTensorType>();
auto callBackFunc = getCallDecl("__mlir_allocate", auto callBackFunc = getCallDecl("__mlir_allocate",
{rewriter.getIndexType(), rewriter.getIndexType()}, {rewriter.getIndexType(), rewriter.getIndexType()},
{tensorType.toMemref()}, {m_dialectLowerer.convertType(tensorType)},
rewriter); rewriter);
auto size = tensorType.getSizeInBytes(); auto size = tensorType.getSizeInBytes();
...@@ -265,30 +265,36 @@ namespace ...@@ -265,30 +265,36 @@ namespace
return callBackFuncPtr; return callBackFuncPtr;
} }
// NGDialect converters // NGDialect converters
Type DialectLowerer::convertType(Type t) Type DialectLowerer::convertType(Type type)
{ {
if (auto tensor = t.dyn_cast<NGTensorType>()) // We may need to refactor this code to a external utility if type conversion is needed
// outside of the lowering context since DialectLowerer is private.
if (auto tensor_type = type.dyn_cast<NGTensorType>())
{ {
return tensor.toMemref(); // Convert NGTensorType to Std MemRefType directly instead of going to Std TensorType.
// This may change in the future.
return MemRefType::get(tensor_type.getShape(),
convertType(tensor_type.getElementType()),
{/* no map used */},
0);
} }
// element type if (auto float_type = type.dyn_cast<NGFloatType>())
if (auto type = t.dyn_cast<NGFloatType>())
{ {
// Float // Float types are already std type.
// float types are already std type return float_type;
return type;
} }
if (auto type = t.dyn_cast<NGIntegerType>()) if (auto int_type = type.dyn_cast<NGIntegerType>())
{ {
// map it to std type return mlir::IntegerType::get(int_type.getWidth(), int_type.getContext());
return type.toStdType();
} }
if (auto type = t.dyn_cast<NGBoolType>()) if (auto bool_type = type.dyn_cast<NGBoolType>())
{ {
return type.toStdType(); return mlir::IntegerType::get(1 /* width */, bool_type.getContext());
} }
NGRAPH_FAIL() << "Unsupported type to lower"; NGRAPH_FAIL() << "Unsupported type to lower";
return t; return type;
} }
#define REWRITER(OP) \ #define REWRITER(OP) \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment