Unverified Commit c85ff3b8 authored by Artur Wojcik's avatar Artur Wojcik Committed by GitHub

onnx: flatten operatos bridge hierarchy (#1846)

Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>
parent 3167b167
...@@ -62,7 +62,7 @@ namespace ngraph ...@@ -62,7 +62,7 @@ namespace ngraph
std::vector<std::shared_ptr<Function>> output_functions; std::vector<std::shared_ptr<Function>> output_functions;
Model model{model_proto}; Model model{model_proto};
Graph graph{model_proto.graph(), Graph graph{model_proto.graph(),
ops_bridge::get_operator_set(model.get_opset_version())}; OperatorsBridge::get_operator_set(model.get_opset_version())};
for (const auto& output : graph.get_outputs()) for (const auto& output : graph.get_outputs())
{ {
output_functions.emplace_back(std::make_shared<Function>( output_functions.emplace_back(std::make_shared<Function>(
......
...@@ -87,166 +87,104 @@ namespace ngraph ...@@ -87,166 +87,104 @@ namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
namespace detail const OperatorSet& OperatorsBridge::get_operator_set_version_1() const
{ {
namespace error static OperatorSet operator_set;
if (operator_set.empty())
{ {
struct UnknownOperator : ngraph_error for (const auto& op : m_map)
{ {
explicit UnknownOperator(const std::string& op_type) for (const auto& it : op.second)
: ngraph_error{"unknown operator: \"" + op_type + "\""}
{ {
} if (it.first == 1)
};
struct UnsupportedVersion : ngraph_error
{
explicit UnsupportedVersion(std::int64_t version)
: ngraph_error{"unsupported operator set version: " +
std::to_string(version)}
{
}
};
} // namespace error
class OperatorsBridge
{
public:
OperatorsBridge(const OperatorsBridge&) = delete;
OperatorsBridge& operator=(const OperatorsBridge&) = delete;
OperatorsBridge(OperatorsBridge&&) = delete;
OperatorsBridge& operator=(OperatorsBridge&&) = delete;
static const OperatorSet& get_operator_set(std::int64_t version)
{
return instance().get_operator_set_version(version);
}
private:
std::unordered_map<std::string,
std::map<std::int64_t, std::function<NodeVector(const Node&)>>>
m_map;
static const OperatorsBridge& instance()
{
static OperatorsBridge instance;
return instance;
}
const Operator& get_operator(const std::string& name, std::int64_t version) const
{
auto op = m_map.find(name);
if (op == std::end(m_map))
{
throw error::UnknownOperator{name};
}
auto it = op->second.find(version);
if (it == std::end(op->second))
{
throw error::UnsupportedVersion{version};
}
return it->second;
}
const OperatorSet& get_operator_set_version_1() const
{
static OperatorSet operator_set;
if (operator_set.empty())
{
for (const auto& op : m_map)
{ {
for (const auto& it : op.second) operator_set.emplace(op.first, it.second);
{
if (it.first == 1)
{
operator_set.emplace(op.first, it.second);
}
}
} }
} }
return operator_set;
} }
}
return operator_set;
}
const OperatorSet& get_operator_set_version_2() const const OperatorSet& OperatorsBridge::get_operator_set_version_2() const
{ {
static OperatorSet operator_set; static OperatorSet operator_set;
if (operator_set.empty()) if (operator_set.empty())
{ {
operator_set = get_operator_set_version_1(); operator_set = get_operator_set_version_1();
} }
return operator_set; return operator_set;
} }
const OperatorSet& get_operator_set_version_3() const const OperatorSet& OperatorsBridge::get_operator_set_version_3() const
{ {
static OperatorSet operator_set; static OperatorSet operator_set;
if (operator_set.empty()) if (operator_set.empty())
{ {
operator_set = get_operator_set_version_2(); operator_set = get_operator_set_version_2();
} }
return operator_set; return operator_set;
} }
const OperatorSet& get_operator_set_version_4() const const OperatorSet& OperatorsBridge::get_operator_set_version_4() const
{ {
static OperatorSet operator_set; static OperatorSet operator_set;
if (operator_set.empty()) if (operator_set.empty())
{ {
operator_set = get_operator_set_version_3(); operator_set = get_operator_set_version_3();
} }
return operator_set; return operator_set;
} }
const OperatorSet& get_operator_set_version_5() const const OperatorSet& OperatorsBridge::get_operator_set_version_5() const
{ {
static OperatorSet operator_set; static OperatorSet operator_set;
if (operator_set.empty()) if (operator_set.empty())
{ {
operator_set = get_operator_set_version_4(); operator_set = get_operator_set_version_4();
} }
return operator_set; return operator_set;
} }
const OperatorSet& get_operator_set_version_6() const const OperatorSet& OperatorsBridge::get_operator_set_version_6() const
{ {
static OperatorSet operator_set; static OperatorSet operator_set;
if (operator_set.empty()) if (operator_set.empty())
{ {
operator_set = get_operator_set_version_5(); operator_set = get_operator_set_version_5();
} }
return operator_set; return operator_set;
} }
const OperatorSet& get_operator_set_version_7() const const OperatorSet& OperatorsBridge::get_operator_set_version_7() const
{ {
static OperatorSet operator_set; static OperatorSet operator_set;
if (operator_set.empty()) if (operator_set.empty())
{ {
operator_set = get_operator_set_version_6(); operator_set = get_operator_set_version_6();
} }
return operator_set; return operator_set;
} }
const OperatorSet& get_operator_set_version_8() const const OperatorSet& OperatorsBridge::get_operator_set_version_8() const
{ {
static OperatorSet operator_set; static OperatorSet operator_set;
if (operator_set.empty()) if (operator_set.empty())
{ {
operator_set = get_operator_set_version_7(); operator_set = get_operator_set_version_7();
} }
return operator_set; return operator_set;
} }
const OperatorSet& get_operator_set_version_9() const const OperatorSet& OperatorsBridge::get_operator_set_version_9() const
{ {
static OperatorSet operator_set; static OperatorSet operator_set;
if (operator_set.empty()) if (operator_set.empty())
{ {
operator_set = get_operator_set_version_8(); operator_set = get_operator_set_version_8();
} }
return operator_set; return operator_set;
} }
#define OPERATOR_SET_NAME(version_) get_operator_set_version_##version_() #define OPERATOR_SET_NAME(version_) get_operator_set_version_##version_()
...@@ -258,111 +196,99 @@ namespace ngraph ...@@ -258,111 +196,99 @@ namespace ngraph
#define DEFAULT_OPERATOR_SET() return OPERATOR_SET_NAME_HELPER(ONNX_OPSET_VERSION) #define DEFAULT_OPERATOR_SET() return OPERATOR_SET_NAME_HELPER(ONNX_OPSET_VERSION)
const OperatorSet& get_operator_set_version(std::int64_t version) const const OperatorSet& OperatorsBridge::get_operator_set_version(std::int64_t version) const
{ {
switch (version) switch (version)
{ {
GET_OPERATOR_SET(1); GET_OPERATOR_SET(1);
GET_OPERATOR_SET(2); GET_OPERATOR_SET(2);
GET_OPERATOR_SET(3); GET_OPERATOR_SET(3);
GET_OPERATOR_SET(4); GET_OPERATOR_SET(4);
GET_OPERATOR_SET(5); GET_OPERATOR_SET(5);
GET_OPERATOR_SET(6); GET_OPERATOR_SET(6);
GET_OPERATOR_SET(7); GET_OPERATOR_SET(7);
GET_OPERATOR_SET(8); GET_OPERATOR_SET(8);
GET_OPERATOR_SET(9); GET_OPERATOR_SET(9);
default: DEFAULT_OPERATOR_SET(); default: DEFAULT_OPERATOR_SET();
} }
} }
#define REGISTER_OPERATOR(name_, version_, fn_) \ #define REGISTER_OPERATOR(name_, version_, fn_) \
m_map[name_].emplace(version_, std::bind(op::set_##version_::fn_, std::placeholders::_1)) m_map[name_].emplace(version_, std::bind(op::set_##version_::fn_, std::placeholders::_1))
OperatorsBridge() OperatorsBridge::OperatorsBridge()
{
REGISTER_OPERATOR("Abs", 1, abs);
REGISTER_OPERATOR("Add", 1, add);
REGISTER_OPERATOR("And", 1, logical_and);
REGISTER_OPERATOR("AveragePool", 1, average_pool);
REGISTER_OPERATOR("BatchNormalization", 1, batch_norm);
REGISTER_OPERATOR("Cast", 1, cast);
REGISTER_OPERATOR("Ceil", 1, ceil);
REGISTER_OPERATOR("Clip", 1, clip);
REGISTER_OPERATOR("Concat", 1, concat);
REGISTER_OPERATOR("Constant", 1, constant);
REGISTER_OPERATOR("Conv", 1, conv);
REGISTER_OPERATOR("Div", 1, div);
REGISTER_OPERATOR("Dropout", 1, identity);
REGISTER_OPERATOR("Elu", 1, elu);
REGISTER_OPERATOR("Equal", 1, equal);
REGISTER_OPERATOR("Exp", 1, exp);
REGISTER_OPERATOR("Flatten", 1, flatten);
REGISTER_OPERATOR("Floor", 1, floor);
REGISTER_OPERATOR("Gemm", 1, gemm);
REGISTER_OPERATOR("GlobalAveragePool", 1, global_average_pool);
REGISTER_OPERATOR("GlobalMaxPool", 1, global_max_pool);
REGISTER_OPERATOR("Greater", 1, greater);
REGISTER_OPERATOR("HardSigmoid", 1, hard_sigmoid);
REGISTER_OPERATOR("Identity", 1, identity);
REGISTER_OPERATOR("LeakyRelu", 1, leaky_relu);
REGISTER_OPERATOR("Less", 1, less);
REGISTER_OPERATOR("Log", 1, log);
REGISTER_OPERATOR("LogSoftmax", 1, log_softmax);
REGISTER_OPERATOR("LRN", 1, lrn);
REGISTER_OPERATOR("MatMul", 1, matmul);
REGISTER_OPERATOR("MaxPool", 1, max_pool);
REGISTER_OPERATOR("Max", 1, max);
REGISTER_OPERATOR("Mean", 1, mean);
REGISTER_OPERATOR("Min", 1, min);
REGISTER_OPERATOR("Mul", 1, mul);
REGISTER_OPERATOR("Neg", 1, neg);
REGISTER_OPERATOR("Not", 1, logical_not);
REGISTER_OPERATOR("Or", 1, logical_or);
REGISTER_OPERATOR("Pow", 1, pow);
REGISTER_OPERATOR("PRelu", 1, prelu);
REGISTER_OPERATOR("Reciprocal", 1, reciprocal);
REGISTER_OPERATOR("ReduceLogSum", 1, reduce_log_sum);
REGISTER_OPERATOR("ReduceLogSumExp", 1, reduce_log_sum_exp);
REGISTER_OPERATOR("ReduceL1", 1, reduce_l1);
REGISTER_OPERATOR("ReduceL2", 1, reduce_l2);
REGISTER_OPERATOR("ReduceMax", 1, reduce_max);
REGISTER_OPERATOR("ReduceMean", 1, reduce_mean);
REGISTER_OPERATOR("ReduceMin", 1, reduce_min);
REGISTER_OPERATOR("ReduceProd", 1, reduce_prod);
REGISTER_OPERATOR("ReduceSum", 1, reduce_sum);
REGISTER_OPERATOR("ReduceSumSquare", 1, reduce_sum_square);
REGISTER_OPERATOR("Relu", 1, relu);
REGISTER_OPERATOR("Reshape", 1, reshape);
REGISTER_OPERATOR("Selu", 1, selu);
REGISTER_OPERATOR("Shape", 1, shape);
REGISTER_OPERATOR("Sigmoid", 1, sigmoid);
REGISTER_OPERATOR("Slice", 1, slice);
REGISTER_OPERATOR("Softmax", 1, softmax);
REGISTER_OPERATOR("Softplus", 1, softplus);
REGISTER_OPERATOR("Softsign", 1, softsign);
REGISTER_OPERATOR("Split", 1, split);
REGISTER_OPERATOR("Sqrt", 1, sqrt);
REGISTER_OPERATOR("Squeeze", 1, squeeze);
REGISTER_OPERATOR("Sub", 1, sub);
REGISTER_OPERATOR("Sum", 1, sum);
REGISTER_OPERATOR("Tanh", 1, tanh);
REGISTER_OPERATOR("ThresholdedRelu", 1, thresholded_relu);
REGISTER_OPERATOR("Transpose", 1, transpose);
REGISTER_OPERATOR("Unsqueeze", 1, unsqueeze);
REGISTER_OPERATOR("Xor", 1, logical_xor);
}
};
} // namespace detail
namespace ops_bridge
{ {
const OperatorSet& get_operator_set(std::int64_t version) REGISTER_OPERATOR("Abs", 1, abs);
{ REGISTER_OPERATOR("Add", 1, add);
return detail::OperatorsBridge::get_operator_set(version); REGISTER_OPERATOR("And", 1, logical_and);
} REGISTER_OPERATOR("AveragePool", 1, average_pool);
REGISTER_OPERATOR("BatchNormalization", 1, batch_norm);
} // namespace ops_bridge REGISTER_OPERATOR("Cast", 1, cast);
REGISTER_OPERATOR("Ceil", 1, ceil);
REGISTER_OPERATOR("Clip", 1, clip);
REGISTER_OPERATOR("Concat", 1, concat);
REGISTER_OPERATOR("Constant", 1, constant);
REGISTER_OPERATOR("Conv", 1, conv);
REGISTER_OPERATOR("Div", 1, div);
REGISTER_OPERATOR("Dropout", 1, identity);
REGISTER_OPERATOR("Elu", 1, elu);
REGISTER_OPERATOR("Equal", 1, equal);
REGISTER_OPERATOR("Exp", 1, exp);
REGISTER_OPERATOR("Flatten", 1, flatten);
REGISTER_OPERATOR("Floor", 1, floor);
REGISTER_OPERATOR("Gemm", 1, gemm);
REGISTER_OPERATOR("GlobalAveragePool", 1, global_average_pool);
REGISTER_OPERATOR("GlobalMaxPool", 1, global_max_pool);
REGISTER_OPERATOR("Greater", 1, greater);
REGISTER_OPERATOR("HardSigmoid", 1, hard_sigmoid);
REGISTER_OPERATOR("Identity", 1, identity);
REGISTER_OPERATOR("LeakyRelu", 1, leaky_relu);
REGISTER_OPERATOR("Less", 1, less);
REGISTER_OPERATOR("Log", 1, log);
REGISTER_OPERATOR("LogSoftmax", 1, log_softmax);
REGISTER_OPERATOR("LRN", 1, lrn);
REGISTER_OPERATOR("MatMul", 1, matmul);
REGISTER_OPERATOR("MaxPool", 1, max_pool);
REGISTER_OPERATOR("Max", 1, max);
REGISTER_OPERATOR("Mean", 1, mean);
REGISTER_OPERATOR("Min", 1, min);
REGISTER_OPERATOR("Mul", 1, mul);
REGISTER_OPERATOR("Neg", 1, neg);
REGISTER_OPERATOR("Not", 1, logical_not);
REGISTER_OPERATOR("Or", 1, logical_or);
REGISTER_OPERATOR("Pow", 1, pow);
REGISTER_OPERATOR("PRelu", 1, prelu);
REGISTER_OPERATOR("Reciprocal", 1, reciprocal);
REGISTER_OPERATOR("ReduceLogSum", 1, reduce_log_sum);
REGISTER_OPERATOR("ReduceLogSumExp", 1, reduce_log_sum_exp);
REGISTER_OPERATOR("ReduceL1", 1, reduce_l1);
REGISTER_OPERATOR("ReduceL2", 1, reduce_l2);
REGISTER_OPERATOR("ReduceMax", 1, reduce_max);
REGISTER_OPERATOR("ReduceMean", 1, reduce_mean);
REGISTER_OPERATOR("ReduceMin", 1, reduce_min);
REGISTER_OPERATOR("ReduceProd", 1, reduce_prod);
REGISTER_OPERATOR("ReduceSum", 1, reduce_sum);
REGISTER_OPERATOR("ReduceSumSquare", 1, reduce_sum_square);
REGISTER_OPERATOR("Relu", 1, relu);
REGISTER_OPERATOR("Reshape", 1, reshape);
REGISTER_OPERATOR("Selu", 1, selu);
REGISTER_OPERATOR("Shape", 1, shape);
REGISTER_OPERATOR("Sigmoid", 1, sigmoid);
REGISTER_OPERATOR("Slice", 1, slice);
REGISTER_OPERATOR("Softmax", 1, softmax);
REGISTER_OPERATOR("Softplus", 1, softplus);
REGISTER_OPERATOR("Softsign", 1, softsign);
REGISTER_OPERATOR("Split", 1, split);
REGISTER_OPERATOR("Sqrt", 1, sqrt);
REGISTER_OPERATOR("Squeeze", 1, squeeze);
REGISTER_OPERATOR("Sub", 1, sub);
REGISTER_OPERATOR("Sum", 1, sum);
REGISTER_OPERATOR("Tanh", 1, tanh);
REGISTER_OPERATOR("ThresholdedRelu", 1, thresholded_relu);
REGISTER_OPERATOR("Transpose", 1, transpose);
REGISTER_OPERATOR("Unsqueeze", 1, unsqueeze);
REGISTER_OPERATOR("Xor", 1, logical_xor);
}
} // namespace onnx_import } // namespace onnx_import
......
...@@ -17,6 +17,11 @@ ...@@ -17,6 +17,11 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <map>
#include <string>
#include <unordered_map>
#include "ngraph/except.hpp"
#include "core/operator_set.hpp" #include "core/operator_set.hpp"
...@@ -24,11 +29,61 @@ namespace ngraph ...@@ -24,11 +29,61 @@ namespace ngraph
{ {
namespace onnx_import namespace onnx_import
{ {
namespace ops_bridge namespace error
{
struct UnknownOperator : ngraph_error
{
explicit UnknownOperator(const std::string& op_type)
: ngraph_error{"unknown operator: \"" + op_type + "\""}
{
}
};
struct UnsupportedVersion : ngraph_error
{
explicit UnsupportedVersion(std::int64_t version)
: ngraph_error{"unsupported operator set version: " + std::to_string(version)}
{
}
};
} // namespace error
class OperatorsBridge
{ {
const OperatorSet& get_operator_set(std::int64_t version); public:
OperatorsBridge(const OperatorsBridge&) = delete;
OperatorsBridge& operator=(const OperatorsBridge&) = delete;
OperatorsBridge(OperatorsBridge&&) = delete;
OperatorsBridge& operator=(OperatorsBridge&&) = delete;
static const OperatorSet& get_operator_set(std::int64_t version)
{
return instance().get_operator_set_version(version);
}
private:
std::unordered_map<std::string, std::map<std::int64_t, Operator>> m_map;
OperatorsBridge();
static const OperatorsBridge& instance()
{
static OperatorsBridge instance;
return instance;
}
} // namespace ops_bridge const OperatorSet& get_operator_set_version_1() const;
const OperatorSet& get_operator_set_version_2() const;
const OperatorSet& get_operator_set_version_3() const;
const OperatorSet& get_operator_set_version_4() const;
const OperatorSet& get_operator_set_version_5() const;
const OperatorSet& get_operator_set_version_6() const;
const OperatorSet& get_operator_set_version_7() const;
const OperatorSet& get_operator_set_version_8() const;
const OperatorSet& get_operator_set_version_9() const;
const OperatorSet& get_operator_set_version(std::int64_t version) const;
};
} // namespace onnx_import } // namespace onnx_import
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment