Commit 49a32b14 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Michał Karzyński

[ONNX] Expose API to check whether ONNX Op is supported. (#2299)

parent 1db9707f
...@@ -99,6 +99,14 @@ namespace ngraph ...@@ -99,6 +99,14 @@ namespace ngraph
return op_list; return op_list;
} }
bool is_operator_supported(const std::string& op_name,
std::int64_t version,
const std::string& domain)
{
return OperatorsBridge::is_operator_registered(
op_name, version, domain == "ai.onnx" ? "" : domain);
}
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -52,6 +52,18 @@ namespace ngraph ...@@ -52,6 +52,18 @@ namespace ngraph
std::set<std::string> get_supported_operators(std::int64_t version, std::set<std::string> get_supported_operators(std::int64_t version,
const std::string& domain); const std::string& domain);
/// \brief Determines whether ONNX operator is supported.
///
/// \param[in] op_name The ONNX operator name.
/// \param[in] version The ONNX operator set version.
/// \param[in] domain The domain the ONNX operator is registered to.
///
/// \return True if operator is supported, False otherwise.
///
bool is_operator_supported(const std::string& op_name,
std::int64_t version,
const std::string& domain = "ai.onnx");
/// \brief Convert an ONNX model to nGraph function /// \brief Convert an ONNX model to nGraph function
/// The function translated serialized ONNX model to nGraph function. The serialized /// The function translated serialized ONNX model to nGraph function. The serialized
/// ONNX model is read from input stream. /// ONNX model is read from input stream.
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <unordered_map> #include <unordered_map>
#include "core/attribute.hpp" #include "core/attribute.hpp"
#include "ngraph/log.hpp"
#include "op/abs.hpp" #include "op/abs.hpp"
#include "op/acos.hpp" #include "op/acos.hpp"
#include "op/add.hpp" #include "op/add.hpp"
...@@ -102,20 +103,19 @@ namespace ngraph ...@@ -102,20 +103,19 @@ namespace ngraph
{ {
namespace detail namespace detail
{ {
const Operator& find(const std::string& name, const std::map<std::int64_t, Operator>::const_iterator
std::int64_t version, find(std::int64_t version, const std::map<std::int64_t, Operator>& map)
const std::string& domain,
const std::map<std::int64_t, Operator>& map)
{ {
std::map<std::int64_t, Operator>::const_iterator it{};
while (version > 0) while (version > 0)
{ {
const auto it = map.find(version--); it = map.find(version--);
if (it != std::end(map)) if (it != std::end(map))
{ {
return it->second; return it;
} }
} }
throw error::UnsupportedVersion{name, version, domain}; return it;
} }
} }
...@@ -136,13 +136,51 @@ namespace ngraph ...@@ -136,13 +136,51 @@ namespace ngraph
{ {
throw error::UnknownDomain{domain}; throw error::UnknownDomain{domain};
} }
if (version > OperatorsBridge::LATEST_SUPPORTED_OPSET_VERSION)
{
NGRAPH_WARN << "Currently operator set version: " << version << " is unsupported."
<< " Falling back to: "
<< OperatorsBridge::LATEST_SUPPORTED_OPSET_VERSION;
}
for (const auto& op : dm->second) for (const auto& op : dm->second)
{ {
result.emplace(op.first, detail::find(op.first, version, domain, op.second)); const auto& it = detail::find(version, op.second);
if (it == std::end(op.second))
{
throw error::UnsupportedVersion{op.first, version, domain};
}
result.emplace(op.first, it->second);
} }
return result; return result;
} }
bool OperatorsBridge::_is_operator_registered(const std::string& name,
std::int64_t version,
const std::string& domain)
{
// search for domain
auto dm_map = m_map.find(domain);
if (dm_map == std::end(m_map))
{
return false;
}
// search for name
auto op_map = dm_map->second.find(name);
if (op_map == std::end(dm_map->second))
{
return false;
}
if (detail::find(version, op_map->second) != std::end(op_map->second))
{
return true;
}
else
{
return false;
}
}
#define REGISTER_OPERATOR(name_, ver_, fn_) \ #define REGISTER_OPERATOR(name_, ver_, fn_) \
m_map[""][name_].emplace(ver_, std::bind(op::set_##ver_::fn_, std::placeholders::_1)) m_map[""][name_].emplace(ver_, std::bind(op::set_##ver_::fn_, std::placeholders::_1))
......
...@@ -62,6 +62,8 @@ namespace ngraph ...@@ -62,6 +62,8 @@ namespace ngraph
class OperatorsBridge class OperatorsBridge
{ {
public: public:
static constexpr const int LATEST_SUPPORTED_OPSET_VERSION = ONNX_OPSET_VERSION;
OperatorsBridge(const OperatorsBridge&) = delete; OperatorsBridge(const OperatorsBridge&) = delete;
OperatorsBridge& operator=(const OperatorsBridge&) = delete; OperatorsBridge& operator=(const OperatorsBridge&) = delete;
OperatorsBridge(OperatorsBridge&&) = delete; OperatorsBridge(OperatorsBridge&&) = delete;
...@@ -80,6 +82,13 @@ namespace ngraph ...@@ -80,6 +82,13 @@ namespace ngraph
instance()._register_operator(name, version, domain, std::move(fn)); instance()._register_operator(name, version, domain, std::move(fn));
} }
static bool is_operator_registered(const std::string& name,
std::int64_t version,
const std::string& domain)
{
return instance()._is_operator_registered(name, version, domain);
}
private: private:
std::unordered_map<std::string, std::unordered_map<std::string,
std::unordered_map<std::string, std::map<std::int64_t, Operator>>> std::unordered_map<std::string, std::map<std::int64_t, Operator>>>
...@@ -98,6 +107,9 @@ namespace ngraph ...@@ -98,6 +107,9 @@ namespace ngraph
const std::string& domain, const std::string& domain,
Operator fn); Operator fn);
OperatorSet _get_operator_set(std::int64_t version, const std::string& domain); OperatorSet _get_operator_set(std::int64_t version, const std::string& domain);
bool _is_operator_registered(const std::string& name,
std::int64_t version,
const std::string& domain);
}; };
} // namespace onnx_import } // namespace onnx_import
......
...@@ -1678,3 +1678,34 @@ TEST(onnx, model_argmin_int32) ...@@ -1678,3 +1678,34 @@ TEST(onnx, model_argmin_int32)
execute<std::int32_t, std::int64_t>(function, inputs, "INTERPRETER")}; execute<std::int32_t, std::int64_t>(function, inputs, "INTERPRETER")};
EXPECT_TRUE(test::all_close(expected_output.front(), outputs.front())); EXPECT_TRUE(test::all_close(expected_output.front(), outputs.front()));
} }
TEST(onnx, model_is_op_supported)
{
// Simple case
EXPECT_TRUE(onnx_import::is_operator_supported("Sum", 1, "ai.onnx"));
// With fallback
EXPECT_TRUE(onnx_import::is_operator_supported("Sum", 100, "ai.onnx"));
// Different opset versions
EXPECT_TRUE(onnx_import::is_operator_supported("Add", 1, "ai.onnx"));
EXPECT_TRUE(onnx_import::is_operator_supported("Add", 7, "ai.onnx"));
// Default domain name
EXPECT_TRUE(onnx_import::is_operator_supported("Sum", 1));
// Unregistered operator
EXPECT_FALSE(onnx_import::is_operator_supported("DummyOp", 1));
EXPECT_FALSE(onnx_import::is_operator_supported("DummyOp", 1, "ai.onnx"));
EXPECT_FALSE(onnx_import::is_operator_supported("DummyOp", 10, "ai.onnx"));
// Operator with bad domain name
EXPECT_FALSE(onnx_import::is_operator_supported("Sum", 1, "bad.domain"));
// Registered custom operator
onnx_import::register_operator(
"AddQ", 1, "com.intel.ai", [](const onnx_import::Node& node) -> NodeVector {
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
});
EXPECT_TRUE(onnx_import::is_operator_supported("AddQ", 1, "com.intel.ai"));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment