Commit 14a1645b authored by Robert Kimball's avatar Robert Kimball Committed by Sang Ik Lee

Update friendly names to work properly in serialized graphs (#2566)

* update friendly names to work properly in serialized graphs

* style

* fix compile error
parent 05aa055c
......@@ -430,6 +430,12 @@ static shared_ptr<ngraph::Function>
try
{
string node_name = node_js.at("name").get<string>();
string friendly_name;
auto it = node_js.find("friendly_name");
if (it != node_js.end())
{
friendly_name = it->get<string>();
}
string node_op = node_js.at("op").get<string>();
vector<string> node_inputs = node_js.at("inputs").get<vector<string>>();
vector<string> control_deps_inputs =
......@@ -598,12 +604,13 @@ static shared_ptr<ngraph::Function>
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape");
try
auto value_it = node_js.find("value");
if (value_it != node_js.end())
{
auto value = node_js.at("value").get<vector<string>>();
auto value = value_it->get<vector<string>>();
node = make_shared<op::Constant>(element_type, shape, value);
}
catch (...)
else
{
node = const_data_callback(node_name, element_type, shape);
}
......@@ -1193,17 +1200,21 @@ static shared_ptr<ngraph::Function>
node->add_control_dependency(node_map.at(name));
}
node->set_friendly_name(node_name);
if (!friendly_name.empty())
{
node->set_friendly_name(friendly_name);
}
node_map[node_name] = node;
}
catch (...)
{
string node_name;
try
auto it = node_js.find("name");
if (it != node_js.end())
{
node_name = node_js.at("name").get<string>();
node_name = it->get<string>();
}
catch (...)
else
{
node_name = "UNKNOWN";
}
......@@ -1251,7 +1262,11 @@ static shared_ptr<ngraph::Function>
static json write(const Node& n, bool binary_constant_data)
{
json node;
node["name"] = n.get_friendly_name();
node["name"] = n.get_name();
if (n.get_name() != n.get_friendly_name())
{
node["friendly_name"] = n.get_friendly_name();
}
node["op"] = n.description();
// TODO Multiple outputs
json inputs = json::array();
......
......@@ -91,6 +91,50 @@ TEST(serialize, main)
handle->call_with_validate({result}, {x, z, y});
EXPECT_EQ((vector<float>{50, 72, 98, 128}), read_vector<float>(result));
}
TEST(serialize, friendly_name)
{
// First create "f(A,B,C) = (A+B)*C".
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto sum = A + B;
auto product = sum * C;
auto f = make_shared<Function>(product, ParameterVector{A, B, C}, "f");
A->set_friendly_name("A");
B->set_friendly_name("B");
C->set_friendly_name("C");
sum->set_friendly_name("Sum");
product->set_friendly_name("Product");
string js = serialize(f, 4);
ofstream out("serialize_function.js");
out << js;
istringstream in(js);
shared_ptr<Function> sfunc = deserialize(in);
auto backend = runtime::Backend::create("INTERPRETER");
auto handle = backend->compile(sfunc);
auto x = backend->create_tensor(element::f32, shape);
copy_data(x, vector<float>{1, 2, 3, 4});
auto y = backend->create_tensor(element::f32, shape);
copy_data(y, vector<float>{5, 6, 7, 8});
auto z = backend->create_tensor(element::f32, shape);
copy_data(z, vector<float>{9, 10, 11, 12});
auto result = backend->create_tensor(element::f32, shape);
handle->call_with_validate({result}, {x, y, z});
EXPECT_EQ((vector<float>{54, 80, 110, 144}), read_vector<float>(result));
handle->call_with_validate({result}, {y, x, z});
EXPECT_EQ((vector<float>{54, 80, 110, 144}), read_vector<float>(result));
handle->call_with_validate({result}, {x, z, y});
EXPECT_EQ((vector<float>{50, 72, 98, 128}), read_vector<float>(result));
}
#endif
TEST(serialize, existing_models)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment