Commit c086eb2d authored by Artur Wojcik, committed by Scott Cyphers

onnx [2]: add core wrappers (#1169)

* onnx: add core wrappers
Signed-off-by: Artur Wojcik <artur.wojcik@intel.com>

* onnx: add '\n' at end of files
Signed-off-by: Artur Wojcik <artur.wojcik@intel.com>

* onnx: fix compilation with clang
Signed-off-by: Artur Wojcik <artur.wojcik@intel.com>

* onnx: fix code style
Signed-off-by: Artur Wojcik <artur.wojcik@intel.com>
parent d05b5e39
@@ -31,7 +31,18 @@ add_library(onnx_import_interface OBJECT
 onnx.proto)
 add_library(onnx_import STATIC
-onnx.pb.cc)
+onnx.pb.cc
+attribute.cpp
+attribute.hpp
+graph.cpp
+graph.hpp
+model.hpp
+node.cpp
+node.hpp
+ops_bridge.cpp
+ops_bridge.hpp
+tensor.hpp
+value_info.hpp)
 add_dependencies(onnx_import onnx_import_interface)
@@ -46,5 +57,8 @@ if ("${CMAKE_CXX_COMPILER_ID}" MATCHES "^(Apple)?Clang$")
 target_compile_options(onnx_import PRIVATE -Wno-undef -Wno-reserved-id-macro -Wno-switch-enum
 -Wno-extended-offsetof -Wno-zero-as-null-pointer-constant -Wno-shorten-64-to-32 -Wno-unused-macros
 -Wno-missing-variable-declarations -Wno-unused-private-field)
+target_compile_options(onnx_import_interface PRIVATE -Wno-undef -Wno-reserved-id-macro -Wno-switch-enum
+-Wno-extended-offsetof -Wno-zero-as-null-pointer-constant -Wno-shorten-64-to-32 -Wno-unused-macros
+-Wno-missing-variable-declarations -Wno-unused-private-field)
 endif()
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "attribute.hpp"
#include "graph.hpp"
namespace ngraph
{
namespace onnx_import
{
std::vector<Graph> Attribute::get_graph_array() const
{
return {std::begin(m_attribute_proto.graphs()), std::end(m_attribute_proto.graphs())};
}
Graph Attribute::get_graph() const { return Graph{m_attribute_proto.g()}; }
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/except.hpp"
#include "onnx.pb.h"
#include "tensor.hpp"
namespace ngraph
{
namespace onnx_import
{
// forward declaration
class Graph;
namespace error
{
namespace attribute
{
namespace detail
{
struct attribute : ngraph_error
{
attribute(std::string msg, onnx::AttributeProto_AttributeType type)
: ngraph_error{std::move(msg) + ": " +
onnx::AttributeProto_AttributeType_Name(type)}
{
}
};
} // namespace detail
struct invalid_data : detail::attribute
{
explicit invalid_data(onnx::AttributeProto_AttributeType type)
: attribute{"invalid attribute type", type}
{
}
};
struct unsupported_type : detail::attribute
{
explicit unsupported_type(onnx::AttributeProto_AttributeType type)
: attribute("unsupported attribute type", type)
{
}
};
} // namespace attribute
} // namespace error
class Attribute
{
public:
enum class Type
{
undefined = onnx::AttributeProto_AttributeType_UNDEFINED,
float_point = onnx::AttributeProto_AttributeType_FLOAT,
integer = onnx::AttributeProto_AttributeType_INT,
string = onnx::AttributeProto_AttributeType_STRING,
tensor = onnx::AttributeProto_AttributeType_TENSOR,
graph = onnx::AttributeProto_AttributeType_GRAPH,
float_point_array = onnx::AttributeProto_AttributeType_FLOATS,
integer_array = onnx::AttributeProto_AttributeType_INTS,
string_array = onnx::AttributeProto_AttributeType_STRINGS,
tensor_array = onnx::AttributeProto_AttributeType_TENSORS,
graph_array = onnx::AttributeProto_AttributeType_GRAPHS
};
Attribute() = delete;
explicit Attribute(const onnx::AttributeProto& attribute_proto)
: m_attribute_proto{attribute_proto}
{
}
Attribute(Attribute&&) noexcept = default;
Attribute(const Attribute&) = default;
Attribute& operator=(Attribute&&) noexcept = delete;
Attribute& operator=(const Attribute&) = delete;
const std::string& get_name() const { return m_attribute_proto.name(); }
Type get_type() const { return static_cast<Type>(m_attribute_proto.type()); }
bool is_tensor() const { return get_type() == Type::tensor; }
bool is_tensor_array() const { return get_type() == Type::tensor_array; }
bool is_float() const { return get_type() == Type::float_point; }
bool is_float_array() const { return get_type() == Type::float_point_array; }
bool is_integer() const { return get_type() == Type::integer; }
bool is_integer_array() const { return get_type() == Type::integer_array; }
bool is_string() const { return get_type() == Type::string; }
bool is_string_array() const { return get_type() == Type::string_array; }
bool is_graph() const { return get_type() == Type::graph; }
bool is_graph_array() const { return get_type() == Type::graph_array; }
Tensor get_tensor() const { return Tensor{m_attribute_proto.t()}; }
float get_float() const { return m_attribute_proto.f(); }
int64_t get_integer() const { return m_attribute_proto.i(); }
const std::string& get_string() const { return m_attribute_proto.s(); }
Graph get_graph() const;
std::vector<Tensor> get_tensor_array() const
{
return {std::begin(m_attribute_proto.tensors()),
std::end(m_attribute_proto.tensors())};
}
std::vector<float> get_float_array() const
{
return {std::begin(m_attribute_proto.floats()),
std::end(m_attribute_proto.floats())};
}
std::vector<int64_t> get_integer_array() const
{
return {std::begin(m_attribute_proto.ints()), std::end(m_attribute_proto.ints())};
}
std::vector<std::string> get_string_array() const
{
return {std::begin(m_attribute_proto.strings()),
std::end(m_attribute_proto.strings())};
}
std::vector<Graph> get_graph_array() const;
/* explicit */ operator onnx::AttributeProto_AttributeType() const
{
return m_attribute_proto.type();
}
private:
const onnx::AttributeProto& m_attribute_proto;
};
} // namespace onnx_import
} // namespace ngraph
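
Usage sketch (illustrative only, not part of this commit): type-checked access through the Attribute wrapper. The node_proto variable and the "alpha" attribute name are assumptions.

#include "attribute.hpp"

// Hypothetical helper: scan a parsed onnx::NodeProto for a float
// attribute named "alpha" and fall back to a default when absent.
inline float get_alpha_or_default(const onnx::NodeProto& node_proto, float default_value)
{
    for (const auto& attribute_proto : node_proto.attribute())
    {
        ngraph::onnx_import::Attribute attribute{attribute_proto};
        if ((attribute.get_name() == "alpha") && attribute.is_float())
        {
            return attribute.get_float();
        }
    }
    return default_value;
}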
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "graph.hpp"
#include "node.hpp"
namespace ngraph
{
namespace onnx_import
{
Graph::Graph(const onnx::GraphProto& graph_proto)
: m_graph_proto(graph_proto)
{
for (const auto& tensor : m_graph_proto.initializer())
{
if (tensor.has_name())
{
m_initializers.emplace(tensor.name(), Tensor{tensor});
}
}
// Process all ONNX graph inputs, convert them to nGraph nodes and store in cache
for (const auto& input : m_graph_proto.input())
{
m_inputs.emplace_back(input);
m_ng_node_cache[input.name()] =
m_inputs.back().get_ng_node(m_parameters, m_initializers);
}
for (const auto& output : m_graph_proto.output())
{
    m_outputs.emplace_back(output);
}
// Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_graph_proto.node())
{
m_nodes.emplace_back(node_proto, this);
const Node& node{m_nodes.back()};
const auto& ng_nodes{node.get_ng_nodes()};
for (int i = 0; i < static_cast<int>(ng_nodes.size()); ++i)
{
m_ng_node_cache[node.output(i)] = ng_nodes[i];
}
}
}
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <map>
#include <ostream>
#include <string>
#include <vector>
#include "ngraph/op/parameter_vector.hpp"
#include "onnx.pb.h"
#include "value_info.hpp"
namespace ngraph
{
namespace onnx_import
{
class Graph
{
public:
explicit Graph(const onnx::GraphProto& proto);
const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
const std::vector<ValueInfo>& get_outputs() const { return m_outputs; }
const op::ParameterVector& get_ng_parameters() const { return m_parameters; }
std::shared_ptr<ngraph::Node> get_ng_node_from_cache(const std::string& name) const
{
return m_ng_node_cache.at(name);
}
const std::string& get_name() const { return m_graph_proto.name(); }
private:
const onnx::GraphProto& m_graph_proto;
std::vector<Node> m_nodes;
std::vector<ValueInfo> m_inputs;
std::vector<ValueInfo> m_outputs;
op::ParameterVector m_parameters;
std::map<std::string, std::shared_ptr<ngraph::Node>> m_ng_node_cache;
std::map<std::string, Tensor> m_initializers;
};
inline std::ostream& operator<<(std::ostream& outs, const Graph& graph)
{
return (outs << "<Graph: " << graph.get_name() << ">");
}
} // namespace onnx_import
} // namespace ngraph
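
Usage sketch (illustrative only, not part of this commit): constructing a Graph eagerly converts inputs and nodes, so any ONNX value name can then be resolved through the cache. Note that with the empty ops_bridge map introduced later in this commit, a graph that contains nodes would currently throw unknown_operation; the sketch assumes operator support lands in a follow-up.

#include <memory>
#include "graph.hpp"

// Illustrative: `graph_proto` must outlive the Graph, which stores only
// a reference to it, and is assumed to declare at least one output.
inline std::shared_ptr<ngraph::Node> convert_and_get_first_output(const onnx::GraphProto& graph_proto)
{
    ngraph::onnx_import::Graph graph{graph_proto};
    const auto& outputs = graph.get_outputs();
    return graph.get_ng_node_from_cache(outputs.front().get_name());
}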
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <ostream>
#include "onnx.pb.h"
namespace ngraph
{
namespace onnx_import
{
class Model
{
public:
Model() = delete;
explicit Model(const onnx::ModelProto& model_proto)
: m_model_proto{model_proto}
{
}
Model(Model&&) noexcept = default;
Model(const Model&) = default;
Model& operator=(Model&&) noexcept = delete;
Model& operator=(const Model&) = delete;
const std::string& get_producer_name() const { return m_model_proto.producer_name(); }
const onnx::GraphProto& get_graph() const { return m_model_proto.graph(); }
std::int64_t get_model_version() const { return m_model_proto.model_version(); }
const std::string& get_producer_version() const
{
return m_model_proto.producer_version();
}
private:
const onnx::ModelProto& m_model_proto;
};
inline std::ostream& operator<<(std::ostream& outs, const Model& model)
{
return (outs << "<Model: " << model.get_producer_name() << ">");
}
} // namespace onnx_import
} // namespace ngraph
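
Loading sketch (illustrative only; the file name is an assumption). ParseFromIstream is the stock protobuf Message API, so wiring a serialized model into the wrapper looks like:

#include <fstream>
#include <iostream>
#include "model.hpp"

int main()
{
    std::ifstream stream{"model.onnx", std::ios::binary};
    onnx::ModelProto model_proto;
    if (!stream || !model_proto.ParseFromIstream(&stream))
    {
        std::cerr << "failed to load model.onnx\n";
        return 1;
    }
    // Model stores only a reference, so model_proto must stay alive.
    ngraph::onnx_import::Model model{model_proto};
    std::cout << model << ", model version " << model.get_model_version() << '\n';
}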
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "node.hpp"
#include "graph.hpp"
#include "ops_bridge.hpp"
namespace ngraph
{
namespace onnx_import
{
NodeVector Node::get_ng_nodes() const { return ops_bridge::make_ng_nodes(*this); }
NodeVector Node::get_ng_inputs() const
{
NodeVector result;
for (const auto& name : m_node_proto.input())
{
result.push_back(m_graph->get_ng_node_from_cache(name));
}
return result;
}
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <ostream>
#include <string>
#include "ngraph/node_vector.hpp"
#include "onnx.pb.h"
#include "attribute.hpp"
#include "tensor.hpp"
namespace ngraph
{
namespace onnx_import
{
class Graph;
class Node
{
public:
Node() = delete;
Node(const onnx::NodeProto& node_proto, const Graph* graph)
: m_node_proto{node_proto}
, m_graph{graph}
, m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())}
{
}
Node(Node&&) noexcept = default;
Node(const Node&) = default;
Node& operator=(Node&&) noexcept = delete;
Node& operator=(const Node&) = delete;
const std::vector<Attribute>& attributes() const { return m_attributes; }
NodeVector get_ng_nodes() const;
NodeVector get_ng_inputs() const;
const std::string& op_type() const { return m_node_proto.op_type(); }
const std::string& get_name() const { return m_node_proto.name(); }
const std::string& output(int index) const { return m_node_proto.output(index); }
private:
const onnx::NodeProto& m_node_proto;
const Graph* m_graph;
std::vector<Attribute> m_attributes;
};
inline std::ostream& operator<<(std::ostream& outs, const Node& node)
{
return (outs << "<Node(" << node.op_type() << "): " << node.get_name() << ">");
}
} // namespace onnx_import
} // namespace ngraph
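
Inspection sketch (illustrative only, not part of this commit). Attribute and op_type access never dereferences the graph pointer, so a null graph is assumed to be tolerable for pure inspection; get_ng_inputs()/get_ng_nodes() must not be called on such a node.

#include <iostream>
#include "node.hpp"

inline void dump_node(const onnx::NodeProto& node_proto)
{
    ngraph::onnx_import::Node node{node_proto, nullptr};
    std::cout << node << '\n';
    for (const auto& attribute : node.attributes())
    {
        std::cout << "  attribute: " << attribute.get_name() << '\n';
    }
}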
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <functional>
#include "ops_bridge.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace detail
{
namespace error
{
struct unknown_operation : ngraph_error
{
explicit unknown_operation(const std::string& op_type)
: ngraph_error{"unknown operation: " + op_type}
{
}
};
} // namespace error
class ops_bridge
{
public:
ops_bridge(const ops_bridge&) = delete;
ops_bridge& operator=(const ops_bridge&) = delete;
ops_bridge(ops_bridge&&) = delete;
ops_bridge& operator=(ops_bridge&&) = delete;
static NodeVector make_ng_nodes(const Node& node)
{
return ops_bridge::get()(node);
}
private:
std::map<std::string, std::function<NodeVector(const Node&)>> m_map;
static const ops_bridge& get()
{
static ops_bridge instance;
return instance;
}
ops_bridge() {}
NodeVector operator()(const Node& node) const
{
try
{
return m_map.at(node.op_type())(node);
}
catch (const std::exception&)
{
throw detail::error::unknown_operation{node.op_type()};
}
}
};
} // namespace detail
namespace ops_bridge
{
NodeVector make_ng_nodes(const Node& node)
{
return detail::ops_bridge::make_ng_nodes(node);
}
} // namespace ops_bridge
} // namespace onnx_import
} // namespace ngraph
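
Note that the ops_bridge constructor leaves m_map empty in this commit, so make_ng_nodes throws unknown_operation for every node. A hypothetical registration replacing the empty constructor (the "Add" mapping and its lambda are assumptions, not part of this change) might look like:

#include "ngraph/op/add.hpp"

// Hypothetical follow-up: populate the dispatch map keyed by op_type.
ops_bridge::ops_bridge()
{
    m_map.emplace("Add", [](const Node& node) -> NodeVector {
        NodeVector inputs{node.get_ng_inputs()};
        return {std::make_shared<ngraph::op::Add>(inputs.at(0), inputs.at(1))};
    });
}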
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node_vector.hpp"
#include "node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace ops_bridge
{
NodeVector make_ng_nodes(const onnx_import::Node&);
}
} // namespace onnx_import
} // namespace ngraph
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <vector>
#include "ngraph/except.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "onnx.pb.h"
namespace ngraph
{
namespace onnx_import
{
namespace error
{
namespace tensor
{
struct invalid_data_type : ngraph_error
{
explicit invalid_data_type(onnx::TensorProto_DataType type)
: ngraph_error{"invalid data type: " +
onnx::TensorProto_DataType_Name(type)}
{
}
};
struct unsupported_data_type : ngraph_error
{
explicit unsupported_data_type(onnx::TensorProto_DataType type)
: ngraph_error{"unsupported data type: " +
onnx::TensorProto_DataType_Name(type)}
{
}
};
struct unspecified_name : ngraph_error
{
unspecified_name()
: ngraph_error{"tensor has no name specified"}
{
}
};
struct unspecified_data_type : ngraph_error
{
unspecified_data_type()
: ngraph_error{"tensor has no data type specified"}
{
}
};
} // namespace tensor
} // namespace error
namespace detail
{
namespace tensor
{
namespace
{
namespace detail
{
template <typename T, typename Container>
inline std::vector<T> __get_data(const Container& container)
{
return {std::begin(container), std::end(container)};
}
} // namespace detail
} // anonymous namespace
template <typename T>
inline std::vector<T> get_data(const onnx::TensorProto& tensor)
{
throw error::tensor::unsupported_data_type{tensor.data_type()};
}
template <>
inline std::vector<double> get_data(const onnx::TensorProto& tensor)
{
if (tensor.data_type() == onnx::TensorProto_DataType_DOUBLE)
{
return detail::__get_data<double>(tensor.double_data());
}
if ((tensor.data_type() == onnx::TensorProto_DataType_FLOAT) or
(tensor.data_type() == onnx::TensorProto_DataType_FLOAT16))
{
return detail::__get_data<double>(tensor.float_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_INT32)
{
return detail::__get_data<double>(tensor.int32_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_INT64)
{
return detail::__get_data<double>(tensor.int64_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_UINT64)
{
return detail::__get_data<double>(tensor.uint64_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<float> get_data(const onnx::TensorProto& tensor)
{
if ((tensor.data_type() == onnx::TensorProto_DataType_FLOAT) or
(tensor.data_type() == onnx::TensorProto_DataType_FLOAT16))
{
return detail::__get_data<float>(tensor.float_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_INT32)
{
return detail::__get_data<float>(tensor.int32_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_INT64)
{
return detail::__get_data<float>(tensor.int64_data());
}
if (tensor.data_type() == onnx::TensorProto_DataType_UINT64)
{
return detail::__get_data<float>(tensor.uint64_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<int32_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.data_type() == onnx::TensorProto_DataType_INT32)
{
return detail::__get_data<int32_t>(tensor.int32_data());
}
throw error::tensor::invalid_data_type{tensor.data_type()};
}
template <>
inline std::vector<int64_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.data_type() != onnx::TensorProto_DataType_INT64)
{
throw error::tensor::invalid_data_type{tensor.data_type()};
}
return detail::__get_data<int64_t>(tensor.int64_data());
}
template <>
inline std::vector<uint64_t> get_data(const onnx::TensorProto& tensor)
{
if (tensor.data_type() != onnx::TensorProto_DataType_UINT64)
{
throw error::tensor::invalid_data_type{tensor.data_type()};
}
return detail::__get_data<uint64_t>(tensor.uint64_data());
}
} // namespace tensor
} // namespace detail
class Tensor
{
public:
enum class Type
{
undefined = onnx::TensorProto_DataType_UNDEFINED,
float32 = onnx::TensorProto_DataType_FLOAT,
uint8 = onnx::TensorProto_DataType_UINT8,
int8 = onnx::TensorProto_DataType_INT8,
uint16 = onnx::TensorProto_DataType_UINT16,
int16 = onnx::TensorProto_DataType_INT16,
int32 = onnx::TensorProto_DataType_INT32,
int64 = onnx::TensorProto_DataType_INT64,
string = onnx::TensorProto_DataType_STRING,
boolean = onnx::TensorProto_DataType_BOOL,
float16 = onnx::TensorProto_DataType_FLOAT16,
float64 = onnx::TensorProto_DataType_DOUBLE,
uint32 = onnx::TensorProto_DataType_UINT32,
uint64 = onnx::TensorProto_DataType_UINT64,
complex64 = onnx::TensorProto_DataType_COMPLEX64,
complex128 = onnx::TensorProto_DataType_COMPLEX128
};
Tensor() = delete;
explicit Tensor(const onnx::TensorProto& tensor)
: m_tensor_proto{tensor}
, m_shape{std::begin(tensor.dims()), std::end(tensor.dims())}
{
}
Tensor(const Tensor&) = default;
Tensor(Tensor&&) = default;
Tensor& operator=(const Tensor&) = delete;
Tensor& operator=(Tensor&&) = delete;
const Shape& get_shape() const { return m_shape; }
template <typename T>
std::vector<T> get_data() const
{
return detail::tensor::get_data<T>(m_tensor_proto);
}
const std::string& get_name() const
{
if (!m_tensor_proto.has_name())
{
throw error::tensor::unspecified_name{};
}
return m_tensor_proto.name();
}
Type get_type() const
{
if (!m_tensor_proto.has_data_type())
{
throw error::tensor::unspecified_data_type{};
}
return static_cast<Type>(m_tensor_proto.data_type());
}
const element::Type& get_ng_type() const
{
if (!m_tensor_proto.has_data_type())
{
throw error::tensor::unspecified_data_type{};
}
switch (m_tensor_proto.data_type())
{
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: return element::boolean;
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: return element::f32;
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE: return element::f64;
case onnx::TensorProto_DataType::TensorProto_DataType_INT8: return element::i8;
case onnx::TensorProto_DataType::TensorProto_DataType_INT16: return element::i16;
case onnx::TensorProto_DataType::TensorProto_DataType_INT32: return element::i32;
case onnx::TensorProto_DataType::TensorProto_DataType_INT64: return element::i64;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8: return element::u8;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: return element::u16;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: return element::u32;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: return element::u64;
default: throw error::tensor::unsupported_data_type{m_tensor_proto.data_type()};
}
}
operator onnx::TensorProto_DataType() const { return m_tensor_proto.data_type(); }
private:
const onnx::TensorProto& m_tensor_proto;
Shape m_shape;
};
inline std::ostream& operator<<(std::ostream& outs, const Tensor& tensor)
{
return (outs << "<Tensor: " << tensor.get_name() << ">");
}
} // namespace onnx_import
} // namespace ngraph
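
Usage sketch (illustrative only): get_data<T> dispatches on the proto's data_type field, so callers choose the destination type and the specializations validate it, throwing error::tensor::invalid_data_type on a mismatch.

#include <vector>
#include "tensor.hpp"

// Illustrative: read FLOAT (or FLOAT16, which the float specialization
// also accepts) initializer data into a std::vector<float>.
inline std::vector<float> read_float_data(const onnx::TensorProto& tensor_proto)
{
    ngraph::onnx_import::Tensor tensor{tensor_proto};
    return tensor.get_data<float>();
}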
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <map>
#include <ostream>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter_vector.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "onnx.pb.h"
#include "node.hpp"
#include "tensor.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace error
{
namespace value_info
{
struct unspecified_element_type : ngraph_error
{
unspecified_element_type()
: ngraph_error{"value info has not element type specified"}
{
}
};
struct unsupported_element_type : ngraph_error
{
explicit unsupported_element_type(onnx::TensorProto_DataType type)
: ngraph_error{"unsupported value info element type: " +
onnx::TensorProto_DataType_Name(type)}
{
}
};
} // namespace value_info
} // namespace error
class ValueInfo
{
public:
ValueInfo(ValueInfo&&) = default;
ValueInfo(const ValueInfo&) = default;
ValueInfo() = delete;
explicit ValueInfo(const onnx::ValueInfoProto& value_info_proto)
: m_value_info_proto{value_info_proto}
{
if (value_info_proto.type().has_tensor_type())
{
for (const auto& dim : value_info_proto.type().tensor_type().shape().dim())
{
m_shape.emplace_back(static_cast<Shape::value_type>(dim.dim_value()));
}
}
}
ValueInfo& operator=(const ValueInfo&) = delete;
ValueInfo& operator=(ValueInfo&&) = delete;
const std::string& get_name() const { return m_value_info_proto.name(); }
const Shape& get_shape() const { return m_shape; }
const element::Type& get_element_type() const
{
if (!m_value_info_proto.type().tensor_type().has_elem_type())
{
throw error::value_info::unspecified_element_type{};
}
switch (m_value_info_proto.type().tensor_type().elem_type())
{
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: return element::boolean;
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: return element::f32;
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE: return element::f64;
case onnx::TensorProto_DataType::TensorProto_DataType_INT8: return element::i8;
case onnx::TensorProto_DataType::TensorProto_DataType_INT16: return element::i16;
case onnx::TensorProto_DataType::TensorProto_DataType_INT32: return element::i32;
case onnx::TensorProto_DataType::TensorProto_DataType_INT64: return element::i64;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8: return element::u8;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: return element::u16;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: return element::u32;
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: return element::u64;
default:
throw error::value_info::unsupported_element_type{
m_value_info_proto.type().tensor_type().elem_type()};
}
}
std::shared_ptr<ngraph::Node>
get_ng_node(op::ParameterVector& parameters,
const std::map<std::string, Tensor>& initializers) const
{
const auto it{initializers.find(get_name())};
if (it != std::end(initializers))
{
return get_ng_constant(it->second);
}
else
{
parameters.push_back(get_ng_parameter());
return parameters.back();
}
}
protected:
std::shared_ptr<op::Parameter> get_ng_parameter() const
{
return std::make_shared<op::Parameter>(get_element_type(), get_shape());
}
std::shared_ptr<op::Constant> get_ng_constant(const Tensor& tensor) const
{
switch (m_value_info_proto.type().tensor_type().elem_type())
{
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
return make_ng_constant<bool>(element::boolean, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
return make_ng_constant<float>(element::f32, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
return make_ng_constant<double>(element::f64, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
return make_ng_constant<int8_t>(element::i8, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
return make_ng_constant<int16_t>(element::i16, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
return make_ng_constant<int32_t>(element::i32, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
return make_ng_constant<int64_t>(element::i64, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
return make_ng_constant<uint8_t>(element::u8, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
return make_ng_constant<uint16_t>(element::u16, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
return make_ng_constant<uint32_t>(element::u32, tensor);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
return make_ng_constant<uint64_t>(element::u64, tensor);
default:
throw error::value_info::unsupported_element_type{
m_value_info_proto.type().tensor_type().elem_type()};
}
}
template <typename T>
std::shared_ptr<op::Constant> make_ng_constant(const element::Type& type,
const Tensor& tensor) const
{
return std::make_shared<op::Constant>(type, m_shape, tensor.get_data<T>());
}
private:
const onnx::ValueInfoProto& m_value_info_proto;
Shape m_shape;
};
inline std::ostream& operator<<(std::ostream& outs, const ValueInfo& info)
{
return (outs << "<ValueInfo: " << info.get_name() << ">");
}
} // namespace onnx_import
} // namespace ngraph
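
Closing sketch (illustrative only, not part of this commit): this mirrors what Graph's constructor does per input, with get_ng_node choosing between a Constant (matching initializer found) and a fresh Parameter appended to the parameter vector.

#include <map>
#include <memory>
#include <string>
#include "value_info.hpp"

inline std::shared_ptr<ngraph::Node>
    convert_input(const onnx::ValueInfoProto& input_proto,
                  ngraph::op::ParameterVector& parameters,
                  const std::map<std::string, ngraph::onnx_import::Tensor>& initializers)
{
    ngraph::onnx_import::ValueInfo value_info{input_proto};
    // Appends to `parameters` only when no initializer matches the name.
    return value_info.get_ng_node(parameters, initializers);
}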