Commit 67536ddf authored by nmostafa's avatar nmostafa

graph node lowering

parent 3a9de1bb
......@@ -389,7 +389,12 @@ namespace ngraph
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Gather)
{
return nullptr; //compiler.create_gather(ng_node);
auto ng_node_gather = static_cast<const ngraph::op::Gather*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGGatherOp>(ng_node);
mlir::Operation* op = result->getDefiningOp();
op->setAttr("axis", compiler.m_builder->getI64IntegerAttr(ng_node_gather->get_axis()));
return result;
}
}
}
......
......@@ -174,11 +174,11 @@ mlir::LogicalResult verifyOp(NGGatherOp* op)
Type ty = op->input()->getType();
NGTensorType inputType = ty.cast<NGTensorType>();
ty = op->input()->getType();
ty = op->indices()->getType();
NGTensorType indicesType = ty.cast<NGTensorType>();
// ensure axis < params rank
if (op->axis().getSExtValue() >= inputType.getRank());
if (op->axis().getSExtValue() >= inputType.getRank())
return op->emitOpError("Gather axis is larger than input rank");
ty = indicesType.getElementType();
......@@ -195,7 +195,7 @@ mlir::LogicalResult verifyOp(NGGatherOp* op)
NGTensorType resType = r0.cast<NGTensorType>();
// ensure result is compatible with input
if (!resType.isCompatible(inputType))
if (!resType.getRank() == inputType.getRank() + indicesType.getRank() - 1)
return op->emitOpError("Incompatible result shape and/or type");
return mlir::success();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment