Commit c5b976c8 authored by Adam Procter, committed by Scott Cyphers

[MLIR] Concat (#3225)

* WIP

* All but two unit tests passing

* Explanatory comment

* Cleanup

* A bit of cleanup stemming from review comments

* Rewrite to use LoopNestBuilder

* Remove unnecessary check from CompiledKernel

* Removed some dead-ish code I missed

* Switch to camelCase in lowerer.cpp

* Fix assignment of resIndexHandles that was triggering an assert

* Add in some safety checks

* dyn_cast -> cast
parent 2c41d422
@@ -26,6 +26,7 @@
 #include "ngraph/op/add.hpp"
 #include "ngraph/op/argmax.hpp"
 #include "ngraph/op/argmin.hpp"
+#include "ngraph/op/concat.hpp"
 #include "ngraph/op/divide.hpp"
 #include "ngraph/op/dot.hpp"
 #include "ngraph/op/experimental/compiled_kernel.hpp"
@@ -363,6 +364,12 @@ namespace ngraph
     return compiler.create_binary_op<mlir::NGDotOp>(ng_node);
 }
+
+template <>
+mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Concat)
+{
+    return compiler.create_concat(ng_node);
+}
 template <>
 mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Relu)
 {
@@ -399,6 +406,26 @@ mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node)
         .getResult();
 }
+
+mlir::Value* MLIRCompiler::create_concat(const ngraph::Node* ng_node)
+{
+    std::vector<mlir::Value*> arg_values;
+    auto ng_node_concat = static_cast<const ngraph::op::Concat*>(ng_node);
+    for (auto& arg : ng_node->get_arguments())
+    {
+        auto arg_tensor = arg->get_output_tensor_ptr();
+        auto arg_v = get_tensor_value(arg_tensor.get()).m_value;
+        arg_values.push_back(arg_v);
+    }
+    auto res_type = get_mlir_type(ng_node->get_output_tensor_ptr().get());
+    return m_builder
+        ->create<mlir::NGConcatOp>(
+            mlir::UnknownLoc::get(&m_context),
+            res_type,
+            arg_values,
+            m_builder->getI64IntegerAttr(ng_node_concat->get_concatenation_axis()))
+        .getResult();
+}
 void MLIRCompiler::create_return()
 {
     std::vector<mlir::Value*> value_list;
......
@@ -111,6 +111,10 @@ namespace ngraph
 template <typename BinOp>
 mlir::Value* create_binary_op(const ngraph::Node* ng_node);
+
+// TODO(amprocte): Can we have a create_variadic_op that is able to handle the
+// attributes?
+mlir::Value* create_concat(const ngraph::Node* ng_node);
 template <typename RedOp>
 mlir::Value* create_index_reduction(const ngraph::Node* ng_node);
......
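The TODO in the hunk above asks whether a generic create_variadic_op could also handle attributes. A minimal sketch of one possible shape for it, assuming the same MLIRCompiler members and helpers (m_builder, m_context, get_tensor_value, get_mlir_type) that create_binary_op and create_concat use in this commit; the signature and the attrs parameter are assumptions, not part of the commit:

// Hypothetical sketch, not part of this commit: a variadic-op builder that
// forwards op-specific attributes (e.g. concatenation_axis) to the op builder.
template <typename Op>
mlir::Value* MLIRCompiler::create_variadic_op(const ngraph::Node* ng_node,
                                              llvm::ArrayRef<mlir::NamedAttribute> attrs)
{
    // Collect the MLIR values backing each nGraph argument tensor, exactly as
    // create_concat does above.
    std::vector<mlir::Value*> arg_values;
    for (auto& arg : ng_node->get_arguments())
    {
        auto arg_tensor = arg->get_output_tensor_ptr();
        arg_values.push_back(get_tensor_value(arg_tensor.get()).m_value);
    }
    auto res_type = get_mlir_type(ng_node->get_output_tensor_ptr().get());
    // Assumes the op provides a build() overload taking (type, operands, attrs).
    return m_builder
        ->create<Op>(mlir::UnknownLoc::get(&m_context), res_type, arg_values, attrs)
        .getResult();
}

With something like this in place, create_concat would reduce to constructing the concatenation_axis attribute and calling create_variadic_op<mlir::NGConcatOp> with it.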
@@ -113,6 +113,13 @@ mlir::LogicalResult verifyOp(NGDotOp* op)
     return mlir::success();
 }
+
+template <>
+mlir::LogicalResult verifyOp(NGConcatOp* op)
+{
+    // TODO(amprocte): Improve verification: proper shapes, etc.
+    return mlir::success();
+}
 template <>
 mlir::LogicalResult verifyOp(NGSelectOp* op)
 {
......
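The verifier added above is a stub. The shape rule a fuller verifyOp(NGConcatOp*) would enforce: every operand has the same rank as the result, matches the result on every axis except the concatenation axis, and the operands' extents along that axis sum to the result's. A self-contained sketch of just that rule over plain static shapes (the helper is illustrative, not dialect API; element-type checks and dynamic dimensions are left out):

#include <cstdint>
#include <vector>

// Illustrative only: the shape rule an NGConcatOp verifier would enforce,
// phrased over plain static shapes rather than dialect types.
static bool concat_shapes_ok(const std::vector<std::vector<int64_t>>& operand_shapes,
                             const std::vector<int64_t>& result_shape,
                             int64_t axis)
{
    if (operand_shapes.empty())
        return false;
    size_t rank = result_shape.size();
    if (axis < 0 || static_cast<size_t>(axis) >= rank)
        return false;
    int64_t axis_sum = 0;
    for (const auto& s : operand_shapes)
    {
        if (s.size() != rank)
            return false; // rank mismatch
        for (size_t i = 0; i < rank; i++)
        {
            // Off-axis dimensions must agree with the result exactly.
            if (i != static_cast<size_t>(axis) && s[i] != result_shape[i])
                return false;
        }
        axis_sum += s[static_cast<size_t>(axis)];
    }
    // The result's concatenation-axis extent is the sum of the operands' extents.
    return axis_sum == result_shape[static_cast<size_t>(axis)];
}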
@@ -182,6 +182,18 @@ def NGDotOp : NG_Binary_Op<"dot">
     let verifier = [{ return verifyOp(this); }];
 }
+
+// TODO(amprocte): Might be nice to rebase this on some sort of NG_Variadic_Op
+// class, but I'm not sure how to add concatenation_axis into the args if we
+// do that.
+def NGConcatOp :
+    NG_OneResult_Op<"concat", [NoSideEffect]>,
+    Arguments<(ins Variadic<NG_TensorType>:$args, I64Attr:$concatenation_axis)>
+{
+    let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
+    let verifier = [{ return verifyOp(this); }];
+}
+
 class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> :
     NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect], traits)>,
     Arguments<(ins NG_TensorType:$operand, I64ArrayAttr:$axes)>
......
@@ -567,6 +567,86 @@ namespace
     return matchSuccess();
 }
+
+REWRITER(NGConcatOp)
+{
+    auto concat = cast<NGConcatOp>(op);
+    auto loc = concat.getLoc();
+    ScopedContext scope(rewriter, loc);
+
+    // Create Value for result, and extract type info.
+    Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
+    NGRAPH_CHECK(result, "Unexpected null result in ConcatOp");
+    auto resultTy = result->getType().cast<MemRefType>();
+
+    // Create view to write into result.
+    MemRefView vRes(result);
+    auto rank = vRes.rank();
+
+    // For each operand, generate a separate loop to copy into the target slice
+    // of "result". We'll keep track of the slice offsets via concatenationAxisPos.
+    auto concatenationAxis = concat.concatenation_axis().getSExtValue();
+    IndexHandle concatenationAxisPos(index_t(0));
+
+    for (auto& operand : operands)
+    {
+        NGRAPH_CHECK(operand, "Unexpected null operand in ConcatOp");
+        auto operandTy = operand->getType().cast<MemRefType>();
+
+        // Assuming rank = r and concatenation axis A (with A < r), we'll be
+        // creating loops of this form:
+        //
+        //   for i_0 := 0 to operand.dims[0]:
+        //     for i_1 := 0 to operand.dims[1]:
+        //       ...
+        //         for i_(r-2) := 0 to operand.dims[r-2]:
+        //           for i_(r-1) := 0 to operand.dims[r-1]:
+        //             result[i_0][i_1]...
+        //                   [i_(A-1)][i_A + concatenationAxisPos][i_(A+1)]...
+        //                   [i_(r-2)][i_(r-1)]
+        //               :=
+        //             operand[i_0][i_1]...[i_(r-2)][i_(r-1)]
+        MemRefView vOperand(operand);
+        NGRAPH_CHECK(vOperand.rank() == rank, "Unexpected rank mismatch");
+
+        llvm::SmallVector<ValueHandle, 5> indexVars;
+        llvm::SmallVector<ValueHandle*, 5> indexVarPtrs;
+        llvm::SmallVector<ValueHandle, 5> indexVarLbs;
+        llvm::SmallVector<ValueHandle, 5> indexVarUbs;
+        llvm::SmallVector<int64_t, 5> indexVarSteps;
+        for (int i = 0; i < rank; i++)
+        {
+            indexVars.push_back(IndexHandle());
+            indexVarPtrs.push_back(&(indexVars.back()));
+            indexVarLbs.push_back(vOperand.lb(i));
+            indexVarUbs.push_back(vOperand.ub(i));
+            indexVarSteps.push_back(vOperand.step(i));
+        }
+
+        LoopNestBuilder(indexVarPtrs, indexVarLbs, indexVarUbs, indexVarSteps)([&] {
+            IndexedValue ivRes(result);
+            IndexedValue ivOperand(operand);
+
+            // On the LHS of the assignment, adjust the index for the concatenation axis.
+            llvm::SmallVector<ValueHandle, 5> resIndexHandles;
+            for (int i = 0; i < rank; i++)
+            {
+                resIndexHandles.push_back(i == concatenationAxis
+                                              ? indexVars[i] + concatenationAxisPos
+                                              : indexVars[i]);
+            }
+            ivRes(resIndexHandles) = ivOperand(indexVars);
+        });
+
+        // Move up concatenationAxisPos for the next operand.
+        concatenationAxisPos = concatenationAxisPos + vOperand.ub(concatenationAxis);
+    }
+
+    rewriter.replaceOp(op, {result});
+    return matchSuccess();
+}
 REWRITER(NGReturnOp)
 {
     rewriter.replaceOpWithNewOp<ReturnOp>(op);
......
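To make the rewriter above concrete, here is a scalar model of the loop nests it generates, specialized to rank 2 and concatenation axis 0; this is an illustration of the lowering strategy, not code from the commit. It leans on the same assumption as the concatenationAxisPos update above: loop lower bounds are 0, so ub(axis) is exactly the operand's extent along the concatenation axis.

#include <vector>

// Scalar model of the lowered loops: one loop nest per operand, each writing
// into "result" at an offset (axis_pos) along the concatenation axis (axis 0
// here). "result" must be pre-sized to the concatenated shape.
void concat_axis0(const std::vector<std::vector<std::vector<float>>>& operands,
                  std::vector<std::vector<float>>& result)
{
    size_t axis_pos = 0; // plays the role of concatenationAxisPos
    for (const auto& operand : operands)
    {
        for (size_t i0 = 0; i0 < operand.size(); i0++)         // ub(0) of this operand
            for (size_t i1 = 0; i1 < operand[i0].size(); i1++) // ub(1) of this operand
                result[i0 + axis_pos][i1] = operand[i0][i1];   // offset only on axis 0
        axis_pos += operand.size(); // advance by this operand's extent on the axis
    }
}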
@@ -26,6 +26,7 @@
 MLIR_OP(NGAddOp)
 MLIR_OP(NGArgMaxRedOp)
 MLIR_OP(NGArgMinRedOp)
+MLIR_OP(NGConcatOp)
 MLIR_OP(NGDivOp)
 MLIR_OP(NGDotOp)
 MLIR_OP(NGGreaterOp)
......
@@ -8,6 +8,7 @@ MLIR_OP(ArgMin)
 MLIR_OP(ArgMax)
 MLIR_OP(Divide)
 MLIR_OP(Dot)
+MLIR_OP(Concat)
 MLIR_OP(Greater)
 MLIR_OP(Less)
 MLIR_OP(Maximum)
......
@@ -21,6 +21,7 @@
 #include "ngraph/op/add.hpp"
 #include "ngraph/op/argmax.hpp"
 #include "ngraph/op/argmin.hpp"
+#include "ngraph/op/concat.hpp"
 #include "ngraph/op/divide.hpp"
 #include "ngraph/op/dot.hpp"
 #include "ngraph/op/experimental/compiled_kernel.hpp"
......
@@ -71,15 +71,6 @@ ngraph::op::CompiledKernel::CompiledKernel(const NodeVector& node_list,
     constructor_validate_and_infer_types();
     set_output_size(m_output_nodes.size());
-
-    auto ref = node_list.at(0);
-    for (auto n : node_list)
-    {
-        if (n->get_shape() != ref->get_shape() || n->get_element_type() != ref->get_element_type())
-        {
-            throw ngraph_error("types and shapes of the nodes in node_list are different");
-        }
-    }
     for (size_t i = 0; i < outputs.size(); ++i)
     {
         auto& o = outputs.at(i);
......