Commit 67536ddf authored by nmostafa's avatar nmostafa

graph node lowering

parent 3a9de1bb
...@@ -389,7 +389,12 @@ namespace ngraph ...@@ -389,7 +389,12 @@ namespace ngraph
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Gather) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Gather)
{ {
return nullptr; //compiler.create_gather(ng_node); auto ng_node_gather = static_cast<const ngraph::op::Gather*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGGatherOp>(ng_node);
mlir::Operation* op = result->getDefiningOp();
op->setAttr("axis", compiler.m_builder->getI64IntegerAttr(ng_node_gather->get_axis()));
return result;
} }
} }
} }
......
...@@ -174,11 +174,11 @@ mlir::LogicalResult verifyOp(NGGatherOp* op) ...@@ -174,11 +174,11 @@ mlir::LogicalResult verifyOp(NGGatherOp* op)
Type ty = op->input()->getType(); Type ty = op->input()->getType();
NGTensorType inputType = ty.cast<NGTensorType>(); NGTensorType inputType = ty.cast<NGTensorType>();
ty = op->input()->getType(); ty = op->indices()->getType();
NGTensorType indicesType = ty.cast<NGTensorType>(); NGTensorType indicesType = ty.cast<NGTensorType>();
// ensure axis < params rank // ensure axis < params rank
if (op->axis().getSExtValue() >= inputType.getRank()); if (op->axis().getSExtValue() >= inputType.getRank())
return op->emitOpError("Gather axis is larger than input rank"); return op->emitOpError("Gather axis is larger than input rank");
ty = indicesType.getElementType(); ty = indicesType.getElementType();
...@@ -195,7 +195,7 @@ mlir::LogicalResult verifyOp(NGGatherOp* op) ...@@ -195,7 +195,7 @@ mlir::LogicalResult verifyOp(NGGatherOp* op)
NGTensorType resType = r0.cast<NGTensorType>(); NGTensorType resType = r0.cast<NGTensorType>();
// ensure result is compatible with input // ensure result is compatible with input
if (!resType.isCompatible(inputType)) if (!resType.getRank() == inputType.getRank() + indicesType.getRank() - 1)
return op->emitOpError("Incompatible result shape and/or type"); return op->emitOpError("Incompatible result shape and/or type");
return mlir::success(); return mlir::success();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment