Commit e4d5355b authored by Matthew Brookhart's avatar Matthew Brookhart Committed by Scott Cyphers

Add OneHot to CSE to optimize PP NNP CPU fallbacks (#2791)

parent 3bbc26ac
......@@ -43,6 +43,7 @@
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/relu.hpp"
......@@ -129,6 +130,17 @@ static bool cse_reduction(shared_ptr<Node> a, shared_ptr<Node> b)
ar_a->get_reduction_axes() == ar_b->get_reduction_axes();
}
static bool cse_one_hot(shared_ptr<Node> a, shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_one_hot for " << a->get_name() << " and " << b->get_name();
auto one_hot_a = static_pointer_cast<ngraph::op::OneHot>(a);
auto one_hot_b = static_pointer_cast<ngraph::op::OneHot>(b);
return (a->get_argument(0) == b->get_argument(0)) &&
(one_hot_a->get_one_hot_axis() == one_hot_b->get_one_hot_axis()) &&
(a->get_shape() == b->get_shape());
}
static unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>
initialize_ops_to_cse_handlers()
{
......@@ -145,6 +157,7 @@ static unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node
{TI(op::Floor), cse_unarywise},
{TI(op::Log), cse_unarywise},
{TI(op::Negative), cse_unarywise},
{TI(op::OneHot), cse_one_hot},
{TI(op::Relu), cse_unarywise},
{TI(op::Sigmoid), cse_unarywise},
{TI(op::Sign), cse_unarywise},
......
......@@ -306,3 +306,31 @@ TEST(CSE, constant)
ASSERT_NE(abs0->get_argument(0), absf->get_argument(0));
ASSERT_NE(abs111->get_argument(0), abs112->get_argument(0));
}
TEST(CSE, one_hot)
{
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
{
Shape param_shape{8};
Shape out_shape{8, 16};
auto A = std::make_shared<op::Parameter>(element::i32, param_shape);
auto onehot1 = std::make_shared<op::OneHot>(A, out_shape, 1);
auto onehot2 = std::make_shared<op::OneHot>(A, out_shape, 1);
auto f = std::make_shared<Function>(NodeVector{onehot1, onehot2}, ParameterVector{A});
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), f->get_results().at(1)->get_argument(0));
}
{
Shape param_shape{8, 1};
Shape out_shape{8, 16};
auto A = std::make_shared<op::Parameter>(element::i32, param_shape);
auto reshape1 = std::make_shared<op::Reshape>(A, AxisVector{0, 1}, Shape{8});
auto reshape2 = std::make_shared<op::Reshape>(A, AxisVector{0, 1}, Shape{8});
auto onehot1 = std::make_shared<op::OneHot>(reshape1, out_shape, 1);
auto onehot2 = std::make_shared<op::OneHot>(reshape2, out_shape, 1);
auto f = std::make_shared<Function>(NodeVector{onehot1, onehot2}, ParameterVector{A});
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), f->get_results().at(1)->get_argument(0));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment