Commit 41e1182f authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Add copy of friendly name to ngraph::copy_function (#3047)

* Copy friendly name when copying node

* add unit test

* style
parent eb43fcc5
......@@ -249,11 +249,17 @@ std::list<std::shared_ptr<ngraph::Node>>
}
auto cloned_node = node->copy_with_new_args(cloned_args);
//copy control dependencies
// copy control dependencies
for (auto cdep : node->get_control_dependencies())
{
cloned_node->add_control_dependency(node_map.at(cdep.get()));
}
if (node->get_friendly_name() != node->get_name())
{
// There is a friendly name for this node so copy it
cloned_node->set_friendly_name(node->get_friendly_name());
}
node_map[node.get()] = cloned_node;
}
}
......
......@@ -634,3 +634,26 @@ TEST(util, apply_permutation_pshape_rank_dynamic_inviable_permutation_fails)
{
ASSERT_THROW(apply_permutation(PartialShape::dynamic(), AxisVector{0, 1, 2, 2}), CheckFailure);
}
TEST(util, clone_function_friendly_name)
{
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Add>(A, B), ParameterVector{A, B});
A->set_friendly_name("A");
B->set_friendly_name("B");
auto g = clone_function(*f);
bool found_A = false;
bool found_B = false;
for (auto parameter : g->get_parameters())
{
found_A |= parameter->get_friendly_name() == "A";
found_B |= parameter->get_friendly_name() == "B";
}
EXPECT_TRUE(found_A);
EXPECT_TRUE(found_B);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment