Unverified Commit 88bcb685 authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

Add friendly names to nGraph nodes generated from ONNX (#4357)

* Add friendly names to nGraph ONNX nodes generated

* Review comments
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 9ae0a564
......@@ -199,11 +199,28 @@ namespace ngraph
m_model->get_operator(onnx_node.op_type(), onnx_node.domain());
const auto ng_node_vector = ng_node_factory(onnx_node);
set_friendly_names(onnx_node, ng_node_vector);
add_provenance_tags(onnx_node, ng_node_vector);
return ng_node_vector;
}
void Graph::set_friendly_names(const Node& onnx_node,
const NodeVector& ng_node_vector) const
{
for (int i = 0; i < ng_node_vector.size(); ++i)
{
// Trailing optional outputs may not be specified in the ONNX model.
// Other optional outputs should have name set to an empty string.
if (i >= onnx_node.get_outputs_size())
{
break;
}
ng_node_vector[i]->set_friendly_name(onnx_node.output(i));
}
}
void Graph::add_provenance_tag_to_initializer(
const Tensor& tensor, std::shared_ptr<default_opset::Constant> node) const
{
......
......@@ -47,6 +47,8 @@ namespace ngraph
NodeVector make_ng_nodes(const Node& onnx_node) const;
protected:
void set_friendly_names(const Node& onnx_node, const NodeVector& ng_node_vector) const;
void add_provenance_tag_to_initializer(
const Tensor& initializer, std::shared_ptr<default_opset::Constant> node) const;
......
......@@ -90,6 +90,25 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, output_names_check)
}
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, node_names_check)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/add_abc.prototxt"));
// Filter out Add nodes from the function graph
std::vector<std::shared_ptr<Node>> additions;
auto ordered_ops = function->get_ordered_ops();
std::copy_if(
ordered_ops.begin(),
ordered_ops.end(),
std::back_inserter(additions),
[](std::shared_ptr<Node> op) { return std::string(op->get_type_name()) == "Add"; });
EXPECT_EQ(additions.size(), 2);
EXPECT_EQ(additions.at(0)->get_friendly_name(), "X");
EXPECT_EQ(additions.at(1)->get_friendly_name(), "Y");
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_add_abc)
{
auto function = onnx_import::import_onnx_model(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment