Commit 209e3ccc authored by Diego Caballero, committed by nmostafa

[MLIR] Add support for dot with non-square tensor operands (#33)

It extends the dot definition to handle operands that are not square tensors, and fixes a related bug in the lowerer.
parent 5867666f
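
Here, "non-square" means the 2-D operands need not have equal extents along both dimensions. For reference, the new dot2d_non_square test added below multiplies a 2x3 tensor by a 3x3 tensor, which works out to:

\[
\begin{pmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{pmatrix}
\begin{pmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{pmatrix}
=
\begin{pmatrix} 30 & 36 & 42 \\ 66 & 81 & 96 \end{pmatrix}
\]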
@@ -284,7 +284,9 @@ mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node)
     auto rhs = ng_node->get_argument(1)->get_output_tensor_ptr();
     auto lhs_v = get_tensor_value(lhs.get()).m_value;
     auto rhs_v = get_tensor_value(rhs.get()).m_value;
-    return m_builder->create<BinOp>(mlir::UnknownLoc::get(&m_context), lhs_v, rhs_v).getResult();
+    auto res_type = get_mlir_type(ng_node->get_output_tensor_ptr().get());
+    return m_builder->create<BinOp>(mlir::UnknownLoc::get(&m_context), res_type, lhs_v, rhs_v)
+        .getResult();
 }

 void MLIRCompiler::create_return()
......
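
The change above makes create_binary_op take the result type from the nGraph node's own output tensor (via get_mlir_type) and pass it to the builder explicitly, rather than letting it be implied by the operands. A minimal standalone C++ sketch of the idea (TensorShape and infer_dot_result are hypothetical illustrations, not nGraph or MLIR APIs): once results are typed from the node's output, ops whose result shape differs from their inputs, such as a non-square Dot, fit the same pattern.

#include <array>
#include <cstdio>

// Hypothetical 2-D tensor shape, used only for illustration.
struct TensorShape
{
    std::array<int, 2> dims;
};

// For Dot, the result shape comes from the node itself (rows of lhs, cols of rhs);
// with non-square operands it generally matches neither operand's shape.
TensorShape infer_dot_result(const TensorShape& lhs, const TensorShape& rhs)
{
    return TensorShape{{lhs.dims[0], rhs.dims[1]}};
}

int main()
{
    TensorShape lhs{{2, 3}}, rhs{{3, 3}};
    TensorShape res = infer_dot_result(lhs, rhs);
    std::printf("dot result shape: %d x %d\n", res.dims[0], res.dims[1]); // 2 x 3
    return 0;
}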
@@ -100,6 +100,7 @@
     template <typename BinOp>
     mlir::Value* create_binary_op(const ngraph::Node* ng_node);

     void create_return();
+
     /// Helper to create memref arguments for MLIR function signature
......
@@ -84,10 +84,9 @@ class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
     let verifier = [{ return verifyUnaryArithOp(this); }];
 }

-// Arithmetic binary operations
-// Inputs and outputs have same type
+// Base class for arithmetic binary operations without side effects.
 class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
-    NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect, SameValueType], traits)>,
+    NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect], traits)>,
     Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
 {
     // TODO: Implement

@@ -123,14 +122,14 @@ def NGTanhOp : NG_Unary_Arith_Op<"tanh">;
 def NGSqrtOp : NG_Unary_Arith_Op<"sqrt">;

 // Binary Operations
-def NGAddOp : NG_Binary_Arith_Op<"add", [Commutative]>;
-def NGAndOp : NG_Binary_Arith_Op<"and", [Commutative]>;
-def NGSubOp : NG_Binary_Arith_Op<"sub">;
-def NGDivOp : NG_Binary_Arith_Op<"div">;
-def NGMaxOp : NG_Binary_Arith_Op<"max", [Commutative]>;
-def NGMinOp : NG_Binary_Arith_Op<"min", [Commutative]>;
-def NGMulOp : NG_Binary_Arith_Op<"mul", [Commutative]>;
-def NGPowOp : NG_Binary_Arith_Op<"pow">;
+def NGAddOp : NG_Binary_Arith_Op<"add", [SameValueType, Commutative]>;
+def NGAndOp : NG_Binary_Arith_Op<"and", [SameValueType, Commutative]>;
+def NGSubOp : NG_Binary_Arith_Op<"sub", [SameValueType]>;
+def NGDivOp : NG_Binary_Arith_Op<"div", [SameValueType]>;
+def NGMaxOp : NG_Binary_Arith_Op<"max", [SameValueType, Commutative]>;
+def NGMinOp : NG_Binary_Arith_Op<"min", [SameValueType, Commutative]>;
+def NGMulOp : NG_Binary_Arith_Op<"mul", [SameValueType, Commutative]>;
+def NGPowOp : NG_Binary_Arith_Op<"pow", [SameValueType]>;

 // Comparison
 def NGEqOp : NG_OneResult_Op<"equal", [NoSideEffect]>;
......
@@ -366,20 +366,24 @@
         // TODO (dcab): We currently generate a super naive loop nest. Improve loop nest layout.
         MemRefView v_res(result), v_lhs(lhs), v_rhs(rhs);
-        IndexedValue i_res(result), i_lhs(lhs), i_rhs(rhs);

         NGRAPH_ASSERT(v_lhs.rank() == 2 && v_rhs.rank() == 2 && v_res.rank() == 2)
             << "Dot operation is only supported for 2D tensors";

-        // Induction variables, lower bounds, upper bounds and steps of the loop nest.
+        // Create induction variables, lower bounds, upper bounds and steps of the loop nest.
+        // It's important to note that MemRefView provides lb/ub/step info in "reverse order",
+        // i.e., the fastest varying dimension is the last one and the slowest varying
+        // dimension is the first one.
         IndexHandle n, m, k;
-        IndexHandle n_lb(v_lhs.lb(1)), m_lb(v_lhs.lb(0)), k_lb(v_rhs.lb(0));
-        IndexHandle n_ub(v_lhs.ub(1)), m_ub(v_lhs.ub(0)), k_ub(v_rhs.ub(0));
-        int64_t n_step = v_lhs.step(1), m_step = v_lhs.step(0), k_step = v_rhs.step(0);
-        // TODO (dcab): Assert on dims
+        unsigned n_dim = v_lhs.fastestVarying() - 1;
+        unsigned m_dim = v_rhs.fastestVarying();
+        unsigned k_dim = v_rhs.fastestVarying();
+        IndexHandle n_lb(v_lhs.lb(n_dim)), m_lb(v_lhs.lb(m_dim)), k_lb(v_rhs.lb(k_dim));
+        IndexHandle n_ub(v_lhs.ub(n_dim)), m_ub(v_lhs.ub(m_dim)), k_ub(v_rhs.ub(k_dim));
+        int64_t n_step = v_lhs.step(n_dim), m_step = v_lhs.step(m_dim), k_step = v_rhs.step(k_dim);

         // Constants, indexed values and indexes to be used inside the loop nest.
-        IndexedValue ires(result), ilhs(lhs), irhs(rhs);
+        IndexedValue i_res(result), i_lhs(lhs), i_rhs(rhs);
         ValueHandle zero_init(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elem_ty)));

         LoopBuilder(&n, n_lb, n_ub, n_step)([&] {
......
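
The fix above derives each loop's bounds from the dimension it actually iterates over, instead of reading every bound from the same operand as if the tensors were square. A minimal standalone C++ sketch (plain scalar loops, not the actual MLIR/EDSC lowering) of the loop nest this is meant to produce, using the data from the new dot2d_non_square test:

#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, K = 3, N = 3;
    std::vector<float> lhs = {1, 2, 3, 4, 5, 6};          // 2x3, row-major
    std::vector<float> rhs = {1, 2, 3, 4, 5, 6, 7, 8, 9}; // 3x3, row-major
    std::vector<float> res(M * N, 0.0f);                  // 2x3, zero-initialized

    for (int m = 0; m < M; ++m)         // rows of lhs / res
        for (int n = 0; n < N; ++n)     // cols of rhs / res
            for (int k = 0; k < K; ++k) // reduction dimension
                res[m * N + n] += lhs[m * K + k] * rhs[k * N + n];

    for (float v : res)
        std::printf("%g ", v);          // 30 36 42 66 81 96
    std::printf("\n");
    return 0;
}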
@@ -421,6 +421,32 @@ NGRAPH_TEST(${BACKEND_NAME}, dot2d)
     EXPECT_TRUE(test::all_close_f((vector<float>{19, 22, 43, 50}), read_vector<float>(result)));
 }

+NGRAPH_TEST(${BACKEND_NAME}, dot2d_non_square)
+{
+    Shape shape_in1{2, 3};
+    Shape shape_in2{3, 3};
+    Shape shape_out{2, 3};
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Parameter>(element::f32, shape_in2);
+    auto dot = make_shared<op::Dot>(A, B);
+    auto f = make_shared<Function>(dot, ParameterVector{A, B});
+
+    auto backend = runtime::Backend::create("${BACKEND_NAME}");
+
+    // Create some tensors for input/output
+    shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shape_in1);
+    shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shape_in2);
+    shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shape_out);
+
+    copy_data(a, vector<float>{1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
+    copy_data(b, vector<float>{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f});
+
+    auto handle = backend->compile(f);
+    handle->call_with_validate({result}, {a, b});
+    EXPECT_TRUE(test::all_close_f(read_vector<float>(result),
+                                  vector<float>{30.f, 36.f, 42.f, 66.f, 81.f, 96.f}));
+}
+
 //
 // Here is what numpy does:
 //
......