Commit 5867666f authored by Diego Caballero's avatar Diego Caballero Committed by nmostafa

[MLIR] Replace MatmulBiasOp with DotOp (#20)

* [MLIR] Replace MatmulBiasOp with DotOp

We disable CPUFusion if MLIR is enabled to avoid introducing CPU-specific ops for now.
parent 9acdfe04
@@ -14,7 +14,6 @@
 // limitations under the License.
 //*****************************************************************************
 #include "compiler.hpp"
-#include "dialect/ops.hpp"
 #include <llvm/ADT/STLExtras.h>
 #include <llvm/IR/Module.h>
@@ -34,12 +33,14 @@
 #include <mlir/Transforms/Passes.h>
 #include <mutex>
 #include "dialect/dialect.hpp"
+#include "dialect/ops.hpp"
 #include "dialect/type.hpp"
 #include "lowerer.hpp"
 #include "ngraph/descriptor/tensor.hpp"
 #include "ngraph/graph_util.hpp"
 #include "ngraph/node_vector.hpp"
 #include "ngraph/op/add.hpp"
+#include "ngraph/op/dot.hpp"
 #include "ngraph/op/experimental/compiled_kernel.hpp"
 #include "ngraph/runtime/cpu/op/matmul_bias.hpp"
 #include "ngraph/type/element_type.hpp"
@@ -266,19 +267,15 @@ mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
 }

 template <>
-mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::MatmulBias)
+mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
 {
-    // TODO(dcab): Implement all the variants of a Matmul/MatmulBias op.
-    // Keeping it simple for now.
-    NGRAPH_ASSERT(ng_node->get_arguments().size() == 2)
-        << "Bias is not supported in MatmulBias operation";
-    return compiler.create_binary_op<mlir::NGMatMulBiasOp>(ng_node);
+    NGRAPH_ASSERT(ng_node->get_arguments().size() == 2) << "Expected two operands in Dot operation";
+    return compiler.create_binary_op<mlir::NGDotOp>(ng_node);
 }

 const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
     {TI(ngraph::op::Add), &MLIRCompiler::create_op<ngraph::op::Add>},
-    {TI(ngraph::op::MatmulBias), &MLIRCompiler::create_op<ngraph::op::MatmulBias>}};
+    {TI(ngraph::op::Dot), &MLIRCompiler::create_op<ngraph::op::Dot>}};

 template <typename BinOp>
 mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node)
...
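For readers unfamiliar with the dispatch table above: `op_dispatcher` maps an ngraph op's runtime type (`TI` presumably wraps `std::type_index`) to a compile handler, so this commit's swap of MatmulBias for Dot is a single table entry. A minimal self-contained sketch of that pattern, with illustrative names rather than the actual nGraph types:

```cpp
#include <functional>
#include <iostream>
#include <stdexcept>
#include <typeindex>
#include <unordered_map>

struct Node { virtual ~Node() = default; };
struct Add : Node {};
struct Dot : Node {};

using Handler = std::function<void(const Node&)>;

// One entry per supported op, mirroring MLIRCompiler::op_dispatcher.
static const std::unordered_map<std::type_index, Handler> dispatcher{
    {typeid(Add), [](const Node&) { std::cout << "compile Add\n"; }},
    {typeid(Dot), [](const Node&) { std::cout << "compile Dot\n"; }}};

void compile(const Node& node)
{
    // typeid on a reference to a polymorphic type yields the dynamic type.
    auto it = dispatcher.find(typeid(node));
    if (it == dispatcher.end())
        throw std::runtime_error("unsupported op");
    it->second(node);
}

int main()
{
    Dot dot;
    compile(dot); // prints "compile Dot"
}
```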
@@ -56,33 +56,16 @@ static mlir::LogicalResult verifyOp(T* op)
     return op->emitOpError("Unsupported verifier for this operation");
 }

+// Per op specializations
 template <>
-mlir::LogicalResult verifyOp<NGMatMulBiasOp>(NGMatMulBiasOp* op)
+mlir::LogicalResult verifyOp(NGDotOp* op)
 {
-    // Verify that we have 2 operands
-    // Bias operand must be null for now (not implemented)
-    if (op->getNumOperands() != 2)
+    mlir::LogicalResult result = verifyBinaryArithOp(op);
+    if (failed(result))
     {
-        std::stringstream ss;
-        ss << "Unexpected MatmulBiasOp with " << op->getNumOperands()
-           << " operands. 3 operands expected";
-        return op->emitOpError(ss.str());
+        return result;
     }
-
-    // Verify that operand types are supported.
-    auto op0_tensor_ty = op->getOperand(0)->getType().cast<NGTensorType>();
-    auto op1_tensor_ty = op->getOperand(1)->getType().cast<NGTensorType>();
-
-    // Verify that operand shapes are supported.
-    if (op0_tensor_ty.getRank() != 2 || op1_tensor_ty.getRank() != 2)
-    {
-        return op->emitOpError(
-            "Unsupported number of dimensions. Only 2D tensors are supported in "
-            "MatmulBiasOp");
-    }
-
-    // TODO(dcab): Improve verification: matching types, proper shapes, etc.
+    // TODO(dcab): Improve verification: proper shapes, etc.

     return mlir::success();
 }
...
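The new verifier delegates to `verifyBinaryArithOp`, a shared helper for binary arithmetic ops whose exact checks are not shown in this diff. A hedged, self-contained sketch of the kind of checks such a helper typically performs, using plain C++ stand-ins rather than the MLIR API:

```cpp
#include <string>
#include <vector>

// Illustrative stand-in for an NGTensorType: element type name plus shape.
struct TensorType
{
    std::string elem;
    std::vector<long> shape;
};

// Plausible baseline checks for a binary arithmetic op: matching element
// types on the two operands. Per the TODO in the hunk above, shape/rank
// rules for dot are still to be tightened, so they are deliberately omitted.
bool verify_binary_arith(const TensorType& lhs, const TensorType& rhs, std::string& err)
{
    if (lhs.elem != rhs.elem)
    {
        err = "operand element types mismatch";
        return false;
    }
    return true;
}
```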
@@ -26,4 +26,5 @@ namespace mlir
 {
 #define GET_OP_CLASSES
 #include "ops.h.inc"
+#undef GET_OP_CLASSES
 }
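A note on the `#undef` added above: `ops.h.inc` is generated from `ops.td` by mlir-tblgen and only expands to op class declarations while `GET_OP_CLASSES` is defined; undefining it afterwards keeps the macro from leaking into every file that includes this header. The idiom, schematically (the real `ops.h.inc` is produced at build time):

```cpp
#define GET_OP_CLASSES
// #include "ops.h.inc" // tblgen output: op class declarations expand here
#undef GET_OP_CLASSES   // don't leak the macro to downstream includers
```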
@@ -144,13 +144,14 @@ def NGNotEqOp : NG_OneResult_Op<"not.equal", [NoSideEffect]>;
 def NGSelectOp : NG_OneResult_Op<"select", [NoSideEffect]>;

 // Matrix Multiply
-def NGMatMulBiasOp : NG_Binary_Arith_Op<"matmul.bias">
+def NGDotOp : NG_Binary_Arith_Op<"dot">
 {
-  let verifier=[{return verifyOp(this);}];
+  // TODO: Add reduction axis attribute when needed.
+  let verifier = [{ return verifyOp(this); }];
 }

 // Terminator Ops
 def NGReturnOp : NG_Terminator_Op<"return">;

 // Fake ops
 def NGFakeInputOp : NG_MemRefDef_Op<"fake.input", [NoSideEffect]>;
\ No newline at end of file
@@ -56,8 +56,8 @@ namespace
         // Initialize the list of converters.
         void initConverters(OwningRewritePatternList& patterns, MLIRContext* mlirContext) override
         {
-            RewriteListBuilder<NGAddOpConversion, NGMatMulBiasOpConversion, NGReturnOpConversion>::
-                build(patterns, mlirContext, m_pass);
+            RewriteListBuilder<NGAddOpConversion, NGDotOpConversion, NGReturnOpConversion>::build(
+                patterns, mlirContext, m_pass);
         }

     private:
@@ -335,19 +335,17 @@ namespace
         rewriter.replaceOp(op, {result});
     }

-    REWRITER(NGMatMulBiasOp)
+    REWRITER(NGDotOp)
     {
-        auto matmul = cast<NGMatMulBiasOp>(op);
-        auto loc = matmul.getLoc();
-
-        NGRAPH_ASSERT(operands.size() == 2) << "Bias is not supported yet in MatmulBias operation";
+        auto dot = cast<NGDotOp>(op);
+        auto loc = dot.getLoc();

         // Retrieve/generate Values for operands and result.
         ScopedContext scope(rewriter, loc);
         Value* lhs = operands[0];
         Value* rhs = operands[1];
         Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
-        NGRAPH_ASSERT(lhs && rhs && result) << "Unexpected null values in MatmulBiasOp";
+        NGRAPH_ASSERT(lhs && rhs && result) << "Unexpected null values in DotOp";

         auto result_ty = result->getType().dyn_cast<MemRefType>();
         auto lhs_ty = lhs->getType().dyn_cast<MemRefType>();
@@ -358,7 +356,7 @@ namespace
         Type elem_ty = result_ty.getElementType();
         NGRAPH_ASSERT(elem_ty == lhs_ty.getElementType() && elem_ty == rhs_ty.getElementType())
-            << "Types mismatch in MatmulBiasOp";
+            << "Types mismatch in DotOp";

         // Create the following loop nest for matmul operation:
         //   for(n, N, 1)
@@ -371,7 +369,7 @@ namespace
         IndexedValue i_res(result), i_lhs(lhs), i_rhs(rhs);

         NGRAPH_ASSERT(v_lhs.rank() == 2 && v_rhs.rank() == 2 && v_res.rank() == 2)
-            << "MatmulBias operation is only supported for 2D tensors";
+            << "Dot operation is only supported for 2D tensors";

         // Induction variables, lower bounds, upper bounds and steps of the loop nest.
         IndexHandle n, m, k;
@@ -384,22 +382,14 @@ namespace
         IndexedValue ires(result), ilhs(lhs), irhs(rhs);
         ValueHandle zero_init(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elem_ty)));

-        // clang-format off
-        LoopBuilder(&n, n_lb, n_ub, n_step)(
-            [&]{
-                LoopBuilder(&k, k_lb, k_ub, k_step)(
-                    [&]{
-                        i_res(n, k) = zero_init;
-                        LoopBuilder(&m, m_lb, m_ub, m_step)(
-                            [&]{
-                                i_res(n, k) += i_lhs(n, m) * i_rhs(m, k);
-                            }
-                        );
-                    }
-                );
-            }
-        );
-        // clang-format on
+        LoopBuilder(&n, n_lb, n_ub, n_step)([&] {
+            LoopBuilder(&k, k_lb, k_ub, k_step)([&] {
+                i_res(n, k) = zero_init;
+                LoopBuilder(&m, m_lb, m_ub, m_step)(
+                    [&] { i_res(n, k) += i_lhs(n, m) * i_rhs(m, k); });
+            });
+        });

         rewriter.replaceOp(op, {result});
     }
...
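The LoopBuilder nest above is MLIR's EDSC way of emitting the three loops of a naive matmul for NGDotOp (2D-only at this point). Expressed as plain, runnable C++ over nested vectors, with comments mapping back to the builders, the generated computation is:

```cpp
#include <cstddef>
#include <vector>

// res is N x K, lhs is N x M, rhs is M x K; names mirror the rewriter above.
void matmul(const std::vector<std::vector<float>>& lhs,
            const std::vector<std::vector<float>>& rhs,
            std::vector<std::vector<float>>& res)
{
    const std::size_t N = lhs.size();
    const std::size_t M = rhs.size();
    const std::size_t K = rhs.empty() ? 0 : rhs[0].size();
    for (std::size_t n = 0; n < N; ++n)         // LoopBuilder(&n, n_lb, n_ub, n_step)
    {
        for (std::size_t k = 0; k < K; ++k)     // LoopBuilder(&k, k_lb, k_ub, k_step)
        {
            res[n][k] = 0.0f;                   // i_res(n, k) = zero_init
            for (std::size_t m = 0; m < M; ++m) // LoopBuilder(&m, m_lb, m_ub, m_step)
                res[n][k] += lhs[n][m] * rhs[m][k];
        }
    }
}
```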
@@ -30,7 +30,7 @@ public:\
 };

 DECL_OP_CONV(NGAddOp)
-DECL_OP_CONV(NGMatMulBiasOp)
+DECL_OP_CONV(NGDotOp)
 DECL_OP_CONV(NGReturnOp)

 #undef DECL_OP_CONV
@@ -29,7 +29,6 @@
 #include "contrib/mlir/compiler.hpp"
 #endif
-
 using namespace ngraph;
 using namespace std;
...
@@ -1200,7 +1200,13 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
     REGISTER_KNOBBED_PASS_WITH_ARGS(
         CoreFusion, true, ngraph::pass, ngraph::pass::FusionType::ALL_FUSIONS);
     REGISTER_KNOBBED_PASS_WITH_ARGS(FusedOpDecomposition, true, ngraph::pass, is_supported);
-    REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
+    // Disable CPUFusion if MLIR is enabled to preserve core ops.
+    if (std::getenv("NGRAPH_MLIR") == nullptr)
+    {
+        REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
+    }
     REGISTER_KNOBBED_PASS(CPUQuantFusion, true, runtime::cpu::pass);
     REGISTER_KNOBBED_PASS(CPUHorizontalFusion, true, runtime::cpu::pass);
     REGISTER_KNOBBED_PASS(CPUCollapseDims, true, runtime::cpu::pass);
...
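The new guard keys off the same `NGRAPH_MLIR` environment variable the MLIR path uses (the diff itself checks it via `std::getenv`), so a run with `NGRAPH_MLIR` set skips CPUFusion and leaves core ops such as Dot intact for the MLIR compiler, rather than letting them be fused into CPU-specific ops like MatmulBias. The check, isolated:

```cpp
#include <cstdlib>

// True when the MLIR code path is requested via the environment;
// CPUFusion is registered only when this returns false.
bool mlir_enabled()
{
    return std::getenv("NGRAPH_MLIR") != nullptr;
}
```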