Commit 4ab4609f authored by Fabian Boemer's avatar Fabian Boemer Committed by Scott Cyphers

Added op annotations to cloned function (#3773)

parent 26590326
......@@ -282,6 +282,8 @@ std::list<std::shared_ptr<ngraph::Node>>
{
cloned_node->add_provenance_tag(tag);
}
cloned_node->set_op_annotations(node->get_op_annotations());
node_map[node.get()] = cloned_node;
}
}
......
......@@ -25,6 +25,7 @@
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/op_annotations.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/serializer.hpp"
......@@ -652,3 +653,35 @@ TEST(util, clone_function_friendly_name)
EXPECT_TRUE(found_A);
EXPECT_TRUE(found_B);
}
TEST(util, clone_function_op_annotations)
{
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(A + B + C, ParameterVector{A, B, C});
auto cacheable_op_annotation = std::make_shared<op::util::OpAnnotations>();
cacheable_op_annotation->set_cacheable(true);
A->set_op_annotations(cacheable_op_annotation);
auto uncacheable_op_annotation = std::make_shared<op::util::OpAnnotations>();
uncacheable_op_annotation->set_cacheable(false);
B->set_op_annotations(uncacheable_op_annotation);
auto g = clone_function(*f);
bool found_A = false;
bool found_B = false;
for (auto parameter : g->get_parameters())
{
if (auto op_annotation = parameter->get_op_annotations())
{
found_A |= op_annotation->is_cacheable();
found_B |= !op_annotation->is_cacheable();
}
}
EXPECT_TRUE(found_A);
EXPECT_TRUE(found_B);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment