Commit 64b43082 authored by Diego Caballero's avatar Diego Caballero Committed by nmostafa

[WIP] Add ArgMin lowering support

parent 1b2b7d59
...@@ -24,8 +24,10 @@ ...@@ -24,8 +24,10 @@
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp" #include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/util/index_reduction.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include <llvm/ADT/STLExtras.h> #include <llvm/ADT/STLExtras.h>
...@@ -108,12 +110,12 @@ void MLIRCompiler::build_ng_dialect_module() ...@@ -108,12 +110,12 @@ void MLIRCompiler::build_ng_dialect_module()
for (auto input : kernel_inputs) for (auto input : kernel_inputs)
{ {
args_type_list.push_back(get_mlir_type(input->get_output_tensor_ptr().get())); args_type_list.push_back(get_mlir_type(input.get()));
} }
for (auto output : kernel_outputs) for (auto output : kernel_outputs)
{ {
result_type_list.push_back(get_mlir_type(output->get_output_tensor_ptr().get())); result_type_list.push_back(get_mlir_type(output.get()));
} }
auto func_type = mlir::FunctionType::get(args_type_list, result_type_list, &m_context); auto func_type = mlir::FunctionType::get(args_type_list, result_type_list, &m_context);
...@@ -144,17 +146,23 @@ void MLIRCompiler::build_ng_dialect_module() ...@@ -144,17 +146,23 @@ void MLIRCompiler::build_ng_dialect_module()
dump_mlir_module("nGraph Dialect Dump:"); dump_mlir_module("nGraph Dialect Dump:");
} }
// Converts an nGraph Tensor into an MLIR tensor type, including the conversion of the Tensor's // Converts nGraph shape \p ng_shape to MLIR shape \p mlir_shape.
// element type. static void get_mlir_shape(ngraph::Shape ng_shape, llvm::SmallVectorImpl<int64_t>& mlir_shape)
mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
{ {
SmallVector<int64_t, 4> shape; for (auto dim : ng_shape)
for (auto d : tensor->get_shape())
{ {
shape.push_back(d); mlir_shape.push_back(dim);
} }
}
return mlir::NGTensorType::get(&m_context, get_mlir_type(tensor->get_element_type()), shape); // Converts an nGraph Tensor into an MLIR tensor type, including the conversion of the Tensor's
// element type.
mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
{
SmallVector<int64_t, 4> mlir_shape;
get_mlir_shape(tensor->get_shape(), mlir_shape);
return mlir::NGTensorType::get(
&m_context, get_mlir_type(tensor->get_element_type()), mlir_shape);
} }
// Converts an nGraph element type into an MLIR type. // Converts an nGraph element type into an MLIR type.
...@@ -193,6 +201,20 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type) ...@@ -193,6 +201,20 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
#endif #endif
} }
// Converts the output tensor of nGraph node \p node into an MLIR tensor type.
// Every node except ArgMin defers to the tensor-based overload. For ArgMin the
// result keeps the output tensor's shape but its element type is forced to the
// MLIR index type instead of the tensor's own nGraph element type.
mlir::Type MLIRCompiler::get_mlir_type(const ngraph::Node* node)
{
    descriptor::Tensor* result_tensor = node->get_output_tensor_ptr().get();

    // Common case: the MLIR type follows directly from the output tensor.
    if (TI(*node) != TI(ngraph::op::ArgMin))
    {
        return get_mlir_type(result_tensor);
    }

    // ArgMin: same shape as the output tensor, index-typed elements.
    SmallVector<int64_t, 4> result_shape;
    get_mlir_shape(result_tensor->get_shape(), result_shape);
    auto index_elem_ty = mlir::IndexType::get(&m_context);
    return mlir::NGTensorType::get(&m_context, index_elem_ty, result_shape);
}
void MLIRCompiler::update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value) void MLIRCompiler::update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value)
{ {
NGRAPH_CHECK(m_tensor_to_value_map.find(tensor) == m_tensor_to_value_map.end(), NGRAPH_CHECK(m_tensor_to_value_map.find(tensor) == m_tensor_to_value_map.end(),
...@@ -272,6 +294,25 @@ mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add) ...@@ -272,6 +294,25 @@ mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
return compiler.create_binary_op<mlir::NGAddOp>(ng_node); return compiler.create_binary_op<mlir::NGAddOp>(ng_node);
} }
// Builds the nGraph-dialect NGArgMinRedOp for an nGraph ArgMin node.
// The node is read through its IndexReduction base to obtain its single
// argument and its reduction axis. The argument's MLIR value is looked up from
// the values already recorded for its output tensor, and the reduction axis is
// attached to the new op as a one-element I64 array attribute.
// Returns the result value of the created op.
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin)
{
    auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
    auto arg = idx_red->get_argument(0);
    size_t red_axis = idx_red->get_reduction_axis();

    mlir::Value* arg_val = compiler.get_tensor_value(arg->get_output_tensor_ptr().get()).m_value;

    // The attribute builder takes signed 64-bit values; convert the unsigned axis explicitly.
    mlir::ArrayAttr red_axes_attr =
        compiler.m_builder->getI64ArrayAttr({static_cast<int64_t>(red_axis)});

    return compiler.m_builder
        ->create<mlir::NGArgMinRedOp>(mlir::UnknownLoc::get(&compiler.m_context),
                                      compiler.get_mlir_type(ng_node),
                                      arg_val,
                                      red_axes_attr)
        .getResult();
}
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
{ {
......
...@@ -91,6 +91,8 @@ namespace ngraph ...@@ -91,6 +91,8 @@ namespace ngraph
mlir::Type get_mlir_type(const descriptor::Tensor* tensor); mlir::Type get_mlir_type(const descriptor::Tensor* tensor);
mlir::Type get_mlir_type(const element::Type& type); mlir::Type get_mlir_type(const element::Type& type);
mlir::Type get_mlir_type(const ngraph::Node* node);
TensorInfo get_tensor_value(descriptor::Tensor* tensor); TensorInfo get_tensor_value(descriptor::Tensor* tensor);
void update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value); void update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value);
......
...@@ -97,7 +97,7 @@ template <typename T> ...@@ -97,7 +97,7 @@ template <typename T>
static mlir::LogicalResult verifyIndexReductionOp(T* op) static mlir::LogicalResult verifyIndexReductionOp(T* op)
{ {
// TODO: verifyAxisReductionOp(op) + return element type + single axis. // TODO: verifyAxisReductionOp(op) + return element type + single axis.
return mlir::failure(); return mlir::success();
} }
template <typename T> template <typename T>
......
//***************************************************************************** //*****************************************************************************
// Copyright 2017-2019 Intel Corporation // Copyright 2019 Intel Corporation
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -37,7 +37,9 @@ namespace ...@@ -37,7 +37,9 @@ namespace
{ {
using namespace mlir; using namespace mlir;
using namespace mlir::edsc; using namespace mlir::edsc;
using namespace mlir::edsc::op;
using namespace ngraph::runtime; using namespace ngraph::runtime;
using namespace ngraph::runtime::ngmlir;
class DialectLoweringPass; class DialectLoweringPass;
...@@ -59,8 +61,10 @@ namespace ...@@ -59,8 +61,10 @@ namespace
// Initialize the list of converters. // Initialize the list of converters.
void initConverters(OwningRewritePatternList& patterns, MLIRContext* mlirContext) override void initConverters(OwningRewritePatternList& patterns, MLIRContext* mlirContext) override
{ {
RewriteListBuilder<NGAddOpConversion, NGDotOpConversion, NGReturnOpConversion>::build( RewriteListBuilder<NGAddOpConversion,
patterns, mlirContext, m_pass); NGArgMinRedOpConversion,
NGDotOpConversion,
NGReturnOpConversion>::build(patterns, mlirContext, m_pass);
} }
private: private:
...@@ -383,7 +387,7 @@ namespace ...@@ -383,7 +387,7 @@ namespace
IndexHandle n_ub(v_lhs.ub(n_dim)), m_ub(v_lhs.ub(m_dim)), k_ub(v_rhs.ub(k_dim)); IndexHandle n_ub(v_lhs.ub(n_dim)), m_ub(v_lhs.ub(m_dim)), k_ub(v_rhs.ub(k_dim));
int64_t n_step = v_lhs.step(n_dim), m_step = v_lhs.step(m_dim), k_step = v_rhs.step(k_dim); int64_t n_step = v_lhs.step(n_dim), m_step = v_lhs.step(m_dim), k_step = v_rhs.step(k_dim);
// Constants, indexed values and indexes to be used inside the loop nest. // Constants and indexed values to be used inside the loop nest.
IndexedValue i_res(result), i_lhs(lhs), i_rhs(rhs); IndexedValue i_res(result), i_lhs(lhs), i_rhs(rhs);
ValueHandle zero_init(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elem_ty))); ValueHandle zero_init(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elem_ty)));
...@@ -398,6 +402,67 @@ namespace ...@@ -398,6 +402,67 @@ namespace
rewriter.replaceOp(op, {result}); rewriter.replaceOp(op, {result});
} }
// Lowers NGArgMinRedOp to an explicit loop nest built with the EDSC helpers.
// Only rank-2 operands are accepted. The loop nest reduces over the n dimension
// (fastestVarying() - 1) and produces one index per element of the m dimension
// (fastestVarying()). The op's reduction-axis attribute is not consulted yet.
REWRITER(NGArgMinRedOp)
{
auto argmin = cast<NGArgMinRedOp>(op);
auto loc = argmin.getLoc();
NGRAPH_ASSERT(operands.size() == 1 && operands[0] != nullptr)
<< "Expected one non-null operand in ArgMin op";
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* arg = operands[0];
auto arg_type = arg->getType().cast<MemRefType>();
NGRAPH_ASSERT(arg_type.getRank() == 2) << "Unsupported tensor type in ArgMin op";
// TODO: read the reduction axis attribute from the op and assert that it is 0;
// the loop nest below hard-codes the reduction over the n dimension.
Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
// TODO: assert that `result` is non-null.
// FIXME: Workaround to the integer to index conversion.
auto res_ty = result->getType().cast<MemRefType>();
Type res_elem_ty = res_ty.getElementType();
// Disabled alternative kept for reference: retype the result as an index-element memref.
//result->setType(
// MemRefType::get(res_ty.getShape(), IndexType::get(res_elem_ty.getContext())));
// Create the following loop nest for argmin operation:
// for (m = m_lb; m < m_ub; m += m_step)
//   res[m] = m_lb
// for (n = n_lb; n < n_ub; n += n_step)    // Reduction dimension
//   for (m = m_lb; m < m_ub; m += m_step)
//     res[m] = select(arg[n, m] < arg[res[m], m], n, res[m])
MemRefView v_res(result), v_arg(arg);
unsigned n_dim = v_arg.fastestVarying() - 1;
unsigned m_dim = v_arg.fastestVarying();
// Constants, indexed values and other vars to be used inside the loop nest.
IndexedValue i_res(result), i_arg(arg);
// Initialize every result element to the lower bound of the m dimension.
// NOTE(review): the result stores indices along the n dimension, so the n lower
// bound looks like the intended initial value; presumably both bounds are 0 - confirm.
IndexHandle m_init;
IndexHandle m_lb_init(v_arg.lb(m_dim));
IndexHandle m_ub_init(v_arg.ub(m_dim));
int64_t m_step = v_arg.step(m_dim);
LoopBuilder(&m_init, m_lb_init, m_ub_init, m_step)([&] { i_res(m_init) = m_lb_init; });
// Main loop nest for argmin
IndexHandle n, m;
IndexHandle n_lb(v_arg.lb(n_dim)), m_lb(v_arg.lb(m_dim));
IndexHandle n_ub(v_arg.ub(n_dim)), m_ub(v_arg.ub(m_dim));
ValueHandle curr_res(res_elem_ty);
int64_t n_step = v_arg.step(n_dim);
LoopBuilder(&n, n_lb, n_ub, n_step)([&] {
LoopBuilder(&m, m_lb, m_ub, m_step)([&] {
// curr_res is the index currently stored for column m. It is replaced by n only
// when arg[n, m] is strictly smaller than the element at that stored index.
curr_res = i_res(m);
i_res(m) = edsc::intrinsics::select(i_arg(n, m) < i_arg(curr_res, m), n, curr_res);
});
});
rewriter.replaceOp(op, {result});
}
REWRITER(NGReturnOp) { rewriter.replaceOpWithNewOp<ReturnOp>(op); } REWRITER(NGReturnOp) { rewriter.replaceOpWithNewOp<ReturnOp>(op); }
#undef REWRITER #undef REWRITER
} }
......
...@@ -30,6 +30,7 @@ public:\ ...@@ -30,6 +30,7 @@ public:\
}; };
DECL_OP_CONV(NGAddOp) DECL_OP_CONV(NGAddOp)
DECL_OP_CONV(NGArgMinRedOp)
DECL_OP_CONV(NGDotOp) DECL_OP_CONV(NGDotOp)
DECL_OP_CONV(NGReturnOp) DECL_OP_CONV(NGReturnOp)
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#endif #endif
MLIR_OP(Add) MLIR_OP(Add)
MLIR_OP(ArgMin)
MLIR_OP(Dot) MLIR_OP(Dot)
// Add new supported ops here // Add new supported ops here
......
...@@ -15,9 +15,11 @@ ...@@ -15,9 +15,11 @@
//***************************************************************************** //*****************************************************************************
#include "mlir_subgraph_extraction.hpp" #include "mlir_subgraph_extraction.hpp"
#include "ngraph/assertion.hpp" #include "ngraph/assertion.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp" #include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
......
...@@ -55,6 +55,25 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_trivial) ...@@ -55,6 +55,25 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_trivial)
EXPECT_EQ((vector<int>{3, 2, 1}), read_vector<int>(result)); EXPECT_EQ((vector<int>{3, 2, 1}), read_vector<int>(result));
} }
// ArgMin along axis 0 of a 4x3 i32 tensor, with i32 as the index element type.
NGRAPH_TEST(${BACKEND_NAME}, argmin_trivial_i32)
{
    const Shape input_shape{4, 3};
    const Shape output_shape{3};

    // Function under test: a single ArgMin node fed by one parameter.
    auto input_param = make_shared<op::Parameter>(element::i32, input_shape);
    auto argmin_node = make_shared<op::ArgMin>(input_param, 0, element::i32);
    auto function = make_shared<Function>(argmin_node, ParameterVector{input_param});

    auto backend = runtime::Backend::create("${BACKEND_NAME}");

    // Backend tensors for the input data and the computed indices.
    auto input_tensor = backend->create_tensor(element::i32, input_shape);
    copy_data(input_tensor, vector<int>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
    auto output_tensor = backend->create_tensor(element::i32, output_shape);

    auto executable = backend->compile(function);
    executable->call_with_validate({output_tensor}, {input_tensor});

    // Column minima are 3, 1 and 4, located at rows 3, 2 and 1 respectively.
    EXPECT_EQ((vector<int>{3, 2, 1}), read_vector<int>(output_tensor));
}
NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_axis_3_i64) NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_axis_3_i64)
{ {
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3) Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment