Commit 0fc8f4d8 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Scott Cyphers

[ONNX] Update Reduction ops to use v1 operators. (#4084)

* Fix spelling, and comment formatting.

* Update Reduction operation to use v1 operators.

* Upgrade/downgrade passess for ReduceMin/Max ops.

* Address review comments:

- Remove unnecessary AutoBroadcast arg.
- Use default_opset namespace.
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 563c0c21
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
{ {
namespace builder namespace builder
{ {
/// \brief Specyfies method of bias application to avoid numerical problems /// \brief Specifies method of bias application to avoid numerical problems
enum class BiasMode enum class BiasMode
{ {
// Add bias to intermediate result // Add bias to intermediate result
......
...@@ -45,9 +45,10 @@ namespace ngraph ...@@ -45,9 +45,10 @@ namespace ngraph
auto sum_node = std::shared_ptr<ngraph::Node>{reduction::make_ng_reduction_op( auto sum_node = std::shared_ptr<ngraph::Node>{reduction::make_ng_reduction_op(
node, node,
node.get_ng_inputs().at(0), node.get_ng_inputs().at(0),
std::make_shared<ngraph::opset0::Sum, std::make_shared<default_opset::ReduceSum,
const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&,
const ngraph::AxisSet&>)}; const std::shared_ptr<ngraph::Node>&,
bool>)};
auto const_node = default_opset::Constant::create( auto const_node = default_opset::Constant::create(
sum_node->get_element_type(), sum_node->get_element_type(),
...@@ -55,7 +56,7 @@ namespace ngraph ...@@ -55,7 +56,7 @@ namespace ngraph
std::vector<std::size_t>(shape_size(sum_node->get_shape()), std::vector<std::size_t>(shape_size(sum_node->get_shape()),
elem_count_product)); elem_count_product));
return {std::make_shared<ngraph::opset0::Divide>(sum_node, const_node)}; return {std::make_shared<default_opset::Divide>(sum_node, const_node)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -23,17 +23,6 @@ ...@@ -23,17 +23,6 @@
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/builder/norm.hpp" #include "ngraph/builder/norm.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "utils/reduction.hpp" #include "utils/reduction.hpp"
namespace ngraph namespace ngraph
...@@ -61,9 +50,10 @@ namespace ngraph ...@@ -61,9 +50,10 @@ namespace ngraph
std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op( std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op(
node, node,
node.get_ng_inputs().at(0), node.get_ng_inputs().at(0),
std::make_shared<ngraph::opset0::Sum, std::make_shared<default_opset::ReduceSum,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&,
const ngraph::AxisSet&>)}; bool>)};
return {std::make_shared<default_opset::Log>(sum_node)}; return {std::make_shared<default_opset::Log>(sum_node)};
} }
...@@ -86,9 +76,10 @@ namespace ngraph ...@@ -86,9 +76,10 @@ namespace ngraph
std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op( std::shared_ptr<ngraph::Node> sum_node{reduction::make_ng_reduction_op(
node, node,
exp_node, exp_node,
std::make_shared<ngraph::opset0::Sum, std::make_shared<default_opset::ReduceSum,
const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&,
const ngraph::AxisSet&>)}; const std::shared_ptr<ngraph::Node>&,
bool>)};
return {std::make_shared<default_opset::Log>(sum_node)}; return {std::make_shared<default_opset::Log>(sum_node)};
} }
...@@ -155,9 +146,10 @@ namespace ngraph ...@@ -155,9 +146,10 @@ namespace ngraph
return {reduction::make_ng_reduction_op( return {reduction::make_ng_reduction_op(
node, node,
node.get_ng_inputs().at(0), node.get_ng_inputs().at(0),
std::make_shared<ngraph::opset0::Max, std::make_shared<default_opset::ReduceMax,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&,
const ngraph::AxisSet&>)}; bool>)};
} }
/// \brief Compute the mean value of the input tensor's elements along the /// \brief Compute the mean value of the input tensor's elements along the
...@@ -191,9 +183,10 @@ namespace ngraph ...@@ -191,9 +183,10 @@ namespace ngraph
return {reduction::make_ng_reduction_op( return {reduction::make_ng_reduction_op(
node, node,
node.get_ng_inputs().at(0), node.get_ng_inputs().at(0),
std::make_shared<ngraph::opset0::Min, std::make_shared<default_opset::ReduceMin,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&,
const ngraph::AxisSet&>)}; bool>)};
} }
/// \brief Compute the product of the input tensor's elements along the /// \brief Compute the product of the input tensor's elements along the
...@@ -257,13 +250,14 @@ namespace ngraph ...@@ -257,13 +250,14 @@ namespace ngraph
inline NodeVector reduce_sum_square(const Node& node) inline NodeVector reduce_sum_square(const Node& node)
{ {
auto input = std::shared_ptr<ngraph::Node>{node.get_ng_inputs().at(0)}; auto input = std::shared_ptr<ngraph::Node>{node.get_ng_inputs().at(0)};
auto square_node = input * input; auto square_node = std::make_shared<default_opset::Multiply>(input, input);
return {reduction::make_ng_reduction_op( return {reduction::make_ng_reduction_op(
node, node,
square_node, square_node,
std::make_shared<ngraph::opset0::Sum, std::make_shared<default_opset::ReduceSum,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&,
const ngraph::AxisSet&>)}; bool>)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -68,8 +68,7 @@ namespace ngraph ...@@ -68,8 +68,7 @@ namespace ngraph
/// \param[in] node The node representing incoming ONNX operation. /// \param[in] node The node representing incoming ONNX operation.
/// \param[in] ng_input The input (nGraph) Tensor. /// \param[in] ng_input The input (nGraph) Tensor.
/// \param[in] reduction_function The reduction function defining arithmetic dynamic /// \param[in] reduction_function The reduction function defining arithmetic dynamic
/// reduction /// reduction operation (e.g. ReduceProd, ReduceSum).
/// operation (e.g. ReduceProd, ReduceSum).
/// ///
/// \return nGraph node equivalent of the ONNX operation. /// \return nGraph node equivalent of the ONNX operation.
/// ///
......
...@@ -46,7 +46,48 @@ namespace ...@@ -46,7 +46,48 @@ namespace
replace_node(node, replacement_node); replace_node(node, replacement_node);
} }
// Default is that we didn nothing template <typename OpV0, typename OpV1>
void op_cast_reduction_node(const shared_ptr<OpV1>& node)
{
auto replacement_node = make_shared<OpV0>(node->input_value(0), node->input_value(1));
if (node->get_keep_dims())
{
string v1_op_name = string{node->get_type_name()} + ":v1";
string v0_op_name = string{OpV0{}.get_type_name()} + ":v0";
NGRAPH_CHECK(node->reduction_axes_constant(),
"Unable to convert ",
v1_op_name,
"to ",
v0_op_name,
" if reduction axes are not constant (for keep_dims=true). Node: ",
*node);
auto output_pshape = replacement_node->get_output_partial_shape(0);
NGRAPH_CHECK(output_pshape.is_static(),
"Unable to convert ",
v1_op_name,
"to ",
v0_op_name,
" if output shape is dynamic (for keep_dims=true). Node: ",
*node);
const auto output_shape = output_pshape.to_shape();
auto reshaped_output_shape = output_shape;
for (const auto& axis : node->get_reduction_axes())
{
reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
}
auto reshaped_product = make_shared<op::Reshape>(replacement_node->output(0),
get_default_order(output_shape),
reshaped_output_shape);
replace_node(node, reshaped_product);
}
else
{
replace_node(node, replacement_node);
}
}
// Default is that we did nothing
bool op_cast(shared_ptr<Node> node) { return false; } bool op_cast(shared_ptr<Node> node) { return false; }
bool op_cast(shared_ptr<op::v1::Add> node) bool op_cast(shared_ptr<op::v1::Add> node)
{ {
...@@ -549,36 +590,27 @@ namespace ...@@ -549,36 +590,27 @@ namespace
return true; return true;
} }
bool op_cast(shared_ptr<op::v1::ReduceProd> node) bool op_cast(shared_ptr<op::v1::ReduceMax> node)
{ {
auto replacement_node = op_cast_reduction_node<op::v0::Max, op::v1::ReduceMax>(node);
make_shared<op::v0::Product>(node->input_value(0), node->input_value(1)); return true;
if (node->get_keep_dims())
{
NGRAPH_CHECK(node->reduction_axes_constant(),
"Unable to convert ReduceProd:v1 to Product:v0 "
"if reduction axes are not constant (for keep_dims=true). Node: ",
*node);
auto output_pshape = replacement_node->get_output_partial_shape(0);
NGRAPH_CHECK(output_pshape.is_static(),
"Unable to convert ReduceProd:v1 to Product:v0 "
"if output shape is dynamic (for keep_dims=true). Node: ",
*node);
const auto output_shape = output_pshape.to_shape();
auto reshaped_output_shape = output_shape;
for (const auto& axis : node->get_reduction_axes())
{
reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
} }
auto reshaped_product = make_shared<op::Reshape>(replacement_node->output(0),
get_default_order(output_shape), bool op_cast(shared_ptr<op::v1::ReduceMin> node)
reshaped_output_shape); {
replace_node(node, reshaped_product); op_cast_reduction_node<op::v0::Min, op::v1::ReduceMin>(node);
return true;
} }
else
bool op_cast(shared_ptr<op::v1::ReduceProd> node)
{ {
replace_node(node, replacement_node); op_cast_reduction_node<op::v0::Product, op::v1::ReduceProd>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::ReduceSum> node)
{
op_cast_reduction_node<op::v0::Sum, op::v1::ReduceSum>(node);
return true; return true;
} }
...@@ -716,39 +748,6 @@ namespace ...@@ -716,39 +748,6 @@ namespace
return true; return true;
} }
bool op_cast(shared_ptr<op::v1::ReduceSum> node)
{
auto replacement_node =
make_shared<op::v0::Sum>(node->input_value(0), node->input_value(1));
if (node->get_keep_dims())
{
NGRAPH_CHECK(node->reduction_axes_constant(),
"Unable to convert ReduceSum:v1 to Sum:v0 "
"if reduction axes are not constant (for keep_dims=true). Node: ",
*node);
auto output_pshape = replacement_node->get_output_partial_shape(0);
NGRAPH_CHECK(output_pshape.is_static(),
"Unable to convert ReduceSum:v1 to Sum:v0 "
"if output shape is dynamic (for keep_dims=true). Node: ",
*node);
const auto output_shape = output_pshape.to_shape();
auto reshaped_output_shape = output_shape;
for (const auto& axis : node->get_reduction_axes())
{
reshaped_output_shape.insert(reshaped_output_shape.begin() + axis, 1);
}
auto reshaped_product = make_shared<op::Reshape>(replacement_node->output(0),
get_default_order(output_shape),
reshaped_output_shape);
replace_node(node, reshaped_product);
}
else
{
replace_node(node, replacement_node);
}
return true;
}
bool op_cast(shared_ptr<op::v1::TopK> node) bool op_cast(shared_ptr<op::v1::TopK> node)
{ {
const auto axis = node->get_axis(); const auto axis = node->get_axis();
......
...@@ -379,6 +379,15 @@ namespace ...@@ -379,6 +379,15 @@ namespace
return true; return true;
} }
bool op_cast(shared_ptr<op::Max> node)
{
bool keep_dims = false;
auto replacement_node =
make_shared<op::v1::ReduceMax>(node->input_value(0), node->input_value(1), keep_dims);
replace_node(node, replacement_node);
return true;
}
bool op_cast(shared_ptr<op::Maximum> node) bool op_cast(shared_ptr<op::Maximum> node)
{ {
op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node); op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
...@@ -439,6 +448,15 @@ namespace ...@@ -439,6 +448,15 @@ namespace
return true; return true;
} }
bool op_cast(shared_ptr<op::Min> node)
{
bool keep_dims = false;
auto replacement_node =
make_shared<op::v1::ReduceMin>(node->input_value(0), node->input_value(1), keep_dims);
replace_node(node, replacement_node);
return true;
}
bool op_cast(shared_ptr<op::Minimum> node) bool op_cast(shared_ptr<op::Minimum> node)
{ {
op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node); op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
......
...@@ -85,12 +85,11 @@ set(SRC ...@@ -85,12 +85,11 @@ set(SRC
opset_pass/generate_mask_opset_pass.cpp opset_pass/generate_mask_opset_pass.cpp
opset_pass/pad_opset_pass.cpp opset_pass/pad_opset_pass.cpp
opset_pass/poolings_opset_pass.cpp opset_pass/poolings_opset_pass.cpp
opset_pass/product_opset_pass.cpp opset_pass/reduction_opset_pass.cpp
opset_pass/reverse_opset_pass.cpp opset_pass/reverse_opset_pass.cpp
opset_pass/select_opset_pass.cpp opset_pass/select_opset_pass.cpp
opset_pass/slice_opset_pass.cpp opset_pass/slice_opset_pass.cpp
opset_pass/softmax_opset_pass.cpp opset_pass/softmax_opset_pass.cpp
opset_pass/sum_opset_pass.cpp
opset_pass/topk_opset_pass.cpp opset_pass/topk_opset_pass.cpp
partial_shape.cpp partial_shape.cpp
pass.cpp pass.cpp
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(opset_transform, opset1_product_upgrade_pass)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const AxisSet reduction_axes{1, 2};
const auto product_v0 = make_shared<op::Product>(data, reduction_axes);
const auto result = make_shared<op::Result>(product_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reduce_prod_v1 = as_type_ptr<op::v1::ReduceProd>(pass_replacement_node);
ASSERT_TRUE(reduce_prod_v1);
EXPECT_EQ(reduce_prod_v1->get_keep_dims(), false);
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 1});
const auto product_v1 = make_shared<op::v1::ReduceProd>(data, axes, true);
const auto result = make_shared<op::Result>(product_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto reshape_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto reshape = as_type_ptr<op::Reshape>(reshape_replacement_node);
ASSERT_TRUE(reshape);
const auto product_replace_node =
reshape_replacement_node->input(0).get_source_output().get_node_shared_ptr();
const auto product_v0 = as_type_ptr<op::v0::Product>(product_replace_node);
ASSERT_TRUE(product_v0);
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass_axes_not_constant)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto axes = make_shared<op::Parameter>(element::f32, Shape{1});
const auto product_v1 = make_shared<op::v1::ReduceProd>(data, axes, true);
const auto result = make_shared<op::Result>(product_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data, axes});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
try
{
pass_manager.run_passes(f);
FAIL() << "Exception after Opset0Downgrade pass was not thrown.";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Unable to convert ReduceProd:v1 to Product:v0 "
"if reduction axes are not constant (for keep_dims=true)"));
}
catch (...)
{
FAIL() << "ReduceProd pass failed for unexpected reason";
}
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass_output_not_static)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 1});
const auto product_v1 = make_shared<op::v1::ReduceProd>(data, axes, true);
const auto result = make_shared<op::Result>(product_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
try
{
pass_manager.run_passes(f);
FAIL() << "Exception after Opset0Downgrade pass was not thrown.";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Unable to convert ReduceProd:v1 to Product:v0 "
"if output shape is dynamic (for keep_dims=true)"));
}
catch (...)
{
FAIL() << "ReduceProd pass failed for unexpected reason";
}
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass_out_shape_if_keep_dims)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2});
auto keep_dims = true;
auto reduce_prod_v1 = make_shared<op::v1::ReduceProd>(arg, axes, keep_dims);
const auto result = make_shared<op::Result>(reduce_prod_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
ASSERT_TRUE(replacement_node->get_output_partial_shape(0).compatible(PartialShape{3, 1, 1}));
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass_out_shape_if_not_keep_dims)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2});
auto keep_dims = false;
auto reduce_prod_v1 = make_shared<op::v1::ReduceProd>(arg, axes, keep_dims);
const auto result = make_shared<op::Result>(reduce_prod_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
ASSERT_TRUE(replacement_node->get_output_partial_shape(0).compatible(PartialShape{3}));
}
//***************************************************************************** //*****************************************************************************
// Copyright 2017-2020 Intel Corporation // Copyright 2017-2019 Intel Corporation
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
...@@ -26,56 +25,68 @@ ...@@ -26,56 +25,68 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
TEST(opset_transform, opset1_reduce_sum_upgrade_pass) //------------------------------------------------------------------------------
//
// Helper Functions
//
//------------------------------------------------------------------------------
template <typename OpV0, typename OpV1>
void test_reduce_op_opset1_upgrade_pass()
{ {
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3}); const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const AxisSet reduction_axes{1, 2}; const AxisSet reduction_axes{1, 2};
const auto sum_v0 = make_shared<op::Sum>(data, reduction_axes); const auto v0_node = make_shared<OpV0>(data, reduction_axes);
const auto result = make_shared<op::Result>(sum_v0); const auto result = make_shared<op::Result>(v0_node);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data}); auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>(); pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f); pass_manager.run_passes(f);
const auto pass_replacement_node = const auto pass_replacement_node = f->get_result()->input_value(0).get_node_shared_ptr();
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); const auto v1_node = as_type_ptr<OpV1>(pass_replacement_node);
const auto reduce_sum_v1 = as_type_ptr<op::v1::ReduceSum>(pass_replacement_node);
ASSERT_TRUE(reduce_sum_v1); ASSERT_TRUE(v1_node);
EXPECT_EQ(reduce_sum_v1->get_keep_dims(), false); EXPECT_EQ(v1_node->get_keep_dims(), false);
EXPECT_EQ(v1_node->output(0).get_element_type(), element::f32);
EXPECT_EQ(v1_node->output(0).get_shape(), (Shape{1}));
} }
TEST(opset_transform, opset0_reduce_sum_downgrade_pass) template <typename OpV0, typename OpV1>
void test_reduce_op_opset0_downgrade_pass()
{ {
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3}); const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 1}); const auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 1});
const auto sum_v1 = make_shared<op::v1::ReduceSum>(data, axes, true); const auto v1_node = make_shared<OpV1>(data, axes, true);
const auto result = make_shared<op::Result>(sum_v1); const auto result = make_shared<op::Result>(v1_node);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data}); auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>(); pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f); pass_manager.run_passes(f);
const auto reshape_replacement_node = const auto reshape_replacement_node = f->get_result()->input_value(0).get_node_shared_ptr();
f->get_result()->input(0).get_source_output().get_node_shared_ptr(); const auto reshape_node = as_type_ptr<op::Reshape>(reshape_replacement_node);
const auto reshape = as_type_ptr<op::Reshape>(reshape_replacement_node); ASSERT_TRUE(reshape_node);
ASSERT_TRUE(reshape); EXPECT_EQ(reshape_node->output(0).get_element_type(), element::f32);
const auto sum_replace_node = EXPECT_EQ(reshape_node->output(0).get_shape(), (Shape{1, 1, 3}));
reshape_replacement_node->input(0).get_source_output().get_node_shared_ptr();
const auto sum_v0 = as_type_ptr<op::v0::Sum>(sum_replace_node); const auto op_replace_node = reshape_replacement_node->input_value(0).get_node_shared_ptr();
ASSERT_TRUE(sum_v0); const auto v0_node = as_type_ptr<OpV0>(op_replace_node);
ASSERT_TRUE(v0_node);
} }
TEST(opset_transform, opset0_reduce_sum_downgrade_pass_not_constant_axes) template <typename OpV1>
void test_reduce_op_opset0_downgrade_pass_axes_not_constant()
{ {
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3}); const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto axes = make_shared<op::Parameter>(element::f32, Shape{1}); const auto axes = make_shared<op::Parameter>(element::f32, Shape{1});
const auto sum_v1 = make_shared<op::v1::ReduceSum>(data, axes, true); const auto v1_node = make_shared<OpV1>(data, axes, true);
const auto result = make_shared<op::Result>(sum_v1); const auto result = make_shared<op::Result>(v1_node);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data, axes}); auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data, axes});
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
...@@ -87,10 +98,8 @@ TEST(opset_transform, opset0_reduce_sum_downgrade_pass_not_constant_axes) ...@@ -87,10 +98,8 @@ TEST(opset_transform, opset0_reduce_sum_downgrade_pass_not_constant_axes)
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), string("reduction axes are not constant (for keep_dims=true)"));
std::string("Unable to convert ReduceSum:v1 to Sum:v0 "
"if reduction axes are not constant (for keep_dims=true)"));
} }
catch (...) catch (...)
{ {
...@@ -98,13 +107,14 @@ TEST(opset_transform, opset0_reduce_sum_downgrade_pass_not_constant_axes) ...@@ -98,13 +107,14 @@ TEST(opset_transform, opset0_reduce_sum_downgrade_pass_not_constant_axes)
} }
} }
TEST(opset_transform, opset0_reduce_sum_downgrade_pass_output_not_static) template <typename OpV1>
void test_reduce_op_opset0_downgrade_pass_output_not_static()
{ {
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic()); const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 1}); const auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 1});
const auto sum_v1 = make_shared<op::v1::ReduceSum>(data, axes, true); const auto v1_node = make_shared<OpV1>(data, axes, true);
const auto result = make_shared<op::Result>(sum_v1); const auto result = make_shared<op::Result>(v1_node);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data}); auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
...@@ -116,9 +126,7 @@ TEST(opset_transform, opset0_reduce_sum_downgrade_pass_output_not_static) ...@@ -116,9 +126,7 @@ TEST(opset_transform, opset0_reduce_sum_downgrade_pass_output_not_static)
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(), string("output shape is dynamic (for keep_dims=true)"));
std::string("Unable to convert ReduceSum:v1 to Sum:v0 "
"if output shape is dynamic (for keep_dims=true)"));
} }
catch (...) catch (...)
{ {
...@@ -126,40 +134,164 @@ TEST(opset_transform, opset0_reduce_sum_downgrade_pass_output_not_static) ...@@ -126,40 +134,164 @@ TEST(opset_transform, opset0_reduce_sum_downgrade_pass_output_not_static)
} }
} }
TEST(opset_transform, opset0_reduce_sum_downgrade_pass_out_shape_if_keep_dims) template <typename OpV1>
void test_reduce_op_opset0_downgrade_pass_out_shape_if_keep_dims()
{ {
auto arg = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5}); auto arg = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2}); auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2});
auto keep_dims = true; auto keep_dims = true;
auto reduce_sum_v1 = make_shared<op::v1::ReduceSum>(arg, axes, keep_dims); auto v1_node = make_shared<OpV1>(arg, axes, keep_dims);
const auto result = make_shared<op::Result>(reduce_sum_v1); const auto result = make_shared<op::Result>(v1_node);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg}); auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>(); pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f); pass_manager.run_passes(f);
const auto replacement_node = const auto replacement_node = f->get_result()->input_value(0).get_node_shared_ptr();
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
ASSERT_TRUE(replacement_node->get_output_partial_shape(0).compatible(PartialShape{3, 1, 1})); ASSERT_TRUE(replacement_node->get_output_partial_shape(0).compatible(PartialShape{3, 1, 1}));
} }
TEST(opset_transform, opset0_reduce_sum_downgrade_pass_out_shape_if_not_keep_dims) template <typename OpV1>
void test_reduce_op_opset0_downgrade_pass_out_shape_if_not_keep_dims()
{ {
auto arg = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5}); auto arg = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2}); auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2});
auto keep_dims = false; auto keep_dims = false;
auto reduce_sum_v1 = make_shared<op::v1::ReduceSum>(arg, axes, keep_dims); auto v1_node = make_shared<OpV1>(arg, axes, keep_dims);
const auto result = make_shared<op::Result>(reduce_sum_v1); const auto result = make_shared<op::Result>(v1_node);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg}); auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>(); pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f); pass_manager.run_passes(f);
const auto replacement_node = const auto replacement_node = f->get_result()->input_value(0).get_node_shared_ptr();
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
ASSERT_TRUE(replacement_node->get_output_partial_shape(0).compatible(PartialShape{3})); ASSERT_TRUE(replacement_node->get_output_partial_shape(0).compatible(PartialShape{3}));
} }
//------------------------------------------------------------------------------
//
// Test Cases
//
//------------------------------------------------------------------------------
TEST(opset_transform, opset1_reduce_sum_upgrade_pass)
{
test_reduce_op_opset1_upgrade_pass<op::Sum, op::v1::ReduceSum>();
}
TEST(opset_transform, opset0_reduce_sum_downgrade_pass)
{
test_reduce_op_opset0_downgrade_pass<op::v0::Sum, op::v1::ReduceSum>();
}
TEST(opset_transform, opset0_reduce_sum_downgrade_pass_axes_not_constant_axes)
{
test_reduce_op_opset0_downgrade_pass_axes_not_constant<op::v1::ReduceSum>();
}
TEST(opset_transform, opset0_reduce_sum_downgrade_pass_output_not_static)
{
test_reduce_op_opset0_downgrade_pass_output_not_static<op::v1::ReduceSum>();
}
TEST(opset_transform, opset0_reduce_sum_downgrade_pass_out_shape_if_keep_dims)
{
test_reduce_op_opset0_downgrade_pass_out_shape_if_keep_dims<op::v1::ReduceSum>();
}
TEST(opset_transform, opset0_reduce_sum_downgrade_pass_out_shape_if_not_keep_dims)
{
test_reduce_op_opset0_downgrade_pass_out_shape_if_not_keep_dims<op::v1::ReduceSum>();
}
TEST(opset_transform, opset1_reduce_prod_upgrade_pass)
{
test_reduce_op_opset1_upgrade_pass<op::Product, op::v1::ReduceProd>();
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass)
{
test_reduce_op_opset0_downgrade_pass<op::v0::Product, op::v1::ReduceProd>();
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass_axes_not_constant_axes)
{
test_reduce_op_opset0_downgrade_pass_axes_not_constant<op::v1::ReduceProd>();
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass_output_not_static)
{
test_reduce_op_opset0_downgrade_pass_output_not_static<op::v1::ReduceProd>();
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass_out_shape_if_keep_dims)
{
test_reduce_op_opset0_downgrade_pass_out_shape_if_keep_dims<op::v1::ReduceProd>();
}
TEST(opset_transform, opset0_reduce_prod_downgrade_pass_out_shape_if_not_keep_dims)
{
test_reduce_op_opset0_downgrade_pass_out_shape_if_not_keep_dims<op::v1::ReduceProd>();
}
TEST(opset_transform, opset1_reduce_max_upgrade_pass)
{
test_reduce_op_opset1_upgrade_pass<op::Max, op::v1::ReduceMax>();
}
TEST(opset_transform, opset0_reduce_max_downgrade_pass)
{
test_reduce_op_opset0_downgrade_pass<op::v0::Max, op::v1::ReduceMax>();
}
TEST(opset_transform, opset0_reduce_max_downgrade_pass_axes_not_constant_axes)
{
test_reduce_op_opset0_downgrade_pass_axes_not_constant<op::v1::ReduceMax>();
}
TEST(opset_transform, opset0_reduce_max_downgrade_pass_output_not_static)
{
test_reduce_op_opset0_downgrade_pass_output_not_static<op::v1::ReduceMax>();
}
TEST(opset_transform, opset0_reduce_max_downgrade_pass_out_shape_if_keep_dims)
{
test_reduce_op_opset0_downgrade_pass_out_shape_if_keep_dims<op::v1::ReduceMax>();
}
TEST(opset_transform, opset0_reduce_max_downgrade_pass_out_shape_if_not_keep_dims)
{
test_reduce_op_opset0_downgrade_pass_out_shape_if_not_keep_dims<op::v1::ReduceMax>();
}
TEST(opset_transform, opset1_reduce_min_upgrade_pass)
{
test_reduce_op_opset1_upgrade_pass<op::Min, op::v1::ReduceMin>();
}
TEST(opset_transform, opset0_reduce_min_downgrade_pass)
{
test_reduce_op_opset0_downgrade_pass<op::v0::Min, op::v1::ReduceMin>();
}
TEST(opset_transform, opset0_reduce_min_downgrade_pass_axes_not_constant_axes)
{
test_reduce_op_opset0_downgrade_pass_axes_not_constant<op::v1::ReduceMin>();
}
TEST(opset_transform, opset0_reduce_min_downgrade_pass_output_not_static)
{
test_reduce_op_opset0_downgrade_pass_output_not_static<op::v1::ReduceMin>();
}
TEST(opset_transform, opset0_reduce_min_downgrade_pass_out_shape_if_keep_dims)
{
test_reduce_op_opset0_downgrade_pass_out_shape_if_keep_dims<op::v1::ReduceMin>();
}
TEST(opset_transform, opset0_reduce_min_downgrade_pass_out_shape_if_not_keep_dims)
{
test_reduce_op_opset0_downgrade_pass_out_shape_if_not_keep_dims<op::v1::ReduceMin>();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment