Commit edd377ce authored by Mohammad Mahbubuzzaman's avatar Mohammad Mahbubuzzaman Committed by Scott Cyphers

Option to turn off provenance tag prop dynamically (#3575)

* Adds option to switch off provenance tag propagation in runtime

* Fixes style issues.
parent 0819d46a
...@@ -132,7 +132,9 @@ NodeVector ngraph::find_common_args(std::shared_ptr<Node> node1, std::shared_ptr ...@@ -132,7 +132,9 @@ NodeVector ngraph::find_common_args(std::shared_ptr<Node> node1, std::shared_ptr
return common_args; return common_args;
} }
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement) void ngraph::replace_node(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement,
bool disable_prov_tag_prop)
{ {
if (target->is_output()) if (target->is_output())
{ {
...@@ -147,7 +149,7 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re ...@@ -147,7 +149,7 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
// Fix input/output descriptors // Fix input/output descriptors
NGRAPH_CHECK(target->get_output_size() == replacement->get_output_size()); NGRAPH_CHECK(target->get_output_size() == replacement->get_output_size());
if (ngraph::get_provenance_enabled()) if (ngraph::get_provenance_enabled() && !disable_prov_tag_prop)
{ {
auto common_args = ngraph::find_common_args(target, replacement); auto common_args = ngraph::find_common_args(target, replacement);
......
...@@ -212,7 +212,9 @@ namespace ngraph ...@@ -212,7 +212,9 @@ namespace ngraph
/// auto new_N = N->copy_with_new_args(N->get_arguments()); /// auto new_N = N->copy_with_new_args(N->get_arguments());
/// shared_ptr<Node> M = make_shared<SomeUnaryOp>(new_N); /// shared_ptr<Node> M = make_shared<SomeUnaryOp>(new_N);
/// replace_node(N, M); /// replace_node(N, M);
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement); void replace_node(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement,
bool disable_prov_tag_prop = false);
NodeVector find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement); NodeVector find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
......
...@@ -312,4 +312,58 @@ TEST(provenance, provenance) ...@@ -312,4 +312,58 @@ TEST(provenance, provenance)
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"})); EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_d", "tag_e"})); EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_d", "tag_e"}));
} }
//
// Before:
//
// A{tag_a} B{tag_b}
// | |
// C{tag_c}
//
//
// Replacement:
//
// A{tag_a} B{tag_b}
// | |
// E{tag_e} |
// | |
// C -> D{tag_d}
//
//
// After:
//
// A{tag_a} B{tag_b}
// \ /
// E{tag_e} /
// \ /
// D{tag_d}
//
// Comment:
// * D is the replacement root replacing C and creating a new argument node E
// * This test checks the "disable_prov_tag_prop" flag which when set, disables
// provenance tag porpagation.
//
{
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto a = make_shared<op::Add>(x, y);
a->add_provenance_tag("tag_a");
auto b = make_shared<op::Multiply>(y, x);
b->add_provenance_tag("tag_b");
auto c = make_shared<op::Subtract>(a, b);
c->add_provenance_tag("tag_c");
auto f = make_shared<Function>(c, ParameterVector{x, y});
auto e = make_shared<op::Subtract>(a, x);
e->add_provenance_tag("tag_e");
auto d = make_shared<op::Subtract>(e, b);
d->add_provenance_tag("tag_d");
replace_node(c, d, true);
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_d"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_e"}));
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment