Commit aeacd4e3 authored by nmostafa

style-apply

parent 33ec9a8b
@@ -348,7 +348,7 @@ namespace ngraph
return compiler.create_generic_op<mlir::NGMinOp>(ng_node);
}
template <>
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
{
return compiler.create_generic_op<mlir::NGDotOp>(ng_node);
@@ -360,7 +360,8 @@ namespace ngraph
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGArgMaxRedOp>(ng_node);
mlir::Operation* op = result->getDefiningOp();
mlir::ArrayAttr red_axes_attr = compiler.m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
mlir::ArrayAttr red_axes_attr =
compiler.m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
op->setAttr("axes", red_axes_attr);
return result;
}
@@ -371,7 +372,8 @@ namespace ngraph
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGArgMinRedOp>(ng_node);
mlir::Operation* op = result->getDefiningOp();
mlir::ArrayAttr red_axes_attr = compiler.m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
mlir::ArrayAttr red_axes_attr =
compiler.m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
op->setAttr("axes", red_axes_attr);
return result;
}
@@ -382,7 +384,9 @@ namespace ngraph
auto ng_node_concat = static_cast<const ngraph::op::Concat*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGConcatOp>(ng_node);
mlir::Operation* op = result->getDefiningOp();
op->setAttr("concatenation_axis", compiler.m_builder->getI64IntegerAttr(ng_node_concat->get_concatenation_axis()));
op->setAttr("concatenation_axis",
compiler.m_builder->getI64IntegerAttr(
ng_node_concat->get_concatenation_axis()));
return result;
}
@@ -392,9 +396,9 @@ namespace ngraph
auto ng_node_gather = static_cast<const ngraph::op::Gather*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGGatherOp>(ng_node);
mlir::Operation* op = result->getDefiningOp();
op->setAttr("axis", compiler.m_builder->getI64IntegerAttr(ng_node_gather->get_axis()));
op->setAttr("axis",
compiler.m_builder->getI64IntegerAttr(ng_node_gather->get_axis()));
return result;
}
}
}
@@ -418,16 +422,11 @@ mlir::Value* MLIRCompiler::create_generic_op(const ngraph::Node* ng_node)
}
return m_builder
->create<Op,
ArrayRef<mlir::Type>,
ArrayRef<mlir::Value *>,
ArrayRef<mlir::NamedAttribute>>(
mlir::UnknownLoc::get(&m_context),
res_types,
arg_values, {/* no attrs */}).getResult();
->create<Op, ArrayRef<mlir::Type>, ArrayRef<mlir::Value*>, ArrayRef<mlir::NamedAttribute>>(
mlir::UnknownLoc::get(&m_context), res_types, arg_values, {/* no attrs */})
.getResult();
}
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
#define MLIR_OP(OP) {TI(ngraph::op::OP), &MLIRCompiler::create_op<ngraph::op::OP>},
#include "ops_supported.inc"
......
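For context, the op_dispatcher above is populated with an X-macro: ops_supported.inc (not shown in this diff) is a list of MLIR_OP(...) lines that is textually included inside the initializer and expands into one map entry per supported op. A minimal, self-contained sketch of the same pattern, using made-up op names and a simplified handler signature rather than the real nGraph/MLIR types, could look like this:

#include <iostream>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>

struct Add {};
struct Dot {};
struct Gather {}; // stand-ins for ngraph::op::* node classes

template <typename Op>
void lower()
{
    std::cout << "lowering " << typeid(Op).name() << "\n";
}

#define TI(x) std::type_index(typeid(x))

// In the real code, "ops_supported.inc" consists of MLIR_OP(...) lines and is
// included inside the map initializer, one entry per supported op.
#define MLIR_OP(OP) {TI(OP), &lower<OP>},
static const std::unordered_map<std::type_index, void (*)()> dispatcher{
    MLIR_OP(Add)
    MLIR_OP(Dot)
    MLIR_OP(Gather)
};
#undef MLIR_OP

int main()
{
    dispatcher.at(TI(Dot))(); // dispatch on the node's dynamic type
}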
@@ -98,15 +98,15 @@ namespace ngraph
void build_ng_dialect();
template<typename Op>
template <typename Op>
static mlir::Value* create_op(MLIRCompiler& compiler, const ngraph::Node* ng_node)
{
throw std::runtime_error("Unimplemented op '" + ng_node->description() +
"' in MLIR Compiler");
}
// Generic op lowerer to ng dialect.
// Simply maps ngraph tensors to values and generate an OP. No op-specific logic.
// Generic op lowerer to ng dialect.
// Simply maps ngraph tensors to values and generate an OP. No op-specific logic.
template <typename OP>
mlir::Value* create_generic_op(const ngraph::Node* ng_node);
......
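For context, the declarations above pair a primary create_op template that rejects unsupported ops with per-op explicit specializations (the COMPILE_OP_DECL bodies earlier in this diff) that lower through create_generic_op and then attach any op-specific attributes. A stub-typed sketch of that shape, using placeholder types rather than the real MLIR/nGraph classes, might look like:

#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

struct Node { std::string description; };       // placeholder for ngraph::Node
struct Value
{
    std::map<std::string, long> attrs;           // stands in for mlir::Operation attributes
};

// Generic part: build the op from the node's tensors, no op-specific logic.
Value create_generic_op(const Node&) { return Value{}; }

struct Concat; // stand-in for ngraph::op::Concat

template <typename Op>
Value create_op(const Node& node)
{
    throw std::runtime_error("Unimplemented op '" + node.description + "' in MLIR Compiler");
}

template <>
Value create_op<Concat>(const Node& node)
{
    Value v = create_generic_op(node);  // build the op generically first
    v.attrs["concatenation_axis"] = 0;  // then attach the op-specific attribute
    return v;
}

int main()
{
    Node concat_node{"Concat"};
    std::cout << create_op<Concat>(concat_node).attrs.count("concatenation_axis") << "\n"; // prints 1
}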
@@ -173,7 +173,7 @@ mlir::LogicalResult verifyOp(NGGatherOp* op)
{
Type ty = op->params()->getType();
NGTensorType inputType = ty.cast<NGTensorType>();
ty = op->indices()->getType();
NGTensorType indicesType = ty.cast<NGTensorType>();
@@ -190,10 +190,10 @@ mlir::LogicalResult verifyOp(NGGatherOp* op)
NGIntegerType indicesEltType = ty.cast<NGIntegerType>();
if (!indicesEltType.isInt32() && !indicesEltType.isInt64())
return op->emitOpError("Indices tensor is not of I32 or I64 type");
mlir::Type r0 = op->res()->getType();
NGTensorType resType = r0.cast<NGTensorType>();
// ensure result is compatible with input
if (resType.getRank() != inputType.getRank() + indicesType.getRank() - 1)
return op->emitOpError("Incompatible result shape and/or type");
......
@@ -681,19 +681,18 @@ namespace
paramsUbs.push_back(IndexHandle(vParams.ub(i)));
paramsSteps.push_back(vParams.step(i));
}
NGRAPH_CHECK(
paramsLbs.size() == vParams.rank() - 1 &&
paramsUbs.size() == paramsLbs.size() &&
paramsSteps.size() == paramsLbs.size(),
"Incorrect loop nest bounds size for gather params");
NGRAPH_CHECK(paramsLbs.size() == vParams.rank() - 1 &&
paramsUbs.size() == paramsLbs.size() &&
paramsSteps.size() == paramsLbs.size(),
"Incorrect loop nest bounds size for gather params");
paramsIVs = IndexHandle::makeIndexHandles(vParams.rank()-1);
paramsIVs = IndexHandle::makeIndexHandles(vParams.rank() - 1);
paramsIVPtrs = IndexHandle::makeIndexHandlePointers(paramsIVs);
auto indicesLbs = vIndices.getLbs();
auto indicesUbs = vIndices.getUbs();
auto indicesSteps = vIndices.getSteps();
auto indicesIVs = IndexHandle::makeIndexHandles(vIndices.rank());
auto indicesIVPtrs = IndexHandle::makeIndexHandlePointers(indicesIVs);
@@ -719,7 +718,7 @@ namespace
// for I_0:0 -> indices.dim[0]
// ...
// for I_(M-1):0 -> indices.dim[M-1]
// res[P_0, P_1, .. P_(A-1), I_0, .., I_(M-1), P_(A+1), ... P_(N-1)] =
// res[P_0, P_1, .. P_(A-1), I_0, .., I_(M-1), P_(A+1), ... P_(N-1)] =
// params[P_0, P_1, .. P_(A-1), indices[I_0, .., I_(M-1)], P_(A+1), ... P_(N-1)];
LoopNestBuilder(paramsIVPtrs, paramsLbs, paramsUbs, paramsSteps)([&] {
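The loop-nest comment above is easier to see with concrete ranks. Below is a plain C++ sketch of the same indexing scheme (shapes, data, and the flat-array layout are made up for illustration): gathering along axis 1 of a rank-3 params tensor with rank-2 indices, giving a rank 3 + 2 - 1 = 4 result, exactly the relation the NGGatherOp verifier earlier in this diff enforces.

#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    const int P0 = 2, A = 4, P2 = 3; // params shape [P0, A, P2], gather axis = 1
    const int I0 = 2, I1 = 2;        // indices shape [I0, I1]

    // params[p0][a][p2], stored flat and filled with recognizable values.
    std::vector<int> params(P0 * A * P2);
    for (int i = 0; i < (int)params.size(); ++i)
        params[i] = i;

    int64_t indices[I0][I1] = {{3, 0}, {1, 2}}; // values index into the axis dimension

    // res[p0][i0][i1][p2] = params[p0][indices[i0][i1]][p2]
    std::vector<int> res(P0 * I0 * I1 * P2);
    for (int p0 = 0; p0 < P0; ++p0)
        for (int i0 = 0; i0 < I0; ++i0)
            for (int i1 = 0; i1 < I1; ++i1)
                for (int p2 = 0; p2 < P2; ++p2)
                {
                    int64_t a = indices[i0][i1];
                    res[((p0 * I0 + i0) * I1 + i1) * P2 + p2] =
                        params[(p0 * A + a) * P2 + p2];
                }

    std::cout << "res[1][0][1][2] = " << res[((1 * I0 + 0) * I1 + 1) * P2 + 2]
              << " (expects params[1][0][2] = " << params[(1 * A + 0) * P2 + 2] << ")\n";
}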
@@ -735,7 +734,7 @@ namespace
{
paramsIndices.push_back(IndexHandle(axisIdx));
}
else
else
{
paramsIndices.push_back(paramsIVs[j++]);
}
@@ -772,7 +771,7 @@ namespace
}
#undef REWRITER
/// End of pattern matchers
/// End of pattern matchers
template <typename OP>
void lower_binary_elementwise(Operation* op,
ArrayRef<Value*> operands,
......
@@ -39,7 +39,6 @@ using namespace ngraph::descriptor;
using namespace ngraph::op;
using namespace ngraph::pass;
#define TI(x) std::type_index(typeid(x))
int MLIRSubgraphExtractionPass::MLIRSubgraph::m_curr_graph_id = 0;
......