Commit a7706e98 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Sang Ik Lee

[ONNX] Use v1 in ArgMax, ArgMin, Hardmax ops (#4175)

* first working version

* Added using v1 for ArgMin and ArgMax

* code refactor

* Code review remarks introduced

* Code review remarks introduced

* fix style

* revert fix style

* empty
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
parent c967e792
......@@ -137,7 +137,7 @@ shared_ptr<Node> builder::squeeze(const Output<Node>& value, vector<size_t> axes
Shape in_shape{value.get_shape()};
for (size_t idx = 0; idx < axes.size(); ++idx)
{
in_shape.at(idx) = 0;
in_shape.at(axes.at(idx)) = 0;
}
Shape output_shape;
for (auto axis : in_shape)
......@@ -228,3 +228,26 @@ shared_ptr<Node> builder::opset1::expand_dims(const Output<Node>& value, size_t
output_shape.insert(empty_axis_it, 1);
return builder::opset1::reshape(value, output_shape);
}
shared_ptr<Node> builder::opset1::squeeze(const Output<Node>& value, vector<int64_t> axes)
{
if (axes.empty())
{
return value.get_node_shared_ptr();
}
Shape in_shape{value.get_shape()};
for (size_t idx = 0; idx < axes.size(); ++idx)
{
in_shape.at(axes.at(idx)) = 0;
}
Shape output_shape;
for (auto axis : in_shape)
{
if (axis != 0)
{
output_shape.push_back(axis);
}
}
return builder::opset1::reshape(value, output_shape);
}
......@@ -148,6 +148,15 @@ namespace ngraph
///
/// \return Reshape:v1 op.
std::shared_ptr<Node> expand_dims(const Output<Node>& value, std::size_t axis = 0);
/// \brief Remove empty axes from input tensor.
///
/// \param[in] value The value to be squeezed.
/// \param[in] axes The vector defining indexes of axes to be removed.
///
/// \return Reshape:v1 op.
std::shared_ptr<Node> squeeze(const Output<Node>& value,
std::vector<std::int64_t> axes = {0});
}
} // namespace builder
} // namespace ngraph
......@@ -214,6 +214,8 @@ add_library(onnx_import STATIC
op/xor.hpp
ops_bridge.cpp
ops_bridge.hpp
utils/arg_min_max_factory.cpp
utils/arg_min_max_factory.hpp
utils/common.cpp
utils/common.hpp
utils/convpool.cpp
......
......@@ -14,11 +14,8 @@
// limitations under the License.
//*****************************************************************************
#include "argmax.hpp"
#include "core/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "utils/reduction.hpp"
#include "utils/arg_min_max_factory.hpp"
namespace ngraph
{
......@@ -30,7 +27,8 @@ namespace ngraph
{
NodeVector argmax(const Node& node)
{
return {reduction::make_ng_index_reduction_op<ngraph::opset0::ArgMax>(node)};
const utils::ArgMinMaxFactory arg_factory(node);
return {arg_factory.make_arg_max()};
}
} // namespace set_1
......
......@@ -14,11 +14,8 @@
// limitations under the License.
//*****************************************************************************
#include "argmin.hpp"
#include "core/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "utils/reduction.hpp"
#include "utils/arg_min_max_factory.hpp"
namespace ngraph
{
......@@ -30,7 +27,8 @@ namespace ngraph
{
NodeVector argmin(const Node& node)
{
return {reduction::make_ng_index_reduction_op<ngraph::opset0::ArgMin>(node)};
const utils::ArgMinMaxFactory arg_factory(node);
return {arg_factory.make_arg_min()};
}
} // namespace set_1
......
......@@ -17,6 +17,8 @@
#include "hardmax.hpp"
#include "exceptions.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "ngraph/validation_util.hpp"
#include "utils/common.hpp"
......@@ -33,7 +35,7 @@ namespace ngraph
{
const auto input = node.get_ng_inputs().at(0);
const auto& input_shape = input->get_shape();
auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
const auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
const auto normalized_axis =
ngraph::normalize_axis(node.get_description(), axis, input_shape.size());
......@@ -42,19 +44,31 @@ namespace ngraph
const auto coerced_tensor =
ngraph::builder::opset1::flatten(input, normalized_axis);
const auto& coerced_shape = coerced_tensor->get_shape();
const auto row_size = static_cast<int64_t>(coerced_shape.at(1));
const std::shared_ptr<ngraph::Node> argmax_2d =
std::make_shared<ngraph::opset0::ArgMax>(coerced_tensor, 1, element::i64);
const auto indices_axis = 1;
const auto max_indices = std::make_shared<opset0::GetOutputElement>(
std::make_shared<default_opset::TopK>(
coerced_tensor,
default_opset::Constant::create(ngraph::element::i64, Shape{}, {1}),
indices_axis,
default_opset::TopK::Mode::MAX,
default_opset::TopK::SortType::NONE),
1);
std::shared_ptr<ngraph::Node> eye_matrix =
common::square_identity(coerced_shape.at(1), input->get_element_type());
const auto depth =
ngraph::op::Constant::create(ngraph::element::i64, Shape{}, {row_size});
const auto on_value =
ngraph::op::Constant::create(ngraph::element::i64, Shape{}, {1});
const auto off_value =
ngraph::op::Constant::create(ngraph::element::i64, Shape{}, {0});
// the results are elements of the eye_matrix indexed by argmax_2d values
// in other words: eye_matrix[argmax_2d]
auto results =
std::make_shared<ngraph::opset0::EmbeddingLookup>(argmax_2d, eye_matrix);
const auto results = std::make_shared<default_opset::OneHot>(
max_indices, depth, on_value, off_value, indices_axis);
const auto converted_results = std::make_shared<default_opset::Convert>(
results, input->get_element_type());
return {ngraph::builder::opset1::reshape(results, input_shape)};
return {ngraph::builder::opset1::reshape(converted_results, input_shape)};
}
} // namespace set_1
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "utils/arg_min_max_factory.hpp"
#include "default_opset.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "ngraph/validation_util.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace utils
{
ArgMinMaxFactory::ArgMinMaxFactory(const Node& node)
: m_keep_dims{node.get_attribute_value<std::int64_t>("keepdims", 1)}
{
m_input_node = node.get_ng_inputs().at(0);
const auto axis = node.get_attribute_value<std::int64_t>("axis", 0);
m_normalized_axis = ngraph::normalize_axis(
node.get_description(), axis, m_input_node->get_shape().size());
}
std::shared_ptr<ngraph::Node> ArgMinMaxFactory::make_arg_max() const
{
return make_topk_subgraph(default_opset::TopK::Mode::MAX);
}
std::shared_ptr<ngraph::Node> ArgMinMaxFactory::make_arg_min() const
{
return make_topk_subgraph(default_opset::TopK::Mode::MIN);
}
std::shared_ptr<ngraph::Node>
ArgMinMaxFactory::make_topk_subgraph(default_opset::TopK::Mode mode) const
{
const auto k_node =
default_opset::Constant::create(ngraph::element::i64, Shape{}, {1});
const auto topk =
std::make_shared<default_opset::TopK>(m_input_node,
k_node,
m_normalized_axis,
mode,
default_opset::TopK::SortType::NONE);
const auto indices = std::make_shared<ngraph::opset0::GetOutputElement>(topk, 1);
if (m_keep_dims == 0)
{
const auto reshaped_indices =
ngraph::builder::opset1::squeeze(indices, {m_normalized_axis});
return std::make_shared<default_opset::Convert>(reshaped_indices, element::i64);
}
return std::make_shared<default_opset::Convert>(indices, element::i64);
}
}
}
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstdint>
#include <memory>
#include "core/node.hpp"
#include "default_opset.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace utils
{
/// \brief Factory class which generates sub-graphs for ONNX ArgMin, ArgMax ops.
class ArgMinMaxFactory
{
public:
explicit ArgMinMaxFactory(const Node& node);
virtual ~ArgMinMaxFactory() = default;
/// \brief Creates ArgMax ONNX operation.
/// \return Sub-graph representing ArgMax op.
std::shared_ptr<ngraph::Node> make_arg_max() const;
/// \brief Creates ArgMin ONNX operation.
/// \return Sub-graph representing ArgMin op.
std::shared_ptr<ngraph::Node> make_arg_min() const;
private:
std::shared_ptr<ngraph::Node>
make_topk_subgraph(default_opset::TopK::Mode mode) const;
const std::int64_t m_keep_dims;
std::shared_ptr<ngraph::Node> m_input_node;
std::int64_t m_normalized_axis;
};
} // namespace arg
} // namespace onnx_import
} // namespace ngraph
......@@ -78,37 +78,6 @@ namespace ngraph
const std::shared_ptr<ngraph::Node>& ng_input,
RuntimeReductionFunction reduction_function);
template <class IndexReduction>
std::shared_ptr<ngraph::Node> make_ng_index_reduction_op(const Node& node)
{
auto axis = node.get_attribute_value<std::int64_t>("axis", 0);
auto keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1);
auto input_node = node.get_ng_inputs().at(0);
const auto normalized_axis = ngraph::normalize_axis(
node.get_description(), axis, input_node->get_shape().size());
auto op_node =
std::make_shared<IndexReduction>(input_node, normalized_axis, element::i64);
if (keepdims == 0)
{
return std::move(op_node);
}
// WORKAROUND FOR PROBLEMS WITH RESHAPE ON i64 @TODO: remove
auto convert_node = std::make_shared<ngraph::op::Convert>(op_node, element::f32);
auto output_shape = input_node->get_shape();
output_shape.at(normalized_axis) = 1;
auto reshape_node = builder::opset1::reshape(op_node, output_shape);
// WORKAROUND FOR PROBLEMS WITH RESHAPE ON i64 @TODO: remove
auto reconvert_node =
std::make_shared<ngraph::op::Convert>(reshape_node, element::i64);
return std::move(reconvert_node);
}
} // namespace reduction
} // namespace onnx_import
} // namespace ngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment