Commit b8b66de6 authored by Mohammad Mahbubuzzaman's avatar Mohammad Mahbubuzzaman Committed by Scott Cyphers

Mmahbubu/provenance fix grpah replace (#3624)

* Fixes bug in provenenace for subgraph replacement

* Updates unit tests for the provenance algorithm fix

* Changes provnance set to ordered set for better consistency in iteration order.
parent 8295b844
......@@ -153,14 +153,20 @@ void ngraph::replace_node(std::shared_ptr<Node> target,
{
auto common_args = ngraph::find_common_args(target, replacement);
auto set_replacement_prov = [replacement](std::shared_ptr<Node> node) {
replacement->merge_provenance_tags_from(node);
std::set<string> removed_subgraph_tags;
auto set_replacement_prov = [&removed_subgraph_tags](std::shared_ptr<Node> node) {
for (auto tag : node->get_provenance_tags())
{
removed_subgraph_tags.insert(tag);
}
};
traverse_nodes({target}, set_replacement_prov, false, common_args);
replacement->add_provenance_tags(removed_subgraph_tags);
auto set_prov_new_nodes = [replacement](std::shared_ptr<Node> node) {
node->merge_provenance_tags_from(replacement);
auto set_prov_new_nodes = [&removed_subgraph_tags](std::shared_ptr<Node> node) {
node->add_provenance_tags(removed_subgraph_tags);
};
traverse_nodes({replacement}, set_prov_new_nodes, false, common_args);
......
......@@ -318,7 +318,7 @@ void Node::set_placement_index(size_t placement)
m_placement_index = placement;
}
const std::unordered_set<std::string>& Node::get_provenance_tags() const
const std::set<std::string>& Node::get_provenance_tags() const
{
return m_provenance_tags;
}
......@@ -328,6 +328,14 @@ void Node::add_provenance_tag(const std::string& tag)
m_provenance_tags.insert(tag);
}
void Node::add_provenance_tags(const std::set<std::string>& tag_set)
{
for (auto tag : tag_set)
{
add_provenance_tag(tag);
}
}
void Node::remove_provenance_tag(const std::string& tag)
{
m_provenance_tags.erase(tag);
......
......@@ -369,8 +369,9 @@ namespace ngraph
/// Set device placement
void set_placement_index(size_t placement);
const std::unordered_set<std::string>& get_provenance_tags() const;
const std::set<std::string>& get_provenance_tags() const;
void add_provenance_tag(const std::string& tag);
void add_provenance_tags(const std::set<std::string>& tag_set);
void remove_provenance_tag(const std::string& tag);
// to be used when nodes are replaced
......@@ -426,7 +427,7 @@ namespace ngraph
std::string m_unique_name;
NGRAPH_API
static std::atomic<size_t> m_next_instance_id;
std::unordered_set<std::string> m_provenance_tags;
std::set<std::string> m_provenance_tags;
std::deque<descriptor::Input> m_inputs;
std::deque<descriptor::Output> m_outputs;
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
......
......@@ -29,7 +29,7 @@ using namespace std;
using namespace ngraph;
using ::testing::Return;
using ProvSet = std::unordered_set<std::string>;
using ProvSet = std::set<std::string>;
TEST(provenance, provenance)
{
......@@ -231,7 +231,7 @@ TEST(provenance, provenance)
//
// A{tag_a} B{tag_b}
// | |
// E{tag_c, tag_d} |
// E{tag_c} |
// | |
// D{tag_c, tag_d}
//
......@@ -258,7 +258,7 @@ TEST(provenance, provenance)
replace_node(c, d);
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c"}));
}
//
......@@ -310,7 +310,7 @@ TEST(provenance, provenance)
replace_node(c, d);
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_d", "tag_e"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_e"}));
}
//
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment