Commit c3f16bb9 authored by Adam Procter's avatar Adam Procter

Add support for logical reduction ops

parent 8eed208b
...@@ -21,7 +21,9 @@ ...@@ -21,7 +21,9 @@
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/abs.hpp" #include "ngraph/op/abs.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/and.hpp" #include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/ceiling.hpp" #include "ngraph/op/ceiling.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
...@@ -64,7 +66,9 @@ ...@@ -64,7 +66,9 @@
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/reference/abs.hpp" #include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/add.hpp" #include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/all.hpp"
#include "ngraph/runtime/reference/and.hpp" #include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/any.hpp"
#include "ngraph/runtime/reference/broadcast.hpp" #include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/ceiling.hpp" #include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp" #include "ngraph/runtime/reference/concat.hpp"
...@@ -1722,6 +1726,75 @@ void pass::ConstantFolding::construct_constant_arithmetic_reduction() ...@@ -1722,6 +1726,75 @@ void pass::ConstantFolding::construct_constant_arithmetic_reduction()
all_pass_property_off); all_pass_property_off);
} }
static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::Constant> constant,
shared_ptr<Node> reduction_node)
{
vector<char> out_vec(shape_size(reduction_node->get_shape()));
if (auto p = dynamic_pointer_cast<::ngraph::op::All>(reduction_node))
{
runtime::reference::all(constant->get_vector<char>().data(),
out_vec.data(),
constant->get_output_shape(0),
reduction_node->get_shape(),
p->get_reduction_axes());
}
else if (auto p = dynamic_pointer_cast<::ngraph::op::Any>(reduction_node))
{
runtime::reference::any(constant->get_vector<char>().data(),
out_vec.data(),
constant->get_output_shape(0),
reduction_node->get_shape(),
p->get_reduction_axes());
}
else
{
NGRAPH_CHECK(false,
"Internal nGraph error: Ops handled in "
"fold_constant_logical_reduction must be consistent with those "
"matched in construct_constant_logical_reduction");
}
return make_shared<op::Constant>(
reduction_node->get_output_element_type(0), reduction_node->get_shape(), out_vec);
}
void pass::ConstantFolding::construct_constant_logical_reduction()
{
auto constant_data_label = make_shared<pattern::op::Label>(
element::boolean, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto constant_axes_label =
make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
auto is_supported_reduction = [](std::shared_ptr<Node> n) {
return (pattern::has_class<::ngraph::op::All>()(n) ||
pattern::has_class<::ngraph::op::Any>()(n));
};
auto reduction =
std::make_shared<pattern::op::Any>(element::i32,
Shape{2},
is_supported_reduction,
NodeVector{constant_data_label, constant_axes_label});
auto constant_logical_reduction_callback = [constant_data_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_logical_reduction_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
auto reduction_match = m.get_match_root();
replace_node(reduction_match,
fold_constant_logical_reduction(constant_match, reduction_match));
return true;
};
auto logical_reduction_matcher =
make_shared<pattern::Matcher>(reduction, "ConstantFolding.ConstantLogicalReduction");
this->add_matcher(
logical_reduction_matcher, constant_logical_reduction_callback, all_pass_property_off);
}
template <typename T> template <typename T>
static shared_ptr<op::Constant> fold_constant_concat_helper(const shared_ptr<op::Concat>& concat) static shared_ptr<op::Constant> fold_constant_concat_helper(const shared_ptr<op::Concat>& concat)
{ {
......
...@@ -43,6 +43,7 @@ public: ...@@ -43,6 +43,7 @@ public:
SHAPE_OF, SHAPE_OF,
REVERSE, REVERSE,
ARITHMETIC_REDUCTION, ARITHMETIC_REDUCTION,
LOGICAL_REDUCTION,
CONCAT, CONCAT,
GATHER, GATHER,
SLICE, SLICE,
...@@ -66,6 +67,7 @@ public: ...@@ -66,6 +67,7 @@ public:
construct_constant_shape_of(); construct_constant_shape_of();
construct_constant_reverse(); construct_constant_reverse();
construct_constant_arithmetic_reduction(); construct_constant_arithmetic_reduction();
construct_constant_logical_reduction();
construct_constant_concat(); construct_constant_concat();
construct_constant_gather(); construct_constant_gather();
construct_constant_slice(); construct_constant_slice();
...@@ -98,6 +100,9 @@ public: ...@@ -98,6 +100,9 @@ public:
case CFTransformations::ARITHMETIC_REDUCTION: case CFTransformations::ARITHMETIC_REDUCTION:
construct_constant_arithmetic_reduction(); construct_constant_arithmetic_reduction();
break; break;
case CFTransformations::LOGICAL_REDUCTION:
construct_constant_logical_reduction();
break;
case CFTransformations::CONCAT: construct_constant_concat(); break; case CFTransformations::CONCAT: construct_constant_concat(); break;
case CFTransformations::GATHER: construct_constant_gather(); break; case CFTransformations::GATHER: construct_constant_gather(); break;
case CFTransformations::SLICE: construct_constant_slice(); break; case CFTransformations::SLICE: construct_constant_slice(); break;
...@@ -120,6 +125,7 @@ private: ...@@ -120,6 +125,7 @@ private:
void construct_constant_shape_of(); void construct_constant_shape_of();
void construct_constant_reverse(); void construct_constant_reverse();
void construct_constant_arithmetic_reduction(); void construct_constant_arithmetic_reduction();
void construct_constant_logical_reduction();
void construct_constant_concat(); void construct_constant_concat();
void construct_constant_gather(); void construct_constant_gather();
void construct_constant_slice(); void construct_constant_slice();
......
...@@ -486,6 +486,58 @@ TEST(constant_folding, const_min) ...@@ -486,6 +486,58 @@ TEST(constant_folding, const_min)
ASSERT_EQ(values_expected, values_out); ASSERT_EQ(values_expected, values_out);
} }
TEST(constant_folding, const_all)
{
Shape input_shape{3, 3};
vector<char> values_in{0, 1, 1, 0, 1, 0, 1, 1, 1};
auto constant = op::Constant::create(element::boolean, input_shape, values_in);
auto convert = make_shared<op::All>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::All>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<char>();
vector<char> values_expected{0, 0, 1};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_any)
{
Shape input_shape{3, 3};
vector<char> values_in{1, 0, 0, 1, 0, 1, 0, 0, 0};
auto constant = op::Constant::create(element::boolean, input_shape, values_in);
auto convert = make_shared<op::Any>(constant, AxisSet{1});
auto f = make_shared<Function>(convert, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Any>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<char>();
vector<char> values_expected{1, 1, 0};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_concat) TEST(constant_folding, const_concat)
{ {
auto constant0 = auto constant0 =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment