Commit 1f350378 authored by tsocha's avatar tsocha Committed by Michał Karzyński

[ONNX] Enable deselected supported opset domain when needed. (#2350)

parent 676f8d36
......@@ -22,7 +22,7 @@ namespace ngraph
{
namespace onnx_import
{
std::vector<Graph> Attribute::get_graph_array(const Model& model) const
std::vector<Graph> Attribute::get_graph_array(Model& model) const
{
std::vector<Graph> result;
for (const auto& graph : m_attribute_proto->graphs())
......@@ -32,7 +32,7 @@ namespace ngraph
return result;
}
Graph Attribute::get_graph(const Model& model) const
Graph Attribute::get_graph(Model& model) const
{
return Graph{m_attribute_proto->g(), model};
}
......
......@@ -278,7 +278,7 @@ namespace ngraph
float get_float() const { return m_attribute_proto->f(); }
int64_t get_integer() const { return m_attribute_proto->i(); }
const std::string& get_string() const { return m_attribute_proto->s(); }
Graph get_graph(const Model&) const;
Graph get_graph(Model&) const;
std::vector<Tensor> get_tensor_array() const
{
......@@ -303,7 +303,7 @@ namespace ngraph
std::end(m_attribute_proto->strings())};
}
std::vector<Graph> get_graph_array(const Model&) const;
std::vector<Graph> get_graph_array(Model&) const;
/* explicit */ operator onnx::AttributeProto_AttributeType() const
{
......
......@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include <functional>
#include <set>
#include "graph.hpp"
......@@ -25,26 +26,40 @@ namespace ngraph
{
namespace detail
{
std::string to_string(const std::set<std::string>& set)
static std::string to_string(
const std::map<std::string, std::reference_wrapper<const onnx::NodeProto>>& map)
{
std::string result;
for (auto it = std::begin(set); it != std::end(set); ++it)
for (auto it = std::begin(map); it != std::end(map); ++it)
{
result += (it != std::begin(set) ? ", " : "") + *it;
result += (it != std::begin(map) ? ", " : "") + it->first;
}
return result;
}
inline std::string to_string(const onnx::NodeProto& node_proto)
static std::string get_node_domain(const onnx::NodeProto& node_proto)
{
return (node_proto.domain().empty() ? "" : node_proto.domain() + ".") +
node_proto.op_type();
return (node_proto.domain().empty() ? "" : node_proto.domain());
}
/// \brief Gets the operator represented by provided node unique identificator.
///
/// \param[in] node_proto The node protobuf representation object.
///
/// \note The operator is uniquely identified by the tuple (domain, op_type,
/// since_version). The first two elements are stored in NodeProto object,
/// thus we use only them.
///
/// \return The unique identificator.
///
static std::string get_op_domain_and_name(const onnx::NodeProto& node_proto)
{
std::string domain = get_node_domain(node_proto);
return (domain.empty() ? "" : domain + ".") + node_proto.op_type();
}
} // namespace detail
Graph::Graph(const onnx::GraphProto& graph_proto,
const Model& model,
const Weights& weights)
Graph::Graph(const onnx::GraphProto& graph_proto, Model& model, const Weights& weights)
: m_graph_proto{&graph_proto}
, m_model{&model}
{
......@@ -70,17 +85,34 @@ namespace ngraph
}
// Verify that ONNX graph contains only nodes of available operator types
std::set<std::string> unknown_operator_types;
std::map<std::string, std::reference_wrapper<const onnx::NodeProto>> unknown_operators;
for (const auto& node_proto : m_graph_proto->node())
{
if (!m_model->is_operator_available(node_proto))
{
unknown_operator_types.emplace(detail::to_string(node_proto));
unknown_operators.emplace(detail::get_op_domain_and_name(node_proto),
node_proto);
// Try adding missing domain
m_model->enable_opset_domain(detail::get_node_domain(node_proto));
}
}
// Reverify wheter we still have any unavailable operators.
auto it = std::begin(unknown_operators);
while (it != std::end(unknown_operators))
{
if (m_model->is_operator_available(it->second))
{
it = unknown_operators.erase(it);
}
else
{
it++;
}
}
NGRAPH_ASSERT(unknown_operator_types.empty())
<< "unknown operations: " << detail::to_string(unknown_operator_types);
NGRAPH_ASSERT(unknown_operators.empty()) << "unknown operations: "
<< detail::to_string(unknown_operators);
// Process ONNX graph nodes, convert to nGraph nodes
for (const auto& node_proto : m_graph_proto->node())
......
......@@ -33,7 +33,7 @@ namespace ngraph
class Graph
{
public:
Graph(const onnx::GraphProto& proto, const Model& model, const Weights& weights = {});
Graph(const onnx::GraphProto& proto, Model& model, const Weights& weights = {});
const std::vector<Node>& get_nodes() const { return m_nodes; }
const std::vector<ValueInfo>& get_inputs() const { return m_inputs; }
......@@ -59,7 +59,7 @@ namespace ngraph
ParameterVector m_parameters;
std::map<std::string, std::shared_ptr<ngraph::Node>> m_ng_node_cache;
std::map<std::string, Tensor> m_initializers;
const Model* m_model;
Model* m_model;
};
inline std::ostream& operator<<(std::ostream& outs, const Graph& graph)
......
......@@ -17,6 +17,7 @@
#include <onnx-ml.pb.h>
#include "model.hpp"
#include "ngraph/log.hpp"
#include "ops_bridge.hpp"
namespace ngraph
......@@ -33,14 +34,14 @@ namespace ngraph
{
m_opset.emplace(id.domain(),
OperatorsBridge::get_operator_set(
id.version(), (id.domain() == "ai.onnx" ? "" : id.domain())));
(id.domain() == "ai.onnx" ? "" : id.domain()), id.version()));
}
// onnx.proto(.3): the empty string ("") for domain or absence of opset_import field
// implies the operator set that is defined as part of the ONNX specification.
const auto dm = m_opset.find("");
if (dm == std::end(m_opset))
{
m_opset.emplace("", OperatorsBridge::get_operator_set(ONNX_OPSET_VERSION, ""));
m_opset.emplace("", OperatorsBridge::get_operator_set("", ONNX_OPSET_VERSION));
}
}
......@@ -71,6 +72,26 @@ namespace ngraph
return (op != std::end(dm->second));
}
void Model::enable_opset_domain(const std::string& domain)
{
// There is no need to 'update' already enabled domain.
// Since this function may be called only during model import,
// (maybe multiple times) the registered domain opset won't differ
// between subsequent calls.
if (m_opset.find(domain) == std::end(m_opset))
{
OperatorSet opset{OperatorsBridge::get_operator_set(domain)};
if (opset.empty())
{
NGRAPH_WARN << "Couldn't enable domain: " << domain
<< " since it hasn't any registered operators.";
return;
}
m_opset.emplace(domain, opset);
}
}
} // namespace onnx_import
} // namespace ngraph
......@@ -61,6 +61,15 @@ namespace ngraph
/// \return `true` if the operator is available, otherwise it returns `false`.
bool is_operator_available(const onnx::NodeProto& node_proto) const;
/// \brief Enable operators from provided domain to use by this model.
///
/// \note This function makes visible all currently registered in provided domain
/// operators for use in this model.
///
/// \param[in] domain The domain name.
///
void enable_opset_domain(const std::string& domain);
private:
const onnx::ModelProto* m_model_proto;
std::unordered_map<std::string, OperatorSet> m_opset;
......
......@@ -90,7 +90,8 @@ namespace ngraph
std::set<std::string> get_supported_operators(std::int64_t version,
const std::string& domain)
{
OperatorSet op_set{OperatorsBridge::get_operator_set(version, domain)};
OperatorSet op_set{
OperatorsBridge::get_operator_set(domain == "ai.onnx" ? "" : domain, version)};
std::set<std::string> op_list{};
for (const auto& op : op_set)
{
......
......@@ -110,6 +110,11 @@ namespace ngraph
find(std::int64_t version, const std::map<std::int64_t, Operator>& map)
{
std::map<std::int64_t, Operator>::const_iterator it{};
// Get the latest version.
if (version == -1)
{
return map.empty() ? std::end(map) : --std::end(map);
}
while (version > 0)
{
it = map.find(version--);
......@@ -127,23 +132,29 @@ namespace ngraph
const std::string& domain,
Operator fn)
{
m_map[domain][name].emplace(version, std::move(fn));
auto result = m_map[domain][name].emplace(version, std::move(fn));
if (result.second)
{
NGRAPH_WARN << "Overwriting existing operator: "
<< domain + "." + name + ":" + std::to_string(version);
}
}
OperatorSet OperatorsBridge::_get_operator_set(std::int64_t version,
const std::string& domain)
OperatorSet OperatorsBridge::_get_operator_set(const std::string& domain,
std::int64_t version)
{
OperatorSet result;
auto dm = m_map.find(domain);
if (dm == std::end(m_map))
{
throw error::UnknownDomain{domain};
}
if (version > OperatorsBridge::LATEST_SUPPORTED_OPSET_VERSION)
if (domain == "" && version > OperatorsBridge::LATEST_SUPPORTED_ONNX_OPSET_VERSION)
{
NGRAPH_WARN << "Currently operator set version: " << version << " is unsupported."
<< " Falling back to: "
<< OperatorsBridge::LATEST_SUPPORTED_OPSET_VERSION;
NGRAPH_WARN << "Currently ONNX operator set version: " << version
<< " is unsupported. Falling back to: "
<< OperatorsBridge::LATEST_SUPPORTED_ONNX_OPSET_VERSION;
}
for (const auto& op : dm->second)
{
......
......@@ -62,16 +62,17 @@ namespace ngraph
class OperatorsBridge
{
public:
static constexpr const int LATEST_SUPPORTED_OPSET_VERSION = ONNX_OPSET_VERSION;
static constexpr const int LATEST_SUPPORTED_ONNX_OPSET_VERSION = ONNX_OPSET_VERSION;
OperatorsBridge(const OperatorsBridge&) = delete;
OperatorsBridge& operator=(const OperatorsBridge&) = delete;
OperatorsBridge(OperatorsBridge&&) = delete;
OperatorsBridge& operator=(OperatorsBridge&&) = delete;
static OperatorSet get_operator_set(std::int64_t version, const std::string& domain)
static OperatorSet get_operator_set(const std::string& domain,
std::int64_t version = -1)
{
return instance()._get_operator_set(version, domain);
return instance()._get_operator_set(domain, version);
}
static void register_operator(const std::string& name,
......@@ -90,6 +91,20 @@ namespace ngraph
}
private:
// Registered operators structure
// {
// domain_1: {
// op_type_1: {
// version_1: {func_handle},
// version_2: {func_handle},
// ...
// },
// op_type_2: { ... }
// ...
// },
// domain_2: { ... },
// ...
// }
std::unordered_map<std::string,
std::unordered_map<std::string, std::map<std::int64_t, Operator>>>
m_map;
......@@ -106,7 +121,8 @@ namespace ngraph
std::int64_t version,
const std::string& domain,
Operator fn);
OperatorSet _get_operator_set(std::int64_t version, const std::string& domain);
OperatorSet _get_operator_set(const std::string& domain, std::int64_t version);
bool _is_operator_registered(const std::string& name,
std::int64_t version,
const std::string& domain);
......
ONNXnGraphImporter:o

A
BC" CustomAdd: custom.op compute_graphZ
A


Z
B


b
C


B
\ No newline at end of file
......@@ -1820,6 +1820,29 @@ TEST(onnx_${BACKEND_NAME}, model_space_to_depth_no_blocksize)
std::runtime_error);
}
TEST(onnx_${BACKEND_NAME}, model_missing_op_domain)
{
onnx_import::register_operator(
"CustomAdd", 1, "custom.op", [](const onnx_import::Node& node) -> NodeVector {
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
});
EXPECT_TRUE(onnx_import::is_operator_supported("CustomAdd", 1, "custom.op"));
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/missing_op_domain.onnx"));
Inputs inputs;
inputs.emplace_back(std::vector<float>{0.f, 1.f, 2.f, 3.f});
inputs.emplace_back(std::vector<float>{0.f, 1.f, 2.f, 3.f});
Outputs expected_output{std::vector<float>{0.f, 2.f, 4.f, 6.f}};
Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(test::all_close_f(expected_output.front(), outputs.front()));
}
TEST(onnx_${BACKEND_NAME}, model_top_k)
{
auto function =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment