Unverified Commit e7b4106e authored by Artur Wojcik's avatar Artur Wojcik Committed by GitHub

[ONNX] add the ability to register custom ONNX operators (#1856)

* onnx: add information about a domain to operators set
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: updates after review
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: update comments in the code
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: fix bug in node's description method
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: fix CentOS compilation
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: more after review changes
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>
parent 8855c723
...@@ -18,7 +18,10 @@ set(ONNX_OPSET_VERSION 9 CACHE INTERNAL "Supported version of ONNX operator set" ...@@ -18,7 +18,10 @@ set(ONNX_OPSET_VERSION 9 CACHE INTERNAL "Supported version of ONNX operator set"
add_library(onnx_import_interface OBJECT add_library(onnx_import_interface OBJECT
onnx.cpp onnx.cpp
onnx.hpp) onnx.hpp
core/operator_set.hpp
core/node.cpp
core/node.hpp)
add_library(onnx_import STATIC add_library(onnx_import STATIC
core/attribute.cpp core/attribute.cpp
...@@ -27,7 +30,6 @@ add_library(onnx_import STATIC ...@@ -27,7 +30,6 @@ add_library(onnx_import STATIC
core/graph.hpp core/graph.hpp
core/model.cpp core/model.cpp
core/model.hpp core/model.hpp
core/node.cpp
core/node.hpp core/node.hpp
core/operator_set.hpp core/operator_set.hpp
core/tensor.hpp core/tensor.hpp
......
...@@ -16,25 +16,25 @@ ...@@ -16,25 +16,25 @@
#include "attribute.hpp" #include "attribute.hpp"
#include "graph.hpp" #include "graph.hpp"
#include "operator_set.hpp" #include "model.hpp"
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
std::vector<Graph> Attribute::get_graph_array(const OperatorSet& opset) const std::vector<Graph> Attribute::get_graph_array(const Model& model) const
{ {
std::vector<Graph> result; std::vector<Graph> result;
for (const auto& graph : m_attribute_proto->graphs()) for (const auto& graph : m_attribute_proto->graphs())
{ {
result.emplace_back(graph, opset); result.emplace_back(graph, model);
} }
return result; return result;
} }
Graph Attribute::get_graph(const OperatorSet& opset) const Graph Attribute::get_graph(const Model& model) const
{ {
return Graph{m_attribute_proto->g(), opset}; return Graph{m_attribute_proto->g(), model};
} }
} // namespace onnx_import } // namespace onnx_import
......
...@@ -20,15 +20,15 @@ ...@@ -20,15 +20,15 @@
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
#include "operator_set.hpp"
#include "tensor.hpp" #include "tensor.hpp"
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
// forward declaration // forward declarations
class Graph; class Graph;
class Model;
namespace error namespace error
{ {
...@@ -272,7 +272,7 @@ namespace ngraph ...@@ -272,7 +272,7 @@ namespace ngraph
float get_float() const { return m_attribute_proto->f(); } float get_float() const { return m_attribute_proto->f(); }
int64_t get_integer() const { return m_attribute_proto->i(); } int64_t get_integer() const { return m_attribute_proto->i(); }
const std::string& get_string() const { return m_attribute_proto->s(); } const std::string& get_string() const { return m_attribute_proto->s(); }
Graph get_graph(const OperatorSet& opset) const; Graph get_graph(const Model&) const;
std::vector<Tensor> get_tensor_array() const std::vector<Tensor> get_tensor_array() const
{ {
...@@ -297,7 +297,7 @@ namespace ngraph ...@@ -297,7 +297,7 @@ namespace ngraph
std::end(m_attribute_proto->strings())}; std::end(m_attribute_proto->strings())};
} }
std::vector<Graph> get_graph_array(const OperatorSet&) const; std::vector<Graph> get_graph_array(const Model&) const;
/* explicit */ operator onnx::AttributeProto_AttributeType() const /* explicit */ operator onnx::AttributeProto_AttributeType() const
{ {
......
...@@ -34,11 +34,17 @@ namespace ngraph ...@@ -34,11 +34,17 @@ namespace ngraph
} }
return result; return result;
} }
inline std::string to_string(const onnx::NodeProto& node_proto)
{
return (node_proto.domain().empty() ? "" : node_proto.domain() + ".") +
node_proto.op_type();
}
} }
Graph::Graph(const onnx::GraphProto& graph_proto, const OperatorSet& opset) Graph::Graph(const onnx::GraphProto& graph_proto, const Model& model)
: m_graph_proto{&graph_proto} : m_graph_proto{&graph_proto}
, m_opset{&opset} , m_model{&model}
{ {
for (const auto& tensor : m_graph_proto->initializer()) for (const auto& tensor : m_graph_proto->initializer())
{ {
...@@ -65,10 +71,9 @@ namespace ngraph ...@@ -65,10 +71,9 @@ namespace ngraph
std::set<std::string> unknown_operator_types; std::set<std::string> unknown_operator_types;
for (const auto& node_proto : m_graph_proto->node()) for (const auto& node_proto : m_graph_proto->node())
{ {
auto it = m_opset->find(node_proto.op_type()); if (!m_model->is_operator_available(node_proto))
if (it == std::end(*m_opset))
{ {
unknown_operator_types.emplace(node_proto.op_type()); unknown_operator_types.emplace(detail::to_string(node_proto));
} }
} }
...@@ -78,7 +83,7 @@ namespace ngraph ...@@ -78,7 +83,7 @@ namespace ngraph
// Process ONNX graph nodes, convert to nGraph nodes // Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_graph_proto->node()) for (const auto& node_proto : m_graph_proto->node())
{ {
m_nodes.emplace_back(node_proto, this); m_nodes.emplace_back(node_proto, *this);
const Node& node{m_nodes.back()}; const Node& node{m_nodes.back()};
NodeVector ng_nodes{node.get_ng_nodes()}; NodeVector ng_nodes{node.get_ng_nodes()};
for (int i = 0; i < ng_nodes.size(); i++) for (int i = 0; i < ng_nodes.size(); i++)
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "ngraph/op/parameter_vector.hpp" #include "ngraph/op/parameter_vector.hpp"
#include "model.hpp"
#include "operator_set.hpp" #include "operator_set.hpp"
#include "value_info.hpp" #include "value_info.hpp"
...@@ -33,7 +34,7 @@ namespace ngraph ...@@ -33,7 +34,7 @@ namespace ngraph
class Graph class Graph
{ {
public: public:
explicit Graph(const onnx::GraphProto& proto, const OperatorSet& opset); Graph(const onnx::GraphProto& proto, const Model& model);
const std::vector<Node>& get_nodes() const { return m_nodes; } const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; } const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
...@@ -47,7 +48,7 @@ namespace ngraph ...@@ -47,7 +48,7 @@ namespace ngraph
const std::string& get_name() const { return m_graph_proto->name(); } const std::string& get_name() const { return m_graph_proto->name(); }
NodeVector make_ng_nodes(const Node& node) const NodeVector make_ng_nodes(const Node& node) const
{ {
return m_opset->at(node.op_type())(node); return m_model->get_operator(node.op_type(), node.domain())(node);
} }
private: private:
...@@ -58,7 +59,7 @@ namespace ngraph ...@@ -58,7 +59,7 @@ namespace ngraph
op::ParameterVector m_parameters; op::ParameterVector m_parameters;
std::map<std::string, std::shared_ptr<ngraph::Node>> m_ng_node_cache; std::map<std::string, std::shared_ptr<ngraph::Node>> m_ng_node_cache;
std::map<std::string, Tensor> m_initializers; std::map<std::string, Tensor> m_initializers;
const OperatorSet* m_opset; const Model* m_model;
}; };
inline std::ostream& operator<<(std::ostream& outs, const Graph& graph) inline std::ostream& operator<<(std::ostream& outs, const Graph& graph)
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <onnx-ml.pb.h> #include <onnx-ml.pb.h>
#include "model.hpp" #include "model.hpp"
#include "ops_bridge.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -25,15 +26,49 @@ namespace ngraph ...@@ -25,15 +26,49 @@ namespace ngraph
Model::Model(const onnx::ModelProto& model_proto) Model::Model(const onnx::ModelProto& model_proto)
: m_model_proto{&model_proto} : m_model_proto{&model_proto}
{ {
// Walk through the elements of opset_import field and register operator sets
// for each domain. An exception UnknownDomain() will raise if the domain is
// unknown or invalid.
for (const auto& id : m_model_proto->opset_import()) for (const auto& id : m_model_proto->opset_import())
{ {
// onnx.proto(.3): the empty string ("") or absence of this field implies m_opset.emplace(id.domain(),
// the operator set that is defined as part of the ONNX specification. OperatorsBridge::get_operator_set(
if (id.domain().empty()) id.version(), (id.domain() == "ai.onnx" ? "" : id.domain())));
}
// onnx.proto(.3): the empty string ("") for domain or absence of opset_import field
// implies the operator set that is defined as part of the ONNX specification.
const auto dm = m_opset.find("");
if (dm == std::end(m_opset))
{
m_opset.emplace("", OperatorsBridge::get_operator_set(ONNX_OPSET_VERSION, ""));
}
}
const Operator& Model::get_operator(const std::string& name,
const std::string& domain) const
{
const auto dm = m_opset.find(domain);
if (dm == std::end(m_opset))
{ {
m_opset_version = id.version(); throw error::UnknownDomain{domain};
} }
const auto op = dm->second.find(name);
if (op == std::end(dm->second))
{
throw error::UnknownOperator{name, domain};
}
return op->second;
}
bool Model::is_operator_available(const onnx::NodeProto& node_proto) const
{
const auto dm = m_opset.find(node_proto.domain());
if (dm == std::end(m_opset))
{
return false;
} }
const auto op = dm->second.find(node_proto.op_type());
return (op != std::end(dm->second));
} }
} // namespace onnx_import } // namespace onnx_import
......
...@@ -16,9 +16,13 @@ ...@@ -16,9 +16,13 @@
#pragma once #pragma once
#include <onnx-ml.pb.h>
#include <ostream> #include <ostream>
#include <string>
#include <unordered_map>
#include <onnx-ml.pb.h> #include "operator_set.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -30,11 +34,11 @@ namespace ngraph ...@@ -30,11 +34,11 @@ namespace ngraph
Model() = delete; Model() = delete;
explicit Model(const onnx::ModelProto& model_proto); explicit Model(const onnx::ModelProto& model_proto);
Model(Model&&) noexcept = default;
Model(const Model&) = default; Model(const Model&) = default;
Model(Model&&) = default;
Model& operator=(Model&&) noexcept = delete;
Model& operator=(const Model&) = delete; Model& operator=(const Model&) = delete;
Model& operator=(Model&&) = delete;
const std::string& get_producer_name() const { return m_model_proto->producer_name(); } const std::string& get_producer_name() const { return m_model_proto->producer_name(); }
const onnx::GraphProto& get_graph() const { return m_model_proto->graph(); } const onnx::GraphProto& get_graph() const { return m_model_proto->graph(); }
...@@ -44,10 +48,23 @@ namespace ngraph ...@@ -44,10 +48,23 @@ namespace ngraph
return m_model_proto->producer_version(); return m_model_proto->producer_version();
} }
std::int64_t get_opset_version() const { return m_opset_version; } /// \brief Access an operator object by its type name and domain name
/// The function will return the operator object if it exists, or report an error
/// in case of domain or operator absence.
/// \param name type name of the operator object,
/// \param domain domain name of the operator object.
/// \return Reference to the operator object.
/// \throw error::UnknownDomain there is no operator set defined for the given domain,
/// \throw error::UnknownOperator the given operator type name does not exist in operator set.
const Operator& get_operator(const std::string& name, const std::string& domain) const;
/// \brief Check availability of operator base on NodeProto.
/// \return `true` if the operator is available, otherwise it returns `false`.
bool is_operator_available(const onnx::NodeProto& node_proto) const;
private: private:
const onnx::ModelProto* m_model_proto; const onnx::ModelProto* m_model_proto;
std::int64_t m_opset_version{ONNX_OPSET_VERSION}; std::unordered_map<std::string, OperatorSet> m_opset;
}; };
inline std::ostream& operator<<(std::ostream& outs, const Model& model) inline std::ostream& operator<<(std::ostream& outs, const Model& model)
......
...@@ -14,15 +14,110 @@ ...@@ -14,15 +14,110 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "node.hpp" #include <onnx-ml.pb.h>
#include "attribute.hpp"
#include "graph.hpp" #include "graph.hpp"
#include "node.hpp"
#include "tensor.hpp"
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
NodeVector Node::get_ng_nodes() const { return m_graph->make_ng_nodes(*this); } class Node::Impl
NodeVector Node::get_ng_inputs() const {
public:
Impl() = delete;
Impl(const onnx::NodeProto& node_proto, const Graph& graph)
: m_node_proto{&node_proto}
, m_graph{&graph}
, m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())}
, m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())}
{
}
const std::vector<Attribute>& attributes() const;
NodeVector get_ng_nodes(const Node& node) const;
NodeVector get_ng_inputs() const;
const std::string& domain() const;
const std::string& op_type() const;
const std::string& name() const;
const std::string& description() const;
const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const;
const std::string& output(int index) const;
template <typename T>
T get_attribute_value(const std::string& name, T default_value) const;
template <typename T>
T get_attribute_value(const std::string& name) const;
const onnx::NodeProto& node_proto() const;
const Graph& graph() const;
private:
const onnx::NodeProto* m_node_proto;
const Graph* m_graph;
std::vector<Attribute> m_attributes;
std::vector<std::reference_wrapper<const std::string>> m_output_names;
mutable std::string m_description;
};
const onnx::NodeProto& Node::Impl::node_proto() const { return *m_node_proto; }
const Graph& Node::Impl::graph() const { return *m_graph; }
const std::vector<Attribute>& Node::Impl::attributes() const { return m_attributes; }
const std::string& Node::Impl::domain() const { return m_node_proto->domain(); }
const std::string& Node::Impl::op_type() const { return m_node_proto->op_type(); }
const std::string& Node::Impl::name() const { return m_node_proto->name(); }
const std::vector<std::reference_wrapper<const std::string>>&
Node::Impl::get_output_names() const
{
return m_output_names;
}
const std::string& Node::Impl::output(int index) const
{
return m_node_proto->output(index);
}
template <typename T>
T Node::Impl::get_attribute_value(const std::string& name, T default_value) const
{
auto it = std::find_if(
std::begin(m_attributes), std::end(m_attributes), [&](const Attribute& attribute) {
return attribute.get_name() == name;
});
if (it == std::end(m_attributes))
{
return std::forward<T>(default_value);
}
return it->template get_value<T>();
}
template <typename T>
T Node::Impl::get_attribute_value(const std::string& name) const
{
auto it = std::find_if(
std::begin(m_attributes), std::end(m_attributes), [&](const Attribute& attribute) {
return attribute.get_name() == name;
});
if (it == std::end(m_attributes))
{
throw error::node::UnknownAttribute{this->name(), name};
}
return it->template get_value<T>();
}
NodeVector Node::Impl::get_ng_nodes(const Node& node) const
{
return m_graph->make_ng_nodes(node);
}
NodeVector Node::Impl::get_ng_inputs() const
{ {
NodeVector result; NodeVector result;
for (const auto& name : m_node_proto->input()) for (const auto& name : m_node_proto->input())
...@@ -32,20 +127,224 @@ namespace ngraph ...@@ -32,20 +127,224 @@ namespace ngraph
return result; return result;
} }
std::string Node::get_description() const const std::string& Node::Impl::description() const
{
if (m_description.empty())
{ {
if (!get_name().empty()) if (!name().empty())
{ {
return get_name(); m_description = name();
} }
else
std::stringstream stream; {
for (std::size_t index = 0; index < m_output_names.size(); ++index) for (std::size_t index = 0; index < m_output_names.size(); ++index)
{ {
stream << (index != 0 ? ", " : ""); m_description += (index != 0 ? ", " : "") + m_output_names.at(index).get();
stream << m_output_names.at(index).get(); }
}
}
return m_description;
}
Node::Node(const onnx::NodeProto& node_proto, const Graph& graph)
: m_pimpl{new Impl{node_proto, graph}, [](Impl* impl) { delete impl; }}
{
}
Node::Node(Node&& other) noexcept
: m_pimpl{std::move(other.m_pimpl)}
{
}
Node::Node(const Node& other)
: m_pimpl{new Impl{other.m_pimpl->node_proto(), other.m_pimpl->graph()},
[](Impl* impl) { delete impl; }}
{
} }
return stream.str();
NodeVector Node::get_ng_inputs() const { return m_pimpl->get_ng_inputs(); }
NodeVector Node::get_ng_nodes() const { return m_pimpl->get_ng_nodes(*this); }
const std::string& Node::domain() const { return m_pimpl->domain(); }
const std::string& Node::op_type() const { return m_pimpl->op_type(); }
const std::string& Node::get_description() const { return m_pimpl->description(); }
const std::string& Node::get_name() const { return m_pimpl->name(); }
const std::vector<std::reference_wrapper<const std::string>>& Node::get_output_names() const
{
return m_pimpl->get_output_names();
}
const std::string& Node::output(int index) const { return m_pimpl->output(index); }
template <>
float Node::get_attribute_value(const std::string& name, float default_value) const
{
return m_pimpl->template get_attribute_value<float>(name, default_value);
}
template <>
double Node::get_attribute_value(const std::string& name, double default_value) const
{
return m_pimpl->template get_attribute_value<double>(name, default_value);
}
template <>
std::int64_t Node::get_attribute_value(const std::string& name,
std::int64_t default_value) const
{
return m_pimpl->template get_attribute_value<std::int64_t>(name, default_value);
}
template <>
std::string Node::get_attribute_value(const std::string& name,
std::string default_value) const
{
return m_pimpl->template get_attribute_value<std::string>(name,
std::move(default_value));
}
template <>
Tensor Node::get_attribute_value(const std::string& name, Tensor default_value) const
{
return m_pimpl->template get_attribute_value<Tensor>(name, std::move(default_value));
}
template <>
Graph Node::get_attribute_value(const std::string& name, Graph default_value) const
{
return m_pimpl->template get_attribute_value<Graph>(name, std::move(default_value));
}
template <>
std::vector<float> Node::get_attribute_value(const std::string& name,
std::vector<float> default_value) const
{
return m_pimpl->template get_attribute_value<std::vector<float>>(
name, std::move(default_value));
}
template <>
std::vector<double> Node::get_attribute_value(const std::string& name,
std::vector<double> default_value) const
{
return m_pimpl->template get_attribute_value<std::vector<double>>(
name, std::move(default_value));
}
template <>
std::vector<std::int64_t>
Node::get_attribute_value(const std::string& name,
std::vector<std::int64_t> default_value) const
{
return m_pimpl->template get_attribute_value<std::vector<std::int64_t>>(
name, std::move(default_value));
}
template <>
std::vector<std::size_t>
Node::get_attribute_value(const std::string& name,
std::vector<std::size_t> default_value) const
{
return m_pimpl->template get_attribute_value<std::vector<std::size_t>>(
name, std::move(default_value));
}
template <>
std::vector<Tensor> Node::get_attribute_value(const std::string& name,
std::vector<Tensor> default_value) const
{
return m_pimpl->template get_attribute_value<std::vector<Tensor>>(
name, std::move(default_value));
}
template <>
std::vector<Graph> Node::get_attribute_value(const std::string& name,
std::vector<Graph> default_value) const
{
return m_pimpl->template get_attribute_value<std::vector<Graph>>(
name, std::move(default_value));
}
template <>
float Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<float>(name);
}
template <>
double Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<double>(name);
}
template <>
std::int64_t Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<std::int64_t>(name);
}
template <>
std::size_t Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<std::size_t>(name);
}
template <>
std::string Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<std::string>(name);
}
template <>
Tensor Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<Tensor>(name);
}
template <>
Graph Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<Graph>(name);
}
template <>
std::vector<float> Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<std::vector<float>>(name);
}
template <>
std::vector<double> Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<std::vector<double>>(name);
}
template <>
std::vector<std::int64_t> Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<std::vector<std::int64_t>>(name);
}
template <>
std::vector<std::size_t> Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<std::vector<std::size_t>>(name);
}
template <>
std::vector<std::string> Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<std::vector<std::string>>(name);
}
template <>
std::vector<Tensor> Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<std::vector<Tensor>>(name);
}
template <>
std::vector<Graph> Node::get_attribute_value(const std::string& name) const
{
return m_pimpl->template get_attribute_value<std::vector<Graph>>(name);
} }
} // namespace onnx_import } // namespace onnx_import
......
...@@ -18,12 +18,14 @@ ...@@ -18,12 +18,14 @@
#include <string> #include <string>
#include <onnx-ml.pb.h> #include "ngraph/except.hpp"
#include "ngraph/node_vector.hpp" #include "ngraph/node_vector.hpp"
#include "attribute.hpp" namespace onnx
#include "tensor.hpp" {
// forward declaration
class NodeProto;
}
namespace ngraph namespace ngraph
{ {
...@@ -52,70 +54,41 @@ namespace ngraph ...@@ -52,70 +54,41 @@ namespace ngraph
{ {
public: public:
Node() = delete; Node() = delete;
Node(const onnx::NodeProto& node_proto, const Graph* graph) Node(const onnx::NodeProto& node_proto, const Graph& graph);
: m_node_proto{&node_proto}
, m_graph{graph}
, m_attributes{std::begin(node_proto.attribute()), std::end(node_proto.attribute())}
, m_output_names{std::begin(node_proto.output()), std::end(node_proto.output())}
{
}
Node(Node&&) noexcept = default; Node(Node&&) noexcept;
Node(const Node&) = default; Node(const Node&);
Node& operator=(Node&&) noexcept = delete; Node& operator=(Node&&) noexcept = delete;
Node& operator=(const Node&) = delete; Node& operator=(const Node&) = delete;
const std::vector<Attribute>& attributes() const { return m_attributes; }
NodeVector get_ng_nodes() const;
NodeVector get_ng_inputs() const; NodeVector get_ng_inputs() const;
NodeVector get_ng_nodes() const;
const std::string& domain() const;
const std::string& op_type() const;
const std::string& get_name() const;
const std::string& domain() const { return m_node_proto->domain(); } /// \brief Describe the ONNX Node to make debugging graphs easier
const std::string& op_type() const { return m_node_proto->op_type(); }
const std::string& get_name() const { return m_node_proto->name(); }
/// @brief Describe the ONNX Node to make debugging graphs easier
/// Function will return the Node's name if it has one, or the names of its outputs. /// Function will return the Node's name if it has one, or the names of its outputs.
/// \return Description of Node /// \return Description of Node
std::string get_description() const; const std::string& get_description() const;
const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const;
const std::string& output(int index) const;
const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const
{
return m_output_names;
}
const std::string& output(int index) const { return m_node_proto->output(index); }
template <typename T> template <typename T>
T get_attribute_value(const std::string& name, T default_value) const T get_attribute_value(const std::string& name, T default_value) const;
{
auto it = std::find_if(
std::begin(m_attributes),
std::end(m_attributes),
[&](const Attribute& attribute) { return attribute.get_name() == name; });
if (it == std::end(m_attributes))
{
return default_value;
}
return it->template get_value<T>();
}
template <typename T> template <typename T>
T get_attribute_value(const std::string& name) const T get_attribute_value(const std::string& name) const;
{
auto it = std::find_if(
std::begin(m_attributes),
std::end(m_attributes),
[&](const Attribute& attribute) { return attribute.get_name() == name; });
if (it == std::end(m_attributes))
{
throw error::node::UnknownAttribute{get_name(), name};
}
return it->template get_value<T>();
}
private: private:
const onnx::NodeProto* m_node_proto; class Impl;
const Graph* m_graph; // In this case we need custom deleter, because Impl is an incomplete
std::vector<Attribute> m_attributes; // type. Node's are elements of std::vector. Without custom deleter
std::vector<std::reference_wrapper<const std::string>> m_output_names; // compilation fails; the compiler is unable to parameterize an allocator's
// default deleter due to incomple type.
std::unique_ptr<Impl, void (*)(Impl*)> m_pimpl;
}; };
inline std::ostream& operator<<(std::ostream& outs, const Node& node) inline std::ostream& operator<<(std::ostream& outs, const Node& node)
......
...@@ -22,13 +22,12 @@ ...@@ -22,13 +22,12 @@
#include "ngraph/node_vector.hpp" #include "ngraph/node_vector.hpp"
#include "node.hpp"
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
// Forward declaration
class Node;
using Operator = std::function<NodeVector(const Node&)>; using Operator = std::function<NodeVector(const Node&)>;
using OperatorSet = std::unordered_map<std::string, std::reference_wrapper<const Operator>>; using OperatorSet = std::unordered_map<std::string, std::reference_wrapper<const Operator>>;
......
...@@ -61,8 +61,7 @@ namespace ngraph ...@@ -61,8 +61,7 @@ namespace ngraph
} }
std::vector<std::shared_ptr<Function>> output_functions; std::vector<std::shared_ptr<Function>> output_functions;
Model model{model_proto}; Model model{model_proto};
Graph graph{model_proto.graph(), Graph graph{model_proto.graph(), model};
OperatorsBridge::get_operator_set(model.get_opset_version())};
for (const auto& output : graph.get_outputs()) for (const auto& output : graph.get_outputs())
{ {
output_functions.emplace_back(std::make_shared<Function>( output_functions.emplace_back(std::make_shared<Function>(
...@@ -91,6 +90,14 @@ namespace ngraph ...@@ -91,6 +90,14 @@ namespace ngraph
return load_onnx_model(path).front(); return load_onnx_model(path).front();
} }
void register_operator(const std::string& name,
std::int64_t version,
const std::string& domain,
Operator fn)
{
OperatorsBridge::register_operator(name, version, domain, std::move(fn));
}
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -21,10 +21,18 @@ ...@@ -21,10 +21,18 @@
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "core/operator_set.hpp"
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
// Registers ONNX custom operator
void register_operator(const std::string& name,
std::int64_t version,
const std::string& domain,
Operator fn);
// Convert on ONNX model to a vector of nGraph Functions (input stream) // Convert on ONNX model to a vector of nGraph Functions (input stream)
std::vector<std::shared_ptr<Function>> load_onnx_model(std::istream&); std::vector<std::shared_ptr<Function>> load_onnx_model(std::istream&);
......
...@@ -87,134 +87,51 @@ namespace ngraph ...@@ -87,134 +87,51 @@ namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
const OperatorSet& OperatorsBridge::get_operator_set_version_1() const namespace detail
{ {
static OperatorSet operator_set; const Operator& find(const std::string& name,
if (operator_set.empty()) std::int64_t version,
const std::string& domain,
const std::map<std::int64_t, Operator>& map)
{ {
for (const auto& op : m_map) while (version > 0)
{ {
for (const auto& it : op.second) const auto it = map.find(version--);
if (it != std::end(map))
{ {
if (it.first == 1) return it->second;
{
operator_set.emplace(op.first, it.second);
}
} }
} }
throw error::UnsupportedVersion{name, version, domain};
} }
return operator_set;
} }
const OperatorSet& OperatorsBridge::get_operator_set_version_2() const void OperatorsBridge::_register_operator(const std::string& name,
{ std::int64_t version,
static OperatorSet operator_set; const std::string& domain,
if (operator_set.empty()) Operator fn)
{ {
operator_set = get_operator_set_version_1(); m_map[domain][name].emplace(version, std::move(fn));
}
return operator_set;
} }
const OperatorSet& OperatorsBridge::get_operator_set_version_3() const OperatorSet OperatorsBridge::_get_operator_set(std::int64_t version,
const std::string& domain)
{ {
static OperatorSet operator_set; OperatorSet result;
if (operator_set.empty()) auto dm = m_map.find(domain);
if (dm == std::end(m_map))
{ {
operator_set = get_operator_set_version_2(); throw error::UnknownDomain{domain};
}
return operator_set;
} }
for (const auto& op : dm->second)
const OperatorSet& OperatorsBridge::get_operator_set_version_4() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_3();
}
return operator_set;
}
const OperatorSet& OperatorsBridge::get_operator_set_version_5() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_4();
}
return operator_set;
}
const OperatorSet& OperatorsBridge::get_operator_set_version_6() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_5();
}
return operator_set;
}
const OperatorSet& OperatorsBridge::get_operator_set_version_7() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_6();
}
return operator_set;
}
const OperatorSet& OperatorsBridge::get_operator_set_version_8() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_7();
}
return operator_set;
}
const OperatorSet& OperatorsBridge::get_operator_set_version_9() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
operator_set = get_operator_set_version_8();
}
return operator_set;
}
#define OPERATOR_SET_NAME(version_) get_operator_set_version_##version_()
#define GET_OPERATOR_SET(version_) \
case version_: \
return OPERATOR_SET_NAME(version_)
#define OPERATOR_SET_NAME_HELPER(version_) OPERATOR_SET_NAME(version_)
#define DEFAULT_OPERATOR_SET() return OPERATOR_SET_NAME_HELPER(ONNX_OPSET_VERSION)
const OperatorSet& OperatorsBridge::get_operator_set_version(std::int64_t version) const
{
switch (version)
{ {
GET_OPERATOR_SET(1); result.emplace(op.first, detail::find(op.first, version, domain, op.second));
GET_OPERATOR_SET(2);
GET_OPERATOR_SET(3);
GET_OPERATOR_SET(4);
GET_OPERATOR_SET(5);
GET_OPERATOR_SET(6);
GET_OPERATOR_SET(7);
GET_OPERATOR_SET(8);
GET_OPERATOR_SET(9);
default: DEFAULT_OPERATOR_SET();
} }
return result;
} }
#define REGISTER_OPERATOR(name_, version_, fn_) \ #define REGISTER_OPERATOR(name_, ver_, fn_) \
m_map[name_].emplace(version_, std::bind(op::set_##version_::fn_, std::placeholders::_1)) m_map[""][name_].emplace(ver_, std::bind(op::set_##ver_::fn_, std::placeholders::_1))
OperatorsBridge::OperatorsBridge() OperatorsBridge::OperatorsBridge()
{ {
......
...@@ -33,16 +33,27 @@ namespace ngraph ...@@ -33,16 +33,27 @@ namespace ngraph
{ {
struct UnknownOperator : ngraph_error struct UnknownOperator : ngraph_error
{ {
explicit UnknownOperator(const std::string& op_type) UnknownOperator(const std::string& name, const std::string& domain)
: ngraph_error{"unknown operator: \"" + op_type + "\""} : ngraph_error{(domain.empty() ? "" : domain + ".") + name}
{
}
};
struct UnknownDomain : ngraph_error
{
explicit UnknownDomain(const std::string& domain)
: ngraph_error{domain}
{ {
} }
}; };
struct UnsupportedVersion : ngraph_error struct UnsupportedVersion : ngraph_error
{ {
explicit UnsupportedVersion(std::int64_t version) UnsupportedVersion(const std::string& name,
: ngraph_error{"unsupported operator set version: " + std::to_string(version)} std::int64_t version,
const std::string& domain)
: ngraph_error{(domain.empty() ? "" : domain + ".") + name + ":" +
std::to_string(version)}
{ {
} }
}; };
...@@ -57,32 +68,37 @@ namespace ngraph ...@@ -57,32 +68,37 @@ namespace ngraph
OperatorsBridge(OperatorsBridge&&) = delete; OperatorsBridge(OperatorsBridge&&) = delete;
OperatorsBridge& operator=(OperatorsBridge&&) = delete; OperatorsBridge& operator=(OperatorsBridge&&) = delete;
static const OperatorSet& get_operator_set(std::int64_t version) static OperatorSet get_operator_set(std::int64_t version, const std::string& domain)
{
return instance()._get_operator_set(version, domain);
}
static void register_operator(const std::string& name,
std::int64_t version,
const std::string& domain,
Operator fn)
{ {
return instance().get_operator_set_version(version); instance()._register_operator(name, version, domain, std::move(fn));
} }
private: private:
std::unordered_map<std::string, std::map<std::int64_t, Operator>> m_map; std::unordered_map<std::string,
std::unordered_map<std::string, std::map<std::int64_t, Operator>>>
m_map;
OperatorsBridge(); OperatorsBridge();
static const OperatorsBridge& instance() static OperatorsBridge& instance()
{ {
static OperatorsBridge instance; static OperatorsBridge instance;
return instance; return instance;
} }
const OperatorSet& get_operator_set_version_1() const; void _register_operator(const std::string& name,
const OperatorSet& get_operator_set_version_2() const; std::int64_t version,
const OperatorSet& get_operator_set_version_3() const; const std::string& domain,
const OperatorSet& get_operator_set_version_4() const; Operator fn);
const OperatorSet& get_operator_set_version_5() const; OperatorSet _get_operator_set(std::int64_t version, const std::string& domain);
const OperatorSet& get_operator_set_version_6() const;
const OperatorSet& get_operator_set_version_7() const;
const OperatorSet& get_operator_set_version_8() const;
const OperatorSet& get_operator_set_version_9() const;
const OperatorSet& get_operator_set_version(std::int64_t version) const;
}; };
} // namespace onnx_import } // namespace onnx_import
......
...@@ -1297,3 +1297,39 @@ TEST(onnx, model_unsupported_op) ...@@ -1297,3 +1297,39 @@ TEST(onnx, model_unsupported_op)
FAIL() << "Expected ngraph::ngraph_error"; FAIL() << "Expected ngraph::ngraph_error";
} }
} }
TEST(onnx, model_custom_op)
{
onnx_import::register_operator(
"AddQ", 1, "com.intel.ai", [](const onnx_import::Node& node) -> NodeVector {
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
});
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/custom_operator.onnx"));
Inputs inputs{{1, 2, 3, 4}};
Outputs expected_outputs{{3, 6, 9, 12}};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx, model_custom_op_default_domain)
{
onnx_import::register_operator(
"AddQ", 1, "com.intel.ai", [](const onnx_import::Node& node) -> NodeVector {
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
});
auto function = onnx_import::import_onnx_function(
file_util::path_join(SERIALIZED_ZOO, "onnx/custom_operator_default_domain.onnx"));
Inputs inputs{{1, 2, 3, 4}};
Outputs expected_outputs{{3, 6, 9, 12}};
Outputs outputs{execute(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment