Commit 824d6144 authored by Mohammad Mahbubuzzaman's avatar Mohammad Mahbubuzzaman Committed by Scott Cyphers

Provenance fix for new nodes during replace (#3202)

* Fixes pprovenance issue for replace_node() when new nodes are added by the replacement.

* Updates unit test comment and adds one more.
parent 34326357
......@@ -102,28 +102,28 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results,
}
}
NodeVector ngraph::find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
NodeVector ngraph::find_common_args(std::shared_ptr<Node> node1, std::shared_ptr<Node> node2)
{
std::unordered_set<std::shared_ptr<Node>> target_args;
std::unordered_set<std::shared_ptr<Node>> node1_args;
auto compute_target_args = [&target_args](const std::shared_ptr<Node> node) {
target_args.insert(node);
auto compute_node1_args = [&node1_args](const std::shared_ptr<Node> node) {
node1_args.insert(node);
};
traverse_nodes({target}, compute_target_args, false, NodeVector{});
traverse_nodes({node1}, compute_node1_args, false, NodeVector{});
std::unordered_set<std::shared_ptr<Node>> replacement_args;
std::unordered_set<std::shared_ptr<Node>> node2_args;
auto compute_replacement_args = [&replacement_args](const std::shared_ptr<Node> node) {
replacement_args.insert(node);
auto compute_node2_args = [&node2_args](const std::shared_ptr<Node> node) {
node2_args.insert(node);
};
traverse_nodes({replacement}, compute_replacement_args, false, NodeVector{});
traverse_nodes({node2}, compute_node2_args, false, NodeVector{});
NodeVector common_args;
for (auto e : target_args)
for (auto e : node1_args)
{
if (replacement_args.count(e) > 0)
if (node2_args.count(e) > 0)
{
common_args.push_back(e);
}
......@@ -149,12 +149,19 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
if (ngraph::get_provenance_enabled())
{
auto common_args = ngraph::find_common_args(target, replacement);
auto set_replacement_prov = [replacement](std::shared_ptr<Node> node) {
replacement->merge_provenance_tags_from(node);
};
traverse_nodes(
{target}, set_replacement_prov, false, ngraph::find_common_args(target, replacement));
traverse_nodes({target}, set_replacement_prov, false, common_args);
auto set_prov_new_nodes = [replacement](std::shared_ptr<Node> node) {
node->merge_provenance_tags_from(replacement);
};
traverse_nodes({replacement}, set_prov_new_nodes, false, common_args);
}
// For each of target's output O with replacement output O_rep:
......
......@@ -221,19 +221,70 @@ TEST(provenance, provenance)
// Replacement:
//
// A{tag_a} B{tag_b}
// | |
// E{tag_e} |
// | |
// C -> D{tag_d}
// | |
// E{} |
// | |
// C -> D{tag_d}
//
//
// After:
//
// A{tag_a} B{tag_b}
// | |
// E{tag_c, tag_d} |
// | |
// D{tag_c, tag_d}
//
// Comment:
// * D is the replacement root replacing C and creating a new argument node E
//
{
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto a = make_shared<op::Add>(x, y);
a->add_provenance_tag("tag_a");
auto b = make_shared<op::Multiply>(y, x);
b->add_provenance_tag("tag_b");
auto c = make_shared<op::Subtract>(a, b);
c->add_provenance_tag("tag_c");
auto f = make_shared<Function>(c, ParameterVector{x, y});
auto e = make_shared<op::Subtract>(a, x);
auto d = make_shared<op::Subtract>(e, b);
d->add_provenance_tag("tag_d");
replace_node(c, d);
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"}));
}
//
// Before:
//
// A{tag_a} B{tag_b}
// | |
// C{tag_c}
//
//
// Replacement:
//
// A{tag_a} B{tag_b}
// | |
// E{tag_e} |
// | |
// D{tag_c, tag_d}
// C -> D{tag_d}
//
//
// After:
//
// A{tag_a} B{tag_b}
// \ /
// E{tag_c, tag_d, tag_e} /
// \ /
// D{tag_c, tag_d}
//
// Comment:
// * D is the replacement root replacing C and creating a new argument node E
......@@ -252,11 +303,13 @@ TEST(provenance, provenance)
auto f = make_shared<Function>(c, ParameterVector{x, y});
auto e = make_shared<op::Subtract>(a, x);
e->add_provenance_tag("tag_e");
auto d = make_shared<op::Subtract>(e, b);
d->add_provenance_tag("tag_d");
replace_node(c, d);
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_d", "tag_e"}));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment