Commit b8b66de6 authored by Mohammad Mahbubuzzaman's avatar Mohammad Mahbubuzzaman Committed by Scott Cyphers

Mmahbubu/provenance fix grpah replace (#3624)

* Fixes bug in provenenace for subgraph replacement

* Updates unit tests for the provenance algorithm fix

* Changes provnance set to ordered set for better consistency in iteration order.
parent 8295b844
...@@ -153,14 +153,20 @@ void ngraph::replace_node(std::shared_ptr<Node> target, ...@@ -153,14 +153,20 @@ void ngraph::replace_node(std::shared_ptr<Node> target,
{ {
auto common_args = ngraph::find_common_args(target, replacement); auto common_args = ngraph::find_common_args(target, replacement);
auto set_replacement_prov = [replacement](std::shared_ptr<Node> node) { std::set<string> removed_subgraph_tags;
replacement->merge_provenance_tags_from(node);
auto set_replacement_prov = [&removed_subgraph_tags](std::shared_ptr<Node> node) {
for (auto tag : node->get_provenance_tags())
{
removed_subgraph_tags.insert(tag);
}
}; };
traverse_nodes({target}, set_replacement_prov, false, common_args); traverse_nodes({target}, set_replacement_prov, false, common_args);
replacement->add_provenance_tags(removed_subgraph_tags);
auto set_prov_new_nodes = [replacement](std::shared_ptr<Node> node) { auto set_prov_new_nodes = [&removed_subgraph_tags](std::shared_ptr<Node> node) {
node->merge_provenance_tags_from(replacement); node->add_provenance_tags(removed_subgraph_tags);
}; };
traverse_nodes({replacement}, set_prov_new_nodes, false, common_args); traverse_nodes({replacement}, set_prov_new_nodes, false, common_args);
......
...@@ -318,7 +318,7 @@ void Node::set_placement_index(size_t placement) ...@@ -318,7 +318,7 @@ void Node::set_placement_index(size_t placement)
m_placement_index = placement; m_placement_index = placement;
} }
const std::unordered_set<std::string>& Node::get_provenance_tags() const const std::set<std::string>& Node::get_provenance_tags() const
{ {
return m_provenance_tags; return m_provenance_tags;
} }
...@@ -328,6 +328,14 @@ void Node::add_provenance_tag(const std::string& tag) ...@@ -328,6 +328,14 @@ void Node::add_provenance_tag(const std::string& tag)
m_provenance_tags.insert(tag); m_provenance_tags.insert(tag);
} }
void Node::add_provenance_tags(const std::set<std::string>& tag_set)
{
for (auto tag : tag_set)
{
add_provenance_tag(tag);
}
}
void Node::remove_provenance_tag(const std::string& tag) void Node::remove_provenance_tag(const std::string& tag)
{ {
m_provenance_tags.erase(tag); m_provenance_tags.erase(tag);
......
...@@ -369,8 +369,9 @@ namespace ngraph ...@@ -369,8 +369,9 @@ namespace ngraph
/// Set device placement /// Set device placement
void set_placement_index(size_t placement); void set_placement_index(size_t placement);
const std::unordered_set<std::string>& get_provenance_tags() const; const std::set<std::string>& get_provenance_tags() const;
void add_provenance_tag(const std::string& tag); void add_provenance_tag(const std::string& tag);
void add_provenance_tags(const std::set<std::string>& tag_set);
void remove_provenance_tag(const std::string& tag); void remove_provenance_tag(const std::string& tag);
// to be used when nodes are replaced // to be used when nodes are replaced
...@@ -426,7 +427,7 @@ namespace ngraph ...@@ -426,7 +427,7 @@ namespace ngraph
std::string m_unique_name; std::string m_unique_name;
NGRAPH_API NGRAPH_API
static std::atomic<size_t> m_next_instance_id; static std::atomic<size_t> m_next_instance_id;
std::unordered_set<std::string> m_provenance_tags; std::set<std::string> m_provenance_tags;
std::deque<descriptor::Input> m_inputs; std::deque<descriptor::Input> m_inputs;
std::deque<descriptor::Output> m_outputs; std::deque<descriptor::Output> m_outputs;
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map; std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
......
...@@ -29,7 +29,7 @@ using namespace std; ...@@ -29,7 +29,7 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
using ::testing::Return; using ::testing::Return;
using ProvSet = std::unordered_set<std::string>; using ProvSet = std::set<std::string>;
TEST(provenance, provenance) TEST(provenance, provenance)
{ {
...@@ -231,7 +231,7 @@ TEST(provenance, provenance) ...@@ -231,7 +231,7 @@ TEST(provenance, provenance)
// //
// A{tag_a} B{tag_b} // A{tag_a} B{tag_b}
// | | // | |
// E{tag_c, tag_d} | // E{tag_c} |
// | | // | |
// D{tag_c, tag_d} // D{tag_c, tag_d}
// //
...@@ -258,7 +258,7 @@ TEST(provenance, provenance) ...@@ -258,7 +258,7 @@ TEST(provenance, provenance)
replace_node(c, d); replace_node(c, d);
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"})); EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"})); EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c"}));
} }
// //
...@@ -310,7 +310,7 @@ TEST(provenance, provenance) ...@@ -310,7 +310,7 @@ TEST(provenance, provenance)
replace_node(c, d); replace_node(c, d);
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"})); EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_d", "tag_e"})); EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_e"}));
} }
// //
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment