Commit 9debc0bc authored by Artur Wojcik's avatar Artur Wojcik Committed by Scott Cyphers

[ONNX] Add support to OperatorSet (Part 1) - namespace change only (#1812)

* onnx: add register operator macro
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: add set information to 'op' namespace
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>
parent 1d5f047a
...@@ -28,13 +28,17 @@ namespace ngraph ...@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector abs(const Node& node) inline NodeVector abs(const Node& node)
{ {
return {std::make_shared<ngraph::op::Abs>(node.get_ng_inputs().at(0))}; return {std::make_shared<ngraph::op::Abs>(node.get_ng_inputs().at(0))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,6 +27,8 @@ namespace ngraph ...@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector add(const Node& node) inline NodeVector add(const Node& node)
{ {
...@@ -35,7 +37,9 @@ namespace ngraph ...@@ -35,7 +37,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,6 +27,8 @@ namespace ngraph ...@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector logical_and(const Node& node) inline NodeVector logical_and(const Node& node)
{ {
...@@ -35,7 +37,9 @@ namespace ngraph ...@@ -35,7 +37,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,13 +25,17 @@ namespace ngraph ...@@ -25,13 +25,17 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector average_pool(const Node& node) NodeVector average_pool(const Node& node)
{ {
return convpool::make_ng_pool<ngraph::op::AvgPool>(node); return convpool::make_ng_pool<ngraph::op::AvgPool>(node);
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,6 +25,8 @@ namespace ngraph ...@@ -25,6 +25,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
/** /**
* @brief Convert ONNX AveragePool operation to an nGraph node. * @brief Convert ONNX AveragePool operation to an nGraph node.
...@@ -36,7 +38,9 @@ namespace ngraph ...@@ -36,7 +38,9 @@ namespace ngraph
*/ */
NodeVector average_pool(const Node& node); NodeVector average_pool(const Node& node);
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,6 +27,8 @@ namespace ngraph ...@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector batch_norm(const Node& node) NodeVector batch_norm(const Node& node)
{ {
...@@ -58,7 +60,9 @@ namespace ngraph ...@@ -58,7 +60,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::BatchNorm>(epsilon, scale, bias, x)}; return {std::make_shared<ngraph::op::BatchNorm>(epsilon, scale, bias, x)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -23,9 +23,14 @@ namespace ngraph ...@@ -23,9 +23,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector batch_norm(const Node& node); NodeVector batch_norm(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -29,6 +29,8 @@ namespace ngraph ...@@ -29,6 +29,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector cast(const Node& node) NodeVector cast(const Node& node)
{ {
...@@ -50,14 +52,18 @@ namespace ngraph ...@@ -50,14 +52,18 @@ namespace ngraph
case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break; case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break;
case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break; case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break;
case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break; case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break;
case onnx::TensorProto_DataType_UNDEFINED: elem_type = element::unspecified; break; case onnx::TensorProto_DataType_UNDEFINED:
elem_type = element::unspecified;
break;
default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type"; default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type";
} }
return {std::make_shared<ngraph::op::Convert>(data, elem_type)}; return {std::make_shared<ngraph::op::Convert>(data, elem_type)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector cast(const Node& node); NodeVector cast(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,13 +28,17 @@ namespace ngraph ...@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector ceil(const Node& node) inline NodeVector ceil(const Node& node)
{ {
return {std::make_shared<ngraph::op::Ceiling>(node.get_ng_inputs().at(0))}; return {std::make_shared<ngraph::op::Ceiling>(node.get_ng_inputs().at(0))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -33,6 +33,8 @@ namespace ngraph ...@@ -33,6 +33,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector clip(const Node& node) NodeVector clip(const Node& node)
{ {
...@@ -40,24 +42,29 @@ namespace ngraph ...@@ -40,24 +42,29 @@ namespace ngraph
double max_value = double max_value =
node.get_attribute_value<double>("max", std::numeric_limits<double>::max()); node.get_attribute_value<double>("max", std::numeric_limits<double>::max());
double min_value = double min_value = node.get_attribute_value<double>(
node.get_attribute_value<double>("min", std::numeric_limits<double>::lowest()); "min", std::numeric_limits<double>::lowest());
std::shared_ptr<ngraph::Node> max_value_node = std::shared_ptr<ngraph::Node> max_value_node =
std::make_shared<ngraph::op::Constant>( std::make_shared<ngraph::op::Constant>(data->get_element_type(),
data->get_element_type(), ngraph::Shape{}, std::vector<double>{max_value}); ngraph::Shape{},
std::vector<double>{max_value});
max_value_node = make_broadcast_node(max_value_node, data->get_shape()); max_value_node = make_broadcast_node(max_value_node, data->get_shape());
std::shared_ptr<ngraph::Node> min_value_node = std::shared_ptr<ngraph::Node> min_value_node =
std::make_shared<ngraph::op::Constant>( std::make_shared<ngraph::op::Constant>(data->get_element_type(),
data->get_element_type(), ngraph::Shape{}, std::vector<double>{min_value}); ngraph::Shape{},
std::vector<double>{min_value});
min_value_node = make_broadcast_node(min_value_node, data->get_shape()); min_value_node = make_broadcast_node(min_value_node, data->get_shape());
return {std::make_shared<ngraph::op::Minimum>( return {std::make_shared<ngraph::op::Minimum>(
max_value_node, std::make_shared<ngraph::op::Maximum>(data, min_value_node))}; max_value_node,
std::make_shared<ngraph::op::Maximum>(data, min_value_node))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector clip(const Node& node); NodeVector clip(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -23,6 +23,8 @@ namespace ngraph ...@@ -23,6 +23,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector concat(const Node& node) NodeVector concat(const Node& node)
{ {
...@@ -32,7 +34,9 @@ namespace ngraph ...@@ -32,7 +34,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Concat>(inputs, axis)}; return {std::make_shared<ngraph::op::Concat>(inputs, axis)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector concat(const Node& node); NodeVector concat(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -24,6 +24,8 @@ namespace ngraph ...@@ -24,6 +24,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
namespace namespace
{ {
...@@ -36,7 +38,8 @@ namespace ngraph ...@@ -36,7 +38,8 @@ namespace ngraph
} }
template <Tensor::Type> template <Tensor::Type>
inline std::shared_ptr<ngraph::op::Constant> make_ng_constant(const Tensor& tensor) inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant(const Tensor& tensor)
{ {
throw error::tensor::unsupported_data_type{tensor}; throw error::tensor::unsupported_data_type{tensor};
} }
...@@ -114,7 +117,9 @@ namespace ngraph ...@@ -114,7 +117,9 @@ namespace ngraph
return {make_constant(node.get_attribute_value<Tensor>("value"))}; return {make_constant(node.get_attribute_value<Tensor>("value"))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,10 +25,14 @@ namespace ngraph ...@@ -25,10 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector constant(const Node& node); NodeVector constant(const Node& node);
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -36,6 +36,8 @@ namespace ngraph ...@@ -36,6 +36,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
namespace namespace
{ {
...@@ -63,7 +65,8 @@ namespace ngraph ...@@ -63,7 +65,8 @@ namespace ngraph
// initial bounds for splice // initial bounds for splice
std::vector<std::size_t> data_lower_bounds(data->get_shape().size()); std::vector<std::size_t> data_lower_bounds(data->get_shape().size());
std::vector<std::size_t> data_upper_bounds{data->get_shape()}; std::vector<std::size_t> data_upper_bounds{data->get_shape()};
std::vector<std::size_t> filters_lower_bounds(filters->get_shape().size()); std::vector<std::size_t> filters_lower_bounds(
filters->get_shape().size());
std::vector<std::size_t> filters_upper_bounds{filters->get_shape()}; std::vector<std::size_t> filters_upper_bounds{filters->get_shape()};
for (std::size_t group{0}; group < groups; ++group) for (std::size_t group{0}; group < groups; ++group)
...@@ -136,7 +139,9 @@ namespace ngraph ...@@ -136,7 +139,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Add>(conv_node, broadcasted_bias)}; return {std::make_shared<ngraph::op::Add>(conv_node, broadcasted_bias)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,6 +25,8 @@ namespace ngraph ...@@ -25,6 +25,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
/// \brief Performs ONNX Conv operation. /// \brief Performs ONNX Conv operation.
/// ///
...@@ -34,7 +36,9 @@ namespace ngraph ...@@ -34,7 +36,9 @@ namespace ngraph
/// operation. /// operation.
NodeVector conv(const Node& node); NodeVector conv(const Node& node);
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,6 +27,8 @@ namespace ngraph ...@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector div(const Node& node) inline NodeVector div(const Node& node)
{ {
...@@ -35,7 +37,9 @@ namespace ngraph ...@@ -35,7 +37,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -37,27 +37,34 @@ namespace ngraph ...@@ -37,27 +37,34 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector elu(const Node& node) NodeVector elu(const Node& node)
{ {
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1); double alpha = node.get_attribute_value<double>("alpha", 1);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha}); data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape()); alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> zero_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{0}); data->get_element_type(), Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape()); zero_node = make_broadcast_node(zero_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(data, zero_node) + return {std::make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node * std::make_shared<ngraph::op::Exp>( alpha_node *
std::make_shared<ngraph::op::Exp>(
std::make_shared<ngraph::op::Minimum>(data, zero_node)) - std::make_shared<ngraph::op::Minimum>(data, zero_node)) -
alpha_node}; alpha_node};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector elu(const Node& node); NodeVector elu(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,6 +27,8 @@ namespace ngraph ...@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector equal(const Node& node) inline NodeVector equal(const Node& node)
{ {
...@@ -35,7 +37,9 @@ namespace ngraph ...@@ -35,7 +37,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,13 +28,17 @@ namespace ngraph ...@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector exp(const Node& node) inline NodeVector exp(const Node& node)
{ {
return {std::make_shared<ngraph::op::Exp>(node.get_ng_inputs().at(0))}; return {std::make_shared<ngraph::op::Exp>(node.get_ng_inputs().at(0))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -24,6 +24,8 @@ namespace ngraph ...@@ -24,6 +24,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector flatten(const Node& node) NodeVector flatten(const Node& node)
{ {
...@@ -37,7 +39,9 @@ namespace ngraph ...@@ -37,7 +39,9 @@ namespace ngraph
return {reshape::flatten(data, axis)}; return {reshape::flatten(data, axis)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,9 +27,14 @@ namespace ngraph ...@@ -27,9 +27,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector flatten(const Node& node); NodeVector flatten(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,13 +28,17 @@ namespace ngraph ...@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector floor(const Node& node) inline NodeVector floor(const Node& node)
{ {
return {std::make_shared<ngraph::op::Floor>(node.get_ng_inputs().at(0))}; return {std::make_shared<ngraph::op::Floor>(node.get_ng_inputs().at(0))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -31,6 +31,8 @@ namespace ngraph ...@@ -31,6 +31,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector gemm(const Node& node) NodeVector gemm(const Node& node)
{ {
...@@ -60,13 +62,17 @@ namespace ngraph ...@@ -60,13 +62,17 @@ namespace ngraph
std::shared_ptr<ngraph::Node> a_dot_b = std::shared_ptr<ngraph::Node> a_dot_b =
std::make_shared<ngraph::op::Dot>(input_a, input_b); std::make_shared<ngraph::op::Dot>(input_a, input_b);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> alpha_node =
a_dot_b->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha}); std::make_shared<ngraph::op::Constant>(a_dot_b->get_element_type(),
ngraph::Shape{},
std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, a_dot_b->get_shape()); alpha_node = make_broadcast_node(alpha_node, a_dot_b->get_shape());
a_dot_b = std::make_shared<ngraph::op::Multiply>(alpha_node, a_dot_b); a_dot_b = std::make_shared<ngraph::op::Multiply>(alpha_node, a_dot_b);
std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> beta_node =
input_c->get_element_type(), ngraph::Shape{}, std::vector<double>{beta}); std::make_shared<ngraph::op::Constant>(input_c->get_element_type(),
ngraph::Shape{},
std::vector<double>{beta});
beta_node = make_broadcast_node(beta_node, input_c->get_shape()); beta_node = make_broadcast_node(beta_node, input_c->get_shape());
input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c); input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c);
...@@ -75,7 +81,9 @@ namespace ngraph ...@@ -75,7 +81,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Add>(a_dot_b, input_c)}; return {std::make_shared<ngraph::op::Add>(a_dot_b, input_c)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,9 +27,14 @@ namespace ngraph ...@@ -27,9 +27,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector gemm(const Node& node); NodeVector gemm(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,15 +27,20 @@ namespace ngraph ...@@ -27,15 +27,20 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector greater(const Node& node) inline NodeVector greater(const Node& node)
{ {
NodeVector ng_inputs{ NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())}; numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))}; return {
std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -34,6 +34,8 @@ namespace ngraph ...@@ -34,6 +34,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector hard_sigmoid(const Node& node) NodeVector hard_sigmoid(const Node& node)
{ {
...@@ -42,11 +44,13 @@ namespace ngraph ...@@ -42,11 +44,13 @@ namespace ngraph
double alpha = node.get_attribute_value<double>("alpha", 0.2); double alpha = node.get_attribute_value<double>("alpha", 0.2);
double beta = node.get_attribute_value<double>("beta", 0.5); double beta = node.get_attribute_value<double>("beta", 0.5);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha}); data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape()); alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> beta_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{beta}); data->get_element_type(), ngraph::Shape{}, std::vector<double>{beta});
beta_node = make_broadcast_node(beta_node, data->get_shape()); beta_node = make_broadcast_node(beta_node, data->get_shape());
...@@ -54,7 +58,8 @@ namespace ngraph ...@@ -54,7 +58,8 @@ namespace ngraph
data->get_element_type(), Shape{}, std::vector<double>{1}); data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape()); one_node = make_broadcast_node(one_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> zero_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{0}); data->get_element_type(), Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape()); zero_node = make_broadcast_node(zero_node, data->get_shape());
...@@ -64,7 +69,9 @@ namespace ngraph ...@@ -64,7 +69,9 @@ namespace ngraph
alpha_node * data + beta_node))}; alpha_node * data + beta_node))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector hard_sigmoid(const Node& node); NodeVector hard_sigmoid(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,15 @@ namespace ngraph ...@@ -26,8 +26,15 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector identity(const Node& node) { return {node.get_ng_inputs().at(0)}; } namespace set_1
} // namespace op {
inline NodeVector identity(const Node& node)
{
return {node.get_ng_inputs().at(0)};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -37,6 +37,8 @@ namespace ngraph ...@@ -37,6 +37,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector leaky_relu(const Node& node) NodeVector leaky_relu(const Node& node)
{ {
...@@ -46,13 +48,16 @@ namespace ngraph ...@@ -46,13 +48,16 @@ namespace ngraph
ASSERT_VALID_ARGUMENT(node, ((alpha >= 0) && (alpha <= 1))) ASSERT_VALID_ARGUMENT(node, ((alpha >= 0) && (alpha <= 1)))
<< " alpha value should be in range (0,1)"; << " alpha value should be in range (0,1)";
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha}); data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape()); alpha_node = make_broadcast_node(alpha_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)}; return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector leaky_relu(const Node& node); NodeVector leaky_relu(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,6 +27,8 @@ namespace ngraph ...@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector less(const Node& node) inline NodeVector less(const Node& node)
{ {
...@@ -35,7 +37,9 @@ namespace ngraph ...@@ -35,7 +37,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,13 +28,17 @@ namespace ngraph ...@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector log(const Node& node) inline NodeVector log(const Node& node)
{ {
return {std::make_shared<ngraph::op::Log>(node.get_ng_inputs().at(0))}; return {std::make_shared<ngraph::op::Log>(node.get_ng_inputs().at(0))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -30,13 +30,17 @@ namespace ngraph ...@@ -30,13 +30,17 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector log_softmax(const Node& node) inline NodeVector log_softmax(const Node& node)
{ {
return {std::make_shared<ngraph::op::Log>(softmax(node).at(0))}; return {std::make_shared<ngraph::op::Log>(softmax(node).at(0))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,6 +26,8 @@ namespace ngraph ...@@ -26,6 +26,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector lrn(const Node& node) NodeVector lrn(const Node& node)
{ {
...@@ -38,7 +40,9 @@ namespace ngraph ...@@ -38,7 +40,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::LRN>(data, alpha, beta, bias, size)}; return {std::make_shared<ngraph::op::LRN>(data, alpha, beta, bias, size)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,13 @@ namespace ngraph ...@@ -25,9 +25,13 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector lrn(const Node& node); NodeVector lrn(const Node& node);
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,6 +26,8 @@ namespace ngraph ...@@ -26,6 +26,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector matmul(const Node& node) inline NodeVector matmul(const Node& node)
{ {
...@@ -33,7 +35,9 @@ namespace ngraph ...@@ -33,7 +35,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Dot>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::Dot>(ng_inputs.at(0), ng_inputs.at(1))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,13 +27,17 @@ namespace ngraph ...@@ -27,13 +27,17 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector max(const Node& node) inline NodeVector max(const Node& node)
{ {
return variadic::make_ng_variadic_op<ngraph::op::Maximum>(node); return variadic::make_ng_variadic_op<ngraph::op::Maximum>(node);
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,13 +25,17 @@ namespace ngraph ...@@ -25,13 +25,17 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector max_pool(const Node& node) NodeVector max_pool(const Node& node)
{ {
return convpool::make_ng_pool<ngraph::op::MaxPool>(node); return convpool::make_ng_pool<ngraph::op::MaxPool>(node);
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,6 +25,8 @@ namespace ngraph ...@@ -25,6 +25,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
/** /**
* @brief Convert ONNX MaxPool operation to an nGraph node. * @brief Convert ONNX MaxPool operation to an nGraph node.
...@@ -36,7 +38,9 @@ namespace ngraph ...@@ -36,7 +38,9 @@ namespace ngraph
*/ */
NodeVector max_pool(const Node& node); NodeVector max_pool(const Node& node);
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,6 +26,8 @@ namespace ngraph ...@@ -26,6 +26,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector mean(const Node& node) NodeVector mean(const Node& node)
{ {
...@@ -41,7 +43,9 @@ namespace ngraph ...@@ -41,7 +43,9 @@ namespace ngraph
return {sum / count}; return {sum / count};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,10 +25,14 @@ namespace ngraph ...@@ -25,10 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector mean(const Node& node); NodeVector mean(const Node& node);
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,13 +27,17 @@ namespace ngraph ...@@ -27,13 +27,17 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector min(const Node& node) inline NodeVector min(const Node& node)
{ {
return variadic::make_ng_variadic_op<ngraph::op::Minimum>(node); return variadic::make_ng_variadic_op<ngraph::op::Minimum>(node);
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,15 +27,20 @@ namespace ngraph ...@@ -27,15 +27,20 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector mul(const Node& node) inline NodeVector mul(const Node& node)
{ {
NodeVector ng_inputs{ NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())}; numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))}; return {
std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,9 +26,13 @@ namespace ngraph ...@@ -26,9 +26,13 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector neg(const Node& node) { return {-node.get_ng_inputs().at(0)}; } inline NodeVector neg(const Node& node) { return {-node.get_ng_inputs().at(0)}; }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,13 +27,17 @@ namespace ngraph ...@@ -27,13 +27,17 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector logical_not(const Node& node) inline NodeVector logical_not(const Node& node)
{ {
return {std::make_shared<ngraph::op::Not>(node.get_ng_inputs().at(0))}; return {std::make_shared<ngraph::op::Not>(node.get_ng_inputs().at(0))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,6 +27,8 @@ namespace ngraph ...@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector logical_or(const Node& node) inline NodeVector logical_or(const Node& node)
{ {
...@@ -35,7 +37,9 @@ namespace ngraph ...@@ -35,7 +37,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,6 +27,8 @@ namespace ngraph ...@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector pow(const Node& node) inline NodeVector pow(const Node& node)
{ {
...@@ -35,7 +37,9 @@ namespace ngraph ...@@ -35,7 +37,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Power>(ng_inputs.at(0), ng_inputs.at(1))}; return {std::make_shared<ngraph::op::Power>(ng_inputs.at(0), ng_inputs.at(1))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -36,6 +36,8 @@ namespace ngraph ...@@ -36,6 +36,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector prelu(const Node& node) NodeVector prelu(const Node& node)
{ {
...@@ -47,8 +49,8 @@ namespace ngraph ...@@ -47,8 +49,8 @@ namespace ngraph
if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1)) if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
{ {
auto it = auto it = std::find(
std::find(std::begin(data_shape), std::end(data_shape), slope_shape.at(0)); std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
auto index = std::distance(std::begin(data_shape), it); auto index = std::distance(std::begin(data_shape), it);
slope = make_broadcast_node(slope, data->get_shape(), index); slope = make_broadcast_node(slope, data->get_shape(), index);
} }
...@@ -61,7 +63,9 @@ namespace ngraph ...@@ -61,7 +63,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Maximum>(data * slope, data)}; return {std::make_shared<ngraph::op::Maximum>(data * slope, data)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector prelu(const Node& node); NodeVector prelu(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -30,6 +30,8 @@ namespace ngraph ...@@ -30,6 +30,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector reciprocal(const Node& node) NodeVector reciprocal(const Node& node)
{ {
...@@ -42,7 +44,9 @@ namespace ngraph ...@@ -42,7 +44,9 @@ namespace ngraph
return {one_node / data}; return {one_node / data};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector reciprocal(const Node& node); NodeVector reciprocal(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -31,6 +31,8 @@ namespace ngraph ...@@ -31,6 +31,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector reduce_mean(const Node& node) NodeVector reduce_mean(const Node& node)
{ {
...@@ -55,6 +57,10 @@ namespace ngraph ...@@ -55,6 +57,10 @@ namespace ngraph
return {std::make_shared<ngraph::op::Divide>(sum_node, broadcasted_const_node)}; return {std::make_shared<ngraph::op::Divide>(sum_node, broadcasted_const_node)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -38,6 +38,8 @@ namespace ngraph ...@@ -38,6 +38,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
/// \brief Compute the log sum of the input tensor's elements along the provided axes. /// \brief Compute the log sum of the input tensor's elements along the provided axes.
/// ///
...@@ -71,7 +73,8 @@ namespace ngraph ...@@ -71,7 +73,8 @@ namespace ngraph
inline NodeVector reduce_log_sum_exp(const Node& node) inline NodeVector reduce_log_sum_exp(const Node& node)
{ {
auto exp_node = std::make_shared<ngraph::op::Exp>(node.get_ng_inputs().at(0)); auto exp_node = std::make_shared<ngraph::op::Exp>(node.get_ng_inputs().at(0));
auto sum_node = reduction::make_ng_reduction_op<ngraph::op::Sum>(node, exp_node); auto sum_node =
reduction::make_ng_reduction_op<ngraph::op::Sum>(node, exp_node);
return {std::make_shared<ngraph::op::Log>(sum_node)}; return {std::make_shared<ngraph::op::Log>(sum_node)};
} }
...@@ -108,7 +111,8 @@ namespace ngraph ...@@ -108,7 +111,8 @@ namespace ngraph
NodeVector ng_inputs{node.get_ng_inputs()}; NodeVector ng_inputs{node.get_ng_inputs()};
auto square_node = auto square_node =
std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(0)); std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(0));
auto sum_node = reduction::make_ng_reduction_op<ngraph::op::Sum>(node, square_node); auto sum_node =
reduction::make_ng_reduction_op<ngraph::op::Sum>(node, square_node);
return {std::make_shared<ngraph::op::Sqrt>(sum_node)}; return {std::make_shared<ngraph::op::Sqrt>(sum_node)};
} }
...@@ -212,6 +216,10 @@ namespace ngraph ...@@ -212,6 +216,10 @@ namespace ngraph
return {reduction::make_ng_reduction_op<ngraph::op::Sum>(node, square_node)}; return {reduction::make_ng_reduction_op<ngraph::op::Sum>(node, square_node)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -26,6 +26,8 @@ namespace ngraph ...@@ -26,6 +26,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector relu(const Node& node) inline NodeVector relu(const Node& node)
{ {
...@@ -33,7 +35,9 @@ namespace ngraph ...@@ -33,7 +35,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Relu>(ng_inputs.at(0))}; return {std::make_shared<ngraph::op::Relu>(ng_inputs.at(0))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -32,6 +32,8 @@ namespace ngraph ...@@ -32,6 +32,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector reshape(const Node& node) NodeVector reshape(const Node& node)
{ {
...@@ -39,7 +41,8 @@ namespace ngraph ...@@ -39,7 +41,8 @@ namespace ngraph
auto data = ng_inputs.at(0); auto data = ng_inputs.at(0);
auto data_shape = data->get_shape(); auto data_shape = data->get_shape();
auto output_shape = node.get_attribute_value<std::vector<std::size_t>>("shape", {}); auto output_shape =
node.get_attribute_value<std::vector<std::size_t>>("shape", {});
// If no shape argument (opset >= 5) and there is second input. // If no shape argument (opset >= 5) and there is second input.
if (output_shape.empty() && ng_inputs.size() == 2) if (output_shape.empty() && ng_inputs.size() == 2)
...@@ -58,14 +61,17 @@ namespace ngraph ...@@ -58,14 +61,17 @@ namespace ngraph
return {data}; return {data};
} }
output_shape = reshape::infer_dimensions(node.get_name(), data_shape, output_shape); output_shape =
reshape::infer_dimensions(node.get_name(), data_shape, output_shape);
return {std::make_shared<ngraph::op::Reshape>( return {std::make_shared<ngraph::op::Reshape>(
data, data,
reshape::get_default_axis_vector(data_shape.size()), reshape::get_default_axis_vector(data_shape.size()),
Shape{output_shape})}; Shape{output_shape})};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,6 +25,8 @@ namespace ngraph ...@@ -25,6 +25,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
/// ///
/// \brief Reshape the input tensor similar to numpy.reshape. /// \brief Reshape the input tensor similar to numpy.reshape.
...@@ -35,7 +37,9 @@ namespace ngraph ...@@ -35,7 +37,9 @@ namespace ngraph
/// ///
NodeVector reshape(const Node& node); NodeVector reshape(const Node& node);
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -40,33 +40,42 @@ namespace ngraph ...@@ -40,33 +40,42 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector selu(const Node& node) NodeVector selu(const Node& node)
{ {
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1.67326319217681884765625); double alpha =
double gamma = node.get_attribute_value<double>("gamma", 1.05070102214813232421875); node.get_attribute_value<double>("alpha", 1.67326319217681884765625);
double gamma =
node.get_attribute_value<double>("gamma", 1.05070102214813232421875);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha}); data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape()); alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> gamma_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> gamma_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{gamma}); data->get_element_type(), ngraph::Shape{}, std::vector<double>{gamma});
gamma_node = make_broadcast_node(gamma_node, data->get_shape()); gamma_node = make_broadcast_node(gamma_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> zero_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{0}); data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape()); zero_node = make_broadcast_node(zero_node, data->get_shape());
return {gamma_node * return {gamma_node * (std::make_shared<ngraph::op::Maximum>(data, zero_node) +
(std::make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node * std::make_shared<ngraph::op::Exp>( alpha_node * std::make_shared<ngraph::op::Exp>(
std::make_shared<ngraph::op::Minimum>(data, zero_node)) - std::make_shared<ngraph::op::Minimum>(
data, zero_node)) -
alpha_node)}; alpha_node)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector selu(const Node& node); NodeVector selu(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -29,6 +29,8 @@ namespace ngraph ...@@ -29,6 +29,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector shape(const Node& node) NodeVector shape(const Node& node)
{ {
...@@ -39,7 +41,9 @@ namespace ngraph ...@@ -39,7 +41,9 @@ namespace ngraph
ngraph::element::i64, Shape{data_shape.size()}, data_shape)}; ngraph::element::i64, Shape{data_shape.size()}, data_shape)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector shape(const Node& node); NodeVector shape(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,13 +28,17 @@ namespace ngraph ...@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector sigmoid(const Node& node) inline NodeVector sigmoid(const Node& node)
{ {
return {std::make_shared<ngraph::op::Sigmoid>(node.get_ng_inputs().at(0))}; return {std::make_shared<ngraph::op::Sigmoid>(node.get_ng_inputs().at(0))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -34,6 +34,8 @@ namespace ngraph ...@@ -34,6 +34,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector slice(const Node& node) NodeVector slice(const Node& node)
{ {
...@@ -54,13 +56,16 @@ namespace ngraph ...@@ -54,13 +56,16 @@ namespace ngraph
size_t axis = axes.at(idx); size_t axis = axes.at(idx);
lower_bounds.at(axis) = lower_bounds.at(axis) =
get_valid_array_idx(starts.at(idx), data_shape.at(axis)); get_valid_array_idx(starts.at(idx), data_shape.at(axis));
upper_bounds.at(axis) = get_valid_array_idx(ends.at(idx), data_shape.at(axis)); upper_bounds.at(axis) =
get_valid_array_idx(ends.at(idx), data_shape.at(axis));
} }
return {std::make_shared<ngraph::op::Slice>(data, lower_bounds, upper_bounds)}; return {std::make_shared<ngraph::op::Slice>(data, lower_bounds, upper_bounds)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector slice(const Node& node); NodeVector slice(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,6 +26,8 @@ namespace ngraph ...@@ -26,6 +26,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector softmax(const Node& node) NodeVector softmax(const Node& node)
{ {
...@@ -50,7 +52,9 @@ namespace ngraph ...@@ -50,7 +52,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Softmax>(data, axes)}; return {std::make_shared<ngraph::op::Softmax>(data, axes)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector softmax(const Node& node); NodeVector softmax(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -32,6 +32,8 @@ namespace ngraph ...@@ -32,6 +32,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector softplus(const Node& node) NodeVector softplus(const Node& node)
{ {
...@@ -41,11 +43,13 @@ namespace ngraph ...@@ -41,11 +43,13 @@ namespace ngraph
data->get_element_type(), Shape{}, std::vector<double>{1}); data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape()); one_node = make_broadcast_node(one_node, data->get_shape());
return {std::make_shared<ngraph::op::Log>(std::make_shared<ngraph::op::Exp>(data) + return {std::make_shared<ngraph::op::Log>(
one_node)}; std::make_shared<ngraph::op::Exp>(data) + one_node)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector softplus(const Node& node); NodeVector softplus(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -32,6 +32,8 @@ namespace ngraph ...@@ -32,6 +32,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector softsign(const Node& node) NodeVector softsign(const Node& node)
{ {
...@@ -44,7 +46,9 @@ namespace ngraph ...@@ -44,7 +46,9 @@ namespace ngraph
return {data / (std::make_shared<ngraph::op::Abs>(data) + one_node)}; return {data / (std::make_shared<ngraph::op::Abs>(data) + one_node)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector softsign(const Node& node); NodeVector softsign(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -79,6 +79,8 @@ namespace ngraph ...@@ -79,6 +79,8 @@ namespace ngraph
} // namespace error } // namespace error
namespace op namespace op
{
namespace set_1
{ {
namespace detail namespace detail
{ {
...@@ -105,7 +107,8 @@ namespace ngraph ...@@ -105,7 +107,8 @@ namespace ngraph
upper_bounds.at(axis) = upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), node->get_shape().at(axis)); get_valid_array_index(ends.at(index), node->get_shape().at(axis));
} }
return std::make_shared<ngraph::op::Slice>(node, lower_bounds, upper_bounds); return std::make_shared<ngraph::op::Slice>(
node, lower_bounds, upper_bounds);
} }
} // namespace detail } // namespace detail
...@@ -145,14 +148,16 @@ namespace ngraph ...@@ -145,14 +148,16 @@ namespace ngraph
for (const auto& length_part : length_parts) for (const auto& length_part : length_parts)
{ {
std::size_t end_index{start_index + length_part}; std::size_t end_index{start_index + length_part};
outputs.push_back( outputs.push_back(detail::make_ng_slice(
detail::make_ng_slice(input, {axis_to_split}, {start_index}, {end_index})); input, {axis_to_split}, {start_index}, {end_index}));
start_index = end_index; start_index = end_index;
} }
return outputs; return outputs;
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector split(const Node& node); NodeVector split(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,13 +28,17 @@ namespace ngraph ...@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector sqrt(const Node& node) inline NodeVector sqrt(const Node& node)
{ {
return {std::make_shared<ngraph::op::Sqrt>(node.get_ng_inputs().at(0))}; return {std::make_shared<ngraph::op::Sqrt>(node.get_ng_inputs().at(0))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -29,6 +29,8 @@ namespace ngraph ...@@ -29,6 +29,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector squeeze(const Node& node) NodeVector squeeze(const Node& node)
{ {
...@@ -60,7 +62,9 @@ namespace ngraph ...@@ -60,7 +62,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)}; return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector squeeze(const Node& node); NodeVector squeeze(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,15 +27,20 @@ namespace ngraph ...@@ -27,15 +27,20 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector sub(const Node& node) inline NodeVector sub(const Node& node)
{ {
NodeVector ng_inputs{ NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())}; numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))}; return {
std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,13 +27,17 @@ namespace ngraph ...@@ -27,13 +27,17 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector sum(const Node& node) inline NodeVector sum(const Node& node)
{ {
return variadic::make_ng_variadic_op<ngraph::op::Add>(node); return variadic::make_ng_variadic_op<ngraph::op::Add>(node);
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,13 +28,17 @@ namespace ngraph ...@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector tanh(const Node& node) inline NodeVector tanh(const Node& node)
{ {
return {std::make_shared<ngraph::op::Tanh>(node.get_ng_inputs().at(0))}; return {std::make_shared<ngraph::op::Tanh>(node.get_ng_inputs().at(0))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -35,13 +35,16 @@ namespace ngraph ...@@ -35,13 +35,16 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector thresholded_relu(const Node& node) NodeVector thresholded_relu(const Node& node)
{ {
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1.0); double alpha = node.get_attribute_value<double>("alpha", 1.0);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha}); data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape()); alpha_node = make_broadcast_node(alpha_node, data->get_shape());
...@@ -51,7 +54,9 @@ namespace ngraph ...@@ -51,7 +54,9 @@ namespace ngraph
return {data * data_map}; return {data * data_map};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector thresholded_relu(const Node& node); NodeVector thresholded_relu(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,18 +27,23 @@ namespace ngraph ...@@ -27,18 +27,23 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector transpose(const Node& node) NodeVector transpose(const Node& node)
{ {
std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0); std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0);
auto permute_axes = node.get_attribute_value<std::vector<std::size_t>>("perm", {}); auto permute_axes =
node.get_attribute_value<std::vector<std::size_t>>("perm", {});
return {(permute_axes.empty()) ? reshape::transpose(data) return {(permute_axes.empty()) ? reshape::transpose(data)
: reshape::reorder_axes(data, permute_axes)}; : reshape::reorder_axes(data, permute_axes)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector transpose(const Node& node); NodeVector transpose(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,6 +27,8 @@ namespace ngraph ...@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector unsqueeze(const Node& node) NodeVector unsqueeze(const Node& node)
{ {
...@@ -52,7 +54,9 @@ namespace ngraph ...@@ -52,7 +54,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)}; return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,9 +25,14 @@ namespace ngraph ...@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
NodeVector unsqueeze(const Node& node); NodeVector unsqueeze(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -29,6 +29,8 @@ namespace ngraph ...@@ -29,6 +29,8 @@ namespace ngraph
namespace onnx_import namespace onnx_import
{ {
namespace op namespace op
{
namespace set_1
{ {
inline NodeVector logical_xor(const Node& node) inline NodeVector logical_xor(const Node& node)
{ {
...@@ -43,7 +45,9 @@ namespace ngraph ...@@ -43,7 +45,9 @@ namespace ngraph
std::make_shared<ngraph::op::And>(not_left, right))}; std::make_shared<ngraph::op::And>(not_left, right))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -118,83 +118,79 @@ namespace ngraph ...@@ -118,83 +118,79 @@ namespace ngraph
return instance; return instance;
} }
#define REGISTER_OPERATOR(name_, version_, fn_) \
m_map.emplace(name_, std::bind(op::set_##version_::fn_, std::placeholders::_1))
ops_bridge() ops_bridge()
{ {
m_map.emplace("Abs", std::bind(op::abs, std::placeholders::_1)); REGISTER_OPERATOR("Abs", 1, abs);
m_map.emplace("Add", std::bind(op::add, std::placeholders::_1)); REGISTER_OPERATOR("Add", 1, add);
m_map.emplace("And", std::bind(op::logical_and, std::placeholders::_1)); REGISTER_OPERATOR("And", 1, logical_and);
m_map.emplace("AveragePool", REGISTER_OPERATOR("AveragePool", 1, average_pool);
std::bind(op::average_pool, std::placeholders::_1)); REGISTER_OPERATOR("BatchNormalization", 1, batch_norm);
m_map.emplace("BatchNormalization", REGISTER_OPERATOR("Cast", 1, cast);
std::bind(op::batch_norm, std::placeholders::_1)); REGISTER_OPERATOR("Ceil", 1, ceil);
m_map.emplace("Cast", std::bind(op::cast, std::placeholders::_1)); REGISTER_OPERATOR("Clip", 1, clip);
m_map.emplace("Ceil", std::bind(op::ceil, std::placeholders::_1)); REGISTER_OPERATOR("Concat", 1, concat);
m_map.emplace("Clip", std::bind(op::clip, std::placeholders::_1)); REGISTER_OPERATOR("Constant", 1, constant);
m_map.emplace("Concat", std::bind(op::concat, std::placeholders::_1)); REGISTER_OPERATOR("Conv", 1, conv);
m_map.emplace("Constant", std::bind(op::constant, std::placeholders::_1)); REGISTER_OPERATOR("Div", 1, div);
m_map.emplace("Conv", std::bind(op::conv, std::placeholders::_1)); REGISTER_OPERATOR("Dropout", 1, identity);
m_map.emplace("Div", std::bind(op::div, std::placeholders::_1)); REGISTER_OPERATOR("Elu", 1, elu);
m_map.emplace("Dropout", std::bind(op::identity, std::placeholders::_1)); REGISTER_OPERATOR("Equal", 1, equal);
m_map.emplace("Elu", std::bind(op::elu, std::placeholders::_1)); REGISTER_OPERATOR("Exp", 1, exp);
m_map.emplace("Equal", std::bind(op::equal, std::placeholders::_1)); REGISTER_OPERATOR("Flatten", 1, flatten);
m_map.emplace("Exp", std::bind(op::exp, std::placeholders::_1)); REGISTER_OPERATOR("Floor", 1, floor);
m_map.emplace("Flatten", std::bind(op::flatten, std::placeholders::_1)); REGISTER_OPERATOR("Gemm", 1, gemm);
m_map.emplace("Floor", std::bind(op::floor, std::placeholders::_1)); REGISTER_OPERATOR("Greater", 1, greater);
m_map.emplace("Gemm", std::bind(op::gemm, std::placeholders::_1)); REGISTER_OPERATOR("HardSigmoid", 1, hard_sigmoid);
m_map.emplace("Greater", std::bind(op::greater, std::placeholders::_1)); REGISTER_OPERATOR("Identity", 1, identity);
m_map.emplace("HardSigmoid", REGISTER_OPERATOR("LeakyRelu", 1, leaky_relu);
std::bind(op::hard_sigmoid, std::placeholders::_1)); REGISTER_OPERATOR("Less", 1, less);
m_map.emplace("Identity", std::bind(op::identity, std::placeholders::_1)); REGISTER_OPERATOR("Log", 1, log);
m_map.emplace("LeakyRelu", std::bind(op::leaky_relu, std::placeholders::_1)); REGISTER_OPERATOR("LogSoftmax", 1, log_softmax);
m_map.emplace("Less", std::bind(op::less, std::placeholders::_1)); REGISTER_OPERATOR("LRN", 1, lrn);
m_map.emplace("Log", std::bind(op::log, std::placeholders::_1)); REGISTER_OPERATOR("MatMul", 1, matmul);
m_map.emplace("LogSoftmax", std::bind(op::log_softmax, std::placeholders::_1)); REGISTER_OPERATOR("MaxPool", 1, max_pool);
m_map.emplace("LRN", std::bind(op::lrn, std::placeholders::_1)); REGISTER_OPERATOR("Max", 1, max);
m_map.emplace("MatMul", std::bind(op::matmul, std::placeholders::_1)); REGISTER_OPERATOR("Mean", 1, mean);
m_map.emplace("MaxPool", std::bind(op::max_pool, std::placeholders::_1)); REGISTER_OPERATOR("Min", 1, min);
m_map.emplace("Max", std::bind(op::max, std::placeholders::_1)); REGISTER_OPERATOR("Mul", 1, mul);
m_map.emplace("Mean", std::bind(op::mean, std::placeholders::_1)); REGISTER_OPERATOR("Neg", 1, neg);
m_map.emplace("Min", std::bind(op::min, std::placeholders::_1)); REGISTER_OPERATOR("Not", 1, logical_not);
m_map.emplace("Mul", std::bind(op::mul, std::placeholders::_1)); REGISTER_OPERATOR("Or", 1, logical_or);
m_map.emplace("Neg", std::bind(op::neg, std::placeholders::_1)); REGISTER_OPERATOR("Pow", 1, pow);
m_map.emplace("Not", std::bind(op::logical_not, std::placeholders::_1)); REGISTER_OPERATOR("PRelu", 1, prelu);
m_map.emplace("Or", std::bind(op::logical_or, std::placeholders::_1)); REGISTER_OPERATOR("Reciprocal", 1, reciprocal);
m_map.emplace("Pow", std::bind(op::pow, std::placeholders::_1)); REGISTER_OPERATOR("ReduceLogSum", 1, reduce_log_sum);
m_map.emplace("PRelu", std::bind(op::prelu, std::placeholders::_1)); REGISTER_OPERATOR("ReduceLogSumExp", 1, reduce_log_sum_exp);
m_map.emplace("Reciprocal", std::bind(op::reciprocal, std::placeholders::_1)); REGISTER_OPERATOR("ReduceL1", 1, reduce_l1);
m_map.emplace("ReduceLogSum", REGISTER_OPERATOR("ReduceL2", 1, reduce_l2);
std::bind(op::reduce_log_sum, std::placeholders::_1)); REGISTER_OPERATOR("ReduceMax", 1, reduce_max);
m_map.emplace("ReduceLogSumExp", REGISTER_OPERATOR("ReduceMean", 1, reduce_mean);
std::bind(op::reduce_log_sum_exp, std::placeholders::_1)); REGISTER_OPERATOR("ReduceMin", 1, reduce_min);
m_map.emplace("ReduceL1", std::bind(op::reduce_l1, std::placeholders::_1)); REGISTER_OPERATOR("ReduceProd", 1, reduce_prod);
m_map.emplace("ReduceL2", std::bind(op::reduce_l2, std::placeholders::_1)); REGISTER_OPERATOR("ReduceSum", 1, reduce_sum);
m_map.emplace("ReduceMax", std::bind(op::reduce_max, std::placeholders::_1)); REGISTER_OPERATOR("ReduceSumSquare", 1, reduce_sum_square);
m_map.emplace("ReduceMean", std::bind(op::reduce_mean, std::placeholders::_1)); REGISTER_OPERATOR("Relu", 1, relu);
m_map.emplace("ReduceMin", std::bind(op::reduce_min, std::placeholders::_1)); REGISTER_OPERATOR("Reshape", 1, reshape);
m_map.emplace("ReduceProd", std::bind(op::reduce_prod, std::placeholders::_1)); REGISTER_OPERATOR("Selu", 1, selu);
m_map.emplace("ReduceSum", std::bind(op::reduce_sum, std::placeholders::_1)); REGISTER_OPERATOR("Shape", 1, shape);
m_map.emplace("ReduceSumSquare", REGISTER_OPERATOR("Sigmoid", 1, sigmoid);
std::bind(op::reduce_sum_square, std::placeholders::_1)); REGISTER_OPERATOR("Slice", 1, slice);
m_map.emplace("Relu", std::bind(op::relu, std::placeholders::_1)); REGISTER_OPERATOR("Softmax", 1, softmax);
m_map.emplace("Reshape", std::bind(op::reshape, std::placeholders::_1)); REGISTER_OPERATOR("Softplus", 1, softplus);
m_map.emplace("Selu", std::bind(op::selu, std::placeholders::_1)); REGISTER_OPERATOR("Softsign", 1, softsign);
m_map.emplace("Shape", std::bind(op::shape, std::placeholders::_1)); REGISTER_OPERATOR("Split", 1, split);
m_map.emplace("Sigmoid", std::bind(op::sigmoid, std::placeholders::_1)); REGISTER_OPERATOR("Sqrt", 1, sqrt);
m_map.emplace("Slice", std::bind(op::slice, std::placeholders::_1)); REGISTER_OPERATOR("Squeeze", 1, squeeze);
m_map.emplace("Softmax", std::bind(op::softmax, std::placeholders::_1)); REGISTER_OPERATOR("Sub", 1, sub);
m_map.emplace("Softplus", std::bind(op::softplus, std::placeholders::_1)); REGISTER_OPERATOR("Sum", 1, sum);
m_map.emplace("Softsign", std::bind(op::softsign, std::placeholders::_1)); REGISTER_OPERATOR("Tanh", 1, tanh);
m_map.emplace("Split", std::bind(op::split, std::placeholders::_1)); REGISTER_OPERATOR("ThresholdedRelu", 1, thresholded_relu);
m_map.emplace("Sqrt", std::bind(op::sqrt, std::placeholders::_1)); REGISTER_OPERATOR("Transpose", 1, transpose);
m_map.emplace("Squeeze", std::bind(op::squeeze, std::placeholders::_1)); REGISTER_OPERATOR("Unsqueeze", 1, unsqueeze);
m_map.emplace("Sub", std::bind(op::sub, std::placeholders::_1)); REGISTER_OPERATOR("Xor", 1, logical_xor);
m_map.emplace("Sum", std::bind(op::sum, std::placeholders::_1));
m_map.emplace("Tanh", std::bind(op::tanh, std::placeholders::_1));
m_map.emplace("ThresholdedRelu",
std::bind(op::thresholded_relu, std::placeholders::_1));
m_map.emplace("Transpose", std::bind(op::transpose, std::placeholders::_1));
m_map.emplace("Unsqueeze", std::bind(op::unsqueeze, std::placeholders::_1));
m_map.emplace("Xor", std::bind(op::logical_xor, std::placeholders::_1));
} }
NodeVector operator()(const Node& node) const NodeVector operator()(const Node& node) const
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment