Unverified Commit 631d7253 authored by Artur Wojcik's avatar Artur Wojcik Committed by GitHub

[ONNX] Add support for Operator Sets (#1801)

* onnx: fix cmake commands list
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: change exception name to UnknownOperator
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: rename ops_bridge class
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: initial opset versions
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: check operator availability
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: opset versions - after reivew changes
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: fix doxygen comment fromat
parent bcb5a47b
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
# limitations under the License. # limitations under the License.
# ****************************************************************************** # ******************************************************************************
set(ONNX_OPSET_VERSION 9 CACHE INTERNAL "Supported version of ONNX operator set")
add_library(onnx_import_interface OBJECT add_library(onnx_import_interface OBJECT
onnx.cpp onnx.cpp
onnx.hpp) onnx.hpp)
target_include_directories(onnx_import_interface PRIVATE ${ONNX_PROTO_INCLUDE_DIR})
add_library(onnx_import STATIC add_library(onnx_import STATIC
core/attribute.cpp core/attribute.cpp
core/attribute.hpp core/attribute.hpp
...@@ -29,6 +29,7 @@ add_library(onnx_import STATIC ...@@ -29,6 +29,7 @@ add_library(onnx_import STATIC
core/model.hpp core/model.hpp
core/node.cpp core/node.cpp
core/node.hpp core/node.hpp
core/operator_set.hpp
core/tensor.hpp core/tensor.hpp
core/value_info.hpp core/value_info.hpp
exceptions.hpp exceptions.hpp
...@@ -143,12 +144,17 @@ if (NOT NGRAPH_USE_SYSTEM_PROTOBUF) ...@@ -143,12 +144,17 @@ if (NOT NGRAPH_USE_SYSTEM_PROTOBUF)
endif() endif()
set_property(TARGET onnx_import PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET onnx_import PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(onnx_import PUBLIC ${CMAKE_CURRENT_BINARY_DIR} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ${NGRAPH_INCLUDE_PATH} ${ONNX_PROTO_INCLUDE_DIR} ${Protobuf_INCLUDE_DIR}) target_include_directories(onnx_import PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ${NGRAPH_INCLUDE_PATH}
SYSTEM PRIVATE ${ONNX_INCLUDE_DIR} ${ONNX_PROTO_INCLUDE_DIR} ${Protobuf_INCLUDE_DIR})
target_link_libraries(onnx_import PRIVATE ${Protobuf_LIBRARIES} ${ONNX_PROTO_LIBRARY}) target_link_libraries(onnx_import PRIVATE ${Protobuf_LIBRARIES} ${ONNX_PROTO_LIBRARY})
target_compile_definitions(onnx_import PRIVATE ONNX_OPSET_VERSION=${ONNX_OPSET_VERSION})
set_property(TARGET onnx_import_interface PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET onnx_import_interface PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(onnx_import_interface PRIVATE ${CMAKE_CURRENT_BINARY_DIR} target_include_directories(onnx_import_interface PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${NGRAPH_INCLUDE_PATH}
${NGRAPH_INCLUDE_PATH} ${Protobuf_INCLUDE_DIR}) SYSTEM PRIVATE ${ONNX_INCLUDE_DIR} ${ONNX_PROTO_INCLUDE_DIR} ${Protobuf_INCLUDE_DIR})
target_compile_definitions(onnx_import_interface PRIVATE ONNX_OPSET_VERSION=${ONNX_OPSET_VERSION})
if ("${CMAKE_CXX_COMPILER_ID}" MATCHES "^(Apple)?Clang$") if ("${CMAKE_CXX_COMPILER_ID}" MATCHES "^(Apple)?Clang$")
target_compile_options(onnx_import PRIVATE -Wno-undef -Wno-reserved-id-macro -Wno-switch-enum target_compile_options(onnx_import PRIVATE -Wno-undef -Wno-reserved-id-macro -Wno-switch-enum
......
...@@ -16,17 +16,27 @@ ...@@ -16,17 +16,27 @@
#include "attribute.hpp" #include "attribute.hpp"
#include "graph.hpp" #include "graph.hpp"
#include "operator_set.hpp"
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
std::vector<Graph> Attribute::get_graph_array() const std::vector<Graph> Attribute::get_graph_array(const OperatorSet& opset) const
{ {
return {std::begin(m_attribute_proto->graphs()), std::end(m_attribute_proto->graphs())}; std::vector<Graph> result;
for (const auto& graph : m_attribute_proto->graphs())
{
result.emplace_back(graph, opset);
}
return result;
}
Graph Attribute::get_graph(const OperatorSet& opset) const
{
return Graph{m_attribute_proto->g(), opset};
} }
Graph Attribute::get_graph() const { return Graph{m_attribute_proto->g()}; }
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
#include <onnx-ml.pb.h> #include <onnx-ml.pb.h>
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
#include "operator_set.hpp"
#include "tensor.hpp" #include "tensor.hpp"
#define likely(__x) __builtin_expect(!!(__x), 1) #define likely(__x) __builtin_expect(!!(__x), 1)
...@@ -273,7 +275,7 @@ namespace ngraph ...@@ -273,7 +275,7 @@ namespace ngraph
float get_float() const { return m_attribute_proto->f(); } float get_float() const { return m_attribute_proto->f(); }
int64_t get_integer() const { return m_attribute_proto->i(); } int64_t get_integer() const { return m_attribute_proto->i(); }
const std::string& get_string() const { return m_attribute_proto->s(); } const std::string& get_string() const { return m_attribute_proto->s(); }
Graph get_graph() const; Graph get_graph(const OperatorSet& opset) const;
std::vector<Tensor> get_tensor_array() const std::vector<Tensor> get_tensor_array() const
{ {
...@@ -298,7 +300,7 @@ namespace ngraph ...@@ -298,7 +300,7 @@ namespace ngraph
std::end(m_attribute_proto->strings())}; std::end(m_attribute_proto->strings())};
} }
std::vector<Graph> get_graph_array() const; std::vector<Graph> get_graph_array(const OperatorSet&) const;
/* explicit */ operator onnx::AttributeProto_AttributeType() const /* explicit */ operator onnx::AttributeProto_AttributeType() const
{ {
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <set>
#include "graph.hpp" #include "graph.hpp"
#include "node.hpp" #include "node.hpp"
...@@ -21,8 +23,22 @@ namespace ngraph ...@@ -21,8 +23,22 @@ namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
Graph::Graph(const onnx::GraphProto& graph_proto) namespace detail
{
std::string to_string(const std::set<std::string>& set)
{
std::string result;
for (auto it = std::begin(set); it != std::end(set); ++it)
{
result += (it != std::begin(set) ? ", " : "") + *it;
}
return result;
}
}
Graph::Graph(const onnx::GraphProto& graph_proto, const OperatorSet& opset)
: m_graph_proto{&graph_proto} : m_graph_proto{&graph_proto}
, m_opset{&opset}
{ {
for (const auto& tensor : m_graph_proto->initializer()) for (const auto& tensor : m_graph_proto->initializer())
{ {
...@@ -45,6 +61,20 @@ namespace ngraph ...@@ -45,6 +61,20 @@ namespace ngraph
m_outputs.emplace_back(output); m_outputs.emplace_back(output);
} }
// Verify that ONNX graph contains only nodes of available operator types
std::set<std::string> unknown_operator_types;
for (const auto& node_proto : m_graph_proto->node())
{
auto it = m_opset->find(node_proto.op_type());
if (it == std::end(*m_opset))
{
unknown_operator_types.emplace(node_proto.op_type());
}
}
NGRAPH_ASSERT(unknown_operator_types.empty())
<< "unknown operations: " << detail::to_string(unknown_operator_types);
// Process ONNX graph nodes, convert to nGraph nodes // Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_graph_proto->node()) for (const auto& node_proto : m_graph_proto->node())
{ {
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "ngraph/op/parameter_vector.hpp" #include "ngraph/op/parameter_vector.hpp"
#include "operator_set.hpp"
#include "value_info.hpp" #include "value_info.hpp"
namespace ngraph namespace ngraph
...@@ -32,7 +33,7 @@ namespace ngraph ...@@ -32,7 +33,7 @@ namespace ngraph
class Graph class Graph
{ {
public: public:
explicit Graph(const onnx::GraphProto& proto); explicit Graph(const onnx::GraphProto& proto, const OperatorSet& opset);
const std::vector<Node>& get_nodes() const { return m_nodes; } const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; } const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
...@@ -44,6 +45,11 @@ namespace ngraph ...@@ -44,6 +45,11 @@ namespace ngraph
} }
const std::string& get_name() const { return m_graph_proto->name(); } const std::string& get_name() const { return m_graph_proto->name(); }
NodeVector make_ng_nodes(const Node& node) const
{
return m_opset->at(node.op_type())(node);
}
private: private:
const onnx::GraphProto* m_graph_proto; const onnx::GraphProto* m_graph_proto;
std::vector<Node> m_nodes; std::vector<Node> m_nodes;
...@@ -52,6 +58,7 @@ namespace ngraph ...@@ -52,6 +58,7 @@ namespace ngraph
op::ParameterVector m_parameters; op::ParameterVector m_parameters;
std::map<std::string, std::shared_ptr<ngraph::Node>> m_ng_node_cache; std::map<std::string, std::shared_ptr<ngraph::Node>> m_ng_node_cache;
std::map<std::string, Tensor> m_initializers; std::map<std::string, Tensor> m_initializers;
const OperatorSet* m_opset;
}; };
inline std::ostream& operator<<(std::ostream& outs, const Graph& graph) inline std::ostream& operator<<(std::ostream& outs, const Graph& graph)
......
...@@ -14,15 +14,9 @@ ...@@ -14,15 +14,9 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <ostream>
#include <set>
#include <onnx-ml.pb.h> #include <onnx-ml.pb.h>
#include "assertion.hpp"
#include "model.hpp" #include "model.hpp"
#include "ops_bridge.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -31,30 +25,15 @@ namespace ngraph ...@@ -31,30 +25,15 @@ namespace ngraph
Model::Model(const onnx::ModelProto& model_proto) Model::Model(const onnx::ModelProto& model_proto)
: m_model_proto{&model_proto} : m_model_proto{&model_proto}
{ {
// Verify that the ONNX graph contains only nodes of supported op_type for (const auto& id : m_model_proto->opset_import())
assert_all_op_types_supported();
}
void Model::assert_all_op_types_supported()
{
std::set<std::string> unsupported_ops;
for (const auto& node_proto : get_graph().node())
{ {
std::string op_type = node_proto.op_type(); // onnx.proto(.3): the empty string ("") or absence of this field implies
if (!ops_bridge::is_op_type_supported(op_type)) // the operator set that is defined as part of the ONNX specification.
if (id.domain().empty())
{ {
unsupported_ops.insert(op_type); m_opset_version = id.version();
} }
} }
std::string unsupported_ops_str;
std::size_t index = 0;
for (const auto& op_type : unsupported_ops)
{
unsupported_ops_str += (index++ != 0 ? ", " : "");
unsupported_ops_str += op_type;
}
NGRAPH_ASSERT(unsupported_ops.empty()) << "unknown operations: " << unsupported_ops_str;
} }
} // namespace onnx_import } // namespace onnx_import
......
...@@ -44,10 +44,10 @@ namespace ngraph ...@@ -44,10 +44,10 @@ namespace ngraph
return m_model_proto->producer_version(); return m_model_proto->producer_version();
} }
std::int64_t get_opset_version() const { return m_opset_version; }
private: private:
const onnx::ModelProto* m_model_proto; const onnx::ModelProto* m_model_proto;
std::int64_t m_opset_version{ONNX_OPSET_VERSION};
void assert_all_op_types_supported();
}; };
inline std::ostream& operator<<(std::ostream& outs, const Model& model) inline std::ostream& operator<<(std::ostream& outs, const Model& model)
......
...@@ -16,13 +16,12 @@ ...@@ -16,13 +16,12 @@
#include "node.hpp" #include "node.hpp"
#include "graph.hpp" #include "graph.hpp"
#include "ops_bridge.hpp"
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
NodeVector Node::get_ng_nodes() const { return ops_bridge::make_ng_nodes(*this); } NodeVector Node::get_ng_nodes() const { return m_graph->make_ng_nodes(*this); }
NodeVector Node::get_ng_inputs() const NodeVector Node::get_ng_inputs() const
{ {
NodeVector result; NodeVector result;
......
...@@ -70,6 +70,7 @@ namespace ngraph ...@@ -70,6 +70,7 @@ namespace ngraph
NodeVector get_ng_nodes() const; NodeVector get_ng_nodes() const;
NodeVector get_ng_inputs() const; NodeVector get_ng_inputs() const;
const std::string& domain() const { return m_node_proto->domain(); }
const std::string& op_type() const { return m_node_proto->op_type(); } const std::string& op_type() const { return m_node_proto->op_type(); }
const std::string& get_name() const { return m_node_proto->name(); } const std::string& get_name() const { return m_node_proto->name(); }
/// @brief Describe the ONNX Node to make debugging graphs easier /// @brief Describe the ONNX Node to make debugging graphs easier
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <string>
#include <unordered_map>
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
// Forward declaration
class Node;
using Operator = std::function<NodeVector(const Node&)>;
using OperatorSet = std::unordered_map<std::string, std::reference_wrapper<const Operator>>;
} // namespace onnx_import
} // namespace ngraph
...@@ -21,7 +21,9 @@ ...@@ -21,7 +21,9 @@
#include "core/graph.hpp" #include "core/graph.hpp"
#include "core/model.hpp" #include "core/model.hpp"
#include "core/node.hpp" #include "core/node.hpp"
#include "onnx.hpp" #include "onnx.hpp"
#include "ops_bridge.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -59,7 +61,8 @@ namespace ngraph ...@@ -59,7 +61,8 @@ namespace ngraph
} }
std::vector<std::shared_ptr<Function>> output_functions; std::vector<std::shared_ptr<Function>> output_functions;
Model model{model_proto}; Model model{model_proto};
Graph graph{model_proto.graph()}; Graph graph{model_proto.graph(),
ops_bridge::get_operator_set(model.get_opset_version())};
for (const auto& output : graph.get_outputs()) for (const auto& output : graph.get_outputs())
{ {
output_functions.emplace_back(std::make_shared<Function>( output_functions.emplace_back(std::make_shared<Function>(
......
...@@ -28,14 +28,12 @@ namespace ngraph ...@@ -28,14 +28,12 @@ namespace ngraph
{ {
namespace set_1 namespace set_1
{ {
/** /// \brief Convert ONNX AveragePool operation to an nGraph node.
* @brief Convert ONNX AveragePool operation to an nGraph node. ///
* /// \param node The ONNX node object representing this operation.
* @param node The ONNX node object representing this operation. ///
* /// \return The vector containing Ngraph nodes producing output of ONNX AveragePool
* @return The vector containing Ngraph nodes producing output of ONNX AveragePool /// operation.
* operation.
*/
NodeVector average_pool(const Node& node); NodeVector average_pool(const Node& node);
} // namespace set_1 } // namespace set_1
......
...@@ -16,8 +16,9 @@ ...@@ -16,8 +16,9 @@
#pragma once #pragma once
#include "core/node.hpp" #include <cstdint>
#include "ngraph/node_vector.hpp"
#include "core/operator_set.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -25,9 +26,9 @@ namespace ngraph ...@@ -25,9 +26,9 @@ namespace ngraph
{ {
namespace ops_bridge namespace ops_bridge
{ {
NodeVector make_ng_nodes(const onnx_import::Node&); const OperatorSet& get_operator_set(std::int64_t version);
bool is_op_type_supported(const std::string& op_type);
} } // namespace ops_bridge
} // namespace onnx_import } // namespace onnx_import
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment