Commit 2d9a974c authored by Michał Karzyński's avatar Michał Karzyński Committed by Scott Cyphers

[ONNX] Support for initializers without corresponding inputs (#2406)

* [ONNX] Move make_ng_constant to Tensor class

* [ONNX] Add workaround for initializers without corresponding inputs

* Fix malformed test models

* clang-format

* Avoid creating multiple constants

* Code review comments
parent 258d6730
...@@ -63,11 +63,16 @@ namespace ngraph ...@@ -63,11 +63,16 @@ namespace ngraph
: m_graph_proto{&graph_proto} : m_graph_proto{&graph_proto}
, m_model{&model} , m_model{&model}
{ {
for (const auto& tensor : m_graph_proto->initializer()) // Process all initializers in the graph
for (const auto& initializer_tensor : m_graph_proto->initializer())
{ {
if (tensor.has_name()) if (initializer_tensor.has_name())
{ {
m_initializers.emplace(tensor.name(), Tensor{tensor}); Tensor tensor = Tensor{initializer_tensor};
m_initializers.emplace(initializer_tensor.name(), tensor);
// For each initializer, create a Constant node and store in cache
m_ng_node_cache.emplace(initializer_tensor.name(), tensor.get_ng_constant());
} }
} }
...@@ -75,6 +80,13 @@ namespace ngraph ...@@ -75,6 +80,13 @@ namespace ngraph
for (const auto& input : m_graph_proto->input()) for (const auto& input : m_graph_proto->input())
{ {
m_inputs.emplace_back(input); m_inputs.emplace_back(input);
// Check if a Constant node was already created from an initializer
if (m_ng_node_cache.count(input.name()) > 0)
{
continue;
}
m_ng_node_cache[input.name()] = m_ng_node_cache[input.name()] =
m_inputs.back().get_ng_node(m_parameters, m_initializers, weights); m_inputs.back().get_ng_node(m_parameters, m_initializers, weights);
} }
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "ngraph/op/constant.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
...@@ -387,7 +388,44 @@ namespace ngraph ...@@ -387,7 +388,44 @@ namespace ngraph
} }
operator TensorProto_DataType() const { return m_tensor_proto->data_type(); } operator TensorProto_DataType() const { return m_tensor_proto->data_type(); }
std::shared_ptr<ngraph::op::Constant> get_ng_constant() const
{
switch (m_tensor_proto->data_type())
{
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
return make_ng_constant<bool>(element::boolean);
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
return make_ng_constant<float>(element::f32);
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
return make_ng_constant<double>(element::f64);
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
return make_ng_constant<int8_t>(element::i8);
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
return make_ng_constant<int16_t>(element::i16);
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
return make_ng_constant<int32_t>(element::i32);
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
return make_ng_constant<int64_t>(element::i64);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
return make_ng_constant<uint8_t>(element::u8);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
return make_ng_constant<uint16_t>(element::u16);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
return make_ng_constant<uint32_t>(element::u32);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
return make_ng_constant<uint64_t>(element::u64);
default: throw error::tensor::unsupported_data_type{m_tensor_proto->data_type()};
}
}
private: private:
template <typename T>
std::shared_ptr<ngraph::op::Constant> make_ng_constant(const element::Type& type) const
{
return std::make_shared<ngraph::op::Constant>(type, m_shape, get_data<T>());
}
const onnx::TensorProto* m_tensor_proto; const onnx::TensorProto* m_tensor_proto;
Shape m_shape; Shape m_shape;
}; };
......
...@@ -138,42 +138,7 @@ namespace ngraph ...@@ -138,42 +138,7 @@ namespace ngraph
std::shared_ptr<op::Constant> get_ng_constant(const Tensor& tensor) const std::shared_ptr<op::Constant> get_ng_constant(const Tensor& tensor) const
{ {
switch (m_value_info_proto->type().tensor_type().elem_type()) return tensor.get_ng_constant();
{
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
return make_ng_constant<bool>(element::boolean, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
return make_ng_constant<float>(element::f32, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
return make_ng_constant<double>(element::f64, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
return make_ng_constant<int8_t>(element::i8, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
return make_ng_constant<int16_t>(element::i16, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
return make_ng_constant<int32_t>(element::i32, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
return make_ng_constant<int64_t>(element::i64, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
return make_ng_constant<uint8_t>(element::u8, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
return make_ng_constant<uint16_t>(element::u16, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
return make_ng_constant<uint32_t>(element::u32, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
return make_ng_constant<uint64_t>(element::u64, tensor);
default:
throw error::value_info::unsupported_element_type{
m_value_info_proto->type().tensor_type().elem_type()};
}
}
template <typename T>
std::shared_ptr<op::Constant> make_ng_constant(const element::Type& type,
const Tensor& tensor) const
{
return std::make_shared<op::Constant>(type, m_shape, tensor.get_data<T>());
} }
private: private:
......
...@@ -1889,3 +1889,18 @@ TEST(onnx_${BACKEND_NAME}, model_cosh) ...@@ -1889,3 +1889,18 @@ TEST(onnx_${BACKEND_NAME}, model_cosh)
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front())); EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
} }
TEST(onnx_${BACKEND_NAME}, model_initializer_wo_input)
{
// This test checks a model which has an initializer, but no input with the same name
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/initializer_wo_input.onnx"));
Inputs inputs;
inputs.emplace_back(std::vector<float>{0, 1, 2, 3, 4, 5});
std::vector<float> expected_output{0, 2, 6, 12, 20, 30};
Outputs output{execute(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(test::all_close_f(expected_output, output.front()));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment