Commit 8bb48c81 authored by Diego Caballero's avatar Diego Caballero Committed by nmostafa

[MLIR] Fix NG tensor type lowering (#29)

Element type was not lowered.
parent 4df55e63
......@@ -73,9 +73,3 @@ NGTensorType NGTensorType::get(MLIRContext* context, EltType eltType, Shape shap
{
return Base::get(context, NGTypeKind::NG_TENSOR_TYPE_ID, eltType, shape);
}
MemRefType NGTensorType::toMemref()
{
auto memRefType = MemRefType::get(getShape(), getElementType(), {/* no map used */}, 0);
return memRefType;
}
......@@ -114,9 +114,6 @@ namespace mlir
/// Return the bitwidth of this integer type.
unsigned getWidth() const;
/// Convert to equivalent std type
/// std types are sign-agnostic.
mlir::Type toStdType() { return mlir::IntegerType::get(getWidth(), getContext()); }
/// Check if signed type
bool isSigned() const;
......@@ -163,8 +160,6 @@ namespace mlir
static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_BOOL_TYPE_ID; }
static NGBoolType get(mlir::MLIRContext* ctx) { return get(NG_BOOL_TYPE_ID, ctx); }
/// Convert to equivalent std type. Integer of width 1 in that case
mlir::Type toStdType() { return mlir::IntegerType::get(1, getContext()); }
};
// Note that dialect types don't add new data members, so always possible
......@@ -240,8 +235,6 @@ namespace mlir
// Multiply times element size
return s * llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
}
/// convert to memref native MLIR type. Used for lowering.
mlir::MemRefType toMemref();
/// create a unique tensor type based on element type and shape.
static NGTensorType get(mlir::MLIRContext* context, EltType eltType, Shape shape);
/// for llvm RTTI
......
......@@ -185,7 +185,7 @@ namespace
auto tensorType = origResult->getType().cast<NGTensorType>();
auto callBackFunc = getCallDecl("__mlir_allocate",
{rewriter.getIndexType(), rewriter.getIndexType()},
{tensorType.toMemref()},
{m_dialectLowerer.convertType(tensorType)},
rewriter);
auto size = tensorType.getSizeInBytes();
......@@ -265,30 +265,36 @@ namespace
return callBackFuncPtr;
}
// NGDialect converters
Type DialectLowerer::convertType(Type t)
Type DialectLowerer::convertType(Type type)
{
if (auto tensor = t.dyn_cast<NGTensorType>())
// We may need to refactor this code to a external utility if type conversion is needed
// outside of the lowering context since DialectLowerer is private.
if (auto tensor_type = type.dyn_cast<NGTensorType>())
{
return tensor.toMemref();
// Convert NGTensorType to Std MemRefType directly instead of going to Std TensorType.
// This may change in the future.
return MemRefType::get(tensor_type.getShape(),
convertType(tensor_type.getElementType()),
{/* no map used */},
0);
}
// element type
if (auto type = t.dyn_cast<NGFloatType>())
if (auto float_type = type.dyn_cast<NGFloatType>())
{
// Float
// float types are already std type
return type;
// Float types are already std type.
return float_type;
}
if (auto type = t.dyn_cast<NGIntegerType>())
if (auto int_type = type.dyn_cast<NGIntegerType>())
{
// map it to std type
return type.toStdType();
return mlir::IntegerType::get(int_type.getWidth(), int_type.getContext());
}
if (auto type = t.dyn_cast<NGBoolType>())
if (auto bool_type = type.dyn_cast<NGBoolType>())
{
return type.toStdType();
return mlir::IntegerType::get(1 /* width */, bool_type.getContext());
}
NGRAPH_FAIL() << "Unsupported type to lower";
return t;
return type;
}
#define REWRITER(OP) \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment