Unverified Commit 631d7253 authored by Artur Wojcik's avatar Artur Wojcik Committed by GitHub

[ONNX] Add support for Operator Sets (#1801)

* onnx: fix cmake commands list
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: change exception name to UnknownOperator
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: rename ops_bridge class
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: initial opset versions
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: check operator availability
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: opset versions - after reivew changes
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: fix doxygen comment fromat
parent bcb5a47b
......@@ -14,12 +14,12 @@
# limitations under the License.
# ******************************************************************************
set(ONNX_OPSET_VERSION 9 CACHE INTERNAL "Supported version of ONNX operator set")
add_library(onnx_import_interface OBJECT
onnx.cpp
onnx.hpp)
target_include_directories(onnx_import_interface PRIVATE ${ONNX_PROTO_INCLUDE_DIR})
add_library(onnx_import STATIC
core/attribute.cpp
core/attribute.hpp
......@@ -29,6 +29,7 @@ add_library(onnx_import STATIC
core/model.hpp
core/node.cpp
core/node.hpp
core/operator_set.hpp
core/tensor.hpp
core/value_info.hpp
exceptions.hpp
......@@ -143,12 +144,17 @@ if (NOT NGRAPH_USE_SYSTEM_PROTOBUF)
endif()
set_property(TARGET onnx_import PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(onnx_import PUBLIC ${CMAKE_CURRENT_BINARY_DIR} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ${NGRAPH_INCLUDE_PATH} ${ONNX_PROTO_INCLUDE_DIR} ${Protobuf_INCLUDE_DIR})
target_include_directories(onnx_import PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ${NGRAPH_INCLUDE_PATH}
SYSTEM PRIVATE ${ONNX_INCLUDE_DIR} ${ONNX_PROTO_INCLUDE_DIR} ${Protobuf_INCLUDE_DIR})
target_link_libraries(onnx_import PRIVATE ${Protobuf_LIBRARIES} ${ONNX_PROTO_LIBRARY})
target_compile_definitions(onnx_import PRIVATE ONNX_OPSET_VERSION=${ONNX_OPSET_VERSION})
set_property(TARGET onnx_import_interface PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(onnx_import_interface PRIVATE ${CMAKE_CURRENT_BINARY_DIR}
${NGRAPH_INCLUDE_PATH} ${Protobuf_INCLUDE_DIR})
target_include_directories(onnx_import_interface PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${NGRAPH_INCLUDE_PATH}
SYSTEM PRIVATE ${ONNX_INCLUDE_DIR} ${ONNX_PROTO_INCLUDE_DIR} ${Protobuf_INCLUDE_DIR})
target_compile_definitions(onnx_import_interface PRIVATE ONNX_OPSET_VERSION=${ONNX_OPSET_VERSION})
if ("${CMAKE_CXX_COMPILER_ID}" MATCHES "^(Apple)?Clang$")
target_compile_options(onnx_import PRIVATE -Wno-undef -Wno-reserved-id-macro -Wno-switch-enum
......
......@@ -16,17 +16,27 @@
#include "attribute.hpp"
#include "graph.hpp"
#include "operator_set.hpp"
namespace ngraph
{
namespace onnx_import
{
std::vector<Graph> Attribute::get_graph_array() const
std::vector<Graph> Attribute::get_graph_array(const OperatorSet& opset) const
{
return {std::begin(m_attribute_proto->graphs()), std::end(m_attribute_proto->graphs())};
std::vector<Graph> result;
for (const auto& graph : m_attribute_proto->graphs())
{
result.emplace_back(graph, opset);
}
return result;
}
Graph Attribute::get_graph(const OperatorSet& opset) const
{
return Graph{m_attribute_proto->g(), opset};
}
Graph Attribute::get_graph() const { return Graph{m_attribute_proto->g()}; }
} // namespace onnx_import
} // namespace ngraph
......@@ -19,6 +19,8 @@
#include <onnx-ml.pb.h>
#include "ngraph/except.hpp"
#include "operator_set.hpp"
#include "tensor.hpp"
#define likely(__x) __builtin_expect(!!(__x), 1)
......@@ -273,7 +275,7 @@ namespace ngraph
float get_float() const { return m_attribute_proto->f(); }
int64_t get_integer() const { return m_attribute_proto->i(); }
const std::string& get_string() const { return m_attribute_proto->s(); }
Graph get_graph() const;
Graph get_graph(const OperatorSet& opset) const;
std::vector<Tensor> get_tensor_array() const
{
......@@ -298,7 +300,7 @@ namespace ngraph
std::end(m_attribute_proto->strings())};
}
std::vector<Graph> get_graph_array() const;
std::vector<Graph> get_graph_array(const OperatorSet&) const;
/* explicit */ operator onnx::AttributeProto_AttributeType() const
{
......
......@@ -14,6 +14,8 @@
// limitations under the License.
//*****************************************************************************
#include <set>
#include "graph.hpp"
#include "node.hpp"
......@@ -21,8 +23,22 @@ namespace ngraph
{
namespace onnx_import
{
Graph::Graph(const onnx::GraphProto& graph_proto)
namespace detail
{
std::string to_string(const std::set<std::string>& set)
{
std::string result;
for (auto it = std::begin(set); it != std::end(set); ++it)
{
result += (it != std::begin(set) ? ", " : "") + *it;
}
return result;
}
}
Graph::Graph(const onnx::GraphProto& graph_proto, const OperatorSet& opset)
: m_graph_proto{&graph_proto}
, m_opset{&opset}
{
for (const auto& tensor : m_graph_proto->initializer())
{
......@@ -45,6 +61,20 @@ namespace ngraph
m_outputs.emplace_back(output);
}
// Verify that ONNX graph contains only nodes of available operator types
std::set<std::string> unknown_operator_types;
for (const auto& node_proto : m_graph_proto->node())
{
auto it = m_opset->find(node_proto.op_type());
if (it == std::end(*m_opset))
{
unknown_operator_types.emplace(node_proto.op_type());
}
}
NGRAPH_ASSERT(unknown_operator_types.empty())
<< "unknown operations: " << detail::to_string(unknown_operator_types);
// Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_graph_proto->node())
{
......
......@@ -23,6 +23,7 @@
#include "ngraph/op/parameter_vector.hpp"
#include "operator_set.hpp"
#include "value_info.hpp"
namespace ngraph
......@@ -32,7 +33,7 @@ namespace ngraph
class Graph
{
public:
explicit Graph(const onnx::GraphProto& proto);
explicit Graph(const onnx::GraphProto& proto, const OperatorSet& opset);
const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
......@@ -44,6 +45,11 @@ namespace ngraph
}
const std::string& get_name() const { return m_graph_proto->name(); }
NodeVector make_ng_nodes(const Node& node) const
{
return m_opset->at(node.op_type())(node);
}
private:
const onnx::GraphProto* m_graph_proto;
std::vector<Node> m_nodes;
......@@ -52,6 +58,7 @@ namespace ngraph
op::ParameterVector m_parameters;
std::map<std::string, std::shared_ptr<ngraph::Node>> m_ng_node_cache;
std::map<std::string, Tensor> m_initializers;
const OperatorSet* m_opset;
};
inline std::ostream& operator<<(std::ostream& outs, const Graph& graph)
......
......@@ -14,15 +14,9 @@
// limitations under the License.
//*****************************************************************************
#include <ostream>
#include <set>
#include <onnx-ml.pb.h>
#include "assertion.hpp"
#include "model.hpp"
#include "ops_bridge.hpp"
namespace ngraph
{
......@@ -31,30 +25,15 @@ namespace ngraph
Model::Model(const onnx::ModelProto& model_proto)
: m_model_proto{&model_proto}
{
// Verify that the ONNX graph contains only nodes of supported op_type
assert_all_op_types_supported();
}
void Model::assert_all_op_types_supported()
{
std::set<std::string> unsupported_ops;
for (const auto& node_proto : get_graph().node())
for (const auto& id : m_model_proto->opset_import())
{
std::string op_type = node_proto.op_type();
if (!ops_bridge::is_op_type_supported(op_type))
// onnx.proto(.3): the empty string ("") or absence of this field implies
// the operator set that is defined as part of the ONNX specification.
if (id.domain().empty())
{
unsupported_ops.insert(op_type);
m_opset_version = id.version();
}
}
std::string unsupported_ops_str;
std::size_t index = 0;
for (const auto& op_type : unsupported_ops)
{
unsupported_ops_str += (index++ != 0 ? ", " : "");
unsupported_ops_str += op_type;
}
NGRAPH_ASSERT(unsupported_ops.empty()) << "unknown operations: " << unsupported_ops_str;
}
} // namespace onnx_import
......
......@@ -44,10 +44,10 @@ namespace ngraph
return m_model_proto->producer_version();
}
std::int64_t get_opset_version() const { return m_opset_version; }
private:
const onnx::ModelProto* m_model_proto;
void assert_all_op_types_supported();
std::int64_t m_opset_version{ONNX_OPSET_VERSION};
};
inline std::ostream& operator<<(std::ostream& outs, const Model& model)
......
......@@ -16,13 +16,12 @@
#include "node.hpp"
#include "graph.hpp"
#include "ops_bridge.hpp"
namespace ngraph
{
namespace onnx_import
{
NodeVector Node::get_ng_nodes() const { return ops_bridge::make_ng_nodes(*this); }
NodeVector Node::get_ng_nodes() const { return m_graph->make_ng_nodes(*this); }
NodeVector Node::get_ng_inputs() const
{
NodeVector result;
......
......@@ -70,6 +70,7 @@ namespace ngraph
NodeVector get_ng_nodes() const;
NodeVector get_ng_inputs() const;
const std::string& domain() const { return m_node_proto->domain(); }
const std::string& op_type() const { return m_node_proto->op_type(); }
const std::string& get_name() const { return m_node_proto->name(); }
/// @brief Describe the ONNX Node to make debugging graphs easier
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <string>
#include <unordered_map>
#include "ngraph/node_vector.hpp"
namespace ngraph
{
namespace onnx_import
{
// Forward declaration
class Node;
using Operator = std::function<NodeVector(const Node&)>;
using OperatorSet = std::unordered_map<std::string, std::reference_wrapper<const Operator>>;
} // namespace onnx_import
} // namespace ngraph
......@@ -21,7 +21,9 @@
#include "core/graph.hpp"
#include "core/model.hpp"
#include "core/node.hpp"
#include "onnx.hpp"
#include "ops_bridge.hpp"
namespace ngraph
{
......@@ -59,7 +61,8 @@ namespace ngraph
}
std::vector<std::shared_ptr<Function>> output_functions;
Model model{model_proto};
Graph graph{model_proto.graph()};
Graph graph{model_proto.graph(),
ops_bridge::get_operator_set(model.get_opset_version())};
for (const auto& output : graph.get_outputs())
{
output_functions.emplace_back(std::make_shared<Function>(
......
......@@ -28,14 +28,12 @@ namespace ngraph
{
namespace set_1
{
/**
* @brief Convert ONNX AveragePool operation to an nGraph node.
*
* @param node The ONNX node object representing this operation.
*
* @return The vector containing Ngraph nodes producing output of ONNX AveragePool
* operation.
*/
/// \brief Convert ONNX AveragePool operation to an nGraph node.
///
/// \param node The ONNX node object representing this operation.
///
/// \return The vector containing Ngraph nodes producing output of ONNX AveragePool
/// operation.
NodeVector average_pool(const Node& node);
} // namespace set_1
......
......@@ -14,9 +14,11 @@
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <functional>
#include <iterator>
#include <map>
#include <string>
#include <unordered_map>
#include "core/attribute.hpp"
#include "op/abs.hpp"
......@@ -89,47 +91,194 @@ namespace ngraph
{
namespace error
{
struct unknown_operation : ngraph_error
struct UnknownOperator : ngraph_error
{
explicit unknown_operation(const std::string& op_type)
: ngraph_error{"unknown operation: " + op_type}
explicit UnknownOperator(const std::string& op_type)
: ngraph_error{"unknown operator: \"" + op_type + "\""}
{
}
};
struct UnsupportedVersion : ngraph_error
{
explicit UnsupportedVersion(std::int64_t version)
: ngraph_error{"unsupported operator set version: " +
std::to_string(version)}
{
}
};
} // namespace error
class ops_bridge
class OperatorsBridge
{
public:
ops_bridge(const ops_bridge&) = delete;
ops_bridge& operator=(const ops_bridge&) = delete;
ops_bridge(ops_bridge&&) = delete;
ops_bridge& operator=(ops_bridge&&) = delete;
OperatorsBridge(const OperatorsBridge&) = delete;
OperatorsBridge& operator=(const OperatorsBridge&) = delete;
OperatorsBridge(OperatorsBridge&&) = delete;
OperatorsBridge& operator=(OperatorsBridge&&) = delete;
static NodeVector make_ng_nodes(const Node& node)
static const OperatorSet& get_operator_set(std::int64_t version)
{
return ops_bridge::get()(node);
return instance().get_operator_set_version(version);
}
static bool is_op_type_supported(const std::string& op_type)
private:
std::unordered_map<std::string,
std::map<std::int64_t, std::function<NodeVector(const Node&)>>>
m_map;
static const OperatorsBridge& instance()
{
return ops_bridge::get().is_op_type_supported_(op_type);
static OperatorsBridge instance;
return instance;
}
private:
std::map<std::string, std::function<NodeVector(const Node&)>> m_map;
const Operator& get_operator(const std::string& name, std::int64_t version) const
{
auto op = m_map.find(name);
if (op == std::end(m_map))
{
throw error::UnknownOperator{name};
}
auto it = op->second.find(version);
if (it == std::end(op->second))
{
throw error::UnsupportedVersion{version};
}
return it->second;
}
static const ops_bridge& get()
const OperatorSet& get_operator_set_version_1() const
{
static ops_bridge instance;
return instance;
static OperatorSet operator_set;
if (operator_set.empty())
{
for (const auto& op : m_map)
{
for (const auto& it : op.second)
{
if (it.first == 1)
{
operator_set.emplace(op.first, it.second);
}
}
}
}
return operator_set;
}
const OperatorSet& get_operator_set_version_2() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_1();
}
return operator_set;
}
const OperatorSet& get_operator_set_version_3() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_2();
}
return operator_set;
}
const OperatorSet& get_operator_set_version_4() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_3();
}
return operator_set;
}
const OperatorSet& get_operator_set_version_5() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_4();
}
return operator_set;
}
const OperatorSet& get_operator_set_version_6() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_5();
}
return operator_set;
}
const OperatorSet& get_operator_set_version_7() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_6();
}
return operator_set;
}
const OperatorSet& get_operator_set_version_8() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_7();
}
return operator_set;
}
const OperatorSet& get_operator_set_version_9() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_8();
}
return operator_set;
}
#define OPERATOR_SET_NAME(version_) get_operator_set_version_##version_()
#define GET_OPERATOR_SET(version_) \
case version_: \
return OPERATOR_SET_NAME(version_)
#define OPERATOR_SET_NAME_HELPER(version_) OPERATOR_SET_NAME(version_)
#define DEFAULT_OPERATOR_SET() return OPERATOR_SET_NAME_HELPER(ONNX_OPSET_VERSION)
const OperatorSet& get_operator_set_version(std::int64_t version) const
{
switch (version)
{
GET_OPERATOR_SET(1);
GET_OPERATOR_SET(2);
GET_OPERATOR_SET(3);
GET_OPERATOR_SET(4);
GET_OPERATOR_SET(5);
GET_OPERATOR_SET(6);
GET_OPERATOR_SET(7);
GET_OPERATOR_SET(8);
GET_OPERATOR_SET(9);
default: DEFAULT_OPERATOR_SET();
}
}
#define REGISTER_OPERATOR(name_, version_, fn_) \
m_map.emplace(name_, std::bind(op::set_##version_::fn_, std::placeholders::_1))
m_map[name_].emplace(version_, std::bind(op::set_##version_::fn_, std::placeholders::_1))
ops_bridge()
OperatorsBridge()
{
REGISTER_OPERATOR("Abs", 1, abs);
REGISTER_OPERATOR("Add", 1, add);
......@@ -202,38 +351,15 @@ namespace ngraph
REGISTER_OPERATOR("Unsqueeze", 1, unsqueeze);
REGISTER_OPERATOR("Xor", 1, logical_xor);
}
NodeVector operator()(const Node& node) const
{
auto it = m_map.find(node.op_type());
if (it == m_map.end())
{
throw detail::error::unknown_operation{node.op_type()};
}
std::function<NodeVector(const Node&)> factory{it->second};
return factory(node);
}
bool is_op_type_supported_(const std::string& op_type) const
{
auto it = m_map.find(op_type);
return !(it == m_map.end());
}
};
} // namespace detail
namespace ops_bridge
{
NodeVector make_ng_nodes(const Node& node)
{
return detail::ops_bridge::make_ng_nodes(node);
}
bool is_op_type_supported(const std::string& op_type)
const OperatorSet& get_operator_set(std::int64_t version)
{
return detail::ops_bridge::is_op_type_supported(op_type);
return detail::OperatorsBridge::get_operator_set(version);
}
} // namespace ops_bridge
......
......@@ -16,8 +16,9 @@
#pragma once
#include "core/node.hpp"
#include "ngraph/node_vector.hpp"
#include <cstdint>
#include "core/operator_set.hpp"
namespace ngraph
{
......@@ -25,9 +26,9 @@ namespace ngraph
{
namespace ops_bridge
{
NodeVector make_ng_nodes(const onnx_import::Node&);
bool is_op_type_supported(const std::string& op_type);
}
const OperatorSet& get_operator_set(std::int64_t version);
} // namespace ops_bridge
} // namespace onnx_import
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment