Commit 83e7dba5 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

get_subgraph_outputs (towards checking that intermediate nodes in a matched graph not used) (#1207)

* get_subgraph_outputs

* simplify the condition
parent 33b54ce1
......@@ -465,6 +465,33 @@ bool ngraph::is_one(std::shared_ptr<Node> reduce_constant)
return result_bool;
}
NodeVector ngraph::get_subgraph_outputs(const NodeVector& nodes,
const NodeVector& exclusions,
bool ignore_unused)
{
std::set<shared_ptr<Node>> exclusions_set(exclusions.begin(), exclusions.end());
std::set<shared_ptr<Node>> nodes_set(nodes.begin(), nodes.end());
NodeVector outputs;
for (auto n : nodes)
{
if (exclusions_set.count(n) != 0)
{
continue;
}
for (auto u : n->get_users())
{
if (nodes_set.count(u) == 0 && (!ignore_unused || is_used(u.get())))
{
outputs.push_back(n);
}
}
}
return outputs;
}
bool ngraph::is_used(Node* node)
{
std::unordered_set<Node*> instances_seen;
......
......@@ -133,6 +133,10 @@ namespace ngraph
bool is_zero(std::shared_ptr<Node> reduce_constant);
NodeVector get_subgraph_outputs(const NodeVector& nodes,
const NodeVector& exclusions,
bool ignore_unused = false);
bool is_one(std::shared_ptr<Node> reduce_constant);
// Returns true if `node` is live in the graph i.e. a result op
......
......@@ -324,3 +324,44 @@ TEST(util, parse_string)
EXPECT_FLOAT_EQ(-numeric_limits<double>::infinity(), parse_string<double>("-INFINITY"));
EXPECT_TRUE(std::isnan(parse_string<double>("NaN")));
}
TEST(graph_util, get_subgraph_outputs_trivial_tests)
{
auto outputs = ngraph::get_subgraph_outputs(NodeVector{}, NodeVector{});
ASSERT_EQ(outputs.size(), 0);
Shape shape{};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto absn = make_shared<op::Abs>(A);
auto neg_absn = make_shared<op::Negative>(absn);
outputs = ngraph::get_subgraph_outputs(NodeVector{A}, NodeVector{});
ASSERT_EQ(outputs, (NodeVector{A}));
outputs = ngraph::get_subgraph_outputs(NodeVector{A}, NodeVector{A});
ASSERT_EQ(outputs, (NodeVector{}));
outputs = ngraph::get_subgraph_outputs(NodeVector{A, absn}, NodeVector{});
ASSERT_EQ(outputs, (NodeVector{absn}));
auto B = make_shared<op::Parameter>(element::f32, shape);
auto abs_b = make_shared<op::Abs>(B);
auto neg_b = make_shared<op::Negative>(B);
auto abs_b_neg = make_shared<op::Negative>(abs_b);
outputs = ngraph::get_subgraph_outputs(NodeVector{B, abs_b}, NodeVector{});
ASSERT_EQ(outputs, (NodeVector{B, abs_b}));
outputs = ngraph::get_subgraph_outputs(NodeVector{B, abs_b}, NodeVector{B});
ASSERT_EQ(outputs, (NodeVector{abs_b}));
outputs = ngraph::get_subgraph_outputs(NodeVector{B, abs_b, abs_b_neg}, NodeVector{});
ASSERT_EQ(outputs, (NodeVector{B}));
auto add_b = make_shared<op::Add>(neg_b, abs_b_neg);
outputs =
ngraph::get_subgraph_outputs(NodeVector{B, abs_b, neg_b, abs_b_neg, add_b}, NodeVector{});
ASSERT_EQ(outputs, (NodeVector{}));
//now add_b uses abs_b_neg
outputs = ngraph::get_subgraph_outputs(NodeVector{B, abs_b, abs_b_neg}, NodeVector{});
ASSERT_EQ(outputs, (NodeVector{B, abs_b_neg}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment