Commit 12b5f085 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Sang Ik Lee

[ONNX] Handle trimmed optional outputs. (#2434)

* Function for retrieving number of node outputs.

* Handle optional trimmed outputs.

* Fix compilation err on clang.

* Fix error for number of outputs.

- Iterate over the minimum of number of outputs we return and the number
  of outputs of respective node in the graph. Some outputs may be
  optional and trimmed, as well as for some op implementations we may
  return not all outputs (ie. Dropout - where we do not return additional
  optional output).

* Update graph.cpp

* Add dropout ONNX op.

* Revert to iterate over node outputs in graph.

* Use more apropriate word in comment.
parent 718e2ef1
......@@ -75,6 +75,7 @@ add_library(onnx_import STATIC
op/depth_to_space.cpp
op/depth_to_space.hpp
op/div.hpp
op/dropout.hpp
op/elu.cpp
op/elu.hpp
op/equal.hpp
......
......@@ -91,6 +91,7 @@ namespace ngraph
m_inputs.back().get_ng_node(m_parameters, m_initializers, weights);
}
// Process all graph outputs
for (const auto& output : m_graph_proto->output())
{
m_outputs.emplace_back(output);
......@@ -104,7 +105,7 @@ namespace ngraph
{
unknown_operators.emplace(detail::get_op_domain_and_name(node_proto),
node_proto);
// Try adding missing domain
// If a node from an unregistered domain is detected, try registering that domain
m_model->enable_opset_domain(detail::get_node_domain(node_proto));
}
}
......@@ -132,8 +133,12 @@ namespace ngraph
{
m_nodes.emplace_back(node_proto, *this);
const Node& node{m_nodes.back()};
NodeVector ng_nodes{node.get_ng_nodes()};
for (int i = 0; i < ng_nodes.size(); i++)
// Iterate over the number of outputs for given node in graph.
// Some of them may be optional and trimmed. See:
// https://github.com/onnx/onnx/blob/master/docs/IR.md#optional-inputs-and-outputs
for (std::size_t i{0}; i < node.get_outputs_size(); ++i)
{
m_ng_node_cache[node.output(i)] = ng_nodes[i];
}
......
......@@ -50,6 +50,7 @@ namespace ngraph
const std::string& description() const;
const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const;
const std::string& output(int index) const;
std::size_t get_outputs_size() const;
template <typename T>
T get_attribute_value(const std::string& name, T default_value) const;
......@@ -85,6 +86,7 @@ namespace ngraph
return m_node_proto->output(index);
}
std::size_t Node::Impl::get_outputs_size() const { return m_output_names.size(); }
template <typename T>
T Node::Impl::get_attribute_value(const std::string& name, T default_value) const
{
......@@ -182,6 +184,7 @@ namespace ngraph
}
const std::string& Node::output(int index) const { return m_pimpl->output(index); }
std::size_t Node::get_outputs_size() const { return m_pimpl->get_outputs_size(); }
template <>
float Node::get_attribute_value(const std::string& name, float default_value) const
{
......
......@@ -16,6 +16,7 @@
#pragma once
#include <cstddef>
#include <string>
#include "ngraph/except.hpp"
......@@ -75,6 +76,7 @@ namespace ngraph
const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const;
const std::string& output(int index) const;
std::size_t get_outputs_size() const;
template <typename T>
T get_attribute_value(const std::string& name, T default_value) const;
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "core/node.hpp"
#include "core/null_node.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector dropout(const Node& node)
{
// First value is actual output of Dropout,
// the second one is just a placeholder for optional trailing output.
return {node.get_ng_inputs().at(0), std::make_shared<NullNode>()};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -46,6 +46,7 @@
#include "op/cosh.hpp"
#include "op/depth_to_space.hpp"
#include "op/div.hpp"
#include "op/dropout.hpp"
#include "op/elu.hpp"
#include "op/equal.hpp"
#include "op/exp.hpp"
......@@ -239,7 +240,7 @@ namespace ngraph
REGISTER_OPERATOR("DepthToSpace", 1, depth_to_space);
REGISTER_OPERATOR("Div", 1, div);
REGISTER_OPERATOR("Div", 7, div);
REGISTER_OPERATOR("Dropout", 1, identity);
REGISTER_OPERATOR("Dropout", 1, dropout);
REGISTER_OPERATOR("Elu", 1, elu);
REGISTER_OPERATOR("Equal", 1, equal);
REGISTER_OPERATOR("Exp", 1, exp);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment