Commit bf7c6f68 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Add Min/Max support to zero dim tensor elimination pass (#2719)

* add Min/Max support to zero dim tensor elimination

* fix infinity
parent 62cd2f68
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/max.hpp"
#include "ngraph/graph_util.hpp"
using namespace std;
using namespace ngraph;
......@@ -30,3 +31,43 @@ shared_ptr<Node> op::Max::copy_with_new_args(const NodeVector& new_args) const
check_new_args_count(this, new_args);
return make_shared<Max>(new_args.at(0), m_reduction_axes);
}
shared_ptr<Node> op::Max::get_default_value() const
{
switch (get_element_type().get_type_enum())
{
case element::Type_t::boolean:
return make_constant_from_string("0", get_element_type(), get_shape());
case element::Type_t::bf16:
return make_constant_from_string("-INFINITY", get_element_type(), get_shape());
case element::Type_t::f32:
return make_constant_from_string("-INFINITY", get_element_type(), get_shape());
case element::Type_t::f64:
return make_constant_from_string("-INFINITY", get_element_type(), get_shape());
case element::Type_t::i8:
return make_constant_from_string(
to_string(numeric_limits<int8_t>::min()), get_element_type(), get_shape());
case element::Type_t::i16:
return make_constant_from_string(
to_string(numeric_limits<int16_t>::min()), get_element_type(), get_shape());
case element::Type_t::i32:
return make_constant_from_string(
to_string(numeric_limits<int32_t>::min()), get_element_type(), get_shape());
case element::Type_t::i64:
return make_constant_from_string(
to_string(numeric_limits<int64_t>::min()), get_element_type(), get_shape());
case element::Type_t::u8:
return make_constant_from_string(
to_string(numeric_limits<uint8_t>::min()), get_element_type(), get_shape());
case element::Type_t::u16:
return make_constant_from_string(
to_string(numeric_limits<uint16_t>::min()), get_element_type(), get_shape());
case element::Type_t::u32:
return make_constant_from_string(
to_string(numeric_limits<uint32_t>::min()), get_element_type(), get_shape());
case element::Type_t::u64:
return make_constant_from_string(
to_string(numeric_limits<uint64_t>::min()), get_element_type(), get_shape());
default: throw runtime_error("Max default value not defined for type");
}
}
......@@ -34,6 +34,9 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The default value for Max.
virtual std::shared_ptr<Node> get_default_value() const override;
};
}
}
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/min.hpp"
#include "ngraph/graph_util.hpp"
using namespace std;
using namespace ngraph;
......@@ -30,3 +31,43 @@ shared_ptr<Node> op::Min::copy_with_new_args(const NodeVector& new_args) const
check_new_args_count(this, new_args);
return make_shared<Min>(new_args.at(0), m_reduction_axes);
}
shared_ptr<Node> op::Min::get_default_value() const
{
switch (get_element_type().get_type_enum())
{
case element::Type_t::boolean:
return make_constant_from_string("1", get_element_type(), get_shape());
case element::Type_t::bf16:
return make_constant_from_string("INFINITY", get_element_type(), get_shape());
case element::Type_t::f32:
return make_constant_from_string("INFINITY", get_element_type(), get_shape());
case element::Type_t::f64:
return make_constant_from_string("INFINITY", get_element_type(), get_shape());
case element::Type_t::i8:
return make_constant_from_string(
to_string(numeric_limits<int8_t>::max()), get_element_type(), get_shape());
case element::Type_t::i16:
return make_constant_from_string(
to_string(numeric_limits<int16_t>::max()), get_element_type(), get_shape());
case element::Type_t::i32:
return make_constant_from_string(
to_string(numeric_limits<int32_t>::max()), get_element_type(), get_shape());
case element::Type_t::i64:
return make_constant_from_string(
to_string(numeric_limits<int64_t>::max()), get_element_type(), get_shape());
case element::Type_t::u8:
return make_constant_from_string(
to_string(numeric_limits<uint8_t>::max()), get_element_type(), get_shape());
case element::Type_t::u16:
return make_constant_from_string(
to_string(numeric_limits<uint16_t>::max()), get_element_type(), get_shape());
case element::Type_t::u32:
return make_constant_from_string(
to_string(numeric_limits<uint32_t>::max()), get_element_type(), get_shape());
case element::Type_t::u64:
return make_constant_from_string(
to_string(numeric_limits<uint64_t>::max()), get_element_type(), get_shape());
default: throw runtime_error("Min default value not defined for type");
}
}
......@@ -34,6 +34,9 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The default value for Min.
virtual std::shared_ptr<Node> get_default_value() const override;
};
}
}
......@@ -41,18 +41,72 @@ TEST(zero_dim_tensor_elimination, zero_sum)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto sum = std::make_shared<op::Sum>(A, AxisSet{0});
auto abs_node = std::make_shared<op::Abs>(A);
auto sum_node = std::make_shared<op::Sum>(abs_node, AxisSet{0});
auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
auto f = std::make_shared<Function>(NodeVector{sum_node, constant}, ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::VisualizeTree>("zero_sum_before.png");
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
pass_manager.register_pass<pass::VisualizeTree>("zero_sum_after.png");
EXPECT_EQ(count_ops_of_type<op::Sum>(f), 1);
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Sum>(f), 0);
EXPECT_EQ(count_ops_of_type<op::Sum>(f), 0);
}
TEST(zero_dim_tensor_elimination, zero_product)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto abs_node = std::make_shared<op::Abs>(A);
auto product_node = std::make_shared<op::Product>(abs_node, AxisSet{0});
auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
auto f = std::make_shared<Function>(NodeVector{product_node, constant}, ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("zero_product_before.png");
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::VisualizeTree>("zero_product_after.png");
EXPECT_EQ(count_ops_of_type<op::Product>(f), 1);
pass_manager.run_passes(f);
EXPECT_EQ(count_ops_of_type<op::Product>(f), 0);
}
TEST(zero_dim_tensor_elimination, zero_min)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto abs_node = std::make_shared<op::Abs>(A);
auto min_node = std::make_shared<op::Min>(abs_node, AxisSet{0});
auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
auto f = std::make_shared<Function>(NodeVector{min_node, constant}, ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("zero_min_before.png");
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::VisualizeTree>("zero_min_after.png");
EXPECT_EQ(count_ops_of_type<op::Min>(f), 1);
pass_manager.run_passes(f);
EXPECT_EQ(count_ops_of_type<op::Min>(f), 0);
}
TEST(zero_dim_tensor_elimination, zero_max)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto abs_node = std::make_shared<op::Abs>(A);
auto max_node = std::make_shared<op::Max>(abs_node, AxisSet{0});
auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
auto f = std::make_shared<Function>(NodeVector{max_node, constant}, ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("zero_max_before.png");
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::VisualizeTree>("zero_max_after.png");
EXPECT_EQ(count_ops_of_type<op::Max>(f), 1);
pass_manager.run_passes(f);
EXPECT_EQ(count_ops_of_type<op::Max>(f), 0);
}
TEST(zero_dim_tensor_elimination, zero_const_conv)
......@@ -68,11 +122,12 @@ TEST(zero_dim_tensor_elimination, zero_const_conv)
std::make_shared<Function>(NodeVector{abs_node, constant}, ParameterVector{A, weights});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::VisualizeTree>("zero_const_conv_before.png");
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
pass_manager.register_pass<pass::VisualizeTree>("zero_const_conv_after.png");
EXPECT_EQ(count_ops_of_type<op::Convolution>(f), 1);
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Convolution>(f), 0);
EXPECT_EQ(count_ops_of_type<op::Convolution>(f), 0);
}
TEST(zero_dim_tensor_elimination, zero_const_avg_pool)
......@@ -87,11 +142,12 @@ TEST(zero_dim_tensor_elimination, zero_const_avg_pool)
auto f = std::make_shared<Function>(NodeVector{abs_node, constant}, ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::VisualizeTree>("zero_const_avg_pool_before.png");
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
pass_manager.register_pass<pass::VisualizeTree>("zero_const_avg_pool_after.png");
EXPECT_EQ(count_ops_of_type<op::AvgPool>(f), 1);
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::AvgPool>(f), 0);
EXPECT_EQ(count_ops_of_type<op::AvgPool>(f), 0);
}
TEST(zero_dim_tensor_elimination, zero_const_pad)
......@@ -106,11 +162,12 @@ TEST(zero_dim_tensor_elimination, zero_const_pad)
auto f = std::make_shared<Function>(NodeVector{abs_node, constant}, ParameterVector{A, B});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::VisualizeTree>("zero_const_pad_before.png");
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
pass_manager.register_pass<pass::VisualizeTree>("zero_const_pad_after.png");
EXPECT_EQ(count_ops_of_type<op::Broadcast>(f), 0);
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Broadcast>(f), 1);
EXPECT_EQ(count_ops_of_type<op::Broadcast>(f), 1);
}
TEST(zero_dim_tensor_elimination, zero_const_slice)
......@@ -125,10 +182,12 @@ TEST(zero_dim_tensor_elimination, zero_const_slice)
auto f = std::make_shared<Function>(NodeVector{abs_node, constant}, ParameterVector{A, B});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::VisualizeTree>("zero_const_slice_before.png");
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
pass_manager.register_pass<pass::VisualizeTree>("zero_const_slice_after.png");
EXPECT_EQ(count_ops_of_type<op::Broadcast>(f), 0);
EXPECT_EQ(count_ops_of_type<op::Slice>(f), 0);
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Broadcast>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Slice>(f), 0);
EXPECT_EQ(count_ops_of_type<op::Broadcast>(f), 1);
EXPECT_EQ(count_ops_of_type<op::Slice>(f), 0);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment