Unverified Commit c85ff3b8 authored by Artur Wojcik's avatar Artur Wojcik Committed by GitHub

onnx: flatten operatos bridge hierarchy (#1846)

Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>
parent 3167b167
......@@ -62,7 +62,7 @@ namespace ngraph
std::vector<std::shared_ptr<Function>> output_functions;
Model model{model_proto};
Graph graph{model_proto.graph(),
ops_bridge::get_operator_set(model.get_opset_version())};
OperatorsBridge::get_operator_set(model.get_opset_version())};
for (const auto& output : graph.get_outputs())
{
output_functions.emplace_back(std::make_shared<Function>(
......
......@@ -87,69 +87,7 @@ namespace ngraph
{
namespace onnx_import
{
namespace detail
{
namespace error
{
struct UnknownOperator : ngraph_error
{
explicit UnknownOperator(const std::string& op_type)
: ngraph_error{"unknown operator: \"" + op_type + "\""}
{
}
};
struct UnsupportedVersion : ngraph_error
{
explicit UnsupportedVersion(std::int64_t version)
: ngraph_error{"unsupported operator set version: " +
std::to_string(version)}
{
}
};
} // namespace error
class OperatorsBridge
{
public:
OperatorsBridge(const OperatorsBridge&) = delete;
OperatorsBridge& operator=(const OperatorsBridge&) = delete;
OperatorsBridge(OperatorsBridge&&) = delete;
OperatorsBridge& operator=(OperatorsBridge&&) = delete;
static const OperatorSet& get_operator_set(std::int64_t version)
{
return instance().get_operator_set_version(version);
}
private:
std::unordered_map<std::string,
std::map<std::int64_t, std::function<NodeVector(const Node&)>>>
m_map;
static const OperatorsBridge& instance()
{
static OperatorsBridge instance;
return instance;
}
const Operator& get_operator(const std::string& name, std::int64_t version) const
{
auto op = m_map.find(name);
if (op == std::end(m_map))
{
throw error::UnknownOperator{name};
}
auto it = op->second.find(version);
if (it == std::end(op->second))
{
throw error::UnsupportedVersion{version};
}
return it->second;
}
const OperatorSet& get_operator_set_version_1() const
const OperatorSet& OperatorsBridge::get_operator_set_version_1() const
{
static OperatorSet operator_set;
if (operator_set.empty())
......@@ -168,7 +106,7 @@ namespace ngraph
return operator_set;
}
const OperatorSet& get_operator_set_version_2() const
const OperatorSet& OperatorsBridge::get_operator_set_version_2() const
{
static OperatorSet operator_set;
if (operator_set.empty())
......@@ -178,7 +116,7 @@ namespace ngraph
return operator_set;
}
const OperatorSet& get_operator_set_version_3() const
const OperatorSet& OperatorsBridge::get_operator_set_version_3() const
{
static OperatorSet operator_set;
if (operator_set.empty())
......@@ -188,7 +126,7 @@ namespace ngraph
return operator_set;
}
const OperatorSet& get_operator_set_version_4() const
const OperatorSet& OperatorsBridge::get_operator_set_version_4() const
{
static OperatorSet operator_set;
if (operator_set.empty())
......@@ -198,7 +136,7 @@ namespace ngraph
return operator_set;
}
const OperatorSet& get_operator_set_version_5() const
const OperatorSet& OperatorsBridge::get_operator_set_version_5() const
{
static OperatorSet operator_set;
if (operator_set.empty())
......@@ -208,7 +146,7 @@ namespace ngraph
return operator_set;
}
const OperatorSet& get_operator_set_version_6() const
const OperatorSet& OperatorsBridge::get_operator_set_version_6() const
{
static OperatorSet operator_set;
if (operator_set.empty())
......@@ -218,7 +156,7 @@ namespace ngraph
return operator_set;
}
const OperatorSet& get_operator_set_version_7() const
const OperatorSet& OperatorsBridge::get_operator_set_version_7() const
{
static OperatorSet operator_set;
if (operator_set.empty())
......@@ -228,7 +166,7 @@ namespace ngraph
return operator_set;
}
const OperatorSet& get_operator_set_version_8() const
const OperatorSet& OperatorsBridge::get_operator_set_version_8() const
{
static OperatorSet operator_set;
if (operator_set.empty())
......@@ -238,7 +176,7 @@ namespace ngraph
return operator_set;
}
const OperatorSet& get_operator_set_version_9() const
const OperatorSet& OperatorsBridge::get_operator_set_version_9() const
{
static OperatorSet operator_set;
if (operator_set.empty())
......@@ -258,7 +196,7 @@ namespace ngraph
#define DEFAULT_OPERATOR_SET() return OPERATOR_SET_NAME_HELPER(ONNX_OPSET_VERSION)
const OperatorSet& get_operator_set_version(std::int64_t version) const
const OperatorSet& OperatorsBridge::get_operator_set_version(std::int64_t version) const
{
switch (version)
{
......@@ -278,7 +216,7 @@ namespace ngraph
#define REGISTER_OPERATOR(name_, version_, fn_) \
m_map[name_].emplace(version_, std::bind(op::set_##version_::fn_, std::placeholders::_1))
OperatorsBridge()
OperatorsBridge::OperatorsBridge()
{
REGISTER_OPERATOR("Abs", 1, abs);
REGISTER_OPERATOR("Add", 1, add);
......@@ -351,18 +289,6 @@ namespace ngraph
REGISTER_OPERATOR("Unsqueeze", 1, unsqueeze);
REGISTER_OPERATOR("Xor", 1, logical_xor);
}
};
} // namespace detail
namespace ops_bridge
{
const OperatorSet& get_operator_set(std::int64_t version)
{
return detail::OperatorsBridge::get_operator_set(version);
}
} // namespace ops_bridge
} // namespace onnx_import
......
......@@ -17,6 +17,11 @@
#pragma once
#include <cstdint>
#include <map>
#include <string>
#include <unordered_map>
#include "ngraph/except.hpp"
#include "core/operator_set.hpp"
......@@ -24,11 +29,61 @@ namespace ngraph
{
namespace onnx_import
{
namespace ops_bridge
namespace error
{
struct UnknownOperator : ngraph_error
{
explicit UnknownOperator(const std::string& op_type)
: ngraph_error{"unknown operator: \"" + op_type + "\""}
{
}
};
struct UnsupportedVersion : ngraph_error
{
explicit UnsupportedVersion(std::int64_t version)
: ngraph_error{"unsupported operator set version: " + std::to_string(version)}
{
}
};
} // namespace error
class OperatorsBridge
{
public:
OperatorsBridge(const OperatorsBridge&) = delete;
OperatorsBridge& operator=(const OperatorsBridge&) = delete;
OperatorsBridge(OperatorsBridge&&) = delete;
OperatorsBridge& operator=(OperatorsBridge&&) = delete;
static const OperatorSet& get_operator_set(std::int64_t version)
{
return instance().get_operator_set_version(version);
}
private:
std::unordered_map<std::string, std::map<std::int64_t, Operator>> m_map;
OperatorsBridge();
static const OperatorsBridge& instance()
{
const OperatorSet& get_operator_set(std::int64_t version);
static OperatorsBridge instance;
return instance;
}
} // namespace ops_bridge
const OperatorSet& get_operator_set_version_1() const;
const OperatorSet& get_operator_set_version_2() const;
const OperatorSet& get_operator_set_version_3() const;
const OperatorSet& get_operator_set_version_4() const;
const OperatorSet& get_operator_set_version_5() const;
const OperatorSet& get_operator_set_version_6() const;
const OperatorSet& get_operator_set_version_7() const;
const OperatorSet& get_operator_set_version_8() const;
const OperatorSet& get_operator_set_version_9() const;
const OperatorSet& get_operator_set_version(std::int64_t version) const;
};
} // namespace onnx_import
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment