Commit 97a44e27 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Add support for ArgMin and ArgMax to ZeroDimTensorElimination (#3022)

parent c555b36a
......@@ -37,3 +37,10 @@ shared_ptr<Node> op::ArgMax::copy_with_new_args(const NodeVector& new_args) cons
check_new_args_count(this, new_args);
return make_shared<ArgMax>(new_args.at(0), m_axis, this->get_element_type());
}
std::shared_ptr<Node> op::ArgMax::get_default_value() const
{
// Choice of value here is arbitrary, because validation should be rejecting cases where the
// axis of reduction has size zero.
return ngraph::make_constant_from_string("0", get_element_type(), get_shape());
}
......@@ -42,6 +42,8 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual std::shared_ptr<Node> get_default_value() const override;
};
}
}
......@@ -36,3 +36,10 @@ shared_ptr<Node> op::ArgMin::copy_with_new_args(const NodeVector& new_args) cons
check_new_args_count(this, new_args);
return make_shared<ArgMin>(new_args.at(0), m_axis, this->get_element_type());
}
std::shared_ptr<Node> op::ArgMin::get_default_value() const
{
// Choice of value here is arbitrary, because validation should be rejecting cases where the
// axis of reduction has size zero.
return ngraph::make_constant_from_string("0", get_element_type(), get_shape());
}
......@@ -43,6 +43,8 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual std::shared_ptr<Node> get_default_value() const override;
};
}
}
......@@ -72,6 +72,7 @@ void op::util::IndexReduction::set_index_element_type(const element::Type& index
void op::util::IndexReduction::validate_and_infer_types()
{
// TODO(amprocte): Should reject if size of reduction axis is zero.
const PartialShape& arg_shape = get_input_partial_shape(0);
Rank rank = arg_shape.rank();
......
......@@ -192,6 +192,38 @@ TEST(zero_dim_tensor_elimination, zero_const_slice)
EXPECT_EQ(count_ops_of_type<op::Slice>(f), 0);
}
TEST(zero_dim_tensor_elimination, zero_argmax)
{
auto A = std::make_shared<op::Parameter>(element::f32, Shape{0, 2, 3});
auto argmax = make_shared<op::ArgMax>(A, 1, element::i32);
auto f = std::make_shared<Function>(NodeVector{argmax}, ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("zero_argmax_before.png");
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::VisualizeTree>("zero_argmax_after.png");
EXPECT_EQ(count_ops_of_type<op::ArgMax>(f), 1);
pass_manager.run_passes(f);
EXPECT_EQ(count_ops_of_type<op::ArgMax>(f), 0);
EXPECT_EQ(f->get_results().at(0)->get_shape(), (Shape{0, 3}));
}
TEST(zero_dim_tensor_elimination, zero_argmin)
{
auto A = std::make_shared<op::Parameter>(element::f32, Shape{0, 2, 3});
auto argmin = make_shared<op::ArgMin>(A, 1, element::i32);
auto f = std::make_shared<Function>(NodeVector{argmin}, ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("zero_argmin_before.png");
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<pass::VisualizeTree>("zero_argmin_after.png");
EXPECT_EQ(count_ops_of_type<op::ArgMin>(f), 1);
pass_manager.run_passes(f);
EXPECT_EQ(count_ops_of_type<op::ArgMin>(f), 0);
EXPECT_EQ(f->get_results().at(0)->get_shape(), (Shape{0, 3}));
}
TEST(zero_dim_tensor_elimination, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ZeroDimTensorElimination>();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment