Commit c31940d4 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by nmostafa

[MLIR] Enable module verification. Fix FakeInput result type to memref (#26)

* Enable module verification. Fix FakeInput result type to memref

* style-apply
parent 6bb90e3c
...@@ -123,6 +123,10 @@ namespace ngraph ...@@ -123,6 +123,10 @@ namespace ngraph
m_builder = llvm::make_unique<mlir::FuncBuilder>(function.get()); m_builder = llvm::make_unique<mlir::FuncBuilder>(function.get());
build_ng_dialect(); build_ng_dialect();
m_module->getFunctions().push_back(function.release()); m_module->getFunctions().push_back(function.release());
if (failed(m_module->verify()))
{
NGRAPH_FAIL() << "Invalid module after lowering to NG dialect";
}
if (std::getenv("NGRAPH_MLIR_DUMP_ALL") != nullptr) if (std::getenv("NGRAPH_MLIR_DUMP_ALL") != nullptr)
{ {
m_module->dump(); m_module->dump();
...@@ -197,8 +201,12 @@ namespace ngraph ...@@ -197,8 +201,12 @@ namespace ngraph
mlir::PassManager pm; mlir::PassManager pm;
pm.addPass(createDialectLoweringPass(this)); pm.addPass(createDialectLoweringPass(this));
pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCanonicalizerPass());
pm.run(m_module.get()); pm.run(m_module.get());
if (failed(m_module->verify()))
{
NGRAPH_FAIL() << "Incorrect module after dialect lowering";
}
if (std::getenv("NGRAPH_MLIR_DUMP_ALL") != nullptr) if (std::getenv("NGRAPH_MLIR_DUMP_ALL") != nullptr)
{ {
m_module->dump(); m_module->dump();
......
...@@ -67,8 +67,9 @@ namespace ngraph ...@@ -67,8 +67,9 @@ namespace ngraph
template <> template <>
mlir::LogicalResult verifyOp<NGMatMulBiasOp>(NGMatMulBiasOp* op) mlir::LogicalResult verifyOp<NGMatMulBiasOp>(NGMatMulBiasOp* op)
{ {
// Verify that we have 3 operands // Verify that we have 2 operands
if (op->getNumOperands() != 3) // Bias operand must be null for now (not implemented)
if (op->getNumOperands() != 2)
{ {
std::stringstream ss; std::stringstream ss;
ss << "Unexpected MatmulBiasOp with " << op->getNumOperands() ss << "Unexpected MatmulBiasOp with " << op->getNumOperands()
...@@ -76,18 +77,12 @@ namespace ngraph ...@@ -76,18 +77,12 @@ namespace ngraph
return op->emitOpError(ss.str()); return op->emitOpError(ss.str());
} }
// Bias operand must be null for now (not implemented).
if (op->getOperand(2) != nullptr)
{
return op->emitOpError("Bias operand is not null in MatmulBiasOp");
}
// Verify that operand types are supported. // Verify that operand types are supported.
auto op0_tensor_ty = op->getOperand(0)->getType().cast<NGTensorType>(); auto op0_tensor_ty = op->getOperand(0)->getType().cast<NGTensorType>();
auto op1_tensor_ty = op->getOperand(1)->getType().cast<NGTensorType>(); auto op1_tensor_ty = op->getOperand(1)->getType().cast<NGTensorType>();
// Verify that operand shapes are supported. // Verify that operand shapes are supported.
if (op0_tensor_ty.getRank() == 2 && op1_tensor_ty.getRank() == 2) if (op0_tensor_ty.getRank() != 2 || op1_tensor_ty.getRank() != 2)
{ {
return op->emitOpError( return op->emitOpError(
"Unsupported number of dimensions. Only 2D tensors are supported in " "Unsupported number of dimensions. Only 2D tensors are supported in "
......
...@@ -43,6 +43,9 @@ include "mlir/IR/OpBase.td" ...@@ -43,6 +43,9 @@ include "mlir/IR/OpBase.td"
def NG_TensorType : Type<CPred<"{0}.isa<ngraph::runtime::ngmlir::NGTensorType>()">, def NG_TensorType : Type<CPred<"{0}.isa<ngraph::runtime::ngmlir::NGTensorType>()">,
"NGraph Tensor Type">; "NGraph Tensor Type">;
// A generic un-typed MemRef. Used for Fake instructions inserted during dialect lowering
def NG_MemRefType : Type<IsMemRefTypePred, "MemRef Type">;
// NGraph operation base class. // NGraph operation base class.
// Prepends "ng." to operation name // Prepends "ng." to operation name
class NG_Op<string mnemonic, list<OpTrait> traits = []> : class NG_Op<string mnemonic, list<OpTrait> traits = []> :
...@@ -53,6 +56,10 @@ class NG_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -53,6 +56,10 @@ class NG_Op<string mnemonic, list<OpTrait> traits = []> :
class NG_OneResult_Op<string mnemonic, list<OpTrait> traits = []> : class NG_OneResult_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, traits>, Results<(outs NG_TensorType:$res)> {} NG_Op<mnemonic, traits>, Results<(outs NG_TensorType:$res)> {}
// Base for fake instructions defining MemRef values
class NG_MemRefDef_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, traits>, Results<(outs NG_MemRefType:$res)> {}
// Operations producing no results // Operations producing no results
class NG_ZeroResult_Op<string mnemonic, list<OpTrait> traits = []> : class NG_ZeroResult_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, traits>, Results<(outs)> {} NG_Op<mnemonic, traits>, Results<(outs)> {}
...@@ -112,4 +119,4 @@ def NGMatMulBiasOp : NG_Binary_Arith_Op<"matmul.bias"> ...@@ -112,4 +119,4 @@ def NGMatMulBiasOp : NG_Binary_Arith_Op<"matmul.bias">
def NGReturnOp : NG_Terminator_Op<"return">; def NGReturnOp : NG_Terminator_Op<"return">;
// Fake ops // Fake ops
def NGFakeInputOp : NG_OneResult_Op<"fake.input", [NoSideEffect]>; def NGFakeInputOp : NG_MemRefDef_Op<"fake.input", [NoSideEffect]>;
\ No newline at end of file \ No newline at end of file
...@@ -170,13 +170,14 @@ namespace ...@@ -170,13 +170,14 @@ namespace
if (it != outputMap.end()) if (it != outputMap.end())
{ {
unsigned argId = (*it).second; unsigned argId = (*it).second;
auto newResult = rewriter auto fakeOp = rewriter.create<ngmlir::NGFakeInputOp>(
.create<ngmlir::NGFakeInputOp>(
op->getLoc(), op->getLoc(),
m_dialectLowerer.convertType( m_dialectLowerer.convertType(
origResult->getType()) /* convert to lowered type */ origResult->getType()) /* convert to lowered type */
) );
.getResult(); // Fake instrution is short-lived. Verify here.
fakeOp.verify();
auto newResult = fakeOp.getResult();
newResults.push_back(newResult); newResults.push_back(newResult);
m_loweredOutputValues[argId] = newResult; m_loweredOutputValues[argId] = newResult;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment