Commit d357cb92 authored by Adam Rogowiec, committed by Robert Kimball

[ONNX] GlobalLpPool operator (#2476)

* Utility functions for calculating Lp norm.

* Use functor object as a reduction operation.

* Use new api of make_ng_reduction_op.

* Use utility norm functions for reduction operations.

* ONNX GlobalLpPool operator.

* Ensure correct shapes after lp_norm reduction.

* Remove unused function overload.

* Fix shapes and tensor types.

* Unit tests.

* Update comments.

* Update supported ops status table.

* Fix: take absolute value of input tensor elements.

* UT: odd-valued p-norm.

* Fix: move taking abs value into respective lp-norm functions.

* Fix clang -Wdocumentation-unknown-command error.

* Update supported op status table with new Jira ticket for Erf op.

* Update supported_ops status table.

* Update interface of make_ng_reduction_op - accept std::function object.

* Update to use new make_ng_reduction_op api.

* Remove unused header.

* Fix errors on CentOS.
parent c2974ac2
......@@ -98,6 +98,8 @@ add_library(onnx_import STATIC
op/less.hpp
op/log.hpp
op/log_softmax.hpp
op/lp_pool.cpp
op/lp_pool.hpp
op/lrn.cpp
op/lrn.hpp
op/lstm.cpp
......@@ -174,6 +176,9 @@ add_library(onnx_import STATIC
utils/common.hpp
utils/convpool.cpp
utils/convpool.hpp
utils/norm.cpp
utils/norm.hpp
utils/reduction.cpp
utils/reduction.hpp
utils/reshape.cpp
utils/reshape.hpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cmath>
#include <cstddef>
#include <cstdint>
#include "exceptions.hpp"
#include "lp_pool.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/reshape.hpp"
#include "utils/common.hpp"
#include "utils/norm.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector global_lp_pool(const Node& node)
{
std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
std::size_t channel_axis{1};
std::size_t channels_count = data->get_shape().at(channel_axis);
std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};
ASSERT_VALID_ARGUMENT(node, p_norm >= 0)
<< "Only positive (including zero) values are supported for 'p' attribute.";
NodeVector slices = reshape::split(data, channels_count, channel_axis);
for (auto& slice : slices)
{
const Shape& orig_shape = data->get_shape();
// reduce over the spatial/feature dimensions (all axes except batch and channel)
AxisSet reduction_axes{
common::get_monotonic_range<std::size_t>(orig_shape.size(), 2)};
slice =
norm::lp_norm(slice, reduction_axes, static_cast<std::size_t>(p_norm));
// the output shape is all ones except for the batch (N) dimension
Shape output_shape(orig_shape.size(), 1);
output_shape.at(0) = orig_shape.at(0);
slice = std::make_shared<ngraph::op::Reshape>(
slice,
reshape::get_default_axis_vector(slice->get_shape().size()),
output_shape);
}
return {std::make_shared<ngraph::op::Concat>(slices, channel_axis)};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
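As a reading aid (not part of the change itself), the shape flow implemented by `global_lp_pool` above can be sketched as follows, assuming an input of shape (N, C, D1, ..., Dk):

```latex
% Sketch of the shape flow in global_lp_pool (illustrative summary, not from the diff):
% input (N, C, D_1, \dots, D_k)
%   -> split along the channel axis into C slices of shape (N, 1, D_1, \dots, D_k)
%   -> lp_norm over axes \{2, \dots, k+1\}, then Reshape to (N, 1, \dots, 1)
%   -> Concat along the channel axis
\mathrm{output}_{n,c} \;=\; \Bigl(\sum_{d_1,\dots,d_k} \lvert x_{n,c,d_1,\dots,d_k} \rvert^{p}\Bigr)^{1/p},
\qquad \mathrm{output\ shape} = (N, C, 1, \dots, 1).
```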
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Creates nGraph node representing ONNX GlobalLpPool operator.
///
/// \note This function calculates "entrywise" norms in the spatial/feature
/// dimensions. That is, it treats the matrix/tensor in the spatial/feature
/// dimensions as a vector and applies the appropriate norm to it. The
/// result is a scalar.
///
/// Suppose A contains the spatial dimensions of the input tensor; then
/// the p-norm of matrix A is defined as the following double sum over
/// all elements:
/// ||A||_p = ||vec(A)||_p = [sum_{i=1}^m sum_{j=1}^n abs(a_{i,j})^p]^{1/p}
///
/// \param[in] node The input ONNX node representing this operation.
///
/// \return Vector of nodes containing the resulting nGraph nodes.
///
NodeVector global_lp_pool(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
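To make the "entrywise" definition above concrete, a small worked example (illustrative values only): for A = [[1, 2], [3, 4]] and p = 2,

```latex
\lVert A \rVert_2 = \bigl(1^2 + 2^2 + 3^2 + 4^2\bigr)^{1/2} = \sqrt{30} \approx 5.477
```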
......@@ -23,7 +23,6 @@
#include "ngraph/op/divide.hpp"
#include "ngraph/shape.hpp"
#include "reduce.hpp"
#include "utils/broadcasting.hpp"
namespace ngraph
{
......@@ -44,16 +43,21 @@ namespace ngraph
[&input_shape](const std::size_t& a, const std::size_t& b) {
return a * input_shape.at(b);
});
auto sum_node = reduction::make_ng_reduction_op<ngraph::op::Sum>(
node, node.get_ng_inputs().at(0));
auto const_node = std::make_shared<ngraph::op::Constant>(
auto sum_node = std::shared_ptr<ngraph::Node>{reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<ngraph::op::Sum,
const std::shared_ptr<ngraph::Node>&,
const ngraph::AxisSet&>)};
auto const_node = ngraph::op::Constant::create(
sum_node->get_element_type(),
Shape{},
std::vector<std::size_t>{elem_count_product});
sum_node->get_shape(),
std::vector<std::size_t>(shape_size(sum_node->get_shape()),
elem_count_product));
auto broadcasted_const_node =
make_broadcast_node(const_node, sum_node->get_shape());
return {std::make_shared<ngraph::op::Divide>(sum_node, broadcasted_const_node)};
return {std::make_shared<ngraph::op::Divide>(sum_node, const_node)};
}
} // namespace set_1
......
......@@ -27,9 +27,9 @@
#include "ngraph/op/min.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/sum.hpp"
#include "utils/broadcasting.hpp"
#include "utils/norm.hpp"
#include "utils/reduction.hpp"
namespace ngraph
......@@ -53,8 +53,12 @@ namespace ngraph
///
inline NodeVector reduce_log_sum(const Node& node)
{
auto sum_node = reduction::make_ng_reduction_op<ngraph::op::Sum>(
node, node.get_ng_inputs().at(0));
std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<ngraph::op::Sum,
const std::shared_ptr<ngraph::Node>&,
const ngraph::AxisSet&>)};
return {std::make_shared<ngraph::op::Log>(sum_node)};
}
......@@ -72,8 +76,12 @@ namespace ngraph
inline NodeVector reduce_log_sum_exp(const Node& node)
{
auto exp_node = std::make_shared<ngraph::op::Exp>(node.get_ng_inputs().at(0));
auto sum_node =
reduction::make_ng_reduction_op<ngraph::op::Sum>(node, exp_node);
std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op(
node,
exp_node,
std::make_shared<ngraph::op::Sum,
const std::shared_ptr<ngraph::Node>&,
const ngraph::AxisSet&>)};
return {std::make_shared<ngraph::op::Log>(sum_node)};
}
......@@ -90,8 +98,8 @@ namespace ngraph
///
inline NodeVector reduce_l1(const Node& node)
{
auto abs_node = std::make_shared<ngraph::op::Abs>(node.get_ng_inputs().at(0));
return {reduction::make_ng_reduction_op<ngraph::op::Sum>(node, abs_node)};
return {reduction::make_ng_reduction_op(
node, node.get_ng_inputs().at(0), norm::l1_norm)};
}
/// \brief Compute the L2 norm of the input tensor's elements along the provided axes.
......@@ -107,12 +115,8 @@ namespace ngraph
///
inline NodeVector reduce_l2(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
auto square_node =
std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(0));
auto sum_node =
reduction::make_ng_reduction_op<ngraph::op::Sum>(node, square_node);
return {std::make_shared<ngraph::op::Sqrt>(sum_node)};
return {reduction::make_ng_reduction_op(
node, node.get_ng_inputs().at(0), norm::l2_norm)};
}
/// \brief Compute the maximum value of the input tensor's elements along the provided axes.
......@@ -128,8 +132,12 @@ namespace ngraph
///
inline NodeVector reduce_max(const Node& node)
{
return {reduction::make_ng_reduction_op<ngraph::op::Max>(
node, node.get_ng_inputs().at(0))};
return {reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<ngraph::op::Max,
const std::shared_ptr<ngraph::Node>&,
const ngraph::AxisSet&>)};
}
/// \brief Compute the mean value of the input tensor's elements along the provided axes.
......@@ -158,8 +166,12 @@ namespace ngraph
///
inline NodeVector reduce_min(const Node& node)
{
return {reduction::make_ng_reduction_op<ngraph::op::Min>(
node, node.get_ng_inputs().at(0))};
return {reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<ngraph::op::Min,
const std::shared_ptr<ngraph::Node>&,
const ngraph::AxisSet&>)};
}
/// \brief Compute the product of the input tensor's elements along the provided axes.
......@@ -175,8 +187,12 @@ namespace ngraph
///
inline NodeVector reduce_prod(const Node& node)
{
return {reduction::make_ng_reduction_op<ngraph::op::Product>(
node, node.get_ng_inputs().at(0))};
return {reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<ngraph::op::Product,
const std::shared_ptr<ngraph::Node>&,
const ngraph::AxisSet&>)};
}
/// \brief Compute the sum of the input tensor's elements along the provided axes.
......@@ -192,8 +208,12 @@ namespace ngraph
///
inline NodeVector reduce_sum(const Node& node)
{
return {reduction::make_ng_reduction_op<ngraph::op::Sum>(
node, node.get_ng_inputs().at(0))};
return {reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<ngraph::op::Sum,
const std::shared_ptr<ngraph::Node>&,
const ngraph::AxisSet&>)};
}
/// \brief Compute the sum of squares of the input tensor's elements along the provided axes.
......@@ -209,10 +229,14 @@ namespace ngraph
///
inline NodeVector reduce_sum_square(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
auto square_node =
std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(0));
return {reduction::make_ng_reduction_op<ngraph::op::Sum>(node, square_node)};
auto input = std::shared_ptr<ngraph::Node>{node.get_ng_inputs().at(0)};
auto square_node = input * input;
return {reduction::make_ng_reduction_op(
node,
square_node,
std::make_shared<ngraph::op::Sum,
const std::shared_ptr<ngraph::Node>&,
const ngraph::AxisSet&>)};
}
} // namespace set_1
......
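In the hunks above, the expression `std::make_shared<ngraph::op::Sum, const std::shared_ptr<ngraph::Node>&, const ngraph::AxisSet&>` names an explicit specialization of the `make_shared` function template and passes it as the `ReductionFunction`. As a sketch (not part of the change), an equivalent callable could also be written as a lambda:

```cpp
// Hypothetical equivalent of the make_shared-specialization pattern used above.
reduction::ReductionFunction sum_factory =
    [](const std::shared_ptr<ngraph::Node>& input,
       const ngraph::AxisSet& axes) -> std::shared_ptr<ngraph::Node> {
        return std::make_shared<ngraph::op::Sum>(input, axes);
    };
// e.g. reduction::make_ng_reduction_op(node, node.get_ng_inputs().at(0), sum_factory);
```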
......@@ -47,7 +47,8 @@ opset versions starting from `1` to `6` and to the latest opset version.
| Flatten | 1-9- |
| Floor | 1-6- |
| Gemm | 1-6-7-9 |
| GlobalAveragePool | 1- |
| GlobalAveragePool | 1- |
| GlobalLpPool | 1-2- |
| GlobalMaxPool | 1- |
| Greater | 1-7-9 |
| HardSigmoid | 1-6- |
......@@ -110,11 +111,11 @@ opset versions starting from `1` to `6` and to the latest opset version.
### Lack of support in nGraph
| Name | Opset supported | NGCORE | NGONNX | Comment |
|------|-----------------|--------|--------|---------|
| Erf | (9) | 284 | 442 | Need separate kernel for this in nGraph core. |
| Erf | (9) | 284 | 489 | Need separate kernel for this in nGraph core. |
| Pad | 1-2- | 273 | 416 | Not fully supported. |
| LSTM | 1-7- | | 476 | Mixed sequences length not supported yet. |
| MaxUnpool | (9) | 286, 289 | 447 | |
| LpPool | - | 291 | 437 | Unsupported by nGraph - only max/avg pooling ops. Need separate kernel. |
| LpPool | - | 291 | 488 | Unsupported by nGraph - only max/avg pooling ops. Need separate kernel. |
| Multinomial | - | 199 | 435 | Lack of PRNG in nGraph. |
| RandomNormal | - | 199 | 434 | Lack of PRNG in nGraph. |
| RandomNormalLike | - | 199 | 434 | Lack of PRNG in nGraph. |
......@@ -142,7 +143,7 @@ opset versions starting from `1` to `6` and to the latest opset version.
| OneHot | (9) | NGCORE-339 | 486 | Dynamic output shape
| Tile | - | NGRAPH-3292 | 368 | Dynamic op. |
| Upsample | (7) | 287 | 441 | Dynamic op. |
| MaxRoiPool | - | 288 | 437 | Dynamic op - needs dynamic slicing. Otherwise just use the _slice/op/concat_ pattern. |
| MaxRoiPool | - | 288 | 487 | Dynamic op - needs dynamic slicing. Otherwise just use the _slice/op/concat_ pattern. |
| Reshape | 1-5- | NGRAPH-3290 | 357 | Lack of support for dynamic shape input. Only as a Constant or as an Initializer. |
| Scatter | (9) | 289 | 446 | Dynamic indices input. |
......@@ -151,7 +152,6 @@ opset versions starting from `1` to `6` and to the latest opset version.
|------|-----------------|--------|--------|---------|
| Cast | 1-6- | | 427 | Errors while casting to bool |
| EyeLike | (9) | | 439 | Make constant node. |
| GlobalLpPool | - | | 437 | Probably use _slice/op/concat_ pattern. |
| Hardmax | - | | 431 | Use make constant and Argmax. See `test_ops_unary.py::test_hardmax()` |
| LpNormalization | - | | 436 | Just an equation. Only Lp{1,2} need to be supported. |
| InstanceNormalization | - | | 436 | Just an equation. For per channel computation may _slice/op/concat_ pattern need to be used. |
......
......@@ -62,6 +62,7 @@
#include "op/less.hpp"
#include "op/log.hpp"
#include "op/log_softmax.hpp"
#include "op/lp_pool.hpp"
#include "op/lrn.hpp"
#include "op/lstm.hpp"
#include "op/matmul.hpp"
......@@ -248,6 +249,7 @@ namespace ngraph
REGISTER_OPERATOR("Floor", 1, floor);
REGISTER_OPERATOR("Gemm", 1, gemm);
REGISTER_OPERATOR("GlobalAveragePool", 1, global_average_pool);
REGISTER_OPERATOR("GlobalLpPool", 1, global_lp_pool);
REGISTER_OPERATOR("GlobalMaxPool", 1, global_max_pool);
REGISTER_OPERATOR("Greater", 1, greater);
REGISTER_OPERATOR("HardSigmoid", 1, hard_sigmoid);
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "norm.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/shape.hpp"
#include "utils/common.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace norm
{
namespace detail
{
std::shared_ptr<ngraph::Node> lp_norm(const std::shared_ptr<ngraph::Node>& node,
std::size_t p_norm,
const ngraph::AxisSet& reduction_axes)
{
std::shared_ptr<ngraph::Node> abs_values{
std::make_shared<ngraph::op::Abs>(node)};
std::shared_ptr<ngraph::Node> p_node = ngraph::op::Constant::create(
node->get_element_type(),
node->get_shape(),
std::vector<float>(shape_size(node->get_shape()),
static_cast<float>(p_norm)));
std::shared_ptr<ngraph::Node> values =
std::make_shared<ngraph::op::Power>(abs_values, p_node);
values = std::make_shared<ngraph::op::Sum>(values, reduction_axes);
std::shared_ptr<ngraph::Node> inv_p_node = ngraph::op::Constant::create(
values->get_element_type(),
values->get_shape(),
std::vector<float>(shape_size(values->get_shape()), 1.f / p_norm));
return {std::make_shared<ngraph::op::Power>(values, inv_p_node)};
}
} // namespace detail
std::shared_ptr<ngraph::Node> l0_norm(const std::shared_ptr<ngraph::Node>& node,
const ngraph::AxisSet& reduction_axes)
{
std::shared_ptr<ngraph::Node> abs_values{std::make_shared<ngraph::op::Abs>(node)};
std::shared_ptr<ngraph::Node> zero_node{ngraph::op::Constant::create(
node->get_element_type(),
node->get_shape(),
std::vector<float>(shape_size(node->get_shape()), 0.f))};
std::shared_ptr<ngraph::Node> non_zero_values =
std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::NotEqual>(abs_values, zero_node),
abs_values->get_element_type());
return std::make_shared<ngraph::op::Sum>(non_zero_values, reduction_axes);
}
std::shared_ptr<ngraph::Node> l1_norm(const std::shared_ptr<ngraph::Node>& node,
const ngraph::AxisSet& reduction_axes)
{
return std::make_shared<ngraph::op::Sum>(std::make_shared<ngraph::op::Abs>(node),
reduction_axes);
}
std::shared_ptr<ngraph::Node> l2_norm(const std::shared_ptr<ngraph::Node>& node,
const ngraph::AxisSet& reduction_axes)
{
std::shared_ptr<ngraph::Node> abs_values{std::make_shared<ngraph::op::Abs>(node)};
return {std::make_shared<ngraph::op::Sqrt>(
std::make_shared<ngraph::op::Sum>(abs_values * abs_values, reduction_axes))};
}
std::shared_ptr<ngraph::Node> lp_norm(const std::shared_ptr<ngraph::Node>& node,
const ngraph::AxisSet& reduction_axes,
std::size_t p_norm)
{
// The number of non-zero elements
if (p_norm == 0)
{
return l0_norm(node, reduction_axes);
}
// sum of absolute values.
else if (p_norm == 1)
{
return l1_norm(node, reduction_axes);
}
// sqrt of sum of squares - Euclidean norm
else if (p_norm == 2)
{
return l2_norm(node, reduction_axes);
}
// generic case
else
{
return detail::lp_norm(node, p_norm, reduction_axes);
}
}
} // namespace norm
} // namespace onnx_import
} // namespace ngraph
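For the generic detail::lp_norm path above, a small worked example (illustrative values): reducing the vector (1, 2, 3) with p = 3 gives

```latex
\bigl(\lvert 1\rvert^3 + \lvert 2\rvert^3 + \lvert 3\rvert^3\bigr)^{1/3} = 36^{1/3} \approx 3.302
```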
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace norm
{
/// \brief Calculates L-0 norm of input tensor.
///
/// \note The L-0 norm represents the number of elements different
/// from zero. Strictly speaking, it is not a true norm.
///
/// \param[in] node The input tensor node.
/// \param[in] reduction_axes The axes along which we calculate norm.
///
/// \return Node with calculated L-0 norm values.
///
std::shared_ptr<ngraph::Node> l0_norm(const std::shared_ptr<ngraph::Node>& node,
const ngraph::AxisSet& reduction_axes);
/// \brief Calculates L-1 norm of input tensor.
///
/// \note The L-1 norm represents the sum of absolute values.
///
/// \param[in] node The input tensor node.
/// \param[in] reduction_axes The axes along which we calculate norm.
///
/// \return Node with calculated L-1 norm values.
///
std::shared_ptr<ngraph::Node> l1_norm(const std::shared_ptr<ngraph::Node>& node,
const ngraph::AxisSet& reduction_axes);
/// \brief Calculates L-2 norm of input tensor.
///
/// \note The L-2 norm represents the square root of the sum of squares of
/// the individual elements.
///
/// \param[in] node The input tensor node.
/// \param[in] reduction_axes The axes along which we calculate norm.
///
/// \return Node with calculated L-2 norm values.
///
std::shared_ptr<ngraph::Node> l2_norm(const std::shared_ptr<ngraph::Node>& node,
const ngraph::AxisSet& reduction_axes);
/// \brief Calculates L-p norm of input tensor.
///
/// \param[in] node The input nGraph tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] p_norm The p norm to calculate.
///
/// \return Resulting L-p norm.
///
std::shared_ptr<ngraph::Node> lp_norm(const std::shared_ptr<ngraph::Node>& node,
const ngraph::AxisSet& reduction_axes,
std::size_t p_norm = 2);
} // namespace norm
} // namespace onnx_import
} // namespace ngraph
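A minimal usage sketch of the declarations above (the header paths, parameter shape, axes, and p value are illustrative assumptions, not part of the diff):

```cpp
// Sketch: apply norm::lp_norm to a 2x3 parameter, reducing over axis 1 with p = 3.
#include "ngraph/op/parameter.hpp"
#include "utils/norm.hpp"

std::shared_ptr<ngraph::Node> example_lp_norm()
{
    auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32,
                                                        ngraph::Shape{2, 3});
    // p in {0, 1, 2} dispatches to the dedicated l0/l1/l2 implementations;
    // any other p goes through the generic abs/power/sum path.
    // The reduced axis is dropped by the underlying Sum, so the result has shape {2}.
    return ngraph::onnx_import::norm::lp_norm(data, ngraph::AxisSet{1}, 3);
}
```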
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstddef> // std::size_t
#include <vector>
#include "exceptions.hpp"
#include "reduction.hpp"
#include "utils/common.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace reduction
{
namespace detail
{
AxisSet get_reduction_axes(const Node& node)
{
auto reduction_axes =
node.get_attribute_value<std::vector<std::size_t>>("axes", {});
if (reduction_axes.empty())
{
reduction_axes = onnx_import::common::get_monotonic_range<std::size_t>(
node.get_ng_inputs().at(0)->get_shape().size());
}
return AxisSet{reduction_axes};
}
} // namespace detail
std::shared_ptr<ngraph::Node>
make_ng_reduction_op(const Node& node,
const std::shared_ptr<ngraph::Node>& ng_input,
ReductionFunction reduction_function)
{
auto data_shape = ng_input->get_shape();
auto reduction_axes = detail::get_reduction_axes(node);
ASSERT_VALID_ARGUMENT(node, reduction_axes.size() <= data_shape.size())
<< "provided reduction axes count (" << reduction_axes.size()
<< ") is larger than input tensor rank (" << data_shape.size() << ")";
std::shared_ptr<ngraph::Node> op_node =
reduction_function(ng_input, reduction_axes);
std::int64_t keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1);
if (keepdims == 0)
{
return op_node;
}
auto output_shape = data_shape;
// flatten reduced axes and preserve original dimensions count.
for (const auto& idx : reduction_axes)
{
output_shape.at(idx) = 1;
}
return std::make_shared<ngraph::op::Reshape>(
op_node,
reshape::get_default_axis_vector(op_node->get_shape().size()),
Shape{output_shape});
}
} // namespace reduction
} // namespace onnx_import
} // namespace ngraph
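The keepdims handling above can be summarized with an illustrative example (shapes chosen for the sketch, not taken from the diff):

```cpp
// Behaviour sketch for make_ng_reduction_op (illustrative shapes only):
//   ng_input shape (2, 3, 4), ONNX attribute axes = {1}
//   reduction_function(ng_input, {1})  -> node of shape (2, 4)
//   keepdims == 0 -> that node is returned unchanged:   shape (2, 4)
//   keepdims == 1 -> it is wrapped in a Reshape node:   shape (2, 1, 4)
```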
......@@ -16,24 +16,14 @@
#pragma once
#include <cstddef> // std::size_t
#include <cstdint> // std::int64_t
#include <iterator> // std::begin, std::end
#include <memory> // std::make_shared
#include <numeric> // std::iota
#include <string>
#include <type_traits> // std::enable_if, std::is_base_of
#include <vector>
#include <cstdint> // std::int64_t
#include <memory> // std::make_shared
#include "core/node.hpp"
#include "exceptions.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/arithmetic_reduction.hpp"
#include "ngraph/shape.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
namespace ngraph
......@@ -44,69 +34,33 @@ namespace ngraph
{
namespace detail
{
inline AxisSet get_reduction_axes(const Node& node)
{
auto reduction_axes =
node.get_attribute_value<std::vector<std::size_t>>("axes", {});
if (reduction_axes.empty())
{
reduction_axes = onnx_import::common::get_monotonic_range<std::size_t>(
node.get_ng_inputs().at(0)->get_shape().size());
}
return AxisSet{reduction_axes};
}
AxisSet get_reduction_axes(const Node& node);
} // namespace detail
/// \brief Create an nGraph version of an ONNX reduction operation.
using ReductionFunction = std::function<std::shared_ptr<ngraph::Node>(
const std::shared_ptr<ngraph::Node>&, const ngraph::AxisSet&)>;
///
/// \param[in] node The node representing incoming ONNX operation.
/// \brief Create an nGraph version of an ONNX reduction operation.
///
/// \tparam OnnxOperator Class of an nGraph ArithmeticReduction operation
/// (e.g. Min, Max, Sum, Product).
/// \param[in] node The node representing incoming ONNX operation.
/// \param[in] ng_input The input (nGraph) Tensor.
/// \param[in] reduction_function The reduction function defining arithmetic reduction
/// operation (e.g. Min, Max, Sum, Product).
///
/// \return nGraph node equivalent of the ONNX operation.
///
template <class OnnxOperator,
typename std::enable_if<std::is_base_of<ngraph::op::util::ArithmeticReduction,
OnnxOperator>::value,
int>::type = 0>
std::shared_ptr<ngraph::Node>
make_ng_reduction_op(const Node& node,
const std::shared_ptr<ngraph::Node>& ng_input)
{
auto data_shape = ng_input->get_shape();
auto reduction_axes = detail::get_reduction_axes(node);
ASSERT_VALID_ARGUMENT(node, reduction_axes.size() <= data_shape.size())
<< "provided reduction axes count (" << reduction_axes.size()
<< ") is larger than input tensor rank (" << data_shape.size() << ")";
auto op_node = std::make_shared<OnnxOperator>(ng_input, reduction_axes);
std::int64_t keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1);
if (keepdims == 0)
{
return op_node;
}
auto output_shape = data_shape;
// flatten reduced axes and preserve original dimensions count.
for (const auto& idx : reduction_axes)
{
output_shape.at(idx) = 1;
}
return std::make_shared<ngraph::op::Reshape>(
op_node,
reshape::get_default_axis_vector(op_node->get_shape().size()),
Shape{output_shape});
}
const std::shared_ptr<ngraph::Node>& ng_input,
ReductionFunction reduction_function);
template <class IndexReduction>
std::shared_ptr<ngraph::Node> make_ng_index_reduction_op(const Node& node)
{
auto axis = node.get_attribute_value<int64_t>("axis", 0);
auto keepdims = node.get_attribute_value<int64_t>("keepdims", 1);
auto axis = node.get_attribute_value<std::int64_t>("axis", 0);
auto keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1);
auto input_node = node.get_ng_inputs().at(0);
auto op_node = std::make_shared<IndexReduction>(input_node, axis, element::i64);
......
[Binary files: new serialized ONNX test models for the GlobalLpPool unit tests, each containing a single GlobalLpPool node in a graph named compute_graph; protobuf content not shown.]
......@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <fstream>
......@@ -2055,6 +2056,65 @@ TEST(onnx_${BACKEND_NAME}, model_sign)
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx_${BACKEND_NAME}, model_global_lp_pool_p0)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/global_lp_pool_p0.onnx"));
std::vector<std::vector<std::int64_t>> inputs{std::vector<std::int64_t>{
1, 0, -4, 0, 2, 1, -6, 1, 0, 0, 0, 0, -7, 1, -1, 0, -1, 8, 0, 10, 9, 0, 0, 5}};
std::vector<std::vector<std::int64_t>> expected_outputs{std::vector<std::int64_t>{6, 8}};
std::vector<std::vector<std::int64_t>> outputs{execute(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(test::all_close(expected_outputs.front(), outputs.front()));
}
TEST(onnx_${BACKEND_NAME}, model_global_lp_pool_p1)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/global_lp_pool_p1.onnx"));
Inputs inputs{std::vector<float>(2 * 3 * 4)};
std::iota(std::begin(inputs.front()), std::end(inputs.front()), 0.f);
Outputs expected_outputs{std::vector<float>{66.f, 210.f}};
Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx_${BACKEND_NAME}, model_global_lp_pool_p2)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/global_lp_pool_p2.onnx"));
Inputs inputs{std::vector<float>(2 * 3 * 4)};
std::iota(std::begin(inputs.front()), std::end(inputs.front()), 0.f);
Outputs expected_outputs{std::vector<float>{22.494444f, 61.789967f}};
Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx_${BACKEND_NAME}, model_global_lp_pool_p3)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/global_lp_pool_p3.onnx"));
Inputs inputs{std::vector<float>(2 * 3 * 4)};
std::iota(std::begin(inputs.front()), std::end(inputs.front()), 0.f);
Outputs expected_outputs{std::vector<float>{16.331620904278438f, 41.56697946707537f}};
Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
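For reference, the expected outputs in the tests above follow directly from the GlobalLpPool formula, assuming the serialized models use an input shape of (1, 2, 3, 4), so that channel 0 holds the iota values 0..11 and channel 1 holds 12..23 (the p = 0 test instead counts non-zero elements per channel for its int64 input: 6 and 8):

```latex
p=1:\quad \sum_{k=0}^{11} k = 66, \qquad \sum_{k=12}^{23} k = 210
\\
p=2:\quad \Bigl(\sum_{k=0}^{11} k^2\Bigr)^{1/2} = \sqrt{506} \approx 22.4944, \qquad \sqrt{3818} \approx 61.7900
\\
p=3:\quad \Bigl(\sum_{k=0}^{11} k^3\Bigr)^{1/3} = 4356^{1/3} \approx 16.3316, \qquad 71820^{1/3} \approx 41.5670
```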
TEST(onnx_${BACKEND_NAME}, model_one_hot_with_axis)
{
auto function = onnx_import::import_onnx_model(
......