Commit fbd99f34 authored by Katarzyna Mitrus's avatar Katarzyna Mitrus Committed by Scott Cyphers

Add provenance tags to ONNX importer (#4108)

* Checking if provenance_tags key exists

* Add provenance tag prototype

* Format provenance tag

* Display provenance tag

* Clean debug printing

* Add const to variables

* Separate method for add provenance tags

* Return NodeVector reference

* Return const NodeVector

* Moved add_provenance_tags function to commons

* Style apply

* Simple model for tests

* Provenance tag test

* Expect substring instead of  equal

* Add provenance tags to intermediate nodes recursively

* One tag per node

* Add traverse node args instead of recursion

* Return NodeVector instead of set of pointers

* Use treverse_nodes and lambda function

* Remove unused helper functions

* Remove is_constant() condition

* Update test model prototxt

* Update test substring

* Make code slightly more readable
Co-authored-by: 's avatarTomasz Dołbniak <tomasz.dolbniak@intel.com>
Co-authored-by: 's avatarMichał Karzyński <postrational@users.noreply.github.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
parent 4204096d
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "graph.hpp" #include "graph.hpp"
#include "node.hpp" #include "node.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -155,6 +156,15 @@ namespace ngraph ...@@ -155,6 +156,15 @@ namespace ngraph
return results; return results;
} }
NodeVector Graph::make_ng_nodes(const Node& onnx_node) const
{
const auto ng_node_factory =
m_model->get_operator(onnx_node.op_type(), onnx_node.domain());
const auto ng_node_vector = ng_node_factory(onnx_node);
common::add_provenance_tags(onnx_node, ng_node_vector);
return ng_node_vector;
}
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -34,7 +34,6 @@ namespace ngraph ...@@ -34,7 +34,6 @@ namespace ngraph
{ {
public: public:
Graph(const onnx::GraphProto& proto, Model& model, const Weights& weights = {}); Graph(const onnx::GraphProto& proto, Model& model, const Weights& weights = {});
const std::vector<Node>& get_nodes() const { return m_nodes; } const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; } const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
const std::vector<ValueInfo>& get_outputs() const { return m_outputs; } const std::vector<ValueInfo>& get_outputs() const { return m_outputs; }
...@@ -44,12 +43,8 @@ namespace ngraph ...@@ -44,12 +43,8 @@ namespace ngraph
{ {
return m_ng_node_cache.at(name); return m_ng_node_cache.at(name);
} }
const std::string& get_name() const { return m_graph_proto->name(); } const std::string& get_name() const { return m_graph_proto->name(); }
NodeVector make_ng_nodes(const Node& node) const NodeVector make_ng_nodes(const Node& onnx_node) const;
{
return m_model->get_operator(node.op_type(), node.domain())(node);
}
private: private:
const onnx::GraphProto* m_graph_proto; const onnx::GraphProto* m_graph_proto;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "common.hpp" #include "common.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/opsets/opset0.hpp" #include "ngraph/opsets/opset0.hpp"
#include "validation_util.hpp" #include "validation_util.hpp"
...@@ -27,6 +28,24 @@ namespace ngraph ...@@ -27,6 +28,24 @@ namespace ngraph
{ {
namespace common namespace common
{ {
const NodeVector& add_provenance_tags(const Node& onnx_node,
const NodeVector& ng_node_vector)
{
const std::string node_name =
onnx_node.get_name().empty() ? "unnamed node" : onnx_node.get_name();
const std::string provenance_tag =
"<ONNX " + onnx_node.op_type() + " (" + node_name + ")>";
auto ng_inputs = onnx_node.get_ng_inputs();
ngraph::traverse_nodes(ng_node_vector,
[&](std::shared_ptr<ngraph::Node> ng_node) {
ng_node->add_provenance_tag(provenance_tag);
},
false,
ng_inputs);
return ng_node_vector;
}
const ngraph::element::Type& get_ngraph_element_type(int64_t onnx_type) const ngraph::element::Type& get_ngraph_element_type(int64_t onnx_type)
{ {
switch (onnx_type) switch (onnx_type)
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "core/node.hpp" #include "core/node.hpp"
#include "default_opset.hpp" #include "default_opset.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
...@@ -37,6 +38,8 @@ namespace ngraph ...@@ -37,6 +38,8 @@ namespace ngraph
{ {
namespace common namespace common
{ {
const NodeVector& add_provenance_tags(const Node& onnx_node,
const NodeVector& ng_node_vector);
const ngraph::element::Type& get_ngraph_element_type(std::int64_t onnx_type); const ngraph::element::Type& get_ngraph_element_type(std::int64_t onnx_type);
/// \brief Return a monotonic sequence. /// \brief Return a monotonic sequence.
......
...@@ -2981,10 +2981,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2981,10 +2981,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
} }
if (ngraph::get_provenance_enabled()) if (ngraph::get_provenance_enabled())
{ {
std::vector<json> prov_js = node_js.at("provenance_tags"); if (has_key(node_js, "provenance_tags"))
for (auto prov_tag : prov_js)
{ {
node->add_provenance_tag(prov_tag); const std::vector<json> prov_js = node_js.at("provenance_tags");
for (auto prov_tag : prov_js)
{
node->add_provenance_tag(prov_tag);
}
} }
} }
m_node_map[node_name] = node; m_node_map[node_name] = node;
......
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "A"
input: "B"
output: "X"
name: "add_node1"
op_type: "Add"
}
node {
input: "X"
input: "C"
output: "Y"
name: "add_node2"
op_type: "Add"
}
name: "test_graph"
input {
name: "A"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
input {
name: "C"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
}
opset_import {
version: 4
}
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "util/test_case.hpp" #include "util/test_case.hpp"
#include "util/test_control.hpp" #include "util/test_control.hpp"
#include "util/test_tools.hpp" #include "util/test_tools.hpp"
#include "util/type_prop.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -355,6 +356,21 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_initializer_wo_input) ...@@ -355,6 +356,21 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_initializer_wo_input)
EXPECT_TRUE(test::all_close_f(expected_output, output.front())); EXPECT_TRUE(test::all_close_f(expected_output, output.front()));
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tag_text)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/provenance_tag_add.prototxt"));
auto ng_nodes = function->get_ordered_ops();
for (auto ng_node : ng_nodes)
{
for (auto tag : ng_node->get_provenance_tags())
{
EXPECT_HAS_SUBSTRING(tag, "ONNX");
}
}
}
// ############################################################################ OPERATOR TESTS // ############################################################################ OPERATOR TESTS
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_addmul_abc) NGRAPH_TEST(onnx_${BACKEND_NAME}, model_addmul_abc)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment