Commit ca9adeb1 authored by Ewa Tusień's avatar Ewa Tusień Committed by Sang Ik Lee

Update ONNX importer to use v1 version of Softmax (#3894)

* Added downgrade pass for Softmax.

* Updated Softmax op to v1.

* Created vector with a right capacity.

* Include numeric header to enable std::iota function

* Removed unused numeric header from the old file

* Fix includes style
parent 509d1a7f
...@@ -14,9 +14,8 @@ ...@@ -14,9 +14,8 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <numeric> #include <memory>
#include "exceptions.hpp"
#include "ngraph/op/softmax.hpp" #include "ngraph/op/softmax.hpp"
#include "softmax.hpp" #include "softmax.hpp"
#include "utils/common.hpp" #include "utils/common.hpp"
...@@ -38,10 +37,7 @@ namespace ngraph ...@@ -38,10 +37,7 @@ namespace ngraph
int axis = node.get_attribute_value<int64_t>("axis", 1); int axis = node.get_attribute_value<int64_t>("axis", 1);
std::size_t valid_axis = common::validate_axis(node, axis, data_shape.size()); std::size_t valid_axis = common::validate_axis(node, axis, data_shape.size());
// create vector of capacity data_dimensions - axis_divider position return {std::make_shared<ngraph::op::v1::Softmax>(data, valid_axis)};
std::vector<size_t> axes(data_shape.size() - valid_axis);
std::iota(std::begin(axes), std::end(axes), valid_axis);
return {std::make_shared<ngraph::op::Softmax>(data, axes)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -16,9 +16,8 @@ ...@@ -16,9 +16,8 @@
#pragma once #pragma once
#include <memory>
#include "core/node.hpp" #include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include <cstdint> #include <cstdint>
#include <numeric>
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
...@@ -48,6 +49,7 @@ ...@@ -48,6 +49,7 @@
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/strided_slice.hpp" #include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
#include "ngraph/op/topk.hpp" #include "ngraph/op/topk.hpp"
...@@ -609,6 +611,20 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node) ...@@ -609,6 +611,20 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
replace_node(node, replacement_node); replace_node(node, replacement_node);
break; break;
} }
case OP_TYPEID::Softmax:
{
auto tmp = as_type_ptr<op::v1::Softmax>(node);
auto axis = tmp->get_axis();
auto data = node->input(0);
auto data_shape = data.get_shape();
std::vector<size_t> axes(data_shape.size() - axis);
std::iota(std::begin(axes), std::end(axes), axis);
auto replacement_node =
make_shared<op::v0::Softmax>(node->input(0).get_source_output(), axes);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Sum: case OP_TYPEID::Sum:
{ {
auto tmp = as_type_ptr<op::v1::ReduceSum>(node); auto tmp = as_type_ptr<op::v1::ReduceSum>(node);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment