Unverified Commit 39ad8e42 authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

Produce v1::ReduceMean in ONNX ReduceMean (#4410)

* Use v1::ReduceMean in ONNX ReduceMean

* Add support for keep_dims

* Replace binds with lambdas
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent d3bce642
......@@ -14,13 +14,14 @@
// limitations under the License.
//*****************************************************************************
#include <cstddef> // std::size_t
#include <iterator> // std::begin, std::end
#include <numeric> // std::accumulate
#include <functional>
#include <memory>
#include "default_opset.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/builder/norm.hpp"
#include "ngraph/node.hpp"
#include "reduce.hpp"
#include "utils/reduction.hpp"
namespace ngraph
{
......@@ -30,55 +31,120 @@ namespace ngraph
{
namespace set_1
{
NodeVector reduce_mean(const Node& node)
NodeVector reduce_log_sum(const Node& node)
{
const auto data = node.get_ng_inputs().at(0);
const auto& data_shape = data->get_output_partial_shape(0);
// sum up the input data along the reduction axes
const auto sum_node = reduction::make_ng_reduction_op(
std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op(
node,
data,
node.get_ng_inputs().at(0),
std::make_shared<default_opset::ReduceSum,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool>);
bool>)};
return {std::make_shared<default_opset::Log>(sum_node)};
}
// calculate the product of dimensions pointed to by reduction axes
size_t reduced_elems_count = 1U;
NodeVector reduce_log_sum_exp(const Node& node)
{
auto exp_node =
std::make_shared<default_opset::Exp>(node.get_ng_inputs().at(0));
std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op(
node,
exp_node,
std::make_shared<default_opset::ReduceSum,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool>)};
return {std::make_shared<default_opset::Log>(sum_node)};
}
if (data_shape.is_static())
NodeVector reduce_l1(const Node& node)
{
const auto input_shape = data_shape.to_shape();
auto l1_norm_reduction = [](const std::shared_ptr<ngraph::Node>& node,
const ngraph::AxisSet& axis_set) {
return ngraph::builder::opset1::l1_norm(node, axis_set, 0.f);
};
return {reduction::make_ng_reduction_op(
node, node.get_ng_inputs().at(0), l1_norm_reduction)};
}
// calculate the product of dimensions pointed to by reduction axes
// this value represents the number of input tensor values that were reduced
for (const auto axis : reduction::detail::get_reduction_axes(node))
NodeVector reduce_l2(const Node& node)
{
reduced_elems_count *= input_shape.at(axis);
auto l2_norm_reduction = [](const std::shared_ptr<ngraph::Node>& node,
const ngraph::AxisSet& axis_set) {
return ngraph::builder::opset1::l2_norm(
node, axis_set, 0.f, ngraph::builder::BiasMode::ADD, false);
};
return {reduction::make_ng_reduction_op(
node, node.get_ng_inputs().at(0), l2_norm_reduction)};
}
NodeVector reduce_max(const Node& node)
{
return {reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<default_opset::ReduceMax,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool>)};
}
else
NodeVector reduce_mean(const Node& node)
{
for (const auto axis : reduction::detail::get_reduction_axes(node))
return {reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<default_opset::ReduceMean,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool>)};
}
NodeVector reduce_min(const Node& node)
{
const auto dim_to_reduce = data_shape[axis];
NGRAPH_CHECK(dim_to_reduce.is_static(),
"Axis ",
axis,
" in the input data tensor needs to be statically "
"specified to create a ReduceMean operation");
reduced_elems_count *= dim_to_reduce.get_length();
return {reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<default_opset::ReduceMin,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool>)};
}
NodeVector reduce_prod(const Node& node)
{
return {reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<default_opset::ReduceProd,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool>)};
}
const auto const_node = default_opset::Constant::create(
sum_node->get_element_type(), {}, {reduced_elems_count});
NodeVector reduce_sum(const Node& node)
{
return {reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<default_opset::ReduceSum,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool>)};
}
// divide the sum node containing reduced values by the number
// of those values to obtain the mean
return {std::make_shared<default_opset::Divide>(sum_node, const_node)};
NodeVector reduce_sum_square(const Node& node)
{
auto input = std::shared_ptr<ngraph::Node>{node.get_ng_inputs().at(0)};
auto square_node = std::make_shared<default_opset::Multiply>(input, input);
return {reduction::make_ng_reduction_op(
node,
square_node,
std::make_shared<default_opset::ReduceSum,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool>)};
}
} // namespace set_1
......
......@@ -16,14 +16,7 @@
#pragma once
#include <functional>
#include <memory>
#include "core/node.hpp"
#include "default_opset.hpp"
#include "ngraph/builder/norm.hpp"
#include "ngraph/node.hpp"
#include "utils/reduction.hpp"
namespace ngraph
{
......@@ -45,17 +38,7 @@ namespace ngraph
///
/// \return The nGraph node equivalent of the ONNX operation.
///
inline NodeVector reduce_log_sum(const Node& node)
{
std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<default_opset::ReduceSum,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool>)};
return {std::make_shared<default_opset::Log>(sum_node)};
}
NodeVector reduce_log_sum(const Node& node);
/// \brief Compute the log sum exponent of the input tensor's elements along
/// the provided axes.
......@@ -69,19 +52,7 @@ namespace ngraph
///
/// \return The nGraph node equivalent of the ONNX operation.
///
inline NodeVector reduce_log_sum_exp(const Node& node)
{
auto exp_node =
std::make_shared<default_opset::Exp>(node.get_ng_inputs().at(0));
std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op(
node,
exp_node,
std::make_shared<default_opset::ReduceSum,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool>)};
return {std::make_shared<default_opset::Log>(sum_node)};
}
NodeVector reduce_log_sum_exp(const Node& node);
/// \brief Compute the L1 norm of the input tensor's element along the provided
/// axes.
......@@ -95,15 +66,7 @@ namespace ngraph
///
/// \return The nGraph node equivalent of the ONNX operation.
///
inline NodeVector reduce_l1(const Node& node)
{
auto l1_norm_reduction = std::bind(ngraph::builder::opset1::l1_norm,
std::placeholders::_1,
std::placeholders::_2,
0.f);
return {reduction::make_ng_reduction_op(
node, node.get_ng_inputs().at(0), l1_norm_reduction)};
}
NodeVector reduce_l1(const Node& node);
/// \brief Compute the L2 norm of the input tensor's element along the provided
/// axes.
......@@ -117,17 +80,7 @@ namespace ngraph
///
/// \return The nGraph node equivalent of the ONNX operation.
///
inline NodeVector reduce_l2(const Node& node)
{
auto l2_norm_reduction = std::bind(ngraph::builder::opset1::l2_norm,
std::placeholders::_1,
std::placeholders::_2,
0.f,
ngraph::builder::BiasMode::ADD,
false);
return {reduction::make_ng_reduction_op(
node, node.get_ng_inputs().at(0), l2_norm_reduction)};
}
NodeVector reduce_l2(const Node& node);
/// \brief Compute the maximum value of the input tensor's elements along the
/// provided axes.
......@@ -141,16 +94,7 @@ namespace ngraph
///
/// \return The nGraph node equivalent of the ONNX operation.
///
inline NodeVector reduce_max(const Node& node)
{
return {reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<default_opset::ReduceMax,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool>)};
}
NodeVector reduce_max(const Node& node);
/// \brief Compute the mean value of the input tensor's elements along the
/// provided axes.
......@@ -178,16 +122,7 @@ namespace ngraph
///
/// \return The nGraph node equivalent of the ONNX operation.
///
inline NodeVector reduce_min(const Node& node)
{
return {reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<default_opset::ReduceMin,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool>)};
}
NodeVector reduce_min(const Node& node);
/// \brief Compute the product of the input tensor's elements along the
/// provided axes.
......@@ -201,16 +136,7 @@ namespace ngraph
///
/// \return The nGraph node equivalent of the ONNX operation.
///
inline NodeVector reduce_prod(const Node& node)
{
return {reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<default_opset::ReduceProd,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool>)};
}
NodeVector reduce_prod(const Node& node);
/// \brief Compute the sum of the input tensor's elements along the provided
/// axes.
......@@ -224,16 +150,7 @@ namespace ngraph
///
/// \return The nGraph node equivalent of the ONNX operation.
///
inline NodeVector reduce_sum(const Node& node)
{
return {reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
std::make_shared<default_opset::ReduceSum,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool>)};
}
NodeVector reduce_sum(const Node& node);
/// \brief Compute the sum square of the input tensor's element along the
/// provided axes.
......@@ -247,18 +164,7 @@ namespace ngraph
///
/// \return The nGraph node equivalent of the ONNX operation.
///
inline NodeVector reduce_sum_square(const Node& node)
{
auto input = std::shared_ptr<ngraph::Node>{node.get_ng_inputs().at(0)};
auto square_node = std::make_shared<default_opset::Multiply>(input, input);
return {reduction::make_ng_reduction_op(
node,
square_node,
std::make_shared<default_opset::ReduceSum,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
bool>)};
}
NodeVector reduce_sum_square(const Node& node);
} // namespace set_1
......
......@@ -133,8 +133,8 @@ namespace ngraph
/// A value of -1 is allowed for at most one dimension, in which case the
/// dimension size is inferred based on element count of input tensor.
/// \param special_zero Treats zeros in `pattern` as wildcard flags indicating a
/// copy
/// from input shape at the same index.
/// copy from input shape at the same index.
///
Reshape(const Output<Node>& arg, const Output<Node>& pattern, bool special_zero);
bool visit_attributes(AttributeVisitor& visitor) override;
......
......@@ -81,14 +81,13 @@ namespace
auto reshaped_product = make_shared<op::Reshape>(replacement_node->output(0),
get_default_order(output_shape),
reshaped_output_shape);
replace_node(node, reshaped_product);
return reshaped_product;
}
else
{
replace_node(node, replacement_node);
}
return replacement_node;
}
}
// Default is that we did nothing
shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
......@@ -608,22 +607,63 @@ namespace
shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMax> node)
{
return op_cast_reduction_node<op::v0::Max, op::v1::ReduceMax>(node);
auto replacement_node = op_cast_reduction_node<op::v0::Max, op::v1::ReduceMax>(node);
replace_node(node, replacement_node);
return replacement_node;
}
shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMean> node)
{
// ReduceMean = Sum / Count
auto sum_node = op_cast_reduction_node<op::v0::Sum, op::v1::ReduceMean>(node);
// Count = Sum(Constant(1, shape=data.shape))
const auto data = node->input_value(0);
const auto axes = node->input_value(1);
const auto const_node =
op::v0::Constant::create(data.get_element_type(), data.get_shape(), {1});
std::shared_ptr<Node> count_node = std::make_shared<op::v0::Sum>(const_node, axes);
// Support keep_dims attribute
if (node->get_keep_dims())
{
// In order to keep the original dimensions we need to reshape the Count node
// before we use it in Divide with NUMPY broadcast
auto output_shape = count_node->get_shape();
auto reshaped_output_shape = output_shape;
for (const auto& axis : node->get_reduction_axes())
{
reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
}
count_node = make_shared<op::Reshape>(
count_node->output(0), get_default_order(output_shape), reshaped_output_shape);
}
const auto replacement_node =
std::make_shared<op::v0::Divide>(sum_node, count_node, op::AutoBroadcastSpec::NUMPY);
replace_node(node, replacement_node);
return replacement_node;
}
shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMin> node)
{
return op_cast_reduction_node<op::v0::Min, op::v1::ReduceMin>(node);
auto replacement_node = op_cast_reduction_node<op::v0::Min, op::v1::ReduceMin>(node);
replace_node(node, replacement_node);
return replacement_node;
}
shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceProd> node)
{
return op_cast_reduction_node<op::v0::Product, op::v1::ReduceProd>(node);
auto replacement_node = op_cast_reduction_node<op::v0::Product, op::v1::ReduceProd>(node);
replace_node(node, replacement_node);
return replacement_node;
}
shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceSum> node)
{
return op_cast_reduction_node<op::v0::Sum, op::v1::ReduceSum>(node);
auto replacement_node = op_cast_reduction_node<op::v0::Sum, op::v1::ReduceSum>(node);
replace_node(node, replacement_node);
return replacement_node;
}
shared_ptr<Node> op_cast(shared_ptr<op::v1::Reverse> node)
......
......@@ -5,6 +5,11 @@ graph {
input: "A"
output: "B"
op_type: "ReduceMean"
attribute {
name: "keepdims"
i: 0
type: INT
}
}
name: "compute_graph"
input {
......@@ -36,7 +41,6 @@ graph {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
......
......@@ -932,7 +932,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_reduce_mean)
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_multiple_inputs(inputs);
test_case.add_expected_output(expected_output);
test_case.add_expected_output(Shape{}, expected_output);
test_case.run();
}
......
......@@ -150,8 +150,12 @@ namespace ngraph
auto function_output_type = results.at(m_output_index)->get_element_type();
const auto& output_pshape = results.at(m_output_index)->get_output_partial_shape(0);
NGRAPH_CHECK(output_pshape.compatible(expected_shape),
"Passed output shape is not compatible with nGraph function.");
NGRAPH_CHECK(
output_pshape.compatible(expected_shape),
"nGraph function generated an unexpected output shape. Expected shape: ",
expected_shape,
" Output shape: ",
output_pshape);
m_expected_outputs.emplace_back(std::make_shared<ngraph::op::Constant>(
function_output_type, expected_shape, values));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment