Unverified Commit e7b4106e authored by Artur Wojcik's avatar Artur Wojcik Committed by GitHub

[ONNX] add the ability to register custom ONNX operators (#1856)

* onnx: add information about a domain to operators set
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: updates after review
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: update comments in the code
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: fix bug in node's description method
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: fix CentOS compilation
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: more after review changes
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>
parent 8855c723
......@@ -18,7 +18,10 @@ set(ONNX_OPSET_VERSION 9 CACHE INTERNAL "Supported version of ONNX operator set"
add_library(onnx_import_interface OBJECT
onnx.cpp
onnx.hpp)
onnx.hpp
core/operator_set.hpp
core/node.cpp
core/node.hpp)
add_library(onnx_import STATIC
core/attribute.cpp
......@@ -27,7 +30,6 @@ add_library(onnx_import STATIC
core/graph.hpp
core/model.cpp
core/model.hpp
core/node.cpp
core/node.hpp
core/operator_set.hpp
core/tensor.hpp
......
......@@ -16,25 +16,25 @@
#include "attribute.hpp"
#include "graph.hpp"
#include "operator_set.hpp"
#include "model.hpp"
namespace ngraph
{
namespace onnx_import
{
std::vector<Graph> Attribute::get_graph_array(const OperatorSet& opset) const
std::vector<Graph> Attribute::get_graph_array(const Model& model) const
{
std::vector<Graph> result;
for (const auto& graph : m_attribute_proto->graphs())
{
result.emplace_back(graph, opset);
result.emplace_back(graph, model);
}
return result;
}
Graph Attribute::get_graph(const OperatorSet& opset) const
Graph Attribute::get_graph(const Model& model) const
{
return Graph{m_attribute_proto->g(), opset};
return Graph{m_attribute_proto->g(), model};
}
} // namespace onnx_import
......
......@@ -20,15 +20,15 @@
#include "ngraph/except.hpp"
#include "operator_set.hpp"
#include "tensor.hpp"
namespace ngraph
{
namespace onnx_import
{
// forward declaration
// forward declarations
class Graph;
class Model;
namespace error
{
......@@ -272,7 +272,7 @@ namespace ngraph
float get_float() const { return m_attribute_proto->f(); }
int64_t get_integer() const { return m_attribute_proto->i(); }
const std::string& get_string() const { return m_attribute_proto->s(); }
Graph get_graph(const OperatorSet& opset) const;
Graph get_graph(const Model&) const;
std::vector<Tensor> get_tensor_array() const
{
......@@ -297,7 +297,7 @@ namespace ngraph
std::end(m_attribute_proto->strings())};
}
std::vector<Graph> get_graph_array(const OperatorSet&) const;
std::vector<Graph> get_graph_array(const Model&) const;
/* explicit */ operator onnx::AttributeProto_AttributeType() const
{
......
......@@ -34,11 +34,17 @@ namespace ngraph
}
return result;
}
inline std::string to_string(const onnx::NodeProto& node_proto)
{
return (node_proto.domain().empty() ? "" : node_proto.domain() + ".") +
node_proto.op_type();
}
}
Graph::Graph(const onnx::GraphProto& graph_proto, const OperatorSet& opset)
Graph::Graph(const onnx::GraphProto& graph_proto, const Model& model)
: m_graph_proto{&graph_proto}
, m_opset{&opset}
, m_model{&model}
{
for (const auto& tensor : m_graph_proto->initializer())
{
......@@ -65,10 +71,9 @@ namespace ngraph
std::set<std::string> unknown_operator_types;
for (const auto& node_proto : m_graph_proto->node())
{
auto it = m_opset->find(node_proto.op_type());
if (it == std::end(*m_opset))
if (!m_model->is_operator_available(node_proto))
{
unknown_operator_types.emplace(node_proto.op_type());
unknown_operator_types.emplace(detail::to_string(node_proto));
}
}
......@@ -78,7 +83,7 @@ namespace ngraph
// Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_graph_proto->node())
{
m_nodes.emplace_back(node_proto, this);
m_nodes.emplace_back(node_proto, *this);
const Node& node{m_nodes.back()};
NodeVector ng_nodes{node.get_ng_nodes()};
for (int i = 0; i < ng_nodes.size(); i++)
......
......@@ -23,6 +23,7 @@
#include "ngraph/op/parameter_vector.hpp"
#include "model.hpp"
#include "operator_set.hpp"
#include "value_info.hpp"
......@@ -33,7 +34,7 @@ namespace ngraph
class Graph
{
public:
explicit Graph(const onnx::GraphProto& proto, const OperatorSet& opset);
Graph(const onnx::GraphProto& proto, const Model& model);
const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
......@@ -47,7 +48,7 @@ namespace ngraph
const std::string& get_name() const { return m_graph_proto->name(); }
NodeVector make_ng_nodes(const Node& node) const
{
return m_opset->at(node.op_type())(node);
return m_model->get_operator(node.op_type(), node.domain())(node);
}
private:
......@@ -58,7 +59,7 @@ namespace ngraph
op::ParameterVector m_parameters;
std::map<std::string, std::shared_ptr<ngraph::Node>> m_ng_node_cache;
std::map<std::string, Tensor> m_initializers;
const OperatorSet* m_opset;
const Model* m_model;
};
inline std::ostream& operator<<(std::ostream& outs, const Graph& graph)
......
......@@ -17,6 +17,7 @@
#include <onnx-ml.pb.h>
#include "model.hpp"
#include "ops_bridge.hpp"
namespace ngraph
{
......@@ -25,15 +26,49 @@ namespace ngraph
Model::Model(const onnx::ModelProto& model_proto)
: m_model_proto{&model_proto}
{
// Walk through the elements of opset_import field and register operator sets
// for each domain. An exception UnknownDomain() will raise if the domain is
// unknown or invalid.
for (const auto& id : m_model_proto->opset_import())
{
// onnx.proto(.3): the empty string ("") or absence of this field implies
// the operator set that is defined as part of the ONNX specification.
if (id.domain().empty())
{
m_opset_version = id.version();
}
m_opset.emplace(id.domain(),
OperatorsBridge::get_operator_set(
id.version(), (id.domain() == "ai.onnx" ? "" : id.domain())));
}
// onnx.proto(.3): the empty string ("") for domain or absence of opset_import field
// implies the operator set that is defined as part of the ONNX specification.
const auto dm = m_opset.find("");
if (dm == std::end(m_opset))
{
m_opset.emplace("", OperatorsBridge::get_operator_set(ONNX_OPSET_VERSION, ""));
}
}
const Operator& Model::get_operator(const std::string& name,
const std::string& domain) const
{
const auto dm = m_opset.find(domain);
if (dm == std::end(m_opset))
{
throw error::UnknownDomain{domain};
}
const auto op = dm->second.find(name);
if (op == std::end(dm->second))
{
throw error::UnknownOperator{name, domain};
}
return op->second;
}
bool Model::is_operator_available(const onnx::NodeProto& node_proto) const
{
const auto dm = m_opset.find(node_proto.domain());
if (dm == std::end(m_opset))
{
return false;
}
const auto op = dm->second.find(node_proto.op_type());
return (op != std::end(dm->second));
}
} // namespace onnx_import
......
......@@ -16,9 +16,13 @@
#pragma once
#include <onnx-ml.pb.h>
#include <ostream>
#include <string>
#include <unordered_map>
#include <onnx-ml.pb.h>
#include "operator_set.hpp"
namespace ngraph
{
......@@ -30,11 +34,11 @@ namespace ngraph
Model() = delete;
explicit Model(const onnx::ModelProto& model_proto);
Model(Model&&) noexcept = default;
Model(const Model&) = default;
Model(Model&&) = default;
Model& operator=(Model&&) noexcept = delete;
Model& operator=(const Model&) = delete;
Model& operator=(Model&&) = delete;
const std::string& get_producer_name() const { return m_model_proto->producer_name(); }
const onnx::GraphProto& get_graph() const { return m_model_proto->graph(); }
......@@ -44,10 +48,23 @@ namespace ngraph
return m_model_proto->producer_version();
}
std::int64_t get_opset_version() const { return m_opset_version; }
/// \brief Access an operator object by its type name and domain name
/// The function will return the operator object if it exists, or report an error
/// in case of domain or operator absence.
/// \param name type name of the operator object,
/// \param domain domain name of the operator object.
/// \return Reference to the operator object.
/// \throw error::UnknownDomain there is no operator set defined for the given domain,
/// \throw error::UnknownOperator the given operator type name does not exist in operator set.
const Operator& get_operator(const std::string& name, const std::string& domain) const;
/// \brief Check availability of operator base on NodeProto.
/// \return `true` if the operator is available, otherwise it returns `false`.
bool is_operator_available(const onnx::NodeProto& node_proto) const;
private:
const onnx::ModelProto* m_model_proto;
std::int64_t m_opset_version{ONNX_OPSET_VERSION};
std::unordered_map<std::string, OperatorSet> m_opset;
};
inline std::ostream& operator<<(std::ostream& outs, const Model& model)
......
......@@ -18,12 +18,14 @@
#include <string>
#include <onnx-ml.pb.h>
#include "ngraph/except.hpp"
#include "ngraph/node_vector.hpp"
#include "attribute.hpp"
#include "tensor.hpp"
namespace onnx
{
// forward declaration
class NodeProto;
}
namespace ngraph
{
......@@ -52,70 +54,41 @@ namespace ngraph
{
public:
Node() = delete;
Node(const onnx::NodeProto& node_proto, const Graph* graph)
: m_node_proto{&node_proto}
, m_graph{graph}
, m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())}
, m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())}
{
}
Node(const onnx::NodeProto& node_proto, const Graph& graph);
Node(Node&&) noexcept = default;
Node(const Node&) = default;
Node(Node&&) noexcept;
Node(const Node&);
Node& operator=(Node&&) noexcept = delete;
Node& operator=(const Node&) = delete;
const std::vector<Attribute>& attributes() const { return m_attributes; }
NodeVector get_ng_nodes() const;
NodeVector get_ng_inputs() const;
NodeVector get_ng_nodes() const;
const std::string& domain() const;
const std::string& op_type() const;
const std::string& get_name() const;
const std::string& domain() const { return m_node_proto->domain(); }
const std::string& op_type() const { return m_node_proto->op_type(); }
const std::string& get_name() const { return m_node_proto->name(); }
/// @brief Describe the ONNX Node to make debugging graphs easier
/// \brief Describe the ONNX Node to make debugging graphs easier
/// Function will return the Node's name if it has one, or the names of its outputs.
/// \return Description of Node
std::string get_description() const;
const std::string& get_description() const;
const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const;
const std::string& output(int index) const;
const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const
{
return m_output_names;
}
const std::string& output(int index) const { return m_node_proto->output(index); }
template <typename T>
T get_attribute_value(const std::string& name, T default_value) const
{
auto it = std::find_if(
std::begin(m_attributes),
std::end(m_attributes),
[&](const Attribute& attribute) { return attribute.get_name() == name; });
if (it == std::end(m_attributes))
{
return default_value;
}
return it->template get_value<T>();
}
T get_attribute_value(const std::string& name, T default_value) const;
template <typename T>
T get_attribute_value(const std::string& name) const
{
auto it = std::find_if(
std::begin(m_attributes),
std::end(m_attributes),
[&](const Attribute& attribute) { return attribute.get_name() == name; });
if (it == std::end(m_attributes))
{
throw error::node::UnknownAttribute{get_name(), name};
}
return it->template get_value<T>();
}
T get_attribute_value(const std::string& name) const;
private:
const onnx::NodeProto* m_node_proto;
const Graph* m_graph;
std::vector<Attribute> m_attributes;
std::vector<std::reference_wrapper<const std::string>> m_output_names;
class Impl;
// In this case we need custom deleter, because Impl is an incomplete
// type. Node's are elements of std::vector. Without custom deleter
// compilation fails; the compiler is unable to parameterize an allocator's
// default deleter due to incomple type.
std::unique_ptr<Impl, void (*)(Impl*)> m_pimpl;
};
inline std::ostream& operator<<(std::ostream& outs, const Node& node)
......
......@@ -22,13 +22,12 @@
#include "ngraph/node_vector.hpp"
#include "node.hpp"
namespace ngraph
{
namespace onnx_import
{
// Forward declaration
class Node;
using Operator = std::function<NodeVector(const Node&)>;
using OperatorSet = std::unordered_map<std::string, std::reference_wrapper<const Operator>>;
......
......@@ -61,8 +61,7 @@ namespace ngraph
}
std::vector<std::shared_ptr<Function>> output_functions;
Model model{model_proto};
Graph graph{model_proto.graph(),
OperatorsBridge::get_operator_set(model.get_opset_version())};
Graph graph{model_proto.graph(), model};
for (const auto& output : graph.get_outputs())
{
output_functions.emplace_back(std::make_shared<Function>(
......@@ -91,6 +90,14 @@ namespace ngraph
return load_onnx_model(path).front();
}
void register_operator(const std::string& name,
std::int64_t version,
const std::string& domain,
Operator fn)
{
OperatorsBridge::register_operator(name, version, domain, std::move(fn));
}
} // namespace onnx_import
} // namespace ngraph
......@@ -21,10 +21,18 @@
#include "ngraph/function.hpp"
#include "core/operator_set.hpp"
namespace ngraph
{
namespace onnx_import
{
// Registers ONNX custom operator
void register_operator(const std::string& name,
std::int64_t version,
const std::string& domain,
Operator fn);
// Convert on ONNX model to a vector of nGraph Functions (input stream)
std::vector<std::shared_ptr<Function>> load_onnx_model(std::istream&);
......
......@@ -87,134 +87,51 @@ namespace ngraph
{
namespace onnx_import
{
const OperatorSet& OperatorsBridge::get_operator_set_version_1() const
namespace detail
{
static OperatorSet operator_set;
if (operator_set.empty())
const Operator& find(const std::string& name,
std::int64_t version,
const std::string& domain,
const std::map<std::int64_t, Operator>& map)
{
for (const auto& op : m_map)
while (version > 0)
{
for (const auto& it : op.second)
const auto it = map.find(version--);
if (it != std::end(map))
{
if (it.first == 1)
{
operator_set.emplace(op.first, it.second);
}
return it->second;
}
}
throw error::UnsupportedVersion{name, version, domain};
}
return operator_set;
}
const OperatorSet& OperatorsBridge::get_operator_set_version_2() const
void OperatorsBridge::_register_operator(const std::string& name,
std::int64_t version,
const std::string& domain,
Operator fn)
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_1();
}
return operator_set;
}
const OperatorSet& OperatorsBridge::get_operator_set_version_3() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_2();
}
return operator_set;
}
const OperatorSet& OperatorsBridge::get_operator_set_version_4() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_3();
}
return operator_set;
}
const OperatorSet& OperatorsBridge::get_operator_set_version_5() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_4();
}
return operator_set;
m_map[domain][name].emplace(version, std::move(fn));
}
const OperatorSet& OperatorsBridge::get_operator_set_version_6() const
OperatorSet OperatorsBridge::_get_operator_set(std::int64_t version,
const std::string& domain)
{
static OperatorSet operator_set;
if (operator_set.empty())
OperatorSet result;
auto dm = m_map.find(domain);
if (dm == std::end(m_map))
{
operator_set = get_operator_set_version_5();
throw error::UnknownDomain{domain};
}
return operator_set;
}
const OperatorSet& OperatorsBridge::get_operator_set_version_7() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_6();
}
return operator_set;
}
const OperatorSet& OperatorsBridge::get_operator_set_version_8() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_7();
}
return operator_set;
}
const OperatorSet& OperatorsBridge::get_operator_set_version_9() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_8();
}
return operator_set;
}
#define OPERATOR_SET_NAME(version_) get_operator_set_version_##version_()
#define GET_OPERATOR_SET(version_) \
case version_: \
return OPERATOR_SET_NAME(version_)
#define OPERATOR_SET_NAME_HELPER(version_) OPERATOR_SET_NAME(version_)
#define DEFAULT_OPERATOR_SET() return OPERATOR_SET_NAME_HELPER(ONNX_OPSET_VERSION)
const OperatorSet& OperatorsBridge::get_operator_set_version(std::int64_t version) const
{
switch (version)
for (const auto& op : dm->second)
{
GET_OPERATOR_SET(1);
GET_OPERATOR_SET(2);
GET_OPERATOR_SET(3);
GET_OPERATOR_SET(4);
GET_OPERATOR_SET(5);
GET_OPERATOR_SET(6);
GET_OPERATOR_SET(7);
GET_OPERATOR_SET(8);
GET_OPERATOR_SET(9);
default: DEFAULT_OPERATOR_SET();
result.emplace(op.first, detail::find(op.first, version, domain, op.second));
}
return result;
}
#define REGISTER_OPERATOR(name_, version_, fn_) \
m_map[name_].emplace(version_, std::bind(op::set_##version_::fn_, std::placeholders::_1))
#define REGISTER_OPERATOR(name_, ver_, fn_) \
m_map[""][name_].emplace(ver_, std::bind(op::set_##ver_::fn_, std::placeholders::_1))
OperatorsBridge::OperatorsBridge()
{
......
......@@ -33,16 +33,27 @@ namespace ngraph
{
struct UnknownOperator : ngraph_error
{
explicit UnknownOperator(const std::string& op_type)
: ngraph_error{"unknown operator: \"" + op_type + "\""}
UnknownOperator(const std::string& name, const std::string& domain)
: ngraph_error{(domain.empty() ? "" : domain + ".") + name}
{
}
};
struct UnknownDomain : ngraph_error
{
explicit UnknownDomain(const std::string& domain)
: ngraph_error{domain}
{
}
};
struct UnsupportedVersion : ngraph_error
{
explicit UnsupportedVersion(std::int64_t version)
: ngraph_error{"unsupported operator set version: " + std::to_string(version)}
UnsupportedVersion(const std::string& name,
std::int64_t version,
const std::string& domain)
: ngraph_error{(domain.empty() ? "" : domain + ".") + name + ":" +
std::to_string(version)}
{
}
};
......@@ -57,32 +68,37 @@ namespace ngraph
OperatorsBridge(OperatorsBridge&&) = delete;
OperatorsBridge& operator=(OperatorsBridge&&) = delete;
static const OperatorSet& get_operator_set(std::int64_t version)
static OperatorSet get_operator_set(std::int64_t version, const std::string& domain)
{
return instance()._get_operator_set(version, domain);
}
static void register_operator(const std::string& name,
std::int64_t version,
const std::string& domain,
Operator fn)
{
return instance().get_operator_set_version(version);
instance()._register_operator(name, version, domain, std::move(fn));
}
private:
std::unordered_map<std::string, std::map<std::int64_t, Operator>> m_map;
std::unordered_map<std::string,
std::unordered_map<std::string, std::map<std::int64_t, Operator>>>
m_map;
OperatorsBridge();
static const OperatorsBridge& instance()
static OperatorsBridge& instance()
{
static OperatorsBridge instance;
return instance;
}
const OperatorSet& get_operator_set_version_1() const;
const OperatorSet& get_operator_set_version_2() const;
const OperatorSet& get_operator_set_version_3() const;
const OperatorSet& get_operator_set_version_4() const;
const OperatorSet& get_operator_set_version_5() const;
const OperatorSet& get_operator_set_version_6() const;
const OperatorSet& get_operator_set_version_7() const;
const OperatorSet& get_operator_set_version_8() const;
const OperatorSet& get_operator_set_version_9() const;
const OperatorSet& get_operator_set_version(std::int64_t version) const;
void _register_operator(const std::string& name,
std::int64_t version,
const std::string& domain,
Operator fn);
OperatorSet _get_operator_set(std::int64_t version, const std::string& domain);
};
} // namespace onnx_import
......
......@@ -1297,3 +1297,39 @@ TEST(onnx, model_unsupported_op)
FAIL() << "Expected ngraph::ngraph_error";
}
}
TEST(onnx, model_custom_op)
{
onnx_import::register_operator(
"AddQ", 1, "com.intel.ai", [](const onnx_import::Node& node) -> NodeVector {
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
});
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/custom_operator.onnx"));
Inputs inputs{{1, 2, 3, 4}};
Outputs expected_outputs{{3, 6, 9, 12}};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx, model_custom_op_default_domain)
{
onnx_import::register_operator(
"AddQ", 1, "com.intel.ai", [](const onnx_import::Node& node) -> NodeVector {
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
});
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/custom_operator_default_domain.onnx"));
Inputs inputs{{1, 2, 3, 4}};
Outputs expected_outputs{{3, 6, 9, 12}};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment