Commit aeacd4e3 authored by nmostafa's avatar nmostafa

style-apply

parent 33ec9a8b
...@@ -360,7 +360,8 @@ namespace ngraph ...@@ -360,7 +360,8 @@ namespace ngraph
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node); auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGArgMaxRedOp>(ng_node); mlir::Value* result = compiler.create_generic_op<mlir::NGArgMaxRedOp>(ng_node);
mlir::Operation* op = result->getDefiningOp(); mlir::Operation* op = result->getDefiningOp();
mlir::ArrayAttr red_axes_attr = compiler.m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()}); mlir::ArrayAttr red_axes_attr =
compiler.m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
op->setAttr("axes", red_axes_attr); op->setAttr("axes", red_axes_attr);
return result; return result;
} }
...@@ -371,7 +372,8 @@ namespace ngraph ...@@ -371,7 +372,8 @@ namespace ngraph
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node); auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGArgMinRedOp>(ng_node); mlir::Value* result = compiler.create_generic_op<mlir::NGArgMinRedOp>(ng_node);
mlir::Operation* op = result->getDefiningOp(); mlir::Operation* op = result->getDefiningOp();
mlir::ArrayAttr red_axes_attr = compiler.m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()}); mlir::ArrayAttr red_axes_attr =
compiler.m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
op->setAttr("axes", red_axes_attr); op->setAttr("axes", red_axes_attr);
return result; return result;
} }
...@@ -382,7 +384,9 @@ namespace ngraph ...@@ -382,7 +384,9 @@ namespace ngraph
auto ng_node_concat = static_cast<const ngraph::op::Concat*>(ng_node); auto ng_node_concat = static_cast<const ngraph::op::Concat*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGConcatOp>(ng_node); mlir::Value* result = compiler.create_generic_op<mlir::NGConcatOp>(ng_node);
mlir::Operation* op = result->getDefiningOp(); mlir::Operation* op = result->getDefiningOp();
op->setAttr("concatenation_axis", compiler.m_builder->getI64IntegerAttr(ng_node_concat->get_concatenation_axis())); op->setAttr("concatenation_axis",
compiler.m_builder->getI64IntegerAttr(
ng_node_concat->get_concatenation_axis()));
return result; return result;
} }
...@@ -392,9 +396,9 @@ namespace ngraph ...@@ -392,9 +396,9 @@ namespace ngraph
auto ng_node_gather = static_cast<const ngraph::op::Gather*>(ng_node); auto ng_node_gather = static_cast<const ngraph::op::Gather*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGGatherOp>(ng_node); mlir::Value* result = compiler.create_generic_op<mlir::NGGatherOp>(ng_node);
mlir::Operation* op = result->getDefiningOp(); mlir::Operation* op = result->getDefiningOp();
op->setAttr("axis", compiler.m_builder->getI64IntegerAttr(ng_node_gather->get_axis())); op->setAttr("axis",
compiler.m_builder->getI64IntegerAttr(ng_node_gather->get_axis()));
return result; return result;
} }
} }
} }
...@@ -418,16 +422,11 @@ mlir::Value* MLIRCompiler::create_generic_op(const ngraph::Node* ng_node) ...@@ -418,16 +422,11 @@ mlir::Value* MLIRCompiler::create_generic_op(const ngraph::Node* ng_node)
} }
return m_builder return m_builder
->create<Op, ->create<Op, ArrayRef<mlir::Type>, ArrayRef<mlir::Value*>, ArrayRef<mlir::NamedAttribute>>(
ArrayRef<mlir::Type>, mlir::UnknownLoc::get(&m_context), res_types, arg_values, {/* no attrs */})
ArrayRef<mlir::Value *>, .getResult();
ArrayRef<mlir::NamedAttribute>>(
mlir::UnknownLoc::get(&m_context),
res_types,
arg_values, {/* no attrs */}).getResult();
} }
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{ const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
#define MLIR_OP(OP) {TI(ngraph::op::OP), &MLIRCompiler::create_op<ngraph::op::OP>}, #define MLIR_OP(OP) {TI(ngraph::op::OP), &MLIRCompiler::create_op<ngraph::op::OP>},
#include "ops_supported.inc" #include "ops_supported.inc"
......
...@@ -98,7 +98,7 @@ namespace ngraph ...@@ -98,7 +98,7 @@ namespace ngraph
void build_ng_dialect(); void build_ng_dialect();
template<typename Op> template <typename Op>
static mlir::Value* create_op(MLIRCompiler& compiler, const ngraph::Node* ng_node) static mlir::Value* create_op(MLIRCompiler& compiler, const ngraph::Node* ng_node)
{ {
throw std::runtime_error("Unimplemented op '" + ng_node->description() + throw std::runtime_error("Unimplemented op '" + ng_node->description() +
......
...@@ -681,13 +681,12 @@ namespace ...@@ -681,13 +681,12 @@ namespace
paramsUbs.push_back(IndexHandle(vParams.ub(i))); paramsUbs.push_back(IndexHandle(vParams.ub(i)));
paramsSteps.push_back(vParams.step(i)); paramsSteps.push_back(vParams.step(i));
} }
NGRAPH_CHECK( NGRAPH_CHECK(paramsLbs.size() == vParams.rank() - 1 &&
paramsLbs.size() == vParams.rank() - 1 &&
paramsUbs.size() == paramsLbs.size() && paramsUbs.size() == paramsLbs.size() &&
paramsSteps.size() == paramsLbs.size(), paramsSteps.size() == paramsLbs.size(),
"Incorrect loop nest bounds size for gather params"); "Incorrect loop nest bounds size for gather params");
paramsIVs = IndexHandle::makeIndexHandles(vParams.rank()-1); paramsIVs = IndexHandle::makeIndexHandles(vParams.rank() - 1);
paramsIVPtrs = IndexHandle::makeIndexHandlePointers(paramsIVs); paramsIVPtrs = IndexHandle::makeIndexHandlePointers(paramsIVs);
auto indicesLbs = vIndices.getLbs(); auto indicesLbs = vIndices.getLbs();
...@@ -772,7 +771,7 @@ namespace ...@@ -772,7 +771,7 @@ namespace
} }
#undef REWRITER #undef REWRITER
/// End of pattern matchers /// End of pattern matchers
template <typename OP> template <typename OP>
void lower_binary_elementwise(Operation* op, void lower_binary_elementwise(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value*> operands,
......
...@@ -39,7 +39,6 @@ using namespace ngraph::descriptor; ...@@ -39,7 +39,6 @@ using namespace ngraph::descriptor;
using namespace ngraph::op; using namespace ngraph::op;
using namespace ngraph::pass; using namespace ngraph::pass;
#define TI(x) std::type_index(typeid(x)) #define TI(x) std::type_index(typeid(x))
int MLIRSubgraphExtractionPass::MLIRSubgraph::m_curr_graph_id = 0; int MLIRSubgraphExtractionPass::MLIRSubgraph::m_curr_graph_id = 0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment