Unverified Commit c04c0349 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #3263 from NervanaSystems/nmostafa/gather

[MLIR] Enable Gather Op
parents d34fb157 e12aa4ca
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp" #include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
#include "ngraph/op/less.hpp" #include "ngraph/op/less.hpp"
#include "ngraph/op/maximum.hpp" #include "ngraph/op/maximum.hpp"
...@@ -63,6 +64,7 @@ ...@@ -63,6 +64,7 @@
using llvm::SmallVector; using llvm::SmallVector;
using llvm::StringRef; using llvm::StringRef;
using llvm::make_unique; using llvm::make_unique;
using llvm::ArrayRef;
using namespace ngraph::runtime::ngmlir; using namespace ngraph::runtime::ngmlir;
...@@ -282,11 +284,20 @@ void MLIRCompiler::build_ng_dialect() ...@@ -282,11 +284,20 @@ void MLIRCompiler::build_ng_dialect()
throw unsupported_op{std::string{"The MLIR backend doesn't currently implement the '"} + throw unsupported_op{std::string{"The MLIR backend doesn't currently implement the '"} +
np->description() + "' operation"}; np->description() + "' operation"};
} }
mlir::Value* mlir_value = it->second(*this, np.get()); mlir::Operation* op = it->second(*this, np.get());
// builders that have multiple result values will update the value map, and set their ret values to null // This assumes simple 1:1 mapping between output edges and generated MLIR op results
if (mlir_value) // If the mapping is more complex, the create_op helper can return null operation
// and handles populating the value map itself
if (op)
{ {
update_tensor_value(np->get_output_tensor_ptr().get(), mlir_value); for (auto i = 0; i < op->getNumResults(); i++)
{
mlir::Value* result = op->getResult(i);
if (result)
{
update_tensor_value(np->get_output_tensor_ptr(i).get(), result);
}
}
} }
} }
create_return(); create_return();
...@@ -299,133 +310,125 @@ namespace ngraph ...@@ -299,133 +310,125 @@ namespace ngraph
namespace ngmlir namespace ngmlir
{ {
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
{ {
return compiler.create_binary_op<mlir::NGAddOp>(ng_node); return compiler.create_generic_op<mlir::NGAddOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Subtract) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Subtract)
{ {
return compiler.create_binary_op<mlir::NGSubOp>(ng_node); return compiler.create_generic_op<mlir::NGSubOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Multiply) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Multiply)
{ {
return compiler.create_binary_op<mlir::NGMulOp>(ng_node); return compiler.create_generic_op<mlir::NGMulOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Divide) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Divide)
{ {
return compiler.create_binary_op<mlir::NGDivOp>(ng_node); return compiler.create_generic_op<mlir::NGDivOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Greater) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Greater)
{ {
return compiler.create_binary_op<mlir::NGGreaterOp>(ng_node); return compiler.create_generic_op<mlir::NGGreaterOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Less) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Less)
{ {
return compiler.create_binary_op<mlir::NGLessOp>(ng_node); return compiler.create_generic_op<mlir::NGLessOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Maximum) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Maximum)
{ {
return compiler.create_binary_op<mlir::NGMaxOp>(ng_node); return compiler.create_generic_op<mlir::NGMaxOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Minimum) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Minimum)
{ {
return compiler.create_binary_op<mlir::NGMinOp>(ng_node); return compiler.create_generic_op<mlir::NGMinOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax)
{ {
return compiler.create_index_reduction<mlir::NGArgMaxRedOp>(ng_node); return compiler.create_index_reduction<mlir::NGArgMaxRedOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin)
{ {
return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node); return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
{ {
return compiler.create_binary_op<mlir::NGDotOp>(ng_node); return compiler.create_generic_op<mlir::NGDotOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Concat) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Concat)
{ {
return compiler.create_concat(ng_node); auto ng_node_concat = static_cast<const ngraph::op::Concat*>(ng_node);
auto op = compiler.create_generic_op<mlir::NGConcatOp>(ng_node);
op->setAttr("concatenation_axis",
compiler.m_builder->getI64IntegerAttr(
ng_node_concat->get_concatenation_axis()));
return op;
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Relu) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Gather)
{ {
return compiler.create_unary_op<mlir::NGReluOp>(ng_node); auto ng_node_gather = static_cast<const ngraph::op::Gather*>(ng_node);
auto op = compiler.create_generic_op<mlir::NGGatherOp>(ng_node);
op->setAttr("axis",
compiler.m_builder->getI64IntegerAttr(ng_node_gather->get_axis()));
return op;
} }
} }
} }
} }
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{ template <typename Op>
#define MLIR_OP(OP) {TI(ngraph::op::OP), &MLIRCompiler::create_op<ngraph::op::OP>}, mlir::Operation* MLIRCompiler::create_generic_op(const ngraph::Node* ng_node)
#include "ops_supported.inc"
};
template <typename UnaryOp>
mlir::Value* MLIRCompiler::create_unary_op(const ngraph::Node* ng_node)
{
auto lhs = ng_node->get_argument(0)->get_output_tensor_ptr();
auto lhs_v = get_tensor_value(lhs.get()).m_value;
auto res_type = get_mlir_type(ng_node->get_output_tensor_ptr().get());
return m_builder->create<UnaryOp>(mlir::UnknownLoc::get(&m_context), res_type, lhs_v)
.getResult();
}
template <typename BinOp>
mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node)
{
auto lhs = ng_node->get_argument(0)->get_output_tensor_ptr();
auto rhs = ng_node->get_argument(1)->get_output_tensor_ptr();
auto lhs_v = get_tensor_value(lhs.get()).m_value;
auto rhs_v = get_tensor_value(rhs.get()).m_value;
auto res_type = get_mlir_type(ng_node->get_output_tensor_ptr().get());
return m_builder->create<BinOp>(mlir::UnknownLoc::get(&m_context), res_type, lhs_v, rhs_v)
.getResult();
}
mlir::Value* MLIRCompiler::create_concat(const ngraph::Node* ng_node)
{ {
std::vector<mlir::Value*> arg_values; std::vector<mlir::Value*> arg_values;
auto ng_node_concat = static_cast<const ngraph::op::Concat*>(ng_node); std::vector<mlir::Type> res_types;
for (auto& arg : ng_node->get_arguments()) for (auto& arg : ng_node->get_arguments())
{ {
auto arg_tensor = arg->get_output_tensor_ptr(); auto arg_tensor = arg->get_output_tensor_ptr();
auto arg_v = get_tensor_value(arg_tensor.get()).m_value; auto arg_v = get_tensor_value(arg_tensor.get()).m_value;
arg_values.push_back(arg_v); arg_values.push_back(arg_v);
} }
auto res_type = get_mlir_type(ng_node->get_output_tensor_ptr().get());
return m_builder for (auto& output : ng_node->outputs())
->create<mlir::NGConcatOp>( {
mlir::UnknownLoc::get(&m_context), res_types.push_back(get_mlir_type(output.get_tensor_ptr().get()));
res_type, }
arg_values,
m_builder->getI64IntegerAttr(ng_node_concat->get_concatenation_axis())) return (m_builder->create<Op,
.getResult(); ArrayRef<mlir::Type>,
ArrayRef<mlir::Value*>,
ArrayRef<mlir::NamedAttribute>>(
mlir::UnknownLoc::get(&m_context), res_types, arg_values, {/* no attrs */}))
.getOperation();
} }
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
#define MLIR_OP(OP) {TI(ngraph::op::OP), &MLIRCompiler::create_op<ngraph::op::OP>},
#include "ops_supported.inc"
};
void MLIRCompiler::create_return() void MLIRCompiler::create_return()
{ {
std::vector<mlir::Value*> value_list; std::vector<mlir::Value*> value_list;
...@@ -437,21 +440,16 @@ void MLIRCompiler::create_return() ...@@ -437,21 +440,16 @@ void MLIRCompiler::create_return()
} }
template <typename RedOp> template <typename RedOp>
mlir::Value* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node) mlir::Operation* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node)
{ {
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node); auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
auto op = create_generic_op<RedOp>(ng_node);
auto arg = idx_red->get_argument(0); mlir::ArrayAttr red_axes_attr =
size_t red_axis = idx_red->get_reduction_axis(); m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
op->setAttr("axes", red_axes_attr);
mlir::Value* arg_val = get_tensor_value(arg->get_output_tensor_ptr().get()).m_value; return op;
mlir::ArrayAttr red_axes_attr = m_builder->getI64ArrayAttr({(int64_t)red_axis});
return m_builder
->create<RedOp>(
mlir::UnknownLoc::get(&m_context), get_mlir_type(ng_node), arg_val, red_axes_attr)
.getResult();
} }
// Binds MLIR function arguments to the proper values. This includes externally allocated tensors // Binds MLIR function arguments to the proper values. This includes externally allocated tensors
// helpers to be used inside the function. // helpers to be used inside the function.
void MLIRCompiler::bind_arguments() void MLIRCompiler::bind_arguments()
......
...@@ -98,25 +98,21 @@ namespace ngraph ...@@ -98,25 +98,21 @@ namespace ngraph
void build_ng_dialect(); void build_ng_dialect();
template <typename OP> template <typename Op>
static mlir::Value* create_op(MLIRCompiler& compiler, const ngraph::Node* ng_node) static mlir::Operation* create_op(MLIRCompiler& compiler,
const ngraph::Node* ng_node)
{ {
throw std::runtime_error("Unimplemented op '" + ng_node->description() + throw std::runtime_error("Unimplemented op '" + ng_node->description() +
"' in MLIR Compiler"); "' in MLIR Compiler");
} }
template <typename UnaryOp> // Generic op lowerer to ng dialect.
mlir::Value* create_unary_op(const ngraph::Node* ng_node); // Simply maps ngraph tensors to values and generate an OP. No op-specific logic.
template <typename Op>
template <typename BinOp> mlir::Operation* create_generic_op(const ngraph::Node* ng_node);
mlir::Value* create_binary_op(const ngraph::Node* ng_node);
// TODO(amprocte): Can we have a create_variadic_op that is able to handle the
// attributes?
mlir::Value* create_concat(const ngraph::Node* ng_node);
template <typename RedOp> template <typename RedOp>
mlir::Value* create_index_reduction(const ngraph::Node* ng_node); mlir::Operation* create_index_reduction(const ngraph::Node* ng_node);
void create_return(); void create_return();
...@@ -150,7 +146,7 @@ namespace ngraph ...@@ -150,7 +146,7 @@ namespace ngraph
using TensorToInfo = std::pair<descriptor::Tensor*, TensorInfo>; using TensorToInfo = std::pair<descriptor::Tensor*, TensorInfo>;
using TensorToInfoMap = std::unordered_map<descriptor::Tensor*, TensorInfo>; using TensorToInfoMap = std::unordered_map<descriptor::Tensor*, TensorInfo>;
using MLIRCompOpFunction = using MLIRCompOpFunction =
std::function<mlir::Value*(MLIRCompiler& compiler, const ngraph::Node*)>; std::function<mlir::Operation*(MLIRCompiler& compiler, const ngraph::Node*)>;
using MLIRCompOpMap = std::unordered_map<std::type_index, MLIRCompOpFunction>; using MLIRCompOpMap = std::unordered_map<std::type_index, MLIRCompOpFunction>;
// Maps tensor to the value it represents in the IR // Maps tensor to the value it represents in the IR
......
...@@ -168,6 +168,39 @@ static mlir::LogicalResult verifyCmpOp(T* op) ...@@ -168,6 +168,39 @@ static mlir::LogicalResult verifyCmpOp(T* op)
return mlir::success(); return mlir::success();
} }
template <>
mlir::LogicalResult verifyOp(NGGatherOp* op)
{
Type ty = op->params()->getType();
NGTensorType inputType = ty.cast<NGTensorType>();
ty = op->indices()->getType();
NGTensorType indicesType = ty.cast<NGTensorType>();
// ensure axis < params rank
if (op->axis().getSExtValue() >= inputType.getRank())
return op->emitOpError("Gather axis is larger than input rank");
ty = indicesType.getElementType();
// ensure indices are I32 or I64
if (!ty.isa<NGIntegerType>())
return op->emitOpError("Indices tensor is not of Integer type");
NGIntegerType indicesEltType = ty.cast<NGIntegerType>();
if (!indicesEltType.isInt32() && !indicesEltType.isInt64())
return op->emitOpError("Indices tensor is not of I32 or I64 type");
mlir::Type r0 = op->res()->getType();
NGTensorType resType = r0.cast<NGTensorType>();
// ensure result is compatible with input
if (!resType.getRank() == inputType.getRank() + indicesType.getRank() - 1)
return op->emitOpError("Incompatible result shape and/or type");
return mlir::success();
}
namespace mlir namespace mlir
{ {
#define GET_OP_CLASSES #define GET_OP_CLASSES
......
...@@ -200,7 +200,7 @@ class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -200,7 +200,7 @@ class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> :
{ {
let summary = "Base class for reduction operations that perform a reduction " let summary = "Base class for reduction operations that perform a reduction "
"across the axes of a single tensor."; "across the axes of a single tensor.";
let description = "Axes are represented as an array of I64 attributes."; let description = [{Axes are represented as an array of I64 attributes.}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
...@@ -257,6 +257,24 @@ def NGAnyRedOp : NG_Axis_Reduction_Op<"any.red"> ...@@ -257,6 +257,24 @@ def NGAnyRedOp : NG_Axis_Reduction_Op<"any.red">
let verifier = [{ return verifyLogicalReductionOp(this); }]; let verifier = [{ return verifyLogicalReductionOp(this); }];
} }
// Gather
def NGGatherOp :
NG_OneResult_Op<"gather", [NoSideEffect]>,
Arguments<(ins NG_TensorType:$params, NG_TensorType:$indices, I64Attr:$axis)>
{
let summary = "Gather slices from params along the specified axis according to indices";
let description = [{
Gather slices from axis of params according to indices
params The tensor from which slices are gathered
indices Index tensor. Data type must be `element::i32` or `element::i64`
axis Axis in params to gather
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
}
// Terminator Ops // Terminator Ops
def NGReturnOp : NG_Terminator_Op<"return">; def NGReturnOp : NG_Terminator_Op<"return">;
......
...@@ -199,6 +199,7 @@ namespace mlir ...@@ -199,6 +199,7 @@ namespace mlir
} }
Shape getShape() const { return m_shape; } Shape getShape() const { return m_shape; }
int64_t getRank() const { return m_shape.size(); }
EltType getElementType() const { return m_eltType; } EltType getElementType() const { return m_eltType; }
private: private:
NGTensorTypeStorage(EltType eltType, Shape shape) NGTensorTypeStorage(EltType eltType, Shape shape)
......
...@@ -646,6 +646,123 @@ namespace ...@@ -646,6 +646,123 @@ namespace
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGGatherOp)
{
auto gatherOp = cast<NGGatherOp>(op);
auto loc = gatherOp.getLoc();
ScopedContext scope(rewriter, loc);
// Get operands
Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in GatherOp");
auto resultTy = result->getType().cast<MemRefType>();
Value* params = operands[0];
Value* indices = operands[1];
auto axis = gatherOp.axis().getSExtValue();
// Create view to write into result.
MemRefView vRes(result), vParams(params), vIndices(indices);
// Indexed Values
IndexedValue iRes(result), iParams(params), iIndices(indices);
// Construct outer loop for params dims. Exclude the axis dim.
SmallVector<ValueHandle, 4> paramsLbs, paramsUbs;
SmallVector<IndexHandle, 4> paramsIVs;
SmallVector<int64_t, 4> paramsSteps;
SmallVector<ValueHandle*, 4> paramsIVPtrs;
for (auto i = 0; i < vParams.rank(); i++)
{
// skip gather axis
if (i == axis)
continue;
paramsLbs.push_back(IndexHandle(vParams.lb(i)));
paramsUbs.push_back(IndexHandle(vParams.ub(i)));
paramsSteps.push_back(vParams.step(i));
}
NGRAPH_CHECK(paramsLbs.size() == vParams.rank() - 1 &&
paramsUbs.size() == paramsLbs.size() &&
paramsSteps.size() == paramsLbs.size(),
"Incorrect loop nest bounds size for gather params");
paramsIVs = IndexHandle::makeIndexHandles(vParams.rank() - 1);
paramsIVPtrs = IndexHandle::makeIndexHandlePointers(paramsIVs);
auto indicesLbs = vIndices.getLbs();
auto indicesUbs = vIndices.getUbs();
auto indicesSteps = vIndices.getSteps();
auto indicesIVs = IndexHandle::makeIndexHandles(vIndices.rank());
auto indicesIVPtrs = IndexHandle::makeIndexHandlePointers(indicesIVs);
SmallVector<IndexHandle, 8> paramsIndices, resIndices;
// Make sure we are going to create loops
NGRAPH_CHECK(vParams.rank() > 0, "Invalid size for indices steps");
// Let params rank : N
// Let indices rank : M
// Let axis be A
// Generate
// params loops
// for P_0: 0 -> params.dim[0]
// for P_1: 0 -> params.dim[1]
// for P_2: 0 -> params.dim[2]
// ...
// for P_(A-1):0 -> params.dim[A-1]
// for P_(A+1):0 -> params.dim[A+1]
// ...
// for P_(N-1):0 -> params.dim[N-1]
// indices loops
// for I_0:0 -> indices.dim[0]
// ...
// for I_(M-1):0 -> indices.dim[M-1]
// res[P_0, P_1, .. P_(A-1), I_0, .., I_(M-1), P_(A+1), ... P_(N-1)] =
// params[P_0, P_1, .. P_(A-1), indices[I_0, .., I_(M-1)], P_(A+1), ... P_(N-1)];
LoopNestBuilder(paramsIVPtrs, paramsLbs, paramsUbs, paramsSteps)([&] {
LoopNestBuilder(indicesIVPtrs, indicesLbs, indicesUbs, indicesSteps)([&] {
// Load axis value from indices array and cast it to Index Type
ValueHandle axisIdx = ValueHandle::create<IndexCastOp>(
(ValueHandle)iIndices(indicesIVs), rewriter.getIndexType());
// construct indices for param
// [P_0, P_1, .. P_axis-1, Indices[I0, I1, .. I_k-1], P_axis+1, P_axis+2, .. P_n-1]
for (auto i = 0, j = 0; i < vParams.rank(); i++)
{
if (i == axis)
{
paramsIndices.push_back(IndexHandle(axisIdx));
}
else
{
paramsIndices.push_back(paramsIVs[j++]);
}
}
// construct indices for result
// [P_0, P_1, .. P_axis-1, I0, I1, .. I_k-1, P_axis+1, P_axis+2, .. P_n-1]
for (auto i = 0, j = 0; i < vParams.rank() + vIndices.rank() - 1;)
{
if (i == axis && indicesIVs.size() > 0)
{
resIndices.append(indicesIVs.begin(), indicesIVs.end());
i += indicesIVs.size();
}
else
{
resIndices.push_back(paramsIVs[j++]);
i++;
}
}
// Store into result
iRes(resIndices) = iParams(paramsIndices);
});
});
rewriter.replaceOp(op, {result});
return matchSuccess();
}
REWRITER(NGReturnOp) REWRITER(NGReturnOp)
{ {
rewriter.replaceOpWithNewOp<ReturnOp>(op); rewriter.replaceOpWithNewOp<ReturnOp>(op);
...@@ -653,7 +770,7 @@ namespace ...@@ -653,7 +770,7 @@ namespace
} }
#undef REWRITER #undef REWRITER
/// End of pattern matchers
template <typename OP> template <typename OP>
void lower_binary_elementwise(Operation* op, void lower_binary_elementwise(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value*> operands,
......
...@@ -29,6 +29,7 @@ MLIR_OP(NGArgMinRedOp) ...@@ -29,6 +29,7 @@ MLIR_OP(NGArgMinRedOp)
MLIR_OP(NGConcatOp) MLIR_OP(NGConcatOp)
MLIR_OP(NGDivOp) MLIR_OP(NGDivOp)
MLIR_OP(NGDotOp) MLIR_OP(NGDotOp)
MLIR_OP(NGGatherOp)
MLIR_OP(NGGreaterOp) MLIR_OP(NGGreaterOp)
MLIR_OP(NGLessOp) MLIR_OP(NGLessOp)
MLIR_OP(NGMulOp) MLIR_OP(NGMulOp)
......
...@@ -9,6 +9,7 @@ MLIR_OP(ArgMax) ...@@ -9,6 +9,7 @@ MLIR_OP(ArgMax)
MLIR_OP(Divide) MLIR_OP(Divide)
MLIR_OP(Dot) MLIR_OP(Dot)
MLIR_OP(Concat) MLIR_OP(Concat)
MLIR_OP(Gather)
MLIR_OP(Greater) MLIR_OP(Greater)
MLIR_OP(Less) MLIR_OP(Less)
MLIR_OP(Maximum) MLIR_OP(Maximum)
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp" #include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
#include "ngraph/op/less.hpp" #include "ngraph/op/less.hpp"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment