Commit ba735a80 authored by Diego Caballero's avatar Diego Caballero Committed by nmostafa

[MLIR] Add MatmulBias op with basic support for simple matmuls (#8)

The following test should work now:
NGRAPH_MLIR_DUMP_ALL=1 NGRAPH_MLIR=1 test/unit-test '--gtest_filter=CPU.dot2d'
parent dd5c6fb6
...@@ -38,6 +38,14 @@ namespace ngraph ...@@ -38,6 +38,14 @@ namespace ngraph
auto arg1_buffer_index = external_function->get_buffer_index(args[1].get_name()); auto arg1_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name()); auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
// TODO: Quick hook for MLIR.
if (std::getenv("NGRAPH_MLIR") != nullptr)
{
functors.emplace_back(build_mlir_single_output_binary_op(
node, arg0_tensor, arg1_tensor, out0_tensor));
return;
}
const ngraph::op::MatmulBias* mm = static_cast<const ngraph::op::MatmulBias*>(node); const ngraph::op::MatmulBias* mm = static_cast<const ngraph::op::MatmulBias*>(node);
const auto& arg0_shape = mm->get_a_shape(); const auto& arg0_shape = mm->get_a_shape();
......
...@@ -102,6 +102,7 @@ ...@@ -102,6 +102,7 @@
#include "ngraph/runtime/cpu/kernel/subtract.hpp" #include "ngraph/runtime/cpu/kernel/subtract.hpp"
#include "ngraph/runtime/cpu/kernel/tan.hpp" #include "ngraph/runtime/cpu/kernel/tan.hpp"
#include "ngraph/runtime/cpu/kernel/tanh.hpp" #include "ngraph/runtime/cpu/kernel/tanh.hpp"
#include "ngraph/runtime/cpu/mlir/compiler.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp" #include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/halide_op.hpp" #include "ngraph/runtime/cpu/op/halide_op.hpp"
#include "ngraph/runtime/cpu/op/loop_kernel.hpp" #include "ngraph/runtime/cpu/op/loop_kernel.hpp"
...@@ -502,3 +503,28 @@ namespace ngraph ...@@ -502,3 +503,28 @@ namespace ngraph
} }
} }
} }
using namespace ngraph::runtime::cpu;
CPUKernelFunctor Builder::build_mlir_single_output_binary_op(const ngraph::Node* node,
void*& arg0_tensor,
void*& arg1_tensor,
void*& out_tensor)
{
// TODO: Remove m_ip/op_list construction out of MLIRCompiler.
auto functor = [&, node](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
std::vector<const Node*> nodelist = {node};
// MLIR requires a list of type-erased pointer to arguments. Our arguments
// are already pointers, so we need to pass a double pointer.
std::vector<void*> ptr_args = {arg0_tensor, arg1_tensor, out_tensor};
MLIRCompiler mlirc(nodelist, ptr_args);
// TODO: Decouple 'compile' and 'run' APIs. We want to be able to run the
// same jitted code on different arguments.
mlirc.compile_and_run();
};
return functor;
}
...@@ -403,6 +403,12 @@ namespace ngraph ...@@ -403,6 +403,12 @@ namespace ngraph
const std::vector<TensorViewWrapper>& out) const std::vector<TensorViewWrapper>& out)
{ {
} }
// TODO (dcab): Doc
static CPUKernelFunctor build_mlir_single_output_binary_op(const ngraph::Node* node,
void*& arg0_tensor,
void*& arg1_tensor,
void*& out_tensor);
}; };
} }
} }
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "ngraph/runtime/cpu/mlir/dialect/ops.hpp" #include "ngraph/runtime/cpu/mlir/dialect/ops.hpp"
#include "ngraph/runtime/cpu/mlir/dialect/type.hpp" #include "ngraph/runtime/cpu/mlir/dialect/type.hpp"
#include "ngraph/runtime/cpu/mlir/lowerer.hpp" #include "ngraph/runtime/cpu/mlir/lowerer.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include <llvm/ADT/STLExtras.h> #include <llvm/ADT/STLExtras.h>
...@@ -58,7 +59,7 @@ namespace ngraph ...@@ -58,7 +59,7 @@ namespace ngraph
// Register any LLVM command line options // Register any LLVM command line options
llvm::cl::ParseEnvironmentOptions("ngraph", "MLIR_LLVM_OPTIONS", ""); llvm::cl::ParseEnvironmentOptions("ngraph", "MLIR_LLVM_OPTIONS", "");
} }
void MLIRCompiler::compile() void MLIRCompiler::compile_and_run()
{ {
build_module(); // MLIR gen build_module(); // MLIR gen
lower_dialect(); lower_dialect();
...@@ -277,8 +278,20 @@ namespace ngraph ...@@ -277,8 +278,20 @@ namespace ngraph
return compiler.create_binary_op<NG_AddOp>(ng_node); return compiler.create_binary_op<NG_AddOp>(ng_node);
} }
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::MatmulBias)
{
// TODO(dcab): Implement all the variants of a Matmul/MatmulBias op.
// Keeping it simple for now.
NGRAPH_ASSERT(ng_node->get_arguments().size() == 2)
<< "Bias is not supported in MatmulBias operation";
return compiler.create_binary_op<NG_MatmulBiasOp>(ng_node);
}
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{ const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
{TI(ngraph::op::Add), &MLIRCompiler::create_op<ngraph::op::Add>}}; {TI(ngraph::op::Add), &MLIRCompiler::create_op<ngraph::op::Add>},
{TI(ngraph::op::MatmulBias), &MLIRCompiler::create_op<ngraph::op::MatmulBias>}};
template <typename BinOp> template <typename BinOp>
mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node) mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node)
...@@ -343,6 +356,11 @@ namespace ngraph ...@@ -343,6 +356,11 @@ namespace ngraph
llvm::InitializeNativeTarget(); llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter(); llvm::InitializeNativeTargetAsmPrinter();
// Create an MLIR execution engine. Note that it takes a null pass manager
// to make sure it won't run "default" passes on the MLIR that would trigger
// a second conversion to LLVM IR. The execution engine eagerly JIT-compiles
// the module.
// Create an MLIR execution engine. Note that it takes a null pass manager // Create an MLIR execution engine. Note that it takes a null pass manager
// to make sure it won't run "default" passes on the MLIR that would trigger // to make sure it won't run "default" passes on the MLIR that would trigger
// a second conversion to LLVM IR. The execution engine eagerly JIT-compiles // a second conversion to LLVM IR. The execution engine eagerly JIT-compiles
......
...@@ -49,8 +49,8 @@ namespace ngraph ...@@ -49,8 +49,8 @@ namespace ngraph
} }
static void init_mlir(); static void init_mlir();
// compiles and runs a subgraph in MLIR /// Compiles and runs a subgraph in MLIR.
void compile(); void compile_and_run();
private: private:
struct TensorInfo struct TensorInfo
......
...@@ -29,6 +29,7 @@ namespace ngraph ...@@ -29,6 +29,7 @@ namespace ngraph
{ {
addTypes<NGTensorType>(); addTypes<NGTensorType>();
addOperations<NG_AddOp>(); addOperations<NG_AddOp>();
addOperations<NG_MatmulBiasOp>();
addOperations<NG_ReturnOp>(); addOperations<NG_ReturnOp>();
addOperations<NG_FakeOutput>(); addOperations<NG_FakeOutput>();
} }
......
...@@ -99,6 +99,53 @@ namespace ngraph ...@@ -99,6 +99,53 @@ namespace ngraph
return mlir::success(); return mlir::success();
} }
void runtime::cpu::NG_MatmulBiasOp::build(mlir::Builder* builder,
mlir::OperationState* state,
mlir::Value* lhs,
mlir::Value* rhs)
{
state->types.push_back(lhs->getType());
state->operands.push_back(lhs);
state->operands.push_back(rhs);
}
mlir::LogicalResult runtime::cpu::NG_MatmulBiasOp::verify()
{
// Verify that we have 3 operands
if (getNumOperands() != 3)
{
std::stringstream ss;
ss << "Unexpected MatmulBiasOp with " << getNumOperands()
<< " operands. 3 operands expected";
return emitOpError(ss.str());
}
// Bias operand must be null for now (not implemented).
if (getOperand(2) != nullptr)
{
return emitOpError("Bias operand is not null in MatmulBiasOp");
}
// Verify that operand types are supported.
auto op0_tensor_ty = getOperand(0)->getType().dyn_cast<NGTensorType>();
auto op1_tensor_ty = getOperand(1)->getType().dyn_cast<NGTensorType>();
if (!op0_tensor_ty || !op1_tensor_ty)
{
return emitOpError("Unsupported non-tensor type in MatmulBiasOp");
}
// Verify that operand shapes are supported.
if (op0_tensor_ty.getRank() == 2 && op1_tensor_ty.getRank() == 2)
{
return emitOpError(
"Unsupported number of dimensions. Only 2D tensors are supported in MatmulBiasOp");
}
// TODO(dcab): Improve verification: matching types, proper shapes, etc.
return mlir::success();
}
void runtime::cpu::NG_ReturnOp::build(mlir::Builder* builder, void runtime::cpu::NG_ReturnOp::build(mlir::Builder* builder,
mlir::OperationState* state, mlir::OperationState* state,
std::vector<mlir::Value*> value_list) std::vector<mlir::Value*> value_list)
......
...@@ -66,6 +66,32 @@ namespace ngraph ...@@ -66,6 +66,32 @@ namespace ngraph
using Op::Op; using Op::Op;
}; };
// TODO(dcab): Doc
// TODO(dcab): Increase operands to 3 when supporting bias.
class NG_MatmulBiasOp : public mlir::Op<NG_MatmulBiasOp,
mlir::OpTrait::NOperands<2>::Impl,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect>
{
public:
static llvm::StringRef getOperationName() { return "ng.matmul.bias"; }
/// Custom verification.
mlir::LogicalResult verify();
static void build(mlir::Builder* builder,
mlir::OperationState* state,
mlir::Value* lhs,
mlir::Value* rhs);
/// Convenience accessor for LHS of the expression.
mlir::Value* getLHS() { return getOperand(0); }
/// Convenience accessor for RHS of the expression.
mlir::Value* getRHS() { return getOperand(1); }
/// Convenience accessor for bias operand.
mlir::Value* getBias() { return nullptr; } // TODO
/// Inherit constructor.
using Op::Op;
};
/// Return operations terminate blocks (and functions as well). They take a /// Return operations terminate blocks (and functions as well). They take a
/// single argument and the type must match the function return type. /// single argument and the type must match the function return type.
class NG_ReturnOp : public mlir::Op<NG_ReturnOp, class NG_ReturnOp : public mlir::Op<NG_ReturnOp,
...@@ -97,4 +123,4 @@ namespace ngraph ...@@ -97,4 +123,4 @@ namespace ngraph
}; };
} }
} }
} }
\ No newline at end of file
...@@ -55,8 +55,9 @@ namespace ...@@ -55,8 +55,9 @@ namespace
// Initialize the list of converters. // Initialize the list of converters.
llvm::DenseSet<DialectOpConversion*> initConverters(MLIRContext* context) override llvm::DenseSet<DialectOpConversion*> initConverters(MLIRContext* context) override
{ {
return ConversionListBuilder<NG_AddOpConversion, NG_ReturnOpConversion>::build( return ConversionListBuilder<NG_AddOpConversion,
&allocator, context, m_pass); NG_MatmulBiasOpConversion,
NG_ReturnOpConversion>::build(&allocator, context, m_pass);
} }
private: private:
...@@ -154,7 +155,7 @@ namespace ...@@ -154,7 +155,7 @@ namespace
auto result = m_pass.buildOutputDefs(op, rewriter)[0]; auto result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_ASSERT(result->getType().isa<MemRefType>()); NGRAPH_ASSERT(result->getType().isa<MemRefType>());
// NOte that builder's current function is still the original function body. // Note that builder's current function is still the original function body.
// use getBlock to get the new block instead. // use getBlock to get the new block instead.
// get new operands // get new operands
...@@ -181,6 +182,73 @@ namespace ...@@ -181,6 +182,73 @@ namespace
return {result}; return {result};
} }
SmallVector<Value*, 4> NG_MatmulBiasOpConversion::rewrite(Operation* op,
ArrayRef<Value*> operands,
FuncBuilder& rewriter) const
{
auto matmul = op->cast<NG_MatmulBiasOp>();
auto loc = matmul.getLoc();
NGRAPH_ASSERT(!matmul.getBias() && operands.size() == 2)
<< "Bias is not supported yet in MatmulBias operation";
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* lhs = operands[0];
Value* rhs = operands[1];
Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_ASSERT(lhs && rhs && result) << "Unexpected null values in MatmulBiasOp";
auto result_ty = result->getType().dyn_cast<MemRefType>();
auto lhs_ty = lhs->getType().dyn_cast<MemRefType>();
auto rhs_ty = rhs->getType().dyn_cast<MemRefType>();
NGRAPH_ASSERT(result_ty) << "Unexpected non-memref result type";
NGRAPH_ASSERT(lhs_ty) << "Unexpected non-memref LHS type";
NGRAPH_ASSERT(rhs_ty) << "Unexpected non-memref RHS type";
Type elem_ty = result_ty.getElementType();
NGRAPH_ASSERT(elem_ty == lhs_ty.getElementType() && elem_ty == rhs_ty.getElementType())
<< "Types mismatch in MatmulBiasOp";
// Create the following loop nest for matmul operation:
// for(n, N, 1)
// for(m, M, 1)
// for(k, K, 1)
// res[n, k] += lhs[n, m] * rhs[m, k]
// TODO (dcab): We currently generate a super naive loop nest. Improve loop nest layout.
MemRefView v_res(result), v_lhs(lhs), v_rhs(rhs);
IndexedValue i_res(result), i_lhs(lhs), i_rhs(rhs);
NGRAPH_ASSERT(v_lhs.rank() == 2 && v_rhs.rank() == 2 && v_res.rank() == 2)
<< "MatmulBias operation is only supported for 2D tensors";
// Induction variables, lower bounds, upper bounds and steps of the loop nest.
IndexHandle n, m, k;
IndexHandle n_lb(v_lhs.lb(1)), m_lb(v_lhs.lb(0)), k_lb(v_rhs.lb(0));
IndexHandle n_ub(v_lhs.ub(1)), m_ub(v_lhs.ub(0)), k_ub(v_rhs.ub(0));
int64_t n_step = v_lhs.step(1), m_step = v_lhs.step(0), k_step = v_rhs.step(0);
// TODO (dcab): Assert on dims
// Constants, indexed values and indexes to be used inside the loop nest.
IndexedValue ires(result), ilhs(lhs), irhs(rhs);
ValueHandle zero_init(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elem_ty)));
// clang-format off
LoopBuilder(&n, n_lb, n_ub, n_step)({
LoopBuilder(&k, k_lb, k_ub, k_step)({
i_res(n, k) = zero_init,
LoopBuilder(&m, m_lb, m_ub, m_step)({
i_res(n, k) += i_lhs(n, m) * i_rhs(m, k)
})
}),
});
// clang-format on
// Return result memref.
return {result};
}
SmallVector<Value*, 4> NG_ReturnOpConversion::rewrite(Operation* op, SmallVector<Value*, 4> NG_ReturnOpConversion::rewrite(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value*> operands,
FuncBuilder& rewriter) const FuncBuilder& rewriter) const
......
...@@ -30,6 +30,7 @@ public:\ ...@@ -30,6 +30,7 @@ public:\
}; };
DECL_OP_CONV(NG_AddOp) DECL_OP_CONV(NG_AddOp)
DECL_OP_CONV(NG_MatmulBiasOp)
DECL_OP_CONV(NG_ReturnOp) DECL_OP_CONV(NG_ReturnOp)
#undef DECL_OP_CONV #undef DECL_OP_CONV
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment