Unverified Commit 39ad8e42 authored by Michał Karzyński, committed by GitHub

Produce v1::ReduceMean in ONNX ReduceMean (#4410)

* Use v1::ReduceMean in ONNX ReduceMean

* Add support for keep_dims

* Replace binds with lambdas
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
parent d3bce642
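
Editor's note on the last bullet: the header previously built the norm callbacks with `std::bind` and placeholders; the new reduce.cpp uses lambdas instead (both forms are visible in the diffs below). A minimal, self-contained sketch of the swap, using a hypothetical `scaled_norm` function rather than the actual nGraph builders:

```cpp
#include <functional>

// Hypothetical stand-in for a builder such as ngraph::builder::opset1::l1_norm.
float scaled_norm(float value, float scale, float bias)
{
    return value * scale + bias;
}

int main()
{
    // Before: fix the trailing argument, forward the first two via placeholders.
    auto bound = std::bind(scaled_norm, std::placeholders::_1, std::placeholders::_2, 0.f);

    // After: an equivalent lambda with named parameters and no <functional> machinery.
    auto lambda = [](float value, float scale) { return scaled_norm(value, scale, 0.f); };

    return bound(2.f, 3.f) == lambda(2.f, 3.f) ? 0 : 1; // exits 0: both yield 6
}
```

The lambda reads better at the call site and gives the compiler a concrete type to inline, which is presumably why the commit calls the change out.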
@@ -14,13 +14,14 @@
 // limitations under the License.
 //*****************************************************************************

-#include <cstddef>  // std::size_t
-#include <iterator> // std::begin, std::end
-#include <numeric>  // std::accumulate
+#include <functional>
+#include <memory>

 #include "default_opset.hpp"
-#include "ngraph/shape.hpp"
+#include "ngraph/builder/norm.hpp"
+#include "ngraph/node.hpp"
 #include "reduce.hpp"
+#include "utils/reduction.hpp"

 namespace ngraph
 {
@@ -30,55 +31,120 @@ namespace ngraph
         {
             namespace set_1
             {
-                NodeVector reduce_mean(const Node& node)
-                {
-                    const auto data = node.get_ng_inputs().at(0);
-                    const auto& data_shape = data->get_output_partial_shape(0);
-
-                    // sum up the input data along the reduction axes
-                    const auto sum_node = reduction::make_ng_reduction_op(
-                        node,
-                        data,
-                        std::make_shared<default_opset::ReduceSum,
-                                         const std::shared_ptr<ngraph::Node>&,
-                                         const std::shared_ptr<ngraph::Node>&,
-                                         bool>);
-
-                    // calculate the product of dimensions pointed to by reduction axes
-                    size_t reduced_elems_count = 1U;
-
-                    if (data_shape.is_static())
-                    {
-                        const auto input_shape = data_shape.to_shape();
-
-                        // calculate the product of dimensions pointed to by reduction axes
-                        // this value represents the number of input tensor values that were reduced
-                        for (const auto axis : reduction::detail::get_reduction_axes(node))
-                        {
-                            reduced_elems_count *= input_shape.at(axis);
-                        }
-                    }
-                    else
-                    {
-                        for (const auto axis : reduction::detail::get_reduction_axes(node))
-                        {
-                            const auto dim_to_reduce = data_shape[axis];
-                            NGRAPH_CHECK(dim_to_reduce.is_static(),
-                                         "Axis ",
-                                         axis,
-                                         " in the input data tensor needs to be statically "
-                                         "specified to create a ReduceMean operation");
-
-                            reduced_elems_count *= dim_to_reduce.get_length();
-                        }
-                    }
-
-                    const auto const_node = default_opset::Constant::create(
-                        sum_node->get_element_type(), {}, {reduced_elems_count});
-
-                    // divide the sum node containing reduced values by the number
-                    // of those values to obtain the mean
-                    return {std::make_shared<default_opset::Divide>(sum_node, const_node)};
-                }
+                NodeVector reduce_log_sum(const Node& node)
+                {
+                    std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op(
+                        node,
+                        node.get_ng_inputs().at(0),
+                        std::make_shared<default_opset::ReduceSum,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         bool>)};
+                    return {std::make_shared<default_opset::Log>(sum_node)};
+                }
+
+                NodeVector reduce_log_sum_exp(const Node& node)
+                {
+                    auto exp_node =
+                        std::make_shared<default_opset::Exp>(node.get_ng_inputs().at(0));
+                    std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op(
+                        node,
+                        exp_node,
+                        std::make_shared<default_opset::ReduceSum,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         bool>)};
+                    return {std::make_shared<default_opset::Log>(sum_node)};
+                }
+
+                NodeVector reduce_l1(const Node& node)
+                {
+                    auto l1_norm_reduction = [](const std::shared_ptr<ngraph::Node>& node,
+                                                const ngraph::AxisSet& axis_set) {
+                        return ngraph::builder::opset1::l1_norm(node, axis_set, 0.f);
+                    };
+                    return {reduction::make_ng_reduction_op(
+                        node, node.get_ng_inputs().at(0), l1_norm_reduction)};
+                }
+
+                NodeVector reduce_l2(const Node& node)
+                {
+                    auto l2_norm_reduction = [](const std::shared_ptr<ngraph::Node>& node,
+                                                const ngraph::AxisSet& axis_set) {
+                        return ngraph::builder::opset1::l2_norm(
+                            node, axis_set, 0.f, ngraph::builder::BiasMode::ADD, false);
+                    };
+                    return {reduction::make_ng_reduction_op(
+                        node, node.get_ng_inputs().at(0), l2_norm_reduction)};
+                }
+
+                NodeVector reduce_max(const Node& node)
+                {
+                    return {reduction::make_ng_reduction_op(
+                        node,
+                        node.get_ng_inputs().at(0),
+                        std::make_shared<default_opset::ReduceMax,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         bool>)};
+                }
+
+                NodeVector reduce_mean(const Node& node)
+                {
+                    return {reduction::make_ng_reduction_op(
+                        node,
+                        node.get_ng_inputs().at(0),
+                        std::make_shared<default_opset::ReduceMean,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         bool>)};
+                }
+
+                NodeVector reduce_min(const Node& node)
+                {
+                    return {reduction::make_ng_reduction_op(
+                        node,
+                        node.get_ng_inputs().at(0),
+                        std::make_shared<default_opset::ReduceMin,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         bool>)};
+                }
+
+                NodeVector reduce_prod(const Node& node)
+                {
+                    return {reduction::make_ng_reduction_op(
+                        node,
+                        node.get_ng_inputs().at(0),
+                        std::make_shared<default_opset::ReduceProd,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         bool>)};
+                }
+
+                NodeVector reduce_sum(const Node& node)
+                {
+                    return {reduction::make_ng_reduction_op(
+                        node,
+                        node.get_ng_inputs().at(0),
+                        std::make_shared<default_opset::ReduceSum,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         bool>)};
+                }
+
+                NodeVector reduce_sum_square(const Node& node)
+                {
+                    auto input = std::shared_ptr<ngraph::Node>{node.get_ng_inputs().at(0)};
+                    auto square_node = std::make_shared<default_opset::Multiply>(input, input);
+                    return {reduction::make_ng_reduction_op(
+                        node,
+                        square_node,
+                        std::make_shared<default_opset::ReduceSum,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         const std::shared_ptr<ngraph::Node>&,
+                                         bool>)};
+                }
             } // namespace set_1
...
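
Editor's note: the `std::make_shared<default_opset::ReduceMean, const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&, bool>` arguments above name full specializations of `std::make_shared`, so each one decays to a plain function pointer that `make_ng_reduction_op` can invoke as a factory. A self-contained sketch of that idiom, with a hypothetical `Op` type standing in for the opset classes:

```cpp
#include <iostream>
#include <memory>

// Hypothetical op type standing in for default_opset::ReduceMean and friends.
struct Op
{
    Op(int data, int axes, bool keep_dims)
        : sum{data + axes}
        , keep_dims{keep_dims}
    {
    }
    int sum;
    bool keep_dims;
};

// Accepts any callable that builds an op from (data, axes, keep_dims),
// loosely mirroring the factory parameter of make_ng_reduction_op.
template <typename Factory>
std::shared_ptr<Op> build(Factory factory)
{
    return factory(1, 2, true);
}

int main()
{
    // Naming a full specialization of std::make_shared turns it into an
    // ordinary function that can be passed by value and called later.
    auto op = build(std::make_shared<Op, int, int, bool>);
    std::cout << op->sum << '\n'; // prints 3
}
```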
@@ -16,14 +16,7 @@
 #pragma once

-#include <functional>
-#include <memory>
-
 #include "core/node.hpp"
-#include "default_opset.hpp"
-#include "ngraph/builder/norm.hpp"
-#include "ngraph/node.hpp"
-#include "utils/reduction.hpp"

 namespace ngraph
 {
@@ -45,17 +38,7 @@ namespace ngraph
                ///
                /// \return The nGraph node equivalent of the ONNX operation.
                ///
-               inline NodeVector reduce_log_sum(const Node& node)
-               {
-                   std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op(
-                       node,
-                       node.get_ng_inputs().at(0),
-                       std::make_shared<default_opset::ReduceSum,
-                                        const std::shared_ptr<ngraph::Node>&,
-                                        const std::shared_ptr<ngraph::Node>&,
-                                        bool>)};
-                   return {std::make_shared<default_opset::Log>(sum_node)};
-               }
+               NodeVector reduce_log_sum(const Node& node);

                /// \brief Compute the log sum exponent of the input tensor's elements along
                ///        the provided axes.
@@ -69,19 +52,7 @@ namespace ngraph
                ///
                /// \return The nGraph node equivalent of the ONNX operation.
                ///
-               inline NodeVector reduce_log_sum_exp(const Node& node)
-               {
-                   auto exp_node =
-                       std::make_shared<default_opset::Exp>(node.get_ng_inputs().at(0));
-                   std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op(
-                       node,
-                       exp_node,
-                       std::make_shared<default_opset::ReduceSum,
-                                        const std::shared_ptr<ngraph::Node>&,
-                                        const std::shared_ptr<ngraph::Node>&,
-                                        bool>)};
-                   return {std::make_shared<default_opset::Log>(sum_node)};
-               }
+               NodeVector reduce_log_sum_exp(const Node& node);

                /// \brief Compute the L1 norm of the input tensor's element along the provided
                ///        axes.
@@ -95,15 +66,7 @@ namespace ngraph
                ///
                /// \return The nGraph node equivalent of the ONNX operation.
                ///
-               inline NodeVector reduce_l1(const Node& node)
-               {
-                   auto l1_norm_reduction = std::bind(ngraph::builder::opset1::l1_norm,
-                                                      std::placeholders::_1,
-                                                      std::placeholders::_2,
-                                                      0.f);
-                   return {reduction::make_ng_reduction_op(
-                       node, node.get_ng_inputs().at(0), l1_norm_reduction)};
-               }
+               NodeVector reduce_l1(const Node& node);

                /// \brief Compute the L2 norm of the input tensor's element along the provided
                ///        axes.
@@ -117,17 +80,7 @@ namespace ngraph
                ///
                /// \return The nGraph node equivalent of the ONNX operation.
                ///
-               inline NodeVector reduce_l2(const Node& node)
-               {
-                   auto l2_norm_reduction = std::bind(ngraph::builder::opset1::l2_norm,
-                                                      std::placeholders::_1,
-                                                      std::placeholders::_2,
-                                                      0.f,
-                                                      ngraph::builder::BiasMode::ADD,
-                                                      false);
-                   return {reduction::make_ng_reduction_op(
-                       node, node.get_ng_inputs().at(0), l2_norm_reduction)};
-               }
+               NodeVector reduce_l2(const Node& node);

                /// \brief Compute the maximum value of the input tensor's elements along the
                ///        provided axes.
@@ -141,16 +94,7 @@ namespace ngraph
                ///
                /// \return The nGraph node equivalent of the ONNX operation.
                ///
-               inline NodeVector reduce_max(const Node& node)
-               {
-                   return {reduction::make_ng_reduction_op(
-                       node,
-                       node.get_ng_inputs().at(0),
-                       std::make_shared<default_opset::ReduceMax,
-                                        const std::shared_ptr<ngraph::Node>&,
-                                        const std::shared_ptr<ngraph::Node>&,
-                                        bool>)};
-               }
+               NodeVector reduce_max(const Node& node);

                /// \brief Compute the mean value of the input tensor's elements along the
                ///        provided axes.
@@ -178,16 +122,7 @@ namespace ngraph
                ///
                /// \return The nGraph node equivalent of the ONNX operation.
                ///
-               inline NodeVector reduce_min(const Node& node)
-               {
-                   return {reduction::make_ng_reduction_op(
-                       node,
-                       node.get_ng_inputs().at(0),
-                       std::make_shared<default_opset::ReduceMin,
-                                        const std::shared_ptr<ngraph::Node>&,
-                                        const std::shared_ptr<ngraph::Node>&,
-                                        bool>)};
-               }
+               NodeVector reduce_min(const Node& node);

                /// \brief Compute the product of the input tensor's elements along the
                ///        provided axes.
@@ -201,16 +136,7 @@ namespace ngraph
                ///
                /// \return The nGraph node equivalent of the ONNX operation.
                ///
-               inline NodeVector reduce_prod(const Node& node)
-               {
-                   return {reduction::make_ng_reduction_op(
-                       node,
-                       node.get_ng_inputs().at(0),
-                       std::make_shared<default_opset::ReduceProd,
-                                        const std::shared_ptr<ngraph::Node>&,
-                                        const std::shared_ptr<ngraph::Node>&,
-                                        bool>)};
-               }
+               NodeVector reduce_prod(const Node& node);

                /// \brief Compute the sum of the input tensor's elements along the provided
                ///        axes.
@@ -224,16 +150,7 @@ namespace ngraph
                ///
                /// \return The nGraph node equivalent of the ONNX operation.
                ///
-               inline NodeVector reduce_sum(const Node& node)
-               {
-                   return {reduction::make_ng_reduction_op(
-                       node,
-                       node.get_ng_inputs().at(0),
-                       std::make_shared<default_opset::ReduceSum,
-                                        const std::shared_ptr<ngraph::Node>&,
-                                        const std::shared_ptr<ngraph::Node>&,
-                                        bool>)};
-               }
+               NodeVector reduce_sum(const Node& node);

                /// \brief Compute the sum square of the input tensor's element along the
                ///        provided axes.
@@ -247,18 +164,7 @@ namespace ngraph
                ///
                /// \return The nGraph node equivalent of the ONNX operation.
                ///
-               inline NodeVector reduce_sum_square(const Node& node)
-               {
-                   auto input = std::shared_ptr<ngraph::Node>{node.get_ng_inputs().at(0)};
-                   auto square_node = std::make_shared<default_opset::Multiply>(input, input);
-                   return {reduction::make_ng_reduction_op(
-                       node,
-                       square_node,
-                       std::make_shared<default_opset::ReduceSum,
-                                        const std::shared_ptr<ngraph::Node>&,
-                                        const std::shared_ptr<ngraph::Node>&,
-                                        bool>)};
-               }
+               NodeVector reduce_sum_square(const Node& node);
            } // namespace set_1
...
@@ -133,8 +133,8 @@ namespace ngraph
                ///     A value of -1 is allowed for at most one dimension, in which case the
                ///     dimension size is inferred based on element count of input tensor.
                /// \param special_zero Treats zeros in `pattern` as wildcard flags indicating a
-               ///     copy
-               ///     from input shape at the same index.
+               ///     copy from input shape at the same index.
+               ///
                Reshape(const Output<Node>& arg, const Output<Node>& pattern, bool special_zero);

                bool visit_attributes(AttributeVisitor& visitor) override;
...
@@ -81,14 +81,13 @@ namespace
            auto reshaped_product = make_shared<op::Reshape>(replacement_node->output(0),
                                                             get_default_order(output_shape),
                                                             reshaped_output_shape);
-           replace_node(node, reshaped_product);
+           return reshaped_product;
        }
        else
        {
-           replace_node(node, replacement_node);
+           return replacement_node;
        }
-       return replacement_node;
    }

    // Default is that we did nothing
    shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
@@ -608,22 +607,63 @@
    shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMax> node)
    {
-       return op_cast_reduction_node<op::v0::Max, op::v1::ReduceMax>(node);
+       auto replacement_node = op_cast_reduction_node<op::v0::Max, op::v1::ReduceMax>(node);
+       replace_node(node, replacement_node);
+       return replacement_node;
+   }
+
+   shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMean> node)
+   {
+       // ReduceMean = Sum / Count
+       auto sum_node = op_cast_reduction_node<op::v0::Sum, op::v1::ReduceMean>(node);
+
+       // Count = Sum(Constant(1, shape=data.shape))
+       const auto data = node->input_value(0);
+       const auto axes = node->input_value(1);
+       const auto const_node =
+           op::v0::Constant::create(data.get_element_type(), data.get_shape(), {1});
+       std::shared_ptr<Node> count_node = std::make_shared<op::v0::Sum>(const_node, axes);
+
+       // Support keep_dims attribute
+       if (node->get_keep_dims())
+       {
+           // In order to keep the original dimensions we need to reshape the Count node
+           // before we use it in Divide with NUMPY broadcast
+           auto output_shape = count_node->get_shape();
+           auto reshaped_output_shape = output_shape;
+           for (const auto& axis : node->get_reduction_axes())
+           {
+               reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
+           }
+           count_node = make_shared<op::Reshape>(
+               count_node->output(0), get_default_order(output_shape), reshaped_output_shape);
+       }
+
+       const auto replacement_node =
+           std::make_shared<op::v0::Divide>(sum_node, count_node, op::AutoBroadcastSpec::NUMPY);
+       replace_node(node, replacement_node);
+       return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMin> node)
    {
-       return op_cast_reduction_node<op::v0::Min, op::v1::ReduceMin>(node);
+       auto replacement_node = op_cast_reduction_node<op::v0::Min, op::v1::ReduceMin>(node);
+       replace_node(node, replacement_node);
+       return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceProd> node)
    {
-       return op_cast_reduction_node<op::v0::Product, op::v1::ReduceProd>(node);
+       auto replacement_node = op_cast_reduction_node<op::v0::Product, op::v1::ReduceProd>(node);
+       replace_node(node, replacement_node);
+       return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceSum> node)
    {
-       return op_cast_reduction_node<op::v0::Sum, op::v1::ReduceSum>(node);
+       auto replacement_node = op_cast_reduction_node<op::v0::Sum, op::v1::ReduceSum>(node);
+       replace_node(node, replacement_node);
+       return replacement_node;
    }

    shared_ptr<Node> op_cast(shared_ptr<op::v1::Reverse> node)
...
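
Editor's note: the new `ReduceMean` downgrade above decomposes the op as Sum / Count, where Count is itself a `Sum` over a ones constant of the data shape, and `keep_dims` is honoured by reinserting size-1 axes into the divisor shape so the NUMPY-broadcast `Divide` lines up; this intermediate-node plumbing is also why `op_cast_reduction_node` now returns its result instead of calling `replace_node` itself. A plain C++ sketch of the same arithmetic on a 2x3 matrix reduced along axis 1 (illustration only; no nGraph types):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // data: a 2x3 matrix; we reduce along axis 1 (the columns of each row).
    const std::vector<std::vector<float>> data{{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}};

    for (std::size_t row = 0; row < data.size(); ++row)
    {
        float sum = 0.f;   // Sum of the data over the reduced axis.
        float count = 0.f; // Count: the same Sum applied to a ones tensor.
        for (float value : data[row])
        {
            sum += value;
            count += 1.f;
        }
        // ReduceMean = Sum / Count. With keep_dims the result would keep
        // shape {2, 1} instead of {2}; the values are identical.
        std::cout << "mean(row " << row << ") = " << sum / count << '\n';
    }
    // prints: mean(row 0) = 2, mean(row 1) = 5
}
```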
@@ -5,6 +5,11 @@ graph {
     input: "A"
     output: "B"
     op_type: "ReduceMean"
+    attribute {
+      name: "keepdims"
+      i: 0
+      type: INT
+    }
   }
   name: "compute_graph"
   input {
@@ -36,7 +41,6 @@ graph {
         elem_type: 1
         shape {
           dim {
-            dim_value: 1
           }
         }
...
@@ -932,7 +932,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_reduce_mean)
     auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
     test_case.add_multiple_inputs(inputs);
-    test_case.add_expected_output(expected_output);
+    test_case.add_expected_output(Shape{}, expected_output);
     test_case.run();
 }
...
@@ -150,8 +150,12 @@ namespace ngraph
                auto function_output_type = results.at(m_output_index)->get_element_type();
                const auto& output_pshape = results.at(m_output_index)->get_output_partial_shape(0);
-               NGRAPH_CHECK(output_pshape.compatible(expected_shape),
-                            "Passed output shape is not compatible with nGraph function.");
+               NGRAPH_CHECK(
+                   output_pshape.compatible(expected_shape),
+                   "nGraph function generated an unexpected output shape. Expected shape: ",
+                   expected_shape,
+                   " Output shape: ",
+                   output_pshape);

                m_expected_outputs.emplace_back(std::make_shared<ngraph::op::Constant>(
                    function_output_type, expected_shape, values));
...