Commit 1f350378 authored by tsocha's avatar tsocha Committed by Michał Karzyński

[ONNX] Enable deselected supported opset domain when needed. (#2350)

parent 676f8d36
...@@ -22,7 +22,7 @@ namespace ngraph ...@@ -22,7 +22,7 @@ namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
std::vector<Graph> Attribute::get_graph_array(const Model& model) const std::vector<Graph> Attribute::get_graph_array(Model& model) const
{ {
std::vector<Graph> result; std::vector<Graph> result;
for (const auto& graph : m_attribute_proto->graphs()) for (const auto& graph : m_attribute_proto->graphs())
...@@ -32,7 +32,7 @@ namespace ngraph ...@@ -32,7 +32,7 @@ namespace ngraph
return result; return result;
} }
Graph Attribute::get_graph(const Model& model) const Graph Attribute::get_graph(Model& model) const
{ {
return Graph{m_attribute_proto->g(), model}; return Graph{m_attribute_proto->g(), model};
} }
......
...@@ -278,7 +278,7 @@ namespace ngraph ...@@ -278,7 +278,7 @@ namespace ngraph
float get_float() const { return m_attribute_proto->f(); } float get_float() const { return m_attribute_proto->f(); }
int64_t get_integer() const { return m_attribute_proto->i(); } int64_t get_integer() const { return m_attribute_proto->i(); }
const std::string& get_string() const { return m_attribute_proto->s(); } const std::string& get_string() const { return m_attribute_proto->s(); }
Graph get_graph(const Model&) const; Graph get_graph(Model&) const;
std::vector<Tensor> get_tensor_array() const std::vector<Tensor> get_tensor_array() const
{ {
...@@ -303,7 +303,7 @@ namespace ngraph ...@@ -303,7 +303,7 @@ namespace ngraph
std::end(m_attribute_proto->strings())}; std::end(m_attribute_proto->strings())};
} }
std::vector<Graph> get_graph_array(const Model&) const; std::vector<Graph> get_graph_array(Model&) const;
/* explicit */ operator onnx::AttributeProto_AttributeType() const /* explicit */ operator onnx::AttributeProto_AttributeType() const
{ {
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <functional>
#include <set> #include <set>
#include "graph.hpp" #include "graph.hpp"
...@@ -25,26 +26,40 @@ namespace ngraph ...@@ -25,26 +26,40 @@ namespace ngraph
{ {
namespace detail namespace detail
{ {
std::string to_string(const std::set<std::string>& set) static std::string to_string(
const std::map<std::string, std::reference_wrapper<const onnx::NodeProto>>& map)
{ {
std::string result; std::string result;
for (auto it = std::begin(set); it != std::end(set); ++it) for (auto it = std::begin(map); it != std::end(map); ++it)
{ {
result += (it != std::begin(set) ? ", " : "") + *it; result += (it != std::begin(map) ? ", " : "") + it->first;
} }
return result; return result;
} }
inline std::string to_string(const onnx::NodeProto& node_proto) static std::string get_node_domain(const onnx::NodeProto& node_proto)
{ {
return (node_proto.domain().empty() ? "" : node_proto.domain() + ".") + return (node_proto.domain().empty() ? "" : node_proto.domain());
node_proto.op_type();
} }
/// \brief Gets the operator represented by provided node unique identificator.
///
/// \param[in] node_proto The node protobuf representation object.
///
/// \note The operator is uniquely identified by the tuple (domain, op_type,
/// since_version). The first two elements are stored in NodeProto object,
/// thus we use only them.
///
/// \return The unique identificator.
///
static std::string get_op_domain_and_name(const onnx::NodeProto& node_proto)
{
std::string domain = get_node_domain(node_proto);
return (domain.empty() ? "" : domain + ".") + node_proto.op_type();
} }
} // namespace detail
Graph::Graph(const onnx::GraphProto& graph_proto, Graph::Graph(const onnx::GraphProto& graph_proto, Model& model, const Weights& weights)
const Model& model,
const Weights& weights)
: m_graph_proto{&graph_proto} : m_graph_proto{&graph_proto}
, m_model{&model} , m_model{&model}
{ {
...@@ -70,17 +85,34 @@ namespace ngraph ...@@ -70,17 +85,34 @@ namespace ngraph
} }
// Verify that ONNX graph contains only nodes of available operator types // Verify that ONNX graph contains only nodes of available operator types
std::set<std::string> unknown_operator_types; std::map<std::string, std::reference_wrapper<const onnx::NodeProto>> unknown_operators;
for (const auto& node_proto : m_graph_proto->node()) for (const auto& node_proto : m_graph_proto->node())
{ {
if (!m_model->is_operator_available(node_proto)) if (!m_model->is_operator_available(node_proto))
{ {
unknown_operator_types.emplace(detail::to_string(node_proto)); unknown_operators.emplace(detail::get_op_domain_and_name(node_proto),
node_proto);
// Try adding missing domain
m_model->enable_opset_domain(detail::get_node_domain(node_proto));
}
}
// Reverify wheter we still have any unavailable operators.
auto it = std::begin(unknown_operators);
while (it != std::end(unknown_operators))
{
if (m_model->is_operator_available(it->second))
{
it = unknown_operators.erase(it);
}
else
{
it++;
} }
} }
NGRAPH_ASSERT(unknown_operator_types.empty()) NGRAPH_ASSERT(unknown_operators.empty()) << "unknown operations: "
<< "unknown operations: " << detail::to_string(unknown_operator_types); << detail::to_string(unknown_operators);
// Process ONNX graph nodes, convert to nGraph nodes // Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_graph_proto->node()) for (const auto& node_proto : m_graph_proto->node())
......
...@@ -33,7 +33,7 @@ namespace ngraph ...@@ -33,7 +33,7 @@ namespace ngraph
class Graph class Graph
{ {
public: public:
Graph(const onnx::GraphProto& proto, const Model& model, const Weights& weights = {}); Graph(const onnx::GraphProto& proto, Model& model, const Weights& weights = {});
const std::vector<Node>& get_nodes() const { return m_nodes; } const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; } const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
...@@ -59,7 +59,7 @@ namespace ngraph ...@@ -59,7 +59,7 @@ namespace ngraph
ParameterVector m_parameters; ParameterVector m_parameters;
std::map<std::string, std::shared_ptr<ngraph::Node>> m_ng_node_cache; std::map<std::string, std::shared_ptr<ngraph::Node>> m_ng_node_cache;
std::map<std::string, Tensor> m_initializers; std::map<std::string, Tensor> m_initializers;
const Model* m_model; Model* m_model;
}; };
inline std::ostream& operator<<(std::ostream& outs, const Graph& graph) inline std::ostream& operator<<(std::ostream& outs, const Graph& graph)
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <onnx-ml.pb.h> #include <onnx-ml.pb.h>
#include "model.hpp" #include "model.hpp"
#include "ngraph/log.hpp"
#include "ops_bridge.hpp" #include "ops_bridge.hpp"
namespace ngraph namespace ngraph
...@@ -33,14 +34,14 @@ namespace ngraph ...@@ -33,14 +34,14 @@ namespace ngraph
{ {
m_opset.emplace(id.domain(), m_opset.emplace(id.domain(),
OperatorsBridge::get_operator_set( OperatorsBridge::get_operator_set(
id.version(), (id.domain() == "ai.onnx" ? "" : id.domain()))); (id.domain() == "ai.onnx" ? "" : id.domain()), id.version()));
} }
// onnx.proto(.3): the empty string ("") for domain or absence of opset_import field // onnx.proto(.3): the empty string ("") for domain or absence of opset_import field
// implies the operator set that is defined as part of the ONNX specification. // implies the operator set that is defined as part of the ONNX specification.
const auto dm = m_opset.find(""); const auto dm = m_opset.find("");
if (dm == std::end(m_opset)) if (dm == std::end(m_opset))
{ {
m_opset.emplace("", OperatorsBridge::get_operator_set(ONNX_OPSET_VERSION, "")); m_opset.emplace("", OperatorsBridge::get_operator_set("", ONNX_OPSET_VERSION));
} }
} }
...@@ -71,6 +72,26 @@ namespace ngraph ...@@ -71,6 +72,26 @@ namespace ngraph
return (op != std::end(dm->second)); return (op != std::end(dm->second));
} }
void Model::enable_opset_domain(const std::string& domain)
{
// There is no need to 'update' already enabled domain.
// Since this function may be called only during model import,
// (maybe multiple times) the registered domain opset won't differ
// between subsequent calls.
if (m_opset.find(domain) == std::end(m_opset))
{
OperatorSet opset{OperatorsBridge::get_operator_set(domain)};
if (opset.empty())
{
NGRAPH_WARN << "Couldn't enable domain: " << domain
<< " since it hasn't any registered operators.";
return;
}
m_opset.emplace(domain, opset);
}
}
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -61,6 +61,15 @@ namespace ngraph ...@@ -61,6 +61,15 @@ namespace ngraph
/// \return `true` if the operator is available, otherwise it returns `false`. /// \return `true` if the operator is available, otherwise it returns `false`.
bool is_operator_available(const onnx::NodeProto& node_proto) const; bool is_operator_available(const onnx::NodeProto& node_proto) const;
/// \brief Enable operators from provided domain to use by this model.
///
/// \note This function makes visible all currently registered in provided domain
/// operators for use in this model.
///
/// \param[in] domain The domain name.
///
void enable_opset_domain(const std::string& domain);
private: private:
const onnx::ModelProto* m_model_proto; const onnx::ModelProto* m_model_proto;
std::unordered_map<std::string, OperatorSet> m_opset; std::unordered_map<std::string, OperatorSet> m_opset;
......
...@@ -90,7 +90,8 @@ namespace ngraph ...@@ -90,7 +90,8 @@ namespace ngraph
std::set<std::string> get_supported_operators(std::int64_t version, std::set<std::string> get_supported_operators(std::int64_t version,
const std::string& domain) const std::string& domain)
{ {
OperatorSet op_set{OperatorsBridge::get_operator_set(version, domain)}; OperatorSet op_set{
OperatorsBridge::get_operator_set(domain == "ai.onnx" ? "" : domain, version)};
std::set<std::string> op_list{}; std::set<std::string> op_list{};
for (const auto& op : op_set) for (const auto& op : op_set)
{ {
......
...@@ -110,6 +110,11 @@ namespace ngraph ...@@ -110,6 +110,11 @@ namespace ngraph
find(std::int64_t version, const std::map<std::int64_t, Operator>& map) find(std::int64_t version, const std::map<std::int64_t, Operator>& map)
{ {
std::map<std::int64_t, Operator>::const_iterator it{}; std::map<std::int64_t, Operator>::const_iterator it{};
// Get the latest version.
if (version == -1)
{
return map.empty() ? std::end(map) : --std::end(map);
}
while (version > 0) while (version > 0)
{ {
it = map.find(version--); it = map.find(version--);
...@@ -127,23 +132,29 @@ namespace ngraph ...@@ -127,23 +132,29 @@ namespace ngraph
const std::string& domain, const std::string& domain,
Operator fn) Operator fn)
{ {
m_map[domain][name].emplace(version, std::move(fn)); auto result = m_map[domain][name].emplace(version, std::move(fn));
if (result.second)
{
NGRAPH_WARN << "Overwriting existing operator: "
<< domain + "." + name + ":" + std::to_string(version);
}
} }
OperatorSet OperatorsBridge::_get_operator_set(std::int64_t version, OperatorSet OperatorsBridge::_get_operator_set(const std::string& domain,
const std::string& domain) std::int64_t version)
{ {
OperatorSet result; OperatorSet result;
auto dm = m_map.find(domain); auto dm = m_map.find(domain);
if (dm == std::end(m_map)) if (dm == std::end(m_map))
{ {
throw error::UnknownDomain{domain}; throw error::UnknownDomain{domain};
} }
if (version > OperatorsBridge::LATEST_SUPPORTED_OPSET_VERSION) if (domain == "" && version > OperatorsBridge::LATEST_SUPPORTED_ONNX_OPSET_VERSION)
{ {
NGRAPH_WARN << "Currently operator set version: " << version << " is unsupported." NGRAPH_WARN << "Currently ONNX operator set version: " << version
<< " Falling back to: " << " is unsupported. Falling back to: "
<< OperatorsBridge::LATEST_SUPPORTED_OPSET_VERSION; << OperatorsBridge::LATEST_SUPPORTED_ONNX_OPSET_VERSION;
} }
for (const auto& op : dm->second) for (const auto& op : dm->second)
{ {
......
...@@ -62,16 +62,17 @@ namespace ngraph ...@@ -62,16 +62,17 @@ namespace ngraph
class OperatorsBridge class OperatorsBridge
{ {
public: public:
static constexpr const int LATEST_SUPPORTED_OPSET_VERSION = ONNX_OPSET_VERSION; static constexpr const int LATEST_SUPPORTED_ONNX_OPSET_VERSION = ONNX_OPSET_VERSION;
OperatorsBridge(const OperatorsBridge&) = delete; OperatorsBridge(const OperatorsBridge&) = delete;
OperatorsBridge& operator=(const OperatorsBridge&) = delete; OperatorsBridge& operator=(const OperatorsBridge&) = delete;
OperatorsBridge(OperatorsBridge&&) = delete; OperatorsBridge(OperatorsBridge&&) = delete;
OperatorsBridge& operator=(OperatorsBridge&&) = delete; OperatorsBridge& operator=(OperatorsBridge&&) = delete;
static OperatorSet get_operator_set(std::int64_t version, const std::string& domain) static OperatorSet get_operator_set(const std::string& domain,
std::int64_t version = -1)
{ {
return instance()._get_operator_set(version, domain); return instance()._get_operator_set(domain, version);
} }
static void register_operator(const std::string& name, static void register_operator(const std::string& name,
...@@ -90,6 +91,20 @@ namespace ngraph ...@@ -90,6 +91,20 @@ namespace ngraph
} }
private: private:
// Registered operators structure
// {
// domain_1: {
// op_type_1: {
// version_1: {func_handle},
// version_2: {func_handle},
// ...
// },
// op_type_2: { ... }
// ...
// },
// domain_2: { ... },
// ...
// }
std::unordered_map<std::string, std::unordered_map<std::string,
std::unordered_map<std::string, std::map<std::int64_t, Operator>>> std::unordered_map<std::string, std::map<std::int64_t, Operator>>>
m_map; m_map;
...@@ -106,7 +121,8 @@ namespace ngraph ...@@ -106,7 +121,8 @@ namespace ngraph
std::int64_t version, std::int64_t version,
const std::string& domain, const std::string& domain,
Operator fn); Operator fn);
OperatorSet _get_operator_set(std::int64_t version, const std::string& domain); OperatorSet _get_operator_set(const std::string& domain, std::int64_t version);
bool _is_operator_registered(const std::string& name, bool _is_operator_registered(const std::string& name,
std::int64_t version, std::int64_t version,
const std::string& domain); const std::string& domain);
......
ONNXnGraphImporter:o

A
BC" CustomAdd: custom.op compute_graphZ
A


Z
B


b
C


B
\ No newline at end of file
...@@ -1820,6 +1820,29 @@ TEST(onnx_${BACKEND_NAME}, model_space_to_depth_no_blocksize) ...@@ -1820,6 +1820,29 @@ TEST(onnx_${BACKEND_NAME}, model_space_to_depth_no_blocksize)
std::runtime_error); std::runtime_error);
} }
TEST(onnx_${BACKEND_NAME}, model_missing_op_domain)
{
onnx_import::register_operator(
"CustomAdd", 1, "custom.op", [](const onnx_import::Node& node) -> NodeVector {
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
});
EXPECT_TRUE(onnx_import::is_operator_supported("CustomAdd", 1, "custom.op"));
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/missing_op_domain.onnx"));
Inputs inputs;
inputs.emplace_back(std::vector<float>{0.f, 1.f, 2.f, 3.f});
inputs.emplace_back(std::vector<float>{0.f, 1.f, 2.f, 3.f});
Outputs expected_output{std::vector<float>{0.f, 2.f, 4.f, 6.f}};
Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
TEST(onnx_${BACKEND_NAME}, model_top_k) TEST(onnx_${BACKEND_NAME}, model_top_k)
{ {
auto function = auto function =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment