Commit 5d116018 authored by Tomasz Socha's avatar Tomasz Socha Committed by Scott Cyphers

Use v1::Gather in ONNX Importer (#4037)

parent 8bd37070
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "core/node.hpp" #include "core/node.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/opsets/opset1.hpp"
#include "utils/common.hpp" #include "utils/common.hpp"
namespace ngraph namespace ngraph
...@@ -39,7 +39,10 @@ namespace ngraph ...@@ -39,7 +39,10 @@ namespace ngraph
auto axis = node.get_attribute_value<int64_t>("axis", 0); auto axis = node.get_attribute_value<int64_t>("axis", 0);
auto valid_axis = common::validate_axis(node, axis, data->get_shape().size()); auto valid_axis = common::validate_axis(node, axis, data->get_shape().size());
return {std::make_shared<ngraph::op::Gather>(data, indices, valid_axis)}; return {std::make_shared<opset1::Gather>(
data,
indices,
opset1::Constant::create(element::i64, Shape{}, {valid_axis}))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -231,6 +231,27 @@ namespace ...@@ -231,6 +231,27 @@ namespace
return true; return true;
} }
bool op_cast(shared_ptr<op::v1::Gather> node)
{
auto axis_node = as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
NGRAPH_CHECK(axis_node,
"Unable to convert Gather:v1 to Gather:v0 if axis is not constant. Node: ",
*node);
NGRAPH_CHECK(
axis_node->get_element_type() == element::i64,
"Unable to convert Gather:v1 to Gather:v0 with axis other type than int64. Node: ",
*node);
int64_t axis = axis_node->get_vector<int64_t>()[0];
auto replacement_node =
make_shared<op::v0::Gather>(node->input_value(0), node->input_value(1), axis);
replace_node(node, replacement_node);
return true;
}
bool op_cast(shared_ptr<op::v1::GenerateMask> node) bool op_cast(shared_ptr<op::v1::GenerateMask> node)
{ {
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant()); NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
...@@ -707,7 +728,7 @@ namespace ...@@ -707,7 +728,7 @@ namespace
}; };
return dispatch_map; return dispatch_map;
} }
} } // namespace
bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node) bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment