Commit ca9adeb1 authored by Ewa Tusień's avatar Ewa Tusień Committed by Sang Ik Lee

Update ONNX importer to use v1 version of Softmax (#3894)

* Added downgrade pass for Softmax.

* Updated Softmax op to v1.

* Created vector with a right capacity.

* Include numeric header to enable std::iota function

* Removed unused numeric header from the old file

* Fix includes style
parent 509d1a7f
......@@ -14,9 +14,8 @@
// limitations under the License.
//*****************************************************************************
#include <numeric>
#include <memory>
#include "exceptions.hpp"
#include "ngraph/op/softmax.hpp"
#include "softmax.hpp"
#include "utils/common.hpp"
......@@ -38,10 +37,7 @@ namespace ngraph
int axis = node.get_attribute_value<int64_t>("axis", 1);
std::size_t valid_axis = common::validate_axis(node, axis, data_shape.size());
// create vector of capacity data_dimensions - axis_divider position
std::vector<size_t> axes(data_shape.size() - valid_axis);
std::iota(std::begin(axes), std::end(axes), valid_axis);
return {std::make_shared<ngraph::op::Softmax>(data, axes)};
return {std::make_shared<ngraph::op::v1::Softmax>(data, valid_axis)};
}
} // namespace set_1
......
......@@ -16,9 +16,8 @@
#pragma once
#include <memory>
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include <cstdint>
#include <numeric>
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
......@@ -48,6 +49,7 @@
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/topk.hpp"
......@@ -609,6 +611,20 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
replace_node(node, replacement_node);
break;
}
case OP_TYPEID::Softmax:
{
auto tmp = as_type_ptr<op::v1::Softmax>(node);
auto axis = tmp->get_axis();
auto data = node->input(0);
auto data_shape = data.get_shape();
std::vector<size_t> axes(data_shape.size() - axis);
std::iota(std::begin(axes), std::end(axes), axis);
auto replacement_node =
make_shared<op::v0::Softmax>(node->input(0).get_source_output(), axes);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Sum:
{
auto tmp = as_type_ptr<op::v1::ReduceSum>(node);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment