Unverified Commit 3409bda8 authored by Katarzyna Mitrus's avatar Katarzyna Mitrus Committed by GitHub

Add provenance tags to decompose ops pass (#4213)

* Add provenance tag in fused decomposition pass

* Add and update test for decomposition tag

* Style apply

* Style apply
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 328c328e
......@@ -16,6 +16,7 @@
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/provenance.hpp"
using namespace std;
using namespace ngraph;
......@@ -36,20 +37,24 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
// Op supported by backend. Do not decompose
return modified;
}
// Capture the input values as a base for provenance
OutputVector base_input_values;
for (auto value : node->input_values())
{
base_input_values.push_back(value);
}
auto subgraph_outputs = node->decompose_op();
// Transfer the new provenance tags to the newly created ops
auto provenance_tags = node->get_provenance_tags();
for (auto subgraph : subgraph_outputs)
if (ngraph::get_provenance_enabled())
{
subgraph->add_provenance_tags_above(base_input_values, provenance_tags);
// Capture the input values as an edge for provenance
auto base_input_values = node->input_values();
auto provenance_tags = node->get_provenance_tags();
const std::string tag = "<Decomposed from " + std::string(node->get_type_name()) + ">";
provenance_tags.insert(tag);
// Transfer the new provenance tags to the newly created ops
for (auto output_node : subgraph_outputs)
{
output_node->add_provenance_tags_above(base_input_values, provenance_tags);
}
}
// Run recursively untill no more fused ops
// Run recursively until no more fused ops
auto subgraph = extract_subgraph(subgraph_outputs, node->get_arguments());
for (auto subgraph_node : subgraph)
{
......@@ -61,7 +66,6 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
{
for (size_t j = 0; j < output_node->get_outputs().size(); j++, i++)
{
// TODO: Provenance
set<descriptor::Input*> fop_users{begin(node->get_outputs().at(i).get_inputs()),
end(node->get_outputs().at(i).get_inputs())};
for (auto fop_user : fop_users)
......
......@@ -352,32 +352,61 @@ TEST(provenance, builder)
}
}
TEST(provenance, fused)
TEST(provenance, fused_copy_origin_tags)
{
set_provenance_enabled(true);
auto p1 = make_shared<op::Parameter>(element::f32, PartialShape{2, 3, 4});
p1->add_provenance_tag("P1");
auto g = make_shared<op::Gelu>(p1);
g->add_provenance_tag("G");
auto r = make_shared<op::Result>(g);
auto f = make_shared<Function>(ResultVector{r}, ParameterVector{p1});
pass::Manager manager;
manager.register_pass<pass::FusedOpDecomposition>();
manager.run_passes(f);
traverse_nodes(f, [&](const std::shared_ptr<Node>& node) {
auto tags = node->get_provenance_tags();
if (node == p1)
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"P1"}));
EXPECT_EQ(tags.size(), 1);
EXPECT_TRUE(tags.find("P1") != tags.end());
}
else if (node == r)
{
}
else
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"G"}));
EXPECT_TRUE(tags.find("G") != tags.end());
EXPECT_TRUE(tags.find("<Decomposed from Gelu>") != tags.end());
}
});
}
TEST(provenance, fused_decomposition_tag)
{
set_provenance_enabled(true);
auto p1 = make_shared<op::Parameter>(element::f32, PartialShape{2, 3, 4});
auto fused_op = make_shared<op::MVN>(p1);
auto result = make_shared<op::Result>(fused_op);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{p1});
pass::Manager manager;
manager.register_pass<pass::FusedOpDecomposition>();
manager.run_passes(f);
const auto tag = "<Decomposed from MVN>";
auto tag_check = [&tag](std::shared_ptr<ngraph::Node> node) {
auto tags = node->get_provenance_tags();
EXPECT_TRUE(tags.find(tag) != tags.end());
};
const auto decomposed_op = f->get_result()->input(0).get_source_output().get_node_shared_ptr();
traverse_nodes(as_node_vector(decomposed_op->outputs()), tag_check, false, {p1});
}
TEST(provenance, topk_setk)
{
auto p1 = make_shared<op::Parameter>(element::f32, PartialShape{20, 3, 4});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment