Commit 5b63a3c7 authored by tsocha's avatar tsocha Committed by Robert Kimball

[ONNX] Add null node. (#2386)

* [ONNX] Add null node.

Optional inputs in ONNX standard are represented by empty string.
We need a placeholder to keep information which inputs are not provided.

* Rename class null_node -> NullNode

* Remove unnecesary validate_and_infer_types method

* Add <memory> header

* Change name != "" -> !name.empty()

* Change constructor

* Little description

* Change node type in NullNode

* Add is_null() method

* Docstring

* Add UT

* Use override

* Style check

* Update null_node.cpp
parent 58b78ceb
...@@ -32,6 +32,8 @@ add_library(onnx_import STATIC ...@@ -32,6 +32,8 @@ add_library(onnx_import STATIC
core/model.cpp core/model.cpp
core/model.hpp core/model.hpp
core/node.hpp core/node.hpp
core/null_node.cpp
core/null_node.hpp
core/operator_set.hpp core/operator_set.hpp
core/tensor.hpp core/tensor.hpp
core/value_info.hpp core/value_info.hpp
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "attribute.hpp" #include "attribute.hpp"
#include "graph.hpp" #include "graph.hpp"
#include "node.hpp" #include "node.hpp"
#include "null_node.hpp"
#include "tensor.hpp" #include "tensor.hpp"
namespace ngraph namespace ngraph
...@@ -122,7 +123,14 @@ namespace ngraph ...@@ -122,7 +123,14 @@ namespace ngraph
NodeVector result; NodeVector result;
for (const auto& name : m_node_proto->input()) for (const auto& name : m_node_proto->input())
{ {
result.push_back(m_graph->get_ng_node_from_cache(name)); if (!name.empty())
{
result.push_back(m_graph->get_ng_node_from_cache(name));
}
else
{
result.push_back(std::make_shared<NullNode>());
}
} }
return result; return result;
} }
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "null_node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
NullNode::NullNode()
: ngraph::Node("NullNode", {}, 0)
{
}
std::shared_ptr<Node> NullNode::copy_with_new_args(const NodeVector& new_args) const
{
return std::make_shared<NullNode>();
}
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
/// \brief Represents a missing optional input or output of an ONNX node
///
/// Some ONNX operators have inputs or outputs that are marked as optional,
/// which means that a referring node MAY forgo providing values for such inputs
/// or computing these outputs.
/// An empty string is used in place of a name of such input or output.
///
/// More:
/// https://github.com/onnx/onnx/blob/master/docs/IR.md#optional-inputs-and-outputs
class NullNode : public ngraph::Node
{
public:
NullNode();
bool is_null() const final override { return true; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
} // namespace onnx_import
} // namespace ngraph
...@@ -46,7 +46,7 @@ namespace ngraph ...@@ -46,7 +46,7 @@ namespace ngraph
{ {
class Parameter; class Parameter;
class Result; class Result;
} } // namespace op
void replace_node_users_arguments(std::shared_ptr<Node> target, void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement); std::shared_ptr<Node> replacement);
...@@ -129,6 +129,7 @@ namespace ngraph ...@@ -129,6 +129,7 @@ namespace ngraph
bool is_parameter() const; bool is_parameter() const;
virtual bool is_output() const; virtual bool is_output() const;
virtual bool is_constant() const; virtual bool is_constant() const;
virtual bool is_null() const { return false; }
virtual bool is_op() const { return false; } virtual bool is_op() const { return false; }
virtual bool is_commutative() { return false; } virtual bool is_commutative() { return false; }
size_t get_instance_id() const { return m_instance_id; } size_t get_instance_id() const { return m_instance_id; }
...@@ -291,7 +292,7 @@ namespace ngraph ...@@ -291,7 +292,7 @@ namespace ngraph
}; };
void check_new_args_count(const Node* node, const NodeVector& new_args); void check_new_args_count(const Node* node, const NodeVector& new_args);
} } // namespace ngraph
#define NODE_VALIDATION_ASSERT(node, cond) \ #define NODE_VALIDATION_ASSERT(node, cond) \
NGRAPH_ASSERT_STREAM_WITH_LOC( \ NGRAPH_ASSERT_STREAM_WITH_LOC( \
......
...@@ -1864,6 +1864,53 @@ TEST(onnx_${BACKEND_NAME}, model_top_k) ...@@ -1864,6 +1864,53 @@ TEST(onnx_${BACKEND_NAME}, model_top_k)
EXPECT_TRUE(test::all_close(expected_indices_output, indices_output)); EXPECT_TRUE(test::all_close(expected_indices_output, indices_output));
} }
TEST(onnx_${BACKEND_NAME}, model_missing_input)
{
onnx_import::register_operator(
"TestMissingInOut", 1, "com.intel.ai", [](const onnx_import::Node& node) -> NodeVector {
NodeVector ng_inputs{node.get_ng_inputs()};
std::shared_ptr<ngraph::Node> A = ng_inputs.at(0);
std::shared_ptr<ngraph::Node> B = ng_inputs.at(1);
std::shared_ptr<ngraph::Node> C = ng_inputs.at(2);
A = A * C;
if (!B->is_null())
{
B = B / C;
}
C = C + C;
return {A, B, C};
});
onnx_import::register_operator(
"TestMissingIn", 1, "com.intel.ai", [](const onnx_import::Node& node) -> NodeVector {
NodeVector ng_inputs{node.get_ng_inputs()};
std::shared_ptr<ngraph::Node> result = std::make_shared<ngraph::op::Constant>(
element::f32, ngraph::Shape{2, 2}, std::vector<float>{1, 1, 1, 1});
for (const auto& ng_input : ng_inputs)
{
if (!ng_input->is_null())
{
result = ng_input * result;
}
}
return {result};
});
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/missing_input.onnx"));
Inputs inputs{{1, 2, 3, 4}, {5, 6, 7, 8}};
Outputs expected_outputs{{50, 144, 294, 512}};
Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx_${BACKEND_NAME}, model_sinh) TEST(onnx_${BACKEND_NAME}, model_sinh)
{ {
auto function = auto function =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment