Commit 5867666f authored by Diego Caballero, committed by nmostafa

[MLIR] Replace MatmulBiasOp with DotOp (#20)

* [MLIR] Replace MatmulBiasOp with DotOp

We disable CPUFusion if MLIR is enabled to avoid introducing CPU-specific ops for now.
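For context, the op now mapped through the MLIR bridge is nGraph's plain two-operand Dot. A minimal sketch of the pattern this commit targets (hypothetical shapes; standard nGraph Parameter/Dot construction):

    #include <memory>
    #include "ngraph/op/dot.hpp"
    #include "ngraph/op/parameter.hpp"

    // Hypothetical 2D inputs; only the bias-free, two-operand form is supported.
    auto A = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{4, 8});
    auto B = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{8, 16});
    auto dot = std::make_shared<ngraph::op::Dot>(A, B); // compiled to mlir::NGDotOp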
parent 9acdfe04
@@ -14,7 +14,6 @@
// limitations under the License.
//*****************************************************************************
#include "compiler.hpp"
#include "dialect/ops.hpp"
#include <llvm/ADT/STLExtras.h>
#include <llvm/IR/Module.h>
@@ -34,12 +33,14 @@
#include <mlir/Transforms/Passes.h>
#include <mutex>
#include "dialect/dialect.hpp"
#include "dialect/ops.hpp"
#include "dialect/type.hpp"
#include "lowerer.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/type/element_type.hpp"
@@ -266,19 +267,15 @@ mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
}
template <>
-mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::MatmulBias)
+mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
{
-    // TODO(dcab): Implement all the variants of a Matmul/MatmulBias op.
-    // Keeping it simple for now.
-    NGRAPH_ASSERT(ng_node->get_arguments().size() == 2)
-        << "Bias is not supported in MatmulBias operation";
-    return compiler.create_binary_op<mlir::NGMatMulBiasOp>(ng_node);
+    NGRAPH_ASSERT(ng_node->get_arguments().size() == 2) << "Expected two operands in Dot operation";
+    return compiler.create_binary_op<mlir::NGDotOp>(ng_node);
}
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
{TI(ngraph::op::Add), &MLIRCompiler::create_op<ngraph::op::Add>},
-    {TI(ngraph::op::MatmulBias), &MLIRCompiler::create_op<ngraph::op::MatmulBias>}};
+    {TI(ngraph::op::Dot), &MLIRCompiler::create_op<ngraph::op::Dot>}};
template <typename BinOp>
mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node)
......
@@ -56,33 +56,16 @@ static mlir::LogicalResult verifyOp(T* op)
return op->emitOpError("Unsupported verifier for this operation");
}
// Per op specializations
template <>
-mlir::LogicalResult verifyOp<NGMatMulBiasOp>(NGMatMulBiasOp* op)
+mlir::LogicalResult verifyOp(NGDotOp* op)
{
-    // Verify that we have 2 operands
-    // Bias operand must be null for now (not implemented)
-    if (op->getNumOperands() != 2)
+    mlir::LogicalResult result = verifyBinaryArithOp(op);
+    if (failed(result))
{
-        std::stringstream ss;
-        ss << "Unexpected MatmulBiasOp with " << op->getNumOperands()
-           << " operands. 3 operands expected";
-        return op->emitOpError(ss.str());
+        return result;
}
-    // Verify that operand types are supported.
-    auto op0_tensor_ty = op->getOperand(0)->getType().cast<NGTensorType>();
-    auto op1_tensor_ty = op->getOperand(1)->getType().cast<NGTensorType>();
-    // Verify that operand shapes are supported.
-    if (op0_tensor_ty.getRank() != 2 || op1_tensor_ty.getRank() != 2)
-    {
-        return op->emitOpError(
-            "Unsupported number of dimensions. Only 2D tensors are supported in "
-            "MatmulBiasOp");
-    }
-    // TODO(dcab): Improve verification: matching types, proper shapes, etc.
+    // TODO(dcab): Improve verification: proper shapes, etc.
return mlir::success();
}
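Note that verifyBinaryArithOp is not part of this diff. A plausible sketch of the common checks it could centralize, assuming it mirrors the per-op code removed above (an assumption, not the actual implementation):

    // Hypothetical shared verifier for binary arithmetic ops; the real
    // verifyBinaryArithOp lives elsewhere and may check more or less than this.
    template <typename T>
    static mlir::LogicalResult verifyBinaryArithOp(T* op)
    {
        // Expect exactly two operands of the nGraph dialect tensor type.
        if (op->getNumOperands() != 2)
        {
            return op->emitOpError("Expected two operands");
        }
        auto lhs_ty = op->getOperand(0)->getType().template dyn_cast<NGTensorType>();
        auto rhs_ty = op->getOperand(1)->getType().template dyn_cast<NGTensorType>();
        if (!lhs_ty || !rhs_ty)
        {
            return op->emitOpError("Expected nGraph tensor operands");
        }
        return mlir::success();
    }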
......
@@ -26,4 +26,5 @@ namespace mlir
{
#define GET_OP_CLASSES
#include "ops.h.inc"
+#undef GET_OP_CLASSES
}
@@ -144,13 +144,14 @@ def NGNotEqOp : NG_OneResult_Op<"not.equal", [NoSideEffect]>;
def NGSelectOp : NG_OneResult_Op<"select", [NoSideEffect]>;
// Matrix Multiply
-def NGMatMulBiasOp : NG_Binary_Arith_Op<"matmul.bias">
+def NGDotOp : NG_Binary_Arith_Op<"dot">
{
-  let verifier=[{return verifyOp(this);}];
+  // TODO: Add reduction axis attribute when needed.
+  let verifier = [{ return verifyOp(this); }];
}
// Terminator Ops
def NGReturnOp : NG_Terminator_Op<"return">;
// Fake ops
-def NGFakeInputOp : NG_MemRefDef_Op<"fake.input", [NoSideEffect]>;
\ No newline at end of file
+def NGFakeInputOp : NG_MemRefDef_Op<"fake.input", [NoSideEffect]>;
@@ -56,8 +56,8 @@ namespace
// Initialize the list of converters.
void initConverters(OwningRewritePatternList& patterns, MLIRContext* mlirContext) override
{
-        RewriteListBuilder<NGAddOpConversion, NGMatMulBiasOpConversion, NGReturnOpConversion>::
-            build(patterns, mlirContext, m_pass);
+        RewriteListBuilder<NGAddOpConversion, NGDotOpConversion, NGReturnOpConversion>::build(
+            patterns, mlirContext, m_pass);
}
private:
@@ -335,19 +335,17 @@ namespace
rewriter.replaceOp(op, {result});
}
-    REWRITER(NGMatMulBiasOp)
+    REWRITER(NGDotOp)
{
-        auto matmul = cast<NGMatMulBiasOp>(op);
-        auto loc = matmul.getLoc();
-        NGRAPH_ASSERT(operands.size() == 2) << "Bias is not supported yet in MatmulBias operation";
+        auto dot = cast<NGDotOp>(op);
+        auto loc = dot.getLoc();
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* lhs = operands[0];
Value* rhs = operands[1];
Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_ASSERT(lhs && rhs && result) << "Unexpected null values in MatmulBiasOp";
NGRAPH_ASSERT(lhs && rhs && result) << "Unexpected null values in DotOp";
auto result_ty = result->getType().dyn_cast<MemRefType>();
auto lhs_ty = lhs->getType().dyn_cast<MemRefType>();
@@ -358,7 +356,7 @@ namespace
Type elem_ty = result_ty.getElementType();
NGRAPH_ASSERT(elem_ty == lhs_ty.getElementType() && elem_ty == rhs_ty.getElementType())
<< "Types mismatch in MatmulBiasOp";
<< "Types mismatch in DotOp";
// Create the following loop nest for matmul operation:
// for(n, N, 1)
@@ -371,7 +369,7 @@
IndexedValue i_res(result), i_lhs(lhs), i_rhs(rhs);
NGRAPH_ASSERT(v_lhs.rank() == 2 && v_rhs.rank() == 2 && v_res.rank() == 2)
<< "MatmulBias operation is only supported for 2D tensors";
<< "Dot operation is only supported for 2D tensors";
// Induction variables, lower bounds, upper bounds and steps of the loop nest.
IndexHandle n, m, k;
@@ -384,22 +382,14 @@
IndexedValue ires(result), ilhs(lhs), irhs(rhs);
ValueHandle zero_init(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elem_ty)));
-        // clang-format off
-        LoopBuilder(&n, n_lb, n_ub, n_step)(
-            [&]{
-                LoopBuilder(&k, k_lb, k_ub, k_step)(
-                    [&]{
-                        i_res(n, k) = zero_init;
-                        LoopBuilder(&m, m_lb, m_ub, m_step)(
-                            [&]{
-                                i_res(n, k) += i_lhs(n, m) * i_rhs(m, k);
-                            }
-                        );
-                    }
-                );
-            }
-        );
-        // clang-format on
+        LoopBuilder(&n, n_lb, n_ub, n_step)([&] {
+            LoopBuilder(&k, k_lb, k_ub, k_step)([&] {
+                i_res(n, k) = zero_init;
+                LoopBuilder(&m, m_lb, m_ub, m_step)(
+                    [&] { i_res(n, k) += i_lhs(n, m) * i_rhs(m, k); });
+            });
+        });
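The EDSC builders above emit the equivalent of this plain scalar loop nest (a sketch only; N, K and M stand for the n/k/m loop bounds set up in the elided code above, with res of shape N x K, lhs N x M and rhs M x K):

    // Scalar view of the generated matmul nest (dimension names assumed).
    for (int n = 0; n < N; ++n)
    {
        for (int k = 0; k < K; ++k)
        {
            res[n][k] = 0; // zero_init
            for (int m = 0; m < M; ++m)
            {
                res[n][k] += lhs[n][m] * rhs[m][k];
            }
        }
    }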
rewriter.replaceOp(op, {result});
}
......
@@ -30,7 +30,7 @@ public:\
};
DECL_OP_CONV(NGAddOp)
-DECL_OP_CONV(NGMatMulBiasOp)
+DECL_OP_CONV(NGDotOp)
DECL_OP_CONV(NGReturnOp)
#undef DECL_OP_CONV
@@ -29,7 +29,6 @@
#include "contrib/mlir/compiler.hpp"
#endif
using namespace ngraph;
using namespace std;
......
@@ -1200,7 +1200,13 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
REGISTER_KNOBBED_PASS_WITH_ARGS(
CoreFusion, true, ngraph::pass, ngraph::pass::FusionType::ALL_FUSIONS);
REGISTER_KNOBBED_PASS_WITH_ARGS(FusedOpDecomposition, true, ngraph::pass, is_supported);
-    REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
+    // Disable CPUFusion if MLIR is enabled to preserve core ops.
+    if (std::getenv("NGRAPH_MLIR") == nullptr)
+    {
+        REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
+    }
REGISTER_KNOBBED_PASS(CPUQuantFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUHorizontalFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUCollapseDims, true, runtime::cpu::pass);
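With the guard above, running with the NGRAPH_MLIR environment variable set (for example NGRAPH_MLIR=1) skips CPUFusion entirely, so core ops such as Dot survive for the MLIR path instead of being fused into CPU-specific ops like MatmulBias.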
......