Commit 86bc31cc authored by Nagy Mostafa, committed by nmostafa

[MLIR] Move MLIR-related classes to the MLIR namespace (#23)

* Move dialect and types to mlir namespace

* PR fixes and some cleanup

* Merge fix
parent ea441a6e
@@ -18,11 +18,8 @@
 #include "ops.hpp"
 #include "type.hpp"
 
-using namespace ngraph::runtime::ngmlir;
+using namespace mlir;
 
-/// Register a dialect and its types
-/// Usage:
-/// mlir::registerDialect<ngraph::runtime::ngmlir::Dialect>();
 NGDialect::NGDialect(mlir::MLIRContext* ctx)
     : mlir::Dialect("ng", ctx)
 {
...
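The usage note deleted above was the only place this file documented dialect registration. As a rough sketch, registration after the namespace move would look something like this (the registerDialect call comes from the deleted comment; the wrapper function is illustrative):

    // Illustrative only: global registration of NGDialect, per the usage note
    // removed above. After this commit the dialect lives in the mlir namespace.
    #include "dialect.hpp"

    static void registerNgDialect()
    {
        // Makes NGDialect available to every MLIRContext created afterwards.
        mlir::registerDialect<mlir::NGDialect>();
    }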
@@ -23,24 +23,17 @@
 #include "mlir/IR/TypeSupport.h"
 #include "mlir/IR/Types.h"
 #include "ngraph/assertion.hpp"
 
-namespace ngraph
+namespace mlir
 {
-    namespace runtime
-    {
-        namespace ngmlir
-        {
-            class NGDialect : public mlir::Dialect
-            {
-            public:
-                explicit NGDialect(mlir::MLIRContext* ctx);
-                mlir::Type parseType(llvm::StringRef tyData, mlir::Location loc) const override
-                {
-                    NGRAPH_ASSERT(0) << "Unsupported type parsing.";
-                    return mlir::Type();
-                }
-                void printType(mlir::Type type, llvm::raw_ostream& os) const override;
-            };
-        }
-    }
+    class NGDialect : public mlir::Dialect
+    {
+    public:
+        explicit NGDialect(mlir::MLIRContext* ctx);
+        mlir::Type parseType(llvm::StringRef tyData, mlir::Location loc) const override
+        {
+            NGRAPH_ASSERT(0) << "Unsupported type parsing.";
+            return mlir::Type();
+        }
+        void printType(mlir::Type type, llvm::raw_ostream& os) const override;
+    };
 }
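The header declares printType but leaves the definition to dialect.cpp, which is not part of this diff. A hedged sketch of the usual shape of such a printer (illustrative; the real implementation may differ):

    // Illustrative only: the real printType lives in dialect.cpp, outside this
    // diff. Dialect type printers typically dispatch on the concrete type.
    void mlir::NGDialect::printType(mlir::Type type, llvm::raw_ostream& os) const
    {
        if (auto tensor = type.dyn_cast<mlir::NGTensorType>())
        {
            os << "tensor"; // the real printer also emits shape and element type
            return;
        }
        NGRAPH_ASSERT(0) << "Unsupported type printing.";
    }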
@@ -26,83 +26,69 @@ using llvm::raw_string_ostream;
 using llvm::SmallVector;
 using llvm::StringRef;
 using llvm::Twine;
 
-namespace ngraph
-{
-    namespace runtime
-    {
-        namespace ngmlir
-        {
+using namespace mlir;
+
 // TODO:
 // - Move verifiers and other OP helpers (e.g. getSomeAttribute()) to separate files
 //
 // - Op helpers: Since it is not possible to add arbitrary code (and would complicate the .td file)
 //   to Ops classes, we will add helper classes with static methods for each Op that needs it
 
 // Additional verification methods
 // Tensor type checks are already verified by the caller of these methods
 template <typename T>
 static mlir::LogicalResult verifyUnaryArithOp(T* op)
 {
     // TODO: Check matching element types
     return mlir::success();
 }
 
 // Additional verification methods
 // Tensor type checks are already verified by the caller of these methods
 template <typename T>
 static mlir::LogicalResult verifyBinaryArithOp(T* op)
 {
     // TODO: Check matching element types
     return mlir::success();
 }
 
 template <typename T>
 static mlir::LogicalResult verifyOp(T* op)
 {
     return op->emitOpError("Unsupported verifier for this operation");
 }
 
 // Per op specializations
 template <>
 mlir::LogicalResult verifyOp<NGMatMulBiasOp>(NGMatMulBiasOp* op)
 {
     // Verify that we have 2 operands
     // Bias operand must be null for now (not implemented)
     if (op->getNumOperands() != 2)
     {
         std::stringstream ss;
         ss << "Unexpected MatmulBiasOp with " << op->getNumOperands()
            << " operands. 2 operands expected";
         return op->emitOpError(ss.str());
     }
 
     // Verify that operand types are supported.
     auto op0_tensor_ty = op->getOperand(0)->getType().cast<NGTensorType>();
     auto op1_tensor_ty = op->getOperand(1)->getType().cast<NGTensorType>();
 
     // Verify that operand shapes are supported.
     if (op0_tensor_ty.getRank() != 2 || op1_tensor_ty.getRank() != 2)
     {
         return op->emitOpError(
             "Unsupported number of dimensions. Only 2D tensors are supported in "
             "MatmulBiasOp");
     }
 
     // TODO(dcab): Improve verification: matching types, proper shapes, etc.
     return mlir::success();
 }
-        }
-    }
-}
 
-using namespace mlir;
-namespace runtime
+namespace mlir
 {
-    namespace ngmlir
-    {
 #define GET_OP_CLASSES
 #include "ops.cpp.inc"
-    }
-}
 }
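The verifyOp template plus per-op specializations is the extension point called out in the TODO above. A hedged sketch of how a verifier for a new op would be added, using a hypothetical NGFooOp that is not part of this commit:

    // Hypothetical example only: NGFooOp does not exist in this commit.
    // New per-op verifiers follow the verifyOp<NGMatMulBiasOp> pattern above.
    template <>
    mlir::LogicalResult verifyOp<NGFooOp>(NGFooOp* op)
    {
        // Example constraint: require exactly one operand.
        if (op->getNumOperands() != 1)
        {
            return op->emitOpError("NGFooOp expects exactly one operand");
        }
        return mlir::success();
    }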
@@ -22,19 +22,8 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Support/STLExtras.h"
 
-namespace ngraph
+namespace mlir
 {
-    namespace runtime
-    {
-        namespace ngmlir
-        {
-            // TODO: We shouldn't have this here, but we need to expose mlir types for the .inc file to use
-            // we cannot forward declare the mlir types since they rely on the Ops we are defining (see. Op<NGAddOp, ...>)
-            //
-            // Other ways to avoid namespace pollution ?
-            using namespace mlir;
 #define GET_OP_CLASSES
 #include "ops.h.inc"
-        }
-    }
 }
@@ -40,7 +40,7 @@ include "mlir/IR/OpBase.td"
 // This defines records equivalent to NGraph types. It doesn't generate code.
 // This is used as a type in the DAG input/outputs.
 // Constraints (CPred) are used to type-check args/results of that type during op verification
-def NG_TensorType : Type<CPred<"{0}.isa<ngraph::runtime::ngmlir::NGTensorType>()">,
+def NG_TensorType : Type<CPred<"{0}.isa<mlir::NGTensorType>()">,
                          "NGraph Tensor Type">;
 
 // A generic un-typed MemRef. Used for Fake instructions inserted during dialect lowering
...
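The CPred string is a C++ expression that mlir-tblgen splices into generated verifier code, with {0} replaced by the mlir::Type being checked; this is why the namespace move forces the .td change above. Schematically, the constraint boils down to a predicate like this (illustrative; the exact generated code differs):

    // Illustrative predicate only; mlir-tblgen inlines the CPred expression
    // into the generated verifier rather than emitting a named function.
    static bool isNGTensorType(mlir::Type t)
    {
        return t.isa<mlir::NGTensorType>(); // the CPred body, with {0} bound to t
    }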
@@ -31,55 +31,51 @@ using llvm::SmallVector;
 using llvm::StringRef;
 using llvm::Twine;
 
-namespace ngraph
-{
-    using namespace runtime::ngmlir;
+using namespace mlir;
 
 unsigned NGIntegerType::getWidth() const
 {
     switch (getKind())
     {
     case NG_I8_TYPE_ID:
     case NG_U8_TYPE_ID: return 8;
     case NG_I16_TYPE_ID:
     case NG_U16_TYPE_ID: return 16;
     case NG_I32_TYPE_ID:
     case NG_U32_TYPE_ID: return 32;
     case NG_I64_TYPE_ID:
     case NG_U64_TYPE_ID: return 64;
     default: NGRAPH_FAIL() << "Invalid type ID";
     }
     return 0;
 }
 
 bool NGIntegerType::isSigned() const
 {
     switch (getKind())
     {
     case NG_I8_TYPE_ID:
     case NG_I16_TYPE_ID:
     case NG_I32_TYPE_ID:
     case NG_I64_TYPE_ID: return true;
     case NG_U8_TYPE_ID:
     case NG_U16_TYPE_ID:
     case NG_U32_TYPE_ID:
     case NG_U64_TYPE_ID: return false;
     default: NGRAPH_FAIL() << "Invalid type ID";
     }
     return false;
 }
 
 /// Creates TensorType objects. They all point to the same storage if
 /// element type and shape are the same.
-NGTensorType NGTensorType::get(mlir::MLIRContext* context, EltType eltType, Shape shape)
+NGTensorType NGTensorType::get(MLIRContext* context, EltType eltType, Shape shape)
 {
     return Base::get(context, NGTypeKind::NG_TENSOR_TYPE_ID, eltType, shape);
 }
 
-mlir::MemRefType NGTensorType::toMemref()
+MemRefType NGTensorType::toMemref()
 {
-    auto memRefType =
-        mlir::MemRefType::get(getShape(), getElementType(), {/* no map used */}, 0);
+    auto memRefType = MemRefType::get(getShape(), getElementType(), {/* no map used */}, 0);
     return memRefType;
 }
-}
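As the doc comment above says, NGTensorType::get uniques storage by element type and shape. A small hedged usage sketch (EltType and Shape are the aliases from type.hpp; the f32 element type is an assumption for the example):

    // Usage sketch; assumes EltType/Shape aliases from type.hpp and an element
    // type created elsewhere. Illustrative, not code from the commit.
    #include <cassert>

    void tensorTypeDemo(mlir::MLIRContext* ctx, mlir::EltType f32, mlir::Shape shape)
    {
        auto t1 = mlir::NGTensorType::get(ctx, f32, shape);
        auto t2 = mlir::NGTensorType::get(ctx, f32, shape);
        assert(t1 == t2 && "same element type + shape -> same uniqued storage");

        // During lowering, tensors are viewed as equivalent memrefs:
        mlir::MemRefType memref = t1.toMemref();
        (void)memref;
    }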
@@ -131,7 +131,7 @@ namespace
 // we find out output values by looking at returned values
 // any return should return all outputs of the subgraph
-f->walk<ngmlir::NGReturnOp>([this, &outputCount](ngmlir::NGReturnOp ret) {
+f->walk<NGReturnOp>([this, &outputCount](NGReturnOp ret) {
     for (unsigned i = 0; i < ret.getNumOperands(); i++)
     {
         this->m_outputValueMap.insert(std::pair<Value*, unsigned>(ret.getOperand(i), i));

@@ -151,8 +151,8 @@ namespace
 // however, due to how DialectConversion framework works, new func is only
 // materialized after conversion is done (rewriter->getFunction, or even rewriter->getInsertionBlock()->getFunction()
 // will give you the original func). This makes it very convoluted to insert instructions at entry block.
-auto op = rewriter->create<ngmlir::NGFakeInputOp>(rewriter->getUnknownLoc(),
-                                                  IndexType::get(getModule().getContext()));
+auto op = rewriter->create<NGFakeInputOp>(rewriter->getUnknownLoc(),
+                                          IndexType::get(getModule().getContext()));
 // will be fixed later to read passed arg instead.
 m_memMgrDefs.push_back(op.getResult());
 return op.getResult();

@@ -170,7 +170,7 @@ namespace
 if (it != outputMap.end())
 {
     unsigned argId = (*it).second;
-    auto fakeOp = rewriter.create<ngmlir::NGFakeInputOp>(
+    auto fakeOp = rewriter.create<NGFakeInputOp>(
         op->getLoc(),
         m_dialectLowerer.convertType(
             origResult->getType()) /* convert to lowered type */

@@ -183,7 +183,7 @@ namespace
 }
 else
 {
-    auto tensorType = origResult->getType().cast<ngmlir::NGTensorType>();
+    auto tensorType = origResult->getType().cast<NGTensorType>();
     auto callBackFunc = getCallDecl("__mlir_allocate",
                                     {rewriter.getIndexType(), rewriter.getIndexType()},
                                     {tensorType.toMemref()},

@@ -237,8 +237,7 @@ namespace
 for (auto value : m_loweredOutputValues)
 {
     auto op = value->getDefiningOp();
-    NGRAPH_ASSERT(op->isa<ngmlir::NGFakeInputOp>())
-        << "output value not defined by fake output?";
+    NGRAPH_ASSERT(op->isa<NGFakeInputOp>()) << "output value not defined by fake output?";
     value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i));
     op->erase();
     i++;

@@ -269,23 +268,23 @@ namespace
 // NGDialect converters
 Type DialectLowerer::convertType(Type t)
 {
-    if (auto tensor = t.dyn_cast<ngmlir::NGTensorType>())
+    if (auto tensor = t.dyn_cast<NGTensorType>())
     {
         return tensor.toMemref();
     }
     // element type
-    if (auto type = t.dyn_cast<ngmlir::NGFloatType>())
+    if (auto type = t.dyn_cast<NGFloatType>())
     {
         // Float
         // float types are already std type
         return type;
     }
-    if (auto type = t.dyn_cast<ngmlir::NGIntegerType>())
+    if (auto type = t.dyn_cast<NGIntegerType>())
     {
         // map it to std type
         return type.toStdType();
     }
-    if (auto type = t.dyn_cast<ngmlir::NGBoolType>())
+    if (auto type = t.dyn_cast<NGBoolType>())
     {
         return type.toStdType();
     }
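convertType is the hook the dialect lowering uses to map every nGraph type to a standard/memref type. A hedged sketch of how a whole function signature would be rewritten through it (the helper itself is illustrative; only convertType comes from the commit):

    // Illustrative helper, not from the commit: map a function type through
    // DialectLowerer::convertType (tensors become memrefs, NG scalars become
    // std types, per the cases above).
    mlir::FunctionType convertSignature(DialectLowerer& lowerer,
                                        mlir::FunctionType fnType,
                                        mlir::MLIRContext* ctx)
    {
        llvm::SmallVector<mlir::Type, 4> inputs, results;
        for (mlir::Type t : fnType.getInputs())
            inputs.push_back(lowerer.convertType(t));
        for (mlir::Type t : fnType.getResults())
            results.push_back(lowerer.convertType(t));
        return mlir::FunctionType::get(inputs, results, ctx);
    }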
@@ -298,7 +297,7 @@ namespace
     ArrayRef<Value*> operands,
     FuncBuilder& rewriter) const
 {
-    auto add = op->cast<ngmlir::NGAddOp>();
+    auto add = op->cast<NGAddOp>();
     auto loc = add.getLoc();
 
     Value *origResult, *newResult;

@@ -335,7 +334,7 @@ namespace
     ArrayRef<Value*> operands,
     FuncBuilder& rewriter) const
 {
-    auto matmul = op->cast<ngmlir::NGMatMulBiasOp>();
+    auto matmul = op->cast<NGMatMulBiasOp>();
     auto loc = matmul.getLoc();
 
     NGRAPH_ASSERT(operands.size() == 2) << "Bias is not supported yet in MatmulBias operation";

@@ -406,16 +405,10 @@ namespace
     }
 }
 
-namespace ngraph
+namespace mlir
 {
-    namespace runtime
-    {
-        namespace ngmlir
-        {
-            Pass* createDialectLoweringPass(MLIRCompiler* compiler)
-            {
-                return new DialectLoweringPass(*compiler);
-            }
-        }
-    }
+    Pass* createDialectLoweringPass(ngraph::runtime::ngmlir::MLIRCompiler* compiler)
+    {
+        return new DialectLoweringPass(*compiler);
+    }
 }
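createDialectLoweringPass is the factory the compiler calls to schedule the lowering. Roughly how it would be wired into a pass manager, hedged because the PassManager API of the MLIR snapshot this code pins may differ:

    // Sketch under assumptions: PassManager construction/run signatures track
    // the 2019-era MLIR API and may not match the pinned snapshot exactly.
    void lowerNgDialect(mlir::Module* module,
                        ngraph::runtime::ngmlir::MLIRCompiler* compiler)
    {
        mlir::PassManager pm;
        pm.addPass(mlir::createDialectLoweringPass(compiler));
        pm.run(module); // applies the NGDialect -> std/memref lowering
    }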
@@ -16,9 +16,9 @@
 #pragma once
 
+#include "contrib/mlir/compiler.hpp"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 
 namespace ngraph
 {
     namespace runtime

@@ -26,8 +26,10 @@ namespace ngraph
     namespace ngmlir
     {
         class MLIRCompiler;
-        mlir::Pass* createDialectLoweringPass(MLIRCompiler* compiler);
     }
 }
 }
 
+namespace mlir
+{
+    mlir::Pass* createDialectLoweringPass(ngraph::runtime::ngmlir::MLIRCompiler* compiler);
+}
@@ -22,7 +22,7 @@ class OP##Conversion : public mlir::DialectOpConversion \
 {\
 public:\
     explicit OP##Conversion(mlir::MLIRContext *context, DialectLoweringPass& pass)\
-        : mlir::DialectOpConversion(ngraph::runtime::ngmlir::OP::getOperationName(), 1, context),\
+        : mlir::DialectOpConversion(mlir::OP::getOperationName(), 1, context),\
           m_pass(pass)\
     {} \
     SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, FuncBuilder &rewriter) const override; \
...
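The macro stamps out one DialectOpConversion subclass per op; its name sits above this hunk, so DECL_OP_CONV below is a stand-in. Typical consumption looks like:

    // DECL_OP_CONV is a placeholder name; the real macro is defined just above
    // this hunk. Each invocation generates an OP##Conversion pattern class.
    DECL_OP_CONV(NGAddOp)
    DECL_OP_CONV(NGMatMulBiasOp)

Each generated class then implements rewrite() in lowering.cpp, as seen with NGAddOp and NGMatMulBiasOp earlier in this diff.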