Commit 64b43082 authored by Diego Caballero's avatar Diego Caballero Committed by nmostafa

[WIP] Add ArgMin lowering support

parent 1b2b7d59
......@@ -24,8 +24,10 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/util/index_reduction.hpp"
#include "ngraph/type/element_type.hpp"
#include <llvm/ADT/STLExtras.h>
......@@ -108,12 +110,12 @@ void MLIRCompiler::build_ng_dialect_module()
for (auto input : kernel_inputs)
{
args_type_list.push_back(get_mlir_type(input->get_output_tensor_ptr().get()));
args_type_list.push_back(get_mlir_type(input.get()));
}
for (auto output : kernel_outputs)
{
result_type_list.push_back(get_mlir_type(output->get_output_tensor_ptr().get()));
result_type_list.push_back(get_mlir_type(output.get()));
}
auto func_type = mlir::FunctionType::get(args_type_list, result_type_list, &m_context);
......@@ -144,17 +146,23 @@ void MLIRCompiler::build_ng_dialect_module()
dump_mlir_module("nGraph Dialect Dump:");
}
// Converts an nGraph Tensor into an MLIR tensor type, including the conversion of the Tensor's
// element type.
mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
// Converts nGraph shape \p ng_shape to MLIR shape \p mlir_shape.
static void get_mlir_shape(ngraph::Shape ng_shape, llvm::SmallVectorImpl<int64_t>& mlir_shape)
{
SmallVector<int64_t, 4> shape;
for (auto d : tensor->get_shape())
for (auto dim : ng_shape)
{
shape.push_back(d);
mlir_shape.push_back(dim);
}
}
return mlir::NGTensorType::get(&m_context, get_mlir_type(tensor->get_element_type()), shape);
// Converts an nGraph Tensor into an MLIR tensor type, including the conversion of the Tensor's
// element type.
mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
{
SmallVector<int64_t, 4> mlir_shape;
get_mlir_shape(tensor->get_shape(), mlir_shape);
return mlir::NGTensorType::get(
&m_context, get_mlir_type(tensor->get_element_type()), mlir_shape);
}
// Converts an nGraph element type into an MLIR type.
......@@ -193,6 +201,20 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
#endif
}
mlir::Type MLIRCompiler::get_mlir_type(const ngraph::Node* node)
{
descriptor::Tensor* out_tensor = node->get_output_tensor_ptr().get();
if (TI(*node) == TI(ngraph::op::ArgMin))
{
SmallVector<int64_t, 4> mlir_shape;
get_mlir_shape(out_tensor->get_shape(), mlir_shape);
return mlir::NGTensorType::get(&m_context, mlir::IndexType::get(&m_context), mlir_shape);
}
return get_mlir_type(out_tensor);
}
void MLIRCompiler::update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value)
{
NGRAPH_CHECK(m_tensor_to_value_map.find(tensor) == m_tensor_to_value_map.end(),
......@@ -272,6 +294,25 @@ mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
return compiler.create_binary_op<mlir::NGAddOp>(ng_node);
}
template<>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin)
{
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
auto arg = idx_red->get_argument(0);
size_t red_axis = idx_red->get_reduction_axis();
mlir::Value* arg_val = compiler.get_tensor_value(arg->get_output_tensor_ptr().get()).m_value;
mlir::ArrayAttr red_axes_attr = compiler.m_builder->getI64ArrayAttr({(int64_t)red_axis});
return compiler.m_builder
->create<mlir::NGArgMinRedOp>(mlir::UnknownLoc::get(&compiler.m_context),
compiler.get_mlir_type(ng_node),
arg_val,
red_axes_attr)
.getResult();
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
{
......
......@@ -91,6 +91,8 @@ namespace ngraph
mlir::Type get_mlir_type(const descriptor::Tensor* tensor);
mlir::Type get_mlir_type(const element::Type& type);
mlir::Type get_mlir_type(const ngraph::Node* node);
TensorInfo get_tensor_value(descriptor::Tensor* tensor);
void update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value);
......
......@@ -97,7 +97,7 @@ template <typename T>
static mlir::LogicalResult verifyIndexReductionOp(T* op)
{
// TODO: verifyAxisReductionOp(op) + return element type + single axis.
return mlir::failure();
return mlir::success();
}
template <typename T>
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
// Copyright 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -37,7 +37,9 @@ namespace
{
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::op;
using namespace ngraph::runtime;
using namespace ngraph::runtime::ngmlir;
class DialectLoweringPass;
......@@ -59,8 +61,10 @@ namespace
// Initialize the list of converters.
void initConverters(OwningRewritePatternList& patterns, MLIRContext* mlirContext) override
{
RewriteListBuilder<NGAddOpConversion, NGDotOpConversion, NGReturnOpConversion>::build(
patterns, mlirContext, m_pass);
RewriteListBuilder<NGAddOpConversion,
NGArgMinRedOpConversion,
NGDotOpConversion,
NGReturnOpConversion>::build(patterns, mlirContext, m_pass);
}
private:
......@@ -383,7 +387,7 @@ namespace
IndexHandle n_ub(v_lhs.ub(n_dim)), m_ub(v_lhs.ub(m_dim)), k_ub(v_rhs.ub(k_dim));
int64_t n_step = v_lhs.step(n_dim), m_step = v_lhs.step(m_dim), k_step = v_rhs.step(k_dim);
// Constants, indexed values and indexes to be used inside the loop nest.
// Constants and indexed values to be used inside the loop nest.
IndexedValue i_res(result), i_lhs(lhs), i_rhs(rhs);
ValueHandle zero_init(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elem_ty)));
......@@ -398,6 +402,67 @@ namespace
rewriter.replaceOp(op, {result});
}
REWRITER(NGArgMinRedOp)
{
auto argmin = cast<NGArgMinRedOp>(op);
auto loc = argmin.getLoc();
NGRAPH_ASSERT(operands.size() == 1 && operands[0] != nullptr)
<< "Expected one non-null operand in ArgMin op";
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* arg = operands[0];
auto arg_type = arg->getType().cast<MemRefType>();
NGRAPH_ASSERT(arg_type.getRank() == 2) << "Unsupported tensor type in ArgMin op";
//axis = op->getAttr();
//NGRAPH_ASSERT(axis == 0) << "Unsupported axis in ArgMin op";
Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
//NGRAPH_ASSERT(lhs && rhs && result) << "Unexpected null values in MatmulBiasOp";
// FIXME: Workaround to the integer to index conversion.
auto res_ty = result->getType().cast<MemRefType>();
Type res_elem_ty = res_ty.getElementType();
//result->setType(
// MemRefType::get(res_ty.getShape(), IndexType::get(res_elem_ty.getContext())));
// Create the following loop nest for argmin operation:
// for(i, I, 1)
// for(j, J, 1) // Reduction dimention
// res[j] = select((arg[i, j] < res[j]), i, res[j])
MemRefView v_res(result), v_arg(arg);
unsigned n_dim = v_arg.fastestVarying() - 1;
unsigned m_dim = v_arg.fastestVarying();
// Constants, indexed values and other vars to be used inside the loop nest.
IndexedValue i_res(result), i_arg(arg);
// Initialize result to zero.
IndexHandle m_init;
IndexHandle m_lb_init(v_arg.lb(m_dim));
IndexHandle m_ub_init(v_arg.ub(m_dim));
int64_t m_step = v_arg.step(m_dim);
LoopBuilder(&m_init, m_lb_init, m_ub_init, m_step)([&] { i_res(m_init) = m_lb_init; });
// Main loop nest for argmin
IndexHandle n, m;
IndexHandle n_lb(v_arg.lb(n_dim)), m_lb(v_arg.lb(m_dim));
IndexHandle n_ub(v_arg.ub(n_dim)), m_ub(v_arg.ub(m_dim));
ValueHandle curr_res(res_elem_ty);
int64_t n_step = v_arg.step(n_dim);
LoopBuilder(&n, n_lb, n_ub, n_step)([&] {
LoopBuilder(&m, m_lb, m_ub, m_step)([&] {
curr_res = i_res(m);
i_res(m) = edsc::intrinsics::select(i_arg(n, m) < i_arg(curr_res, m), n, curr_res);
});
});
rewriter.replaceOp(op, {result});
}
REWRITER(NGReturnOp) { rewriter.replaceOpWithNewOp<ReturnOp>(op); }
#undef REWRITER
}
......
......@@ -30,6 +30,7 @@ public:\
};
DECL_OP_CONV(NGAddOp)
DECL_OP_CONV(NGArgMinRedOp)
DECL_OP_CONV(NGDotOp)
DECL_OP_CONV(NGReturnOp)
......
......@@ -4,6 +4,7 @@
#endif
MLIR_OP(Add)
MLIR_OP(ArgMin)
MLIR_OP(Dot)
// Add new supported ops here
......
......@@ -15,9 +15,11 @@
//*****************************************************************************
#include "mlir_subgraph_extraction.hpp"
#include "ngraph/assertion.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/get_output_element.hpp"
......
......@@ -55,6 +55,25 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_trivial)
EXPECT_EQ((vector<int>{3, 2, 1}), read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmin_trivial_i32)
{
Shape shape{4, 3};
Shape rshape{3};
auto A = make_shared<op::Parameter>(element::i32, shape);
auto f = make_shared<Function>(make_shared<op::ArgMin>(A, 0, element::i32), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::i32, shape);
copy_data(a, vector<int>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result = backend->create_tensor(element::i32, rshape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<int>{3, 2, 1}), read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_axis_3_i64)
{
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment