Commit 9bfbd3c6 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Robert Kimball

[ONNX] Extended support for provenance tags (#4154)

* Checking if provenance_tags key exists

* Add provenance tag prototype

* Format provenance tag

* Display provenance tag

* Clean debug printing

* Add const to variables

* Separate method for add provenance tags

* Return NodeVector reference

* Return const NodeVector

* Moved add_provenance_tags function to commons

* Style apply

* Simple model for tests

* Provenance tag test

* Expect substring instead of  equal

* Add provenance tags to intermediate nodes recursively

* One tag per node

* Add traverse node args instead of recursion

* Return NodeVector instead of set of pointers

* Use treverse_nodes and lambda function

* Remove unused helper functions

* Remove is_constant() condition

* Update test model prototxt

* Update test substring

* Use node name and output names to build provenance tags in onnx importer

* Unit tests for onnx_importer provenance tags

* Missing <numeric> include

* Add provenance tags to constants buit from ONNX initializers

* Add provenance tags to Constants and Parameters created out of ONNX inputs and initializers

* More strict assertions in onnx provenance tests

* Unit test for onnx importer Parameter nodes tagging

* Helper function for the onnx provenance tests

* Some docs

* Obsolete comment removal

* Separate file for onnx provenance tags unit tests

* Code formatting

* Move the inputs tagging to the Graph class

* Tagging moved to the Graph class entirely

* Missing include and extra helper variable

* Unit tests helper documentation

* Change the UT helper to lowercase
Co-authored-by: 's avatarKatarzyna Mitrus <katarzyna.mitrus@intel.com>
parent f08372ba
......@@ -15,6 +15,8 @@
//*****************************************************************************
#include <functional>
#include <numeric>
#include <sstream>
#include "graph.hpp"
#include "node.hpp"
......@@ -57,6 +59,37 @@ namespace ngraph
std::string domain = get_node_domain(node_proto);
return (domain.empty() ? "" : domain + ".") + node_proto.op_type();
}
static std::string concat_strings(
const std::vector<std::reference_wrapper<const std::string>>& strings)
{
const auto concat_with_comma =
[](const std::string& accumulator,
std::reference_wrapper<const std::string> next_string) {
return accumulator + ", " + next_string.get();
};
return std::accumulate(
strings.begin() + 1, strings.end(), strings.begin()->get(), concat_with_comma);
}
static std::string build_input_provenance_tag(const std::string& input_name,
const Shape& shape)
{
std::stringstream tag_builder;
tag_builder << "<ONNX Input (" << input_name << ") " << shape << ">";
return tag_builder.str();
}
static std::string build_op_provenance_tag(const Node& onnx_node)
{
const auto output_names = concat_strings(onnx_node.get_output_names());
const auto node_name =
onnx_node.get_name().empty() ? "" : onnx_node.get_name() + " ";
return std::string{"<ONNX " + onnx_node.op_type() + " (" + node_name + "-> " +
output_names + ")>"};
}
} // namespace detail
Graph::Graph(const onnx::GraphProto& graph_proto, Model& model, const Weights& weights)
......@@ -72,7 +105,9 @@ namespace ngraph
m_initializers.emplace(initializer_tensor.name(), tensor);
// For each initializer, create a Constant node and store in cache
m_ng_node_cache.emplace(initializer_tensor.name(), tensor.get_ng_constant());
auto ng_constant = tensor.get_ng_constant();
add_provenance_tag_to_initializer(tensor, ng_constant);
m_ng_node_cache.emplace(initializer_tensor.name(), std::move(ng_constant));
}
}
......@@ -87,8 +122,10 @@ namespace ngraph
continue;
}
m_ng_node_cache[input.name()] =
m_inputs.back().get_ng_node(m_parameters, m_initializers, weights);
const auto value_info = m_inputs.back();
auto ng_node = value_info.get_ng_node(m_parameters, m_initializers, weights);
add_provenance_tag_to_input(value_info, ng_node);
m_ng_node_cache[input.name()] = std::move(ng_node);
}
// Process all graph outputs
......@@ -160,11 +197,43 @@ namespace ngraph
{
const auto ng_node_factory =
m_model->get_operator(onnx_node.op_type(), onnx_node.domain());
const auto ng_node_vector = ng_node_factory(onnx_node);
common::add_provenance_tags(onnx_node, ng_node_vector);
add_provenance_tags(onnx_node, ng_node_vector);
return ng_node_vector;
}
void Graph::add_provenance_tag_to_initializer(
const Tensor& tensor, std::shared_ptr<default_opset::Constant> node) const
{
const std::string tag =
detail::build_input_provenance_tag(tensor.get_name(), tensor.get_shape());
node->add_provenance_tag(tag);
}
void Graph::add_provenance_tag_to_input(const ValueInfo& input,
std::shared_ptr<ngraph::Node> node) const
{
const std::string tag =
detail::build_input_provenance_tag(input.get_name(), input.get_shape());
node->add_provenance_tag(tag);
}
void Graph::add_provenance_tags(const Node& onnx_node,
const NodeVector& ng_node_vector) const
{
const auto tag = detail::build_op_provenance_tag(onnx_node);
const auto ng_inputs = onnx_node.get_ng_inputs();
ngraph::traverse_nodes(
ng_node_vector,
[&tag](std::shared_ptr<ngraph::Node> ng_node) { ng_node->add_provenance_tag(tag); },
false,
ng_inputs);
}
} // namespace onnx_import
} // namespace ngraph
......@@ -20,6 +20,7 @@
#include <string>
#include <vector>
#include "default_opset.hpp"
#include "model.hpp"
#include "ngraph/op/parameter.hpp"
#include "operator_set.hpp"
......@@ -46,6 +47,15 @@ namespace ngraph
const std::string& get_name() const { return m_graph_proto->name(); }
NodeVector make_ng_nodes(const Node& onnx_node) const;
protected:
void add_provenance_tag_to_initializer(
const Tensor& initializer, std::shared_ptr<default_opset::Constant> node) const;
void add_provenance_tag_to_input(const ValueInfo& input,
std::shared_ptr<ngraph::Node> node) const;
void add_provenance_tags(const Node& onnx_node, const NodeVector& ng_node_vector) const;
private:
const onnx::GraphProto* m_graph_proto;
std::vector<Node> m_nodes;
......
......@@ -28,24 +28,6 @@ namespace ngraph
{
namespace common
{
const NodeVector& add_provenance_tags(const Node& onnx_node,
const NodeVector& ng_node_vector)
{
const std::string node_name =
onnx_node.get_name().empty() ? "unnamed node" : onnx_node.get_name();
const std::string provenance_tag =
"<ONNX " + onnx_node.op_type() + " (" + node_name + ")>";
auto ng_inputs = onnx_node.get_ng_inputs();
ngraph::traverse_nodes(ng_node_vector,
[&](std::shared_ptr<ngraph::Node> ng_node) {
ng_node->add_provenance_tag(provenance_tag);
},
false,
ng_inputs);
return ng_node_vector;
}
const ngraph::element::Type& get_ngraph_element_type(int64_t onnx_type)
{
switch (onnx_type)
......
......@@ -38,8 +38,6 @@ namespace ngraph
{
namespace common
{
const NodeVector& add_provenance_tags(const Node& onnx_node,
const NodeVector& ng_node_vector);
const ngraph::element::Type& get_ngraph_element_type(std::int64_t onnx_type);
/// \brief Return a monotonic sequence.
......
......@@ -498,6 +498,7 @@ if (NGRAPH_ONNX_IMPORT_ENABLE)
list(APPEND MULTI_TEST_SRC
onnx/onnx_import.in.cpp
onnx/onnx_import_convpool.in.cpp
onnx/onnx_import_provenance.in.cpp
onnx/onnx_import_reshape.in.cpp
onnx/onnx_import_rnn.in.cpp
onnx/onnx_import_quant.in.cpp)
......
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "initializer_of_A"
input: "input_B"
output: "output_of_add"
op_type: "Add"
name: "Add_node"
}
name: "test_graph"
initializer {
dims: 0
data_type: 7
int64_data: 1
name: "initializer_of_A"
}
input {
name: "input_B"
type {
tensor_type {
elem_type: 7
shape {
dim {
}
}
}
}
}
output {
name: "output_of_add"
type {
tensor_type {
elem_type: 7
shape {
dim {
}
}
}
}
}
}
opset_import {
version: 9
}
ir_version: 4
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "k"
output: "values"
output: "indices"
op_type: "TopK"
name: "TOPK"
}
name: "test_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "k"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "values"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "indices"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 10
}
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "input_A"
input: "input_B"
output: "output_of_add"
op_type: "Add"
name: "Add_node"
}
name: "test_graph"
input {
name: "input_A"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
}
}
}
}
input {
name: "input_B"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
}
}
}
}
output {
name: "output_of_add"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
}
}
}
}
}
opset_import {
version: 9
}
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "input_A"
input: "input_B"
output: "output_of_add"
op_type: "Add"
}
name: "test_graph"
input {
name: "input_A"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
}
}
}
}
input {
name: "input_B"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
}
}
}
}
output {
name: "output_of_add"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
}
}
}
}
}
opset_import {
version: 9
}
......@@ -356,21 +356,6 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_initializer_wo_input)
EXPECT_TRUE(test::all_close_f(expected_output, output.front()));
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tag_text)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/provenance_tag_add.prototxt"));
auto ng_nodes = function->get_ordered_ops();
for (auto ng_node : ng_nodes)
{
for (auto tag : ng_node->get_provenance_tags())
{
EXPECT_HAS_SUBSTRING(tag, "ONNX");
}
}
}
// ############################################################################ OPERATOR TESTS
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_addmul_abc)
{
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/frontend/onnx_import/default_opset.hpp"
#include "ngraph/frontend/onnx_import/onnx.hpp"
#include "util/test_control.hpp"
#include "util/type_prop.hpp"
using namespace ngraph;
using namespace ngraph::onnx_import;
static std::string s_manifest = "${MANIFEST}";
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tag_text)
{
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/provenance_tag_add.prototxt"));
const auto ng_nodes = function->get_ordered_ops();
for (const auto ng_node : ng_nodes)
{
for (const auto tag : ng_node->get_provenance_tags())
{
EXPECT_HAS_SUBSTRING(tag, "ONNX");
}
}
}
// the NodeToCheck parameter of this template is used to find a node in the whole subgraph
// that a particular unit test is supposed to check against the expected provenance tag
template <typename NodeToCheck>
void test_provenance_tags(const std::string& model_path, const std::string& expected_provenance_tag)
{
const auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, model_path));
for (const auto ng_node : function->get_ordered_ops())
{
if (as_type_ptr<NodeToCheck>(ng_node))
{
const auto tags = ng_node->get_provenance_tags();
ASSERT_EQ(tags.size(), 1) << "There should be exactly one provenance tag set for "
<< ng_node;
EXPECT_EQ(*(tags.cbegin()), expected_provenance_tag);
}
}
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_only_output)
{
// the Add node in the model does not have a name,
// only its output name should be found in the provenance tags
test_provenance_tags<default_opset::Add>("onnx/provenance_only_outputs.prototxt",
"<ONNX Add (-> output_of_add)>");
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_node_name_and_outputs)
{
test_provenance_tags<default_opset::Add>("onnx/provenance_node_name_and_outputs.prototxt",
"<ONNX Add (Add_node -> output_of_add)>");
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_multiple_outputs_op)
{
test_provenance_tags<default_opset::TopK>("onnx/provenance_multiple_outputs_op.prototxt",
"<ONNX TopK (TOPK -> values, indices)>");
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tagging_constants)
{
test_provenance_tags<default_opset::Constant>("onnx/provenance_input_tags.prototxt",
"<ONNX Input (initializer_of_A) Shape{0}>");
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tagging_parameters)
{
test_provenance_tags<default_opset::Parameter>("onnx/provenance_input_tags.prototxt",
"<ONNX Input (input_B) Shape{0}>");
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment