Unverified Commit 631d7253 authored by Artur Wojcik's avatar Artur Wojcik Committed by GitHub

[ONNX] Add support for Operator Sets (#1801)

* onnx: fix cmake commands list
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: change exception name to UnknownOperator
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: rename ops_bridge class
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: initial opset versions
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: check operator availability
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: opset versions - after reivew changes
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: fix doxygen comment fromat
parent bcb5a47b
......@@ -14,12 +14,12 @@
# limitations under the License.
# ******************************************************************************
set(ONNX_OPSET_VERSION 9 CACHE INTERNAL "Supported version of ONNX operator set")
add_library(onnx_import_interface OBJECT
onnx.cpp
onnx.hpp)
target_include_directories(onnx_import_interface PRIVATE ${ONNX_PROTO_INCLUDE_DIR})
add_library(onnx_import STATIC
core/attribute.cpp
core/attribute.hpp
......@@ -29,6 +29,7 @@ add_library(onnx_import STATIC
core/model.hpp
core/node.cpp
core/node.hpp
core/operator_set.hpp
core/tensor.hpp
core/value_info.hpp
exceptions.hpp
......@@ -143,12 +144,17 @@ if (NOT NGRAPH_USE_SYSTEM_PROTOBUF)
endif()
set_property(TARGET onnx_import PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(onnx_import PUBLIC ${CMAKE_CURRENT_BINARY_DIR} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ${NGRAPH_INCLUDE_PATH} ${ONNX_PROTO_INCLUDE_DIR} ${Protobuf_INCLUDE_DIR})
target_include_directories(onnx_import PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ${NGRAPH_INCLUDE_PATH}
SYSTEM PRIVATE ${ONNX_INCLUDE_DIR} ${ONNX_PROTO_INCLUDE_DIR} ${Protobuf_INCLUDE_DIR})
target_link_libraries(onnx_import PRIVATE ${Protobuf_LIBRARIES} ${ONNX_PROTO_LIBRARY})
target_compile_definitions(onnx_import PRIVATE ONNX_OPSET_VERSION=${ONNX_OPSET_VERSION})
set_property(TARGET onnx_import_interface PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(onnx_import_interface PRIVATE ${CMAKE_CURRENT_BINARY_DIR}
${NGRAPH_INCLUDE_PATH} ${Protobuf_INCLUDE_DIR})
target_include_directories(onnx_import_interface PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${NGRAPH_INCLUDE_PATH}
SYSTEM PRIVATE ${ONNX_INCLUDE_DIR} ${ONNX_PROTO_INCLUDE_DIR} ${Protobuf_INCLUDE_DIR})
target_compile_definitions(onnx_import_interface PRIVATE ONNX_OPSET_VERSION=${ONNX_OPSET_VERSION})
if ("${CMAKE_CXX_COMPILER_ID}" MATCHES "^(Apple)?Clang$")
target_compile_options(onnx_import PRIVATE -Wno-undef -Wno-reserved-id-macro -Wno-switch-enum
......
......@@ -16,17 +16,27 @@
#include "attribute.hpp"
#include "graph.hpp"
#include "operator_set.hpp"
namespace ngraph
{
namespace onnx_import
{
std::vector<Graph> Attribute::get_graph_array() const
std::vector<Graph> Attribute::get_graph_array(const OperatorSet& opset) const
{
return {std::begin(m_attribute_proto->graphs()), std::end(m_attribute_proto->graphs())};
std::vector<Graph> result;
for (const auto& graph : m_attribute_proto->graphs())
{
result.emplace_back(graph, opset);
}
return result;
}
Graph Attribute::get_graph(const OperatorSet& opset) const
{
return Graph{m_attribute_proto->g(), opset};
}
Graph Attribute::get_graph() const { return Graph{m_attribute_proto->g()}; }
} // namespace onnx_import
} // namespace ngraph
......@@ -19,6 +19,8 @@
#include <onnx-ml.pb.h>
#include "ngraph/except.hpp"
#include "operator_set.hpp"
#include "tensor.hpp"
#define likely(__x) __builtin_expect(!!(__x), 1)
......@@ -273,7 +275,7 @@ namespace ngraph
float get_float() const { return m_attribute_proto->f(); }
int64_t get_integer() const { return m_attribute_proto->i(); }
const std::string& get_string() const { return m_attribute_proto->s(); }
Graph get_graph() const;
Graph get_graph(const OperatorSet& opset) const;
std::vector<Tensor> get_tensor_array() const
{
......@@ -298,7 +300,7 @@ namespace ngraph
std::end(m_attribute_proto->strings())};
}
std::vector<Graph> get_graph_array() const;
std::vector<Graph> get_graph_array(const OperatorSet&) const;
/* explicit */ operator onnx::AttributeProto_AttributeType() const
{
......
......@@ -14,6 +14,8 @@
// limitations under the License.
//*****************************************************************************
#include <set>
#include "graph.hpp"
#include "node.hpp"
......@@ -21,8 +23,22 @@ namespace ngraph
{
namespace onnx_import
{
Graph::Graph(const onnx::GraphProto& graph_proto)
namespace detail
{
std::string to_string(const std::set<std::string>& set)
{
std::string result;
for (auto it = std::begin(set); it != std::end(set); ++it)
{
result += (it != std::begin(set) ? ", " : "") + *it;
}
return result;
}
}
Graph::Graph(const onnx::GraphProto& graph_proto, const OperatorSet& opset)
: m_graph_proto{&graph_proto}
, m_opset{&opset}
{
for (const auto& tensor : m_graph_proto->initializer())
{
......@@ -45,6 +61,20 @@ namespace ngraph
m_outputs.emplace_back(output);
}
// Verify that ONNX graph contains only nodes of available operator types
std::set<std::string> unknown_operator_types;
for (const auto& node_proto : m_graph_proto->node())
{
auto it = m_opset->find(node_proto.op_type());
if (it == std::end(*m_opset))
{
unknown_operator_types.emplace(node_proto.op_type());
}
}
NGRAPH_ASSERT(unknown_operator_types.empty())
<< "unknown operations: " << detail::to_string(unknown_operator_types);
// Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_graph_proto->node())
{
......
......@@ -23,6 +23,7 @@
#include "ngraph/op/parameter_vector.hpp"
#include "operator_set.hpp"
#include "value_info.hpp"
namespace ngraph
......@@ -32,7 +33,7 @@ namespace ngraph
class Graph
{
public:
explicit Graph(const onnx::GraphProto& proto);
explicit Graph(const onnx::GraphProto& proto, const OperatorSet& opset);
const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
......@@ -44,6 +45,11 @@ namespace ngraph
}
const std::string& get_name() const { return m_graph_proto->name(); }
NodeVector make_ng_nodes(const Node& node) const
{
return m_opset->at(node.op_type())(node);
}
private:
const onnx::GraphProto* m_graph_proto;
std::vector<Node> m_nodes;
......@@ -52,6 +58,7 @@ namespace ngraph
op::ParameterVector m_parameters;
std::map<std::string, std::shared_ptr<ngraph::Node>> m_ng_node_cache;
std::map<std::string, Tensor> m_initializers;
const OperatorSet* m_opset;
};
inline std::ostream& operator<<(std::ostream& outs, const Graph& graph)
......
......@@ -14,15 +14,9 @@
// limitations under the License.
//*****************************************************************************
#include <ostream>
#include <set>
#include <onnx-ml.pb.h>
#include "assertion.hpp"
#include "model.hpp"
#include "ops_bridge.hpp"
namespace ngraph
{
......@@ -31,30 +25,15 @@ namespace ngraph
Model::Model(const onnx::ModelProto& model_proto)
: m_model_proto{&model_proto}
{
// Verify that the ONNX graph contains only nodes of supported op_type
assert_all_op_types_supported();
}
void Model::assert_all_op_types_supported()
{
std::set<std::string> unsupported_ops;
for (const auto& node_proto : get_graph().node())
for (const auto& id : m_model_proto->opset_import())
{
std::string op_type = node_proto.op_type();
if (!ops_bridge::is_op_type_supported(op_type))
// onnx.proto(.3): the empty string ("") or absence of this field implies
// the operator set that is defined as part of the ONNX specification.
if (id.domain().empty())
{
unsupported_ops.insert(op_type);
m_opset_version = id.version();
}
}
std::string unsupported_ops_str;
std::size_t index = 0;
for (const auto& op_type : unsupported_ops)
{
unsupported_ops_str += (index++ != 0 ? ", " : "");
unsupported_ops_str += op_type;
}
NGRAPH_ASSERT(unsupported_ops.empty()) << "unknown operations: " << unsupported_ops_str;
}
} // namespace onnx_import
......
......@@ -44,10 +44,10 @@ namespace ngraph
return m_model_proto->producer_version();
}
std::int64_t get_opset_version() const { return m_opset_version; }
private:
const onnx::ModelProto* m_model_proto;
void assert_all_op_types_supported();
std::int64_t m_opset_version{ONNX_OPSET_VERSION};
};
inline std::ostream& operator<<(std::ostream& outs, const Model& model)
......
......@@ -16,13 +16,12 @@
#include "node.hpp"
#include "graph.hpp"
#include "ops_bridge.hpp"
namespace ngraph
{
namespace onnx_import
{
NodeVector Node::get_ng_nodes() const { return ops_bridge::make_ng_nodes(*this); }
NodeVector Node::get_ng_nodes() const { return m_graph->make_ng_nodes(*this); }
NodeVector Node::get_ng_inputs() const
{
NodeVector result;
......
......@@ -70,6 +70,7 @@ namespace ngraph
NodeVector get_ng_nodes() const;
NodeVector get_ng_inputs() const;
const std::string& domain() const { return m_node_proto->domain(); }
const std::string& op_type() const { return m_node_proto->op_type(); }
const std::string& get_name() const { return m_node_proto->name(); }
/// @brief Describe the ONNX Node to make debugging graphs easier
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <string>
#include <unordered_map>
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
// Forward declaration
class Node;
using Operator = std::function<NodeVector(const Node&)>;
using OperatorSet = std::unordered_map<std::string, std::reference_wrapper<const Operator>>;
} // namespace onnx_import
} // namespace ngraph
......@@ -21,7 +21,9 @@
#include "core/graph.hpp"
#include "core/model.hpp"
#include "core/node.hpp"
#include "onnx.hpp"
#include "ops_bridge.hpp"
namespace ngraph
{
......@@ -59,7 +61,8 @@ namespace ngraph
}
std::vector<std::shared_ptr<Function>> output_functions;
Model model{model_proto};
Graph graph{model_proto.graph()};
Graph graph{model_proto.graph(),
ops_bridge::get_operator_set(model.get_opset_version())};
for (const auto& output : graph.get_outputs())
{
output_functions.emplace_back(std::make_shared<Function>(
......
......@@ -28,14 +28,12 @@ namespace ngraph
{
namespace set_1
{
/**
* @brief Convert ONNX AveragePool operation to an nGraph node.
*
* @param node The ONNX node object representing this operation.
*
* @return The vector containing Ngraph nodes producing output of ONNX AveragePool
* operation.
*/
/// \brief Convert ONNX AveragePool operation to an nGraph node.
///
/// \param node The ONNX node object representing this operation.
///
/// \return The vector containing Ngraph nodes producing output of ONNX AveragePool
/// operation.
NodeVector average_pool(const Node& node);
} // namespace set_1
......
......@@ -16,8 +16,9 @@
#pragma once
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include <cstdint>
#include "core/operator_set.hpp"
namespace ngraph
{
......@@ -25,9 +26,9 @@ namespace ngraph
{
namespace ops_bridge
{
NodeVector make_ng_nodes(const onnx_import::Node&);
bool is_op_type_supported(const std::string& op_type);
}
const OperatorSet& get_operator_set(std::int64_t version);
} // namespace ops_bridge
} // namespace onnx_import
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment