Unverified Commit 1ebf4e6a authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

move is_unreachable to ngrapH_util.cpp (#1144)

parent f15877e2
...@@ -434,3 +434,32 @@ bool ngraph::is_one(std::shared_ptr<Node> reduce_constant) ...@@ -434,3 +434,32 @@ bool ngraph::is_one(std::shared_ptr<Node> reduce_constant)
auto result_bool = is_equal_to_const_value("1", reduce_constant); auto result_bool = is_equal_to_const_value("1", reduce_constant);
return result_bool; return result_bool;
} }
bool ngraph::is_used(std::shared_ptr<ngraph::Node> node)
{
std::unordered_set<std::shared_ptr<ngraph::Node>> instances_seen;
std::deque<std::shared_ptr<ngraph::Node>> stack;
stack.push_front(node);
while (stack.size() > 0)
{
std::shared_ptr<ngraph::Node> n = stack.front();
if (instances_seen.count(n) == 0)
{
if (n->is_output())
{
return true;
}
instances_seen.insert(n);
}
stack.pop_front();
for (auto arg : n->get_users())
{
if (instances_seen.count(arg) == 0)
{
stack.push_front(arg);
}
}
}
return false;
}
...@@ -131,4 +131,6 @@ namespace ngraph ...@@ -131,4 +131,6 @@ namespace ngraph
bool is_zero(std::shared_ptr<Node> reduce_constant); bool is_zero(std::shared_ptr<Node> reduce_constant);
bool is_one(std::shared_ptr<Node> reduce_constant); bool is_one(std::shared_ptr<Node> reduce_constant);
bool is_used(std::shared_ptr<Node> node);
} }
...@@ -319,35 +319,6 @@ static std::shared_ptr<ngraph::Node> ...@@ -319,35 +319,6 @@ static std::shared_ptr<ngraph::Node>
} }
} }
static bool is_unreachable(std::shared_ptr<ngraph::Node> node)
{
std::unordered_set<std::shared_ptr<ngraph::Node>> instances_seen;
std::deque<std::shared_ptr<ngraph::Node>> stack;
stack.push_front(node);
while (stack.size() > 0)
{
std::shared_ptr<ngraph::Node> n = stack.front();
if (instances_seen.count(n) == 0)
{
if (n->is_output())
{
return false;
}
instances_seen.insert(n);
}
stack.pop_front();
for (auto arg : n->get_users())
{
if (instances_seen.count(arg) == 0)
{
stack.push_front(arg);
}
}
}
return true;
}
void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop() void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
{ {
auto ht_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100}); auto ht_1 = std::make_shared<pattern::op::Label>(element::f32, Shape{32, 100});
...@@ -568,7 +539,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop() ...@@ -568,7 +539,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
{ {
if (std::find(lstm_nodes.begin(), lstm_nodes.end(), goe0_user) == if (std::find(lstm_nodes.begin(), lstm_nodes.end(), goe0_user) ==
lstm_nodes.end() && lstm_nodes.end() &&
!is_unreachable(goe0_user)) ngraph::is_used(goe0_user))
{ {
lstm_goe0_user.insert(goe0_user); lstm_goe0_user.insert(goe0_user);
map_goe_to_lstm_slices[goe_0] = ht_slice_per_timestep[index]; map_goe_to_lstm_slices[goe_0] = ht_slice_per_timestep[index];
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment