Commit 6c8284a3 authored by Adam Straw's avatar Adam Straw Committed by Robert Kimball

add logic in replace_node for provenance propagation (#2703)

* Add some unit tests for provenance with node replacement

* One more test

* add provenace propagation
parent 3f017a1e
...@@ -70,9 +70,10 @@ void ngraph::traverse_nodes(const Function* p, ...@@ -70,9 +70,10 @@ void ngraph::traverse_nodes(const Function* p,
// directly from the result nodes, not from function parameters. // directly from the result nodes, not from function parameters.
void ngraph::traverse_nodes(const NodeVector& io_nodes, void ngraph::traverse_nodes(const NodeVector& io_nodes,
std::function<void(std::shared_ptr<Node>)> f, std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps) bool include_control_deps,
NodeVector stop_nodes)
{ {
std::unordered_set<std::shared_ptr<Node>> instances_seen; std::unordered_set<std::shared_ptr<Node>> instances_seen(stop_nodes.begin(), stop_nodes.end());
std::deque<std::shared_ptr<Node>> stack; std::deque<std::shared_ptr<Node>> stack;
for (auto r : io_nodes) for (auto r : io_nodes)
...@@ -152,6 +153,20 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re ...@@ -152,6 +153,20 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
// Fix input/output descriptors // Fix input/output descriptors
assert(target->get_outputs().size() == replacement->get_outputs().size()); assert(target->get_outputs().size() == replacement->get_outputs().size());
auto set_replacement_prov = [replacement](std::shared_ptr<Node> node) {
replacement->merge_provenance_tags_from(node);
};
traverse_nodes({target}, set_replacement_prov, false, replacement->get_arguments());
auto propagate_replacement_prov = [replacement](std::shared_ptr<Node> node) {
if (is_post_dominated(node.get(), replacement.get()))
{
node->merge_provenance_tags_from(replacement);
}
};
traverse_nodes({replacement}, propagate_replacement_prov, false);
// For each of target's output O with replacement output O_rep: // For each of target's output O with replacement output O_rep:
// For each O's connected downstream input I: // For each O's connected downstream input I:
// Change I's connected upstream output to O_rep // Change I's connected upstream output to O_rep
...@@ -165,7 +180,6 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re ...@@ -165,7 +180,6 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
input->replace_output(replacement->get_outputs().at(i)); input->replace_output(replacement->get_outputs().at(i));
} }
} }
replacement->merge_provenance_tags_from(target);
} }
// Check if all paths from X to a result go through Y // Check if all paths from X to a result go through Y
......
...@@ -51,7 +51,8 @@ namespace ngraph ...@@ -51,7 +51,8 @@ namespace ngraph
void traverse_nodes(const NodeVector& io_nodes, void traverse_nodes(const NodeVector& io_nodes,
std::function<void(std::shared_ptr<Node>)> f, std::function<void(std::shared_ptr<Node>)> f,
bool include_control_deps); bool include_control_deps,
NodeVector stop_nodes = {});
void traverse_functions(std::shared_ptr<Function> p, void traverse_functions(std::shared_ptr<Function> p,
std::function<void(std::shared_ptr<Function>)> f); std::function<void(std::shared_ptr<Function>)> f);
......
...@@ -57,6 +57,7 @@ set(SRC ...@@ -57,6 +57,7 @@ set(SRC
pass_memory_layout.cpp pass_memory_layout.cpp
pass_shape_specialization.cpp pass_shape_specialization.cpp
pattern.cpp pattern.cpp
provenance.cpp
reshape_elimination.cpp reshape_elimination.cpp
reshape_sinking.cpp reshape_sinking.cpp
serialize.cpp serialize.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using ProvSet = std::unordered_set<std::string>;
TEST(provenance, provenance)
{
//
// Before:
//
// A{tag_a} B{tag_b}
// | |
// C{tag_c}
//
// Replacement:
//
// A{tag_a} B{tag_b}
// | |
// C := D{}
//
// After:
//
// A{tag_a} B{tag_b}
// | |
// D{tag_c}
//
// Comment:
// * D is the replacement root, and its insertion kills C. We should not, however, consider
// A and B to be killed, because they are not post-dominated by D until after C is cut out
// of the graph.
//
{
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto a = make_shared<op::Add>(x, y);
a->add_provenance_tag("tag_a");
auto b = make_shared<op::Multiply>(y, x);
b->add_provenance_tag("tag_b");
auto c = make_shared<op::Subtract>(a, b);
c->add_provenance_tag("tag_c");
auto f = make_shared<Function>(c, ParameterVector{x, y});
auto new_c = make_shared<op::Subtract>(a, b);
replace_node(c, new_c);
EXPECT_EQ(new_c->get_provenance_tags(), ProvSet{"tag_c"});
}
//
// Before:
//
// A{tag_a} B{tag_b}
// | |
// C{tag_c}
//
// Replacement:
//
//
//
// A{tag_a} B{tag_b}
// | |
// C -> D{tag_d}
//
// After:
//
// A{tag_a} B{tag_b}
// | |
// D{tag_c,tag_d}
//
// Comment:
// * D is the replacement root, and its insertion kills C. We should not, however, consider
// A and B to be killed, because they are not post-dominated by D until after C is cut out
// of the graph.
//
{
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto a = make_shared<op::Add>(x, y);
a->add_provenance_tag("tag_a");
auto b = make_shared<op::Multiply>(y, x);
b->add_provenance_tag("tag_b");
auto c = make_shared<op::Subtract>(a, b);
c->add_provenance_tag("tag_c");
auto f = make_shared<Function>(c, ParameterVector{x, y});
auto d = make_shared<op::Subtract>(a, b);
d->add_provenance_tag("tag_d");
replace_node(c, d);
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"}));
}
//
// Before:
//
// A{tag_a} B{tag_b}
// | |
// C{tag_c}
//
// Replacement:
//
// C -> D{tag_d}
//
// After:
//
// D{tag_a,tag_b,tag_c,tag_d}
//
// Comment:
// * D is the replacement root, and its insertion kills A, B, and C.
//
{
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto a = make_shared<op::Add>(x, y);
a->add_provenance_tag("tag_a");
auto b = make_shared<op::Multiply>(y, x);
b->add_provenance_tag("tag_b");
auto c = make_shared<op::Subtract>(a, b);
c->add_provenance_tag("tag_c");
auto f = make_shared<Function>(c, ParameterVector{x, y});
auto d = make_zero(element::i32, Shape{2, 3, 4});
d->add_provenance_tag("tag_d");
replace_node(c, d);
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_a", "tag_b", "tag_c", "tag_d"}));
}
//
// Before:
//
// A{tag_a} B{tag_b}
// | |
// C{tag_c}
//
// Replacement:
//
// C -> D{}
//
// After:
//
// D{tag_a,tag_b,tag_c}
//
// Comment:
// * D is the replacement root, and its insertion kills A, B, and C.
//
{
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto a = make_shared<op::Add>(x, y);
a->add_provenance_tag("tag_a");
auto b = make_shared<op::Multiply>(y, x);
b->add_provenance_tag("tag_b");
auto c = make_shared<op::Subtract>(a, b);
c->add_provenance_tag("tag_c");
auto f = make_shared<Function>(c, ParameterVector{x, y});
auto d = make_zero(element::i32, Shape{2, 3, 4});
replace_node(c, d);
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_a", "tag_b", "tag_c"}));
}
//
// Before:
//
// A{tag_a} B{tag_b}
// | |
// C{tag_c}
//
// Replacement:
//
// D{}
// |
// C -> E{}
//
// After:
//
// D{tag_a,tag_b,tag_c}
// |
// E{tag_a,tag_b,tag_c}
//
// Comment:
// * E is the replacement root, and its insertion kills A, B, and C.
// * D is post-dominated by E, so the tags inherited by E should also be taken on by D.
//
{
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto a = make_shared<op::Add>(x, y);
a->add_provenance_tag("tag_a");
auto b = make_shared<op::Multiply>(y, x);
b->add_provenance_tag("tag_b");
auto c = make_shared<op::Subtract>(a, b);
c->add_provenance_tag("tag_c");
auto f = make_shared<Function>(c, ParameterVector{x, y});
auto d = make_zero(element::i32, Shape{2, 3, 4});
auto e = make_shared<op::Negative>(d);
replace_node(c, e);
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_a", "tag_b", "tag_c"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_a", "tag_b", "tag_c"}));
}
//
// Before:
//
// A{tag_a} B{tag_b}
// | |
// C{tag_c}
//
// Replacement: C is replaced with G, where:
//
// D{} E{}
// \ / \
// G{} F{}
//
// After:
//
// D{tag_a,tag_b,tag_c} E{}
// \ / \
// G{tag_a,tag_b,tag_c} F{}
//
// Comment:
// * G is the replacement root, and its insertion kills A, B, and C.
// * D is post-dominated by G, but E and F are not. Therefore D should take on the subsumed
// tags, but E and F should not.
//
/*
{
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto a = make_shared<op::Add>(x, y);
a->add_provenance_tag("tag_a");
auto b = make_shared<op::Multiply>(y, x);
b->add_provenance_tag("tag_b");
auto c = make_shared<op::Subtract>(a, b);
c->add_provenance_tag("tag_c");
auto func = make_shared<Function>(c, ParameterVector{x, y});
auto d = make_zero(element::i32, Shape{2, 3, 4});
auto e = make_zero(element::i32, Shape{2, 3, 4});
auto f = make_shared<op::Negative>(e);
auto g = make_shared<op::Add>(d, e);
replace_node(c, g);
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_a", "tag_b", "tag_c"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{}));
EXPECT_EQ(f->get_provenance_tags(), (ProvSet{}));
EXPECT_EQ(g->get_provenance_tags(), (ProvSet{"tag_a", "tag_b", "tag_c"}));
}
*/
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment