Commit 0fc8f4d8 authored by Adam Rogowiec, committed by Scott Cyphers

[ONNX] Update Reduction ops to use v1 operators. (#4084)

* Fix spelling and comment formatting.

* Update Reduction operation to use v1 operators.

* Upgrade/downgrade passes for ReduceMin/Max ops.

* Address review comments:

- Remove unnecessary AutoBroadcast arg.
- Use default_opset namespace.
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
parent 563c0c21
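For orientation before the hunks: the recurring change in this diff is the move from the v0 reduction ops, which take a static AxisSet attribute and always drop the reduced dimensions, to the v1 Reduce* ops, which take the reduction axes as a second input node plus a keep_dims flag. A minimal sketch of the two styles, illustrative only and not part of the commit (data stands for any existing node output):

    // v0 style: axes are a compile-time attribute; reduced dims are always dropped.
    auto sum_v0 = std::make_shared<ngraph::op::v0::Sum>(data, ngraph::AxisSet{1, 2});

    // v1 style: axes arrive as a Constant input; keep_dims controls the output rank.
    auto axes = ngraph::op::Constant::create(
        ngraph::element::i64, ngraph::Shape{2}, std::vector<std::int64_t>{1, 2});
    auto sum_v1 =
        std::make_shared<ngraph::op::v1::ReduceSum>(data, axes, /*keep_dims=*/false);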
@@ -25,7 +25,7 @@ namespace ngraph
{
namespace builder
{
- /// \brief Specyfies method of bias application to avoid numerical problems
+ /// \brief Specifies method of bias application to avoid numerical problems
enum class BiasMode
{
// Add bias to intermediate result
...
@@ -45,9 +45,10 @@ namespace ngraph
auto sum_node = std::shared_ptr<ngraph::Node>{reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
- std::make_shared<ngraph::opset0::Sum,
- const std::shared_ptr<ngraph::Node>&,
- const ngraph::AxisSet&>)};
+ std::make_shared<default_opset::ReduceSum,
+ const std::shared_ptr<ngraph::Node>&,
+ const std::shared_ptr<ngraph::Node>&,
+ bool>)};
auto const_node = default_opset::Constant::create(
sum_node->get_element_type(),
@@ -55,7 +56,7 @@
std::vector<std::size_t>(shape_size(sum_node->get_shape()),
elem_count_product));
- return {std::make_shared<ngraph::opset0::Divide>(sum_node, const_node)};
+ return {std::make_shared<default_opset::Divide>(sum_node, const_node)};
}
} // namespace set_1
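As a concrete instance of the ReduceMean lowering above (a sketch with hypothetical shapes; the names mirror the variables in the hunk): for an input of Shape{2, 3, 4} reduced over axes {1, 2}, elem_count_product is the number of elements folded by the reduction, i.e. 3 * 4 = 12, and the mean is the sum divided element-wise by a same-shaped constant:

    // input:      Shape{2, 3, 4}, reduction axes {1, 2}
    // sum_node:   ReduceSum(input, axes)          -> the reduced sums
    // const_node: Constant of sum_node's shape, every element == 12
    // result:     Divide(sum_node, const_node)    -> the element-wise mean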
...
@@ -23,17 +23,6 @@
#include "default_opset.hpp"
#include "ngraph/builder/norm.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "utils/reduction.hpp"
namespace ngraph
@@ -61,9 +50,10 @@ namespace ngraph
std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
- std::make_shared<ngraph::opset0::Sum,
- const std::shared_ptr<ngraph::Node>&,
- const ngraph::AxisSet&>)};
+ std::make_shared<default_opset::ReduceSum,
+ const std::shared_ptr<ngraph::Node>&,
+ const std::shared_ptr<ngraph::Node>&,
+ bool>)};
return {std::make_shared<default_opset::Log>(sum_node)};
}
@@ -86,9 +76,10 @@ namespace ngraph
std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op(
node,
exp_node,
- std::make_shared<ngraph::opset0::Sum,
- const std::shared_ptr<ngraph::Node>&,
- const ngraph::AxisSet&>)};
+ std::make_shared<default_opset::ReduceSum,
+ const std::shared_ptr<ngraph::Node>&,
+ const std::shared_ptr<ngraph::Node>&,
+ bool>)};
return {std::make_shared<default_opset::Log>(sum_node)};
}
@@ -155,9 +146,10 @@ namespace ngraph
return {reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
- std::make_shared<ngraph::opset0::Max,
- const std::shared_ptr<ngraph::Node>&,
- const ngraph::AxisSet&>)};
+ std::make_shared<default_opset::ReduceMax,
+ const std::shared_ptr<ngraph::Node>&,
+ const std::shared_ptr<ngraph::Node>&,
+ bool>)};
}
/// \brief Compute the mean value of the input tensor's elements along the
@@ -191,9 +183,10 @@ namespace ngraph
return {reduction::make_ng_reduction_op(
node,
node.get_ng_inputs().at(0),
- std::make_shared<ngraph::opset0::Min,
- const std::shared_ptr<ngraph::Node>&,
- const ngraph::AxisSet&>)};
+ std::make_shared<default_opset::ReduceMin,
+ const std::shared_ptr<ngraph::Node>&,
+ const std::shared_ptr<ngraph::Node>&,
+ bool>)};
}
/// \brief Compute the product of the input tensor's elements along the
@@ -257,13 +250,14 @@ namespace ngraph
inline NodeVector reduce_sum_square(const Node& node)
{
auto input = std::shared_ptr<ngraph::Node>{node.get_ng_inputs().at(0)};
- auto square_node = input * input;
+ auto square_node = std::make_shared<default_opset::Multiply>(input, input);
return {reduction::make_ng_reduction_op(
node,
square_node,
- std::make_shared<ngraph::opset0::Sum,
- const std::shared_ptr<ngraph::Node>&,
- const ngraph::AxisSet&>)};
+ std::make_shared<default_opset::ReduceSum,
+ const std::shared_ptr<ngraph::Node>&,
+ const std::shared_ptr<ngraph::Node>&,
+ bool>)};
}
} // namespace set_1
...
@@ -68,8 +68,7 @@ namespace ngraph
/// \param[in] node The node representing incoming ONNX operation.
/// \param[in] ng_input The input (nGraph) Tensor.
/// \param[in] reduction_function The reduction function defining arithmetic dynamic
- /// reduction
- /// operation (e.g. ReduceProd, ReduceSum).
+ /// reduction operation (e.g. ReduceProd, ReduceSum).
///
/// \return nGraph node equivalent of the ONNX operation.
///
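A note on the idiom at the call sites above: an explicit instantiation of std::make_shared, passed without parentheses, serves as the reduction_function factory. A sketch of such a call, copied from the pattern in the hunks above (the exact signature of make_ng_reduction_op lives in utils/reduction.hpp and may differ in detail):

    // The explicit template arguments select one make_shared instantiation, so
    // the name decays to a function pointer that builds the chosen v1 reduction
    // op from (data, axes, keep_dims).
    return {reduction::make_ng_reduction_op(
        node,
        node.get_ng_inputs().at(0),
        std::make_shared<default_opset::ReduceSum,
                         const std::shared_ptr<ngraph::Node>&,
                         const std::shared_ptr<ngraph::Node>&,
                         bool>)};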
...
@@ -46,7 +46,48 @@ namespace
replace_node(node, replacement_node);
}
- // Default is that we didn nothing
+ template <typename OpV0, typename OpV1>
+ void op_cast_reduction_node(const shared_ptr<OpV1>& node)
+ {
+ auto replacement_node = make_shared<OpV0>(node->input_value(0), node->input_value(1));
+ if (node->get_keep_dims())
+ {
+ string v1_op_name = string{node->get_type_name()} + ":v1";
+ string v0_op_name = string{OpV0{}.get_type_name()} + ":v0";
+ NGRAPH_CHECK(node->reduction_axes_constant(),
+ "Unable to convert ",
+ v1_op_name,
+ " to ",
+ v0_op_name,
+ " if reduction axes are not constant (for keep_dims=true). Node: ",
+ *node);
+ auto output_pshape = replacement_node->get_output_partial_shape(0);
+ NGRAPH_CHECK(output_pshape.is_static(),
+ "Unable to convert ",
+ v1_op_name,
+ " to ",
+ v0_op_name,
+ " if output shape is dynamic (for keep_dims=true). Node: ",
+ *node);
+ const auto output_shape = output_pshape.to_shape();
+ auto reshaped_output_shape = output_shape;
+ for (const auto& axis : node->get_reduction_axes())
+ {
+ reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
+ }
+ auto reshaped_product = make_shared<op::Reshape>(replacement_node->output(0),
+ get_default_order(output_shape),
+ reshaped_output_shape);
+ replace_node(node, reshaped_product);
+ }
+ else
+ {
+ replace_node(node, replacement_node);
+ }
+ }
+ // Default is that we did nothing
bool op_cast(shared_ptr<Node> node) { return false; }
bool op_cast(shared_ptr<op::v1::Add> node)
{
@@ -549,36 +590,27 @@ namespace
return true;
}
+ bool op_cast(shared_ptr<op::v1::ReduceMax> node)
+ {
+ op_cast_reduction_node<op::v0::Max, op::v1::ReduceMax>(node);
+ return true;
+ }
+ bool op_cast(shared_ptr<op::v1::ReduceMin> node)
+ {
+ op_cast_reduction_node<op::v0::Min, op::v1::ReduceMin>(node);
+ return true;
+ }
bool op_cast(shared_ptr<op::v1::ReduceProd> node)
{
- auto replacement_node =
- make_shared<op::v0::Product>(node->input_value(0), node->input_value(1));
- if (node->get_keep_dims())
- {
- NGRAPH_CHECK(node->reduction_axes_constant(),
- "Unable to convert ReduceProd:v1 to Product:v0 "
- "if reduction axes are not constant (for keep_dims=true). Node: ",
- *node);
- auto output_pshape = replacement_node->get_output_partial_shape(0);
- NGRAPH_CHECK(output_pshape.is_static(),
- "Unable to convert ReduceProd:v1 to Product:v0 "
- "if output shape is dynamic (for keep_dims=true). Node: ",
- *node);
- const auto output_shape = output_pshape.to_shape();
- auto reshaped_output_shape = output_shape;
- for (const auto& axis : node->get_reduction_axes())
- {
- reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
- }
- auto reshaped_product = make_shared<op::Reshape>(replacement_node->output(0),
- get_default_order(output_shape),
- reshaped_output_shape);
- replace_node(node, reshaped_product);
- }
- else
- {
- replace_node(node, replacement_node);
- }
+ op_cast_reduction_node<op::v0::Product, op::v1::ReduceProd>(node);
return true;
}
+ bool op_cast(shared_ptr<op::v1::ReduceSum> node)
+ {
+ op_cast_reduction_node<op::v0::Sum, op::v1::ReduceSum>(node);
+ return true;
+ }
@@ -716,39 +748,6 @@ namespace
return true;
}
- bool op_cast(shared_ptr<op::v1::ReduceSum> node)
- {
- auto replacement_node =
- make_shared<op::v0::Sum>(node->input_value(0), node->input_value(1));
- if (node->get_keep_dims())
- {
- NGRAPH_CHECK(node->reduction_axes_constant(),
- "Unable to convert ReduceSum:v1 to Sum:v0 "
- "if reduction axes are not constant (for keep_dims=true). Node: ",
- *node);
- auto output_pshape = replacement_node->get_output_partial_shape(0);
- NGRAPH_CHECK(output_pshape.is_static(),
- "Unable to convert ReduceSum:v1 to Sum:v0 "
- "if output shape is dynamic (for keep_dims=true). Node: ",
- *node);
- const auto output_shape = output_pshape.to_shape();
- auto reshaped_output_shape = output_shape;
- for (const auto& axis : node->get_reduction_axes())
- {
- reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
- }
- auto reshaped_product = make_shared<op::Reshape>(replacement_node->output(0),
- get_default_order(output_shape),
- reshaped_output_shape);
- replace_node(node, reshaped_product);
- }
- else
- {
- replace_node(node, replacement_node);
- }
- return true;
- }
bool op_cast(shared_ptr<op::v1::TopK> node)
{
const auto axis = node->get_axis();
...
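To make the keep_dims branch of op_cast_reduction_node above concrete, with shapes borrowed from the tests at the end of this commit: v1::ReduceProd over a Shape{3, 4, 5} input with axes {1, 2} and keep_dims=true yields Shape{3, 1, 1}, whereas v0::Product always drops the reduced axes and yields Shape{3}. The pass therefore follows the v0 op with a Reshape that re-inserts a size-1 dimension at each reduced axis. The shape arithmetic as a standalone sketch:

    // Static v0 output shape after reducing Shape{3, 4, 5} over axes {1, 2}.
    ngraph::Shape reshaped{3};
    for (auto axis : ngraph::AxisSet{1, 2}) // std::set iterates in ascending order
    {
        reshaped.insert(reshaped.begin() + axis, 1);
    }
    // reshaped == Shape{3, 1, 1}, handed to op::Reshape as the output shape.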
@@ -379,6 +379,15 @@ namespace
return true;
}
+ bool op_cast(shared_ptr<op::Max> node)
+ {
+ bool keep_dims = false;
+ auto replacement_node =
+ make_shared<op::v1::ReduceMax>(node->input_value(0), node->input_value(1), keep_dims);
+ replace_node(node, replacement_node);
+ return true;
+ }
bool op_cast(shared_ptr<op::Maximum> node)
{
op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
@@ -439,6 +448,15 @@ namespace
return true;
}
+ bool op_cast(shared_ptr<op::Min> node)
+ {
+ bool keep_dims = false;
+ auto replacement_node =
+ make_shared<op::v1::ReduceMin>(node->input_value(0), node->input_value(1), keep_dims);
+ replace_node(node, replacement_node);
+ return true;
+ }
bool op_cast(shared_ptr<op::Minimum> node)
{
op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
...
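A note on the op::Max and op::Min upgrade casts above: unlike the downgrade direction, no trailing Reshape is needed, because v0::Max and v0::Min always drop the reduced axes, which is exactly what the v1 ops produce with keep_dims=false (e.g. reducing a Shape{3, 4, 5} input over axes {1, 2} yields Shape{3} in both versions).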
@@ -85,12 +85,11 @@ set(SRC
opset_pass/generate_mask_opset_pass.cpp
opset_pass/pad_opset_pass.cpp
opset_pass/poolings_opset_pass.cpp
- opset_pass/product_opset_pass.cpp
+ opset_pass/reduction_opset_pass.cpp
opset_pass/reverse_opset_pass.cpp
opset_pass/select_opset_pass.cpp
opset_pass/slice_opset_pass.cpp
opset_pass/softmax_opset_pass.cpp
- opset_pass/sum_opset_pass.cpp
opset_pass/topk_opset_pass.cpp
partial_shape.cpp
pass.cpp
...
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(opset_transform, opset1_product_upgrade_pass)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const AxisSet reduction_axes{1, 2};
const auto product_v0 = make_shared<op::Product>(data, reduction_axes);
const auto result = make_shared<op::Result>(product_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reduce_prod_v1 = as_type_ptr<op::v1::ReduceProd>(pass_replacement_node);
ASSERT_TRUE(reduce_prod_v1);
EXPECT_EQ(reduce_prod_v1->get_keep_dims(), false);
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 1});
const auto product_v1 = make_shared<op::v1::ReduceProd>(data, axes, true);
const auto result = make_shared<op::Result>(product_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto reshape_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reshape = as_type_ptr<op::Reshape>(reshape_replacement_node);
ASSERT_TRUE(reshape);
const auto product_replace_node =
reshape_replacement_node->input(0).get_source_output().get_node_shared_ptr();
const auto product_v0 = as_type_ptr<op::v0::Product>(product_replace_node);
ASSERT_TRUE(product_v0);
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass_axes_not_constant)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto axes = make_shared<op::Parameter>(element::f32, Shape{1});
const auto product_v1 = make_shared<op::v1::ReduceProd>(data, axes, true);
const auto result = make_shared<op::Result>(product_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data, axes});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
try
{
pass_manager.run_passes(f);
FAIL() << "Exception after Opset0Downgrade pass was not thrown.";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Unable to convert ReduceProd:v1 to Product:v0 "
"if reduction axes are not constant (for keep_dims=true)"));
}
catch (...)
{
FAIL() << "ReduceProd pass failed for unexpected reason";
}
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass_output_not_static)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 1});
const auto product_v1 = make_shared<op::v1::ReduceProd>(data, axes, true);
const auto result = make_shared<op::Result>(product_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
try
{
pass_manager.run_passes(f);
FAIL() << "Exception after Opset0Downgrade pass was not thrown.";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Unable to convert ReduceProd:v1 to Product:v0 "
"if output shape is dynamic (for keep_dims=true)"));
}
catch (...)
{
FAIL() << "ReduceProd pass failed for unexpected reason";
}
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass_out_shape_if_keep_dims)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2});
auto keep_dims = true;
auto reduce_prod_v1 = make_shared<op::v1::ReduceProd>(arg, axes, keep_dims);
const auto result = make_shared<op::Result>(reduce_prod_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
ASSERT_TRUE(replacement_node->get_output_partial_shape(0).compatible(PartialShape{3, 1, 1}));
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass_out_shape_if_not_keep_dims)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2});
auto keep_dims = false;
auto reduce_prod_v1 = make_shared<op::v1::ReduceProd>(arg, axes, keep_dims);
const auto result = make_shared<op::Result>(reduce_prod_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
ASSERT_TRUE(replacement_node->get_output_partial_shape(0).compatible(PartialShape{3}));
}