Commit 5f83ff57 authored by Mohammad Mahbubuzzaman's avatar Mohammad Mahbubuzzaman Committed by Scott Cyphers

Fixes provenance bug causing extra tags to be added during node replacement (#2950)

parent a02bfa42
...@@ -135,6 +135,36 @@ void ngraph::traverse_functions(std::shared_ptr<ngraph::Function> p, ...@@ -135,6 +135,36 @@ void ngraph::traverse_functions(std::shared_ptr<ngraph::Function> p,
} }
} }
NodeVector ngraph::find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
{
std::unordered_set<std::shared_ptr<Node>> target_args;
auto compute_target_args = [&target_args](const std::shared_ptr<Node> node) {
target_args.insert(node);
};
traverse_nodes({target}, compute_target_args, false, NodeVector{});
std::unordered_set<std::shared_ptr<Node>> replacement_args;
auto compute_replacement_args = [&replacement_args](const std::shared_ptr<Node> node) {
replacement_args.insert(node);
};
traverse_nodes({replacement}, compute_replacement_args, false, NodeVector{});
NodeVector common_args;
for (auto e : target_args)
{
if (replacement_args.count(e) > 0)
{
common_args.push_back(e);
}
}
return common_args;
}
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement) void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
{ {
if (target->is_output()) if (target->is_output())
...@@ -156,7 +186,8 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re ...@@ -156,7 +186,8 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
replacement->merge_provenance_tags_from(node); replacement->merge_provenance_tags_from(node);
}; };
traverse_nodes({target}, set_replacement_prov, false, replacement->get_arguments()); traverse_nodes(
{target}, set_replacement_prov, false, ngraph::find_common_args(target, replacement));
} }
// For each of target's output O with replacement output O_rep: // For each of target's output O with replacement output O_rep:
......
...@@ -75,6 +75,8 @@ namespace ngraph ...@@ -75,6 +75,8 @@ namespace ngraph
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement); void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
NodeVector find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
template <typename T> template <typename T>
std::list<std::shared_ptr<Node>> topological_sort(const T& nodes, std::list<std::shared_ptr<Node>> topological_sort(const T& nodes,
bool include_control_deps = false) bool include_control_deps = false)
......
...@@ -209,4 +209,54 @@ TEST(provenance, provenance) ...@@ -209,4 +209,54 @@ TEST(provenance, provenance)
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_a", "tag_b", "tag_c"})); EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_a", "tag_b", "tag_c"}));
} }
//
// Before:
//
// A{tag_a} B{tag_b}
// | |
// C{tag_c}
//
//
// Replacement:
//
// A{tag_a} B{tag_b}
// | |
// E{tag_e} |
// | |
// C -> D{tag_d}
//
//
// After:
//
// A{tag_a} B{tag_b}
// | |
// E{tag_e} |
// | |
// D{tag_c, tag_d}
//
// Comment:
// * D is the replacement root replacing C and creating a new argument node E
//
{
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto a = make_shared<op::Add>(x, y);
a->add_provenance_tag("tag_a");
auto b = make_shared<op::Multiply>(y, x);
b->add_provenance_tag("tag_b");
auto c = make_shared<op::Subtract>(a, b);
c->add_provenance_tag("tag_c");
auto f = make_shared<Function>(c, ParameterVector{x, y});
auto e = make_shared<op::Subtract>(a, x);
auto d = make_shared<op::Subtract>(e, b);
d->add_provenance_tag("tag_d");
replace_node(c, d);
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"}));
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment