Commit 9debc0bc authored by Artur Wojcik's avatar Artur Wojcik Committed by Scott Cyphers

[ONNX] Add support to OperatorSet (Part 1) - namespace change only (#1812)

* onnx: add register operator macro
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: add set information to 'op' namespace
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>
parent 1d5f047a
......@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector abs(const Node& node)
{
return {std::make_shared<ngraph::op::Abs>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector add(const Node& node)
{
......@@ -35,7 +37,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector logical_and(const Node& node)
{
......@@ -35,7 +37,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,13 +25,17 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector average_pool(const Node& node)
{
return convpool::make_ng_pool<ngraph::op::AvgPool>(node);
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,6 +25,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
/**
* @brief Convert ONNX AveragePool operation to an nGraph node.
......@@ -36,7 +38,9 @@ namespace ngraph
*/
NodeVector average_pool(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector batch_norm(const Node& node)
{
......@@ -58,7 +60,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::BatchNorm>(epsilon, scale, bias, x)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -23,9 +23,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector batch_norm(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -29,6 +29,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector cast(const Node& node)
{
......@@ -50,14 +52,18 @@ namespace ngraph
case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break;
case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break;
case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break;
case onnx::TensorProto_DataType_UNDEFINED: elem_type = element::unspecified; break;
case onnx::TensorProto_DataType_UNDEFINED:
elem_type = element::unspecified;
break;
default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type";
}
return {std::make_shared<ngraph::op::Convert>(data, elem_type)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector cast(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector ceil(const Node& node)
{
return {std::make_shared<ngraph::op::Ceiling>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -33,6 +33,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector clip(const Node& node)
{
......@@ -40,24 +42,29 @@ namespace ngraph
double max_value =
node.get_attribute_value<double>("max", std::numeric_limits<double>::max());
double min_value =
node.get_attribute_value<double>("min", std::numeric_limits<double>::lowest());
double min_value = node.get_attribute_value<double>(
"min", std::numeric_limits<double>::lowest());
std::shared_ptr<ngraph::Node> max_value_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{max_value});
std::make_shared<ngraph::op::Constant>(data->get_element_type(),
ngraph::Shape{},
std::vector<double>{max_value});
max_value_node = make_broadcast_node(max_value_node, data->get_shape());
std::shared_ptr<ngraph::Node> min_value_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{min_value});
std::make_shared<ngraph::op::Constant>(data->get_element_type(),
ngraph::Shape{},
std::vector<double>{min_value});
min_value_node = make_broadcast_node(min_value_node, data->get_shape());
return {std::make_shared<ngraph::op::Minimum>(
max_value_node, std::make_shared<ngraph::op::Maximum>(data, min_value_node))};
max_value_node,
std::make_shared<ngraph::op::Maximum>(data, min_value_node))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector clip(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -23,6 +23,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector concat(const Node& node)
{
......@@ -32,7 +34,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Concat>(inputs, axis)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector concat(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -24,6 +24,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
namespace
{
......@@ -36,7 +38,8 @@ namespace ngraph
}
template <Tensor::Type>
inline std::shared_ptr<ngraph::op::Constant> make_ng_constant(const Tensor& tensor)
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant(const Tensor& tensor)
{
throw error::tensor::unsupported_data_type{tensor};
}
......@@ -114,7 +117,9 @@ namespace ngraph
return {make_constant(node.get_attribute_value<Tensor>("value"))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,10 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector constant(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -36,6 +36,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
namespace
{
......@@ -63,7 +65,8 @@ namespace ngraph
// initial bounds for splice
std::vector<std::size_t> data_lower_bounds(data->get_shape().size());
std::vector<std::size_t> data_upper_bounds{data->get_shape()};
std::vector<std::size_t> filters_lower_bounds(filters->get_shape().size());
std::vector<std::size_t> filters_lower_bounds(
filters->get_shape().size());
std::vector<std::size_t> filters_upper_bounds{filters->get_shape()};
for (std::size_t group{0}; group < groups; ++group)
......@@ -136,7 +139,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Add>(conv_node, broadcasted_bias)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,6 +25,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Performs ONNX Conv operation.
///
......@@ -34,7 +36,9 @@ namespace ngraph
/// operation.
NodeVector conv(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector div(const Node& node)
{
......@@ -35,7 +37,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -37,27 +37,34 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector elu(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
std::shared_ptr<ngraph::Node> zero_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node * std::make_shared<ngraph::op::Exp>(
alpha_node *
std::make_shared<ngraph::op::Exp>(
std::make_shared<ngraph::op::Minimum>(data, zero_node)) -
alpha_node};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector elu(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector equal(const Node& node)
{
......@@ -35,7 +37,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector exp(const Node& node)
{
return {std::make_shared<ngraph::op::Exp>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -24,6 +24,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector flatten(const Node& node)
{
......@@ -37,7 +39,9 @@ namespace ngraph
return {reshape::flatten(data, axis)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,9 +27,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector flatten(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector floor(const Node& node)
{
return {std::make_shared<ngraph::op::Floor>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -31,6 +31,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector gemm(const Node& node)
{
......@@ -60,13 +62,17 @@ namespace ngraph
std::shared_ptr<ngraph::Node> a_dot_b =
std::make_shared<ngraph::op::Dot>(input_a, input_b);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
a_dot_b->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(a_dot_b->get_element_type(),
ngraph::Shape{},
std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, a_dot_b->get_shape());
a_dot_b = std::make_shared<ngraph::op::Multiply>(alpha_node, a_dot_b);
std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>(
input_c->get_element_type(), ngraph::Shape{}, std::vector<double>{beta});
std::shared_ptr<ngraph::Node> beta_node =
std::make_shared<ngraph::op::Constant>(input_c->get_element_type(),
ngraph::Shape{},
std::vector<double>{beta});
beta_node = make_broadcast_node(beta_node, input_c->get_shape());
input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c);
......@@ -75,7 +81,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Add>(a_dot_b, input_c)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,9 +27,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector gemm(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,15 +27,20 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector greater(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))};
return {
std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -34,6 +34,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector hard_sigmoid(const Node& node)
{
......@@ -42,11 +44,13 @@ namespace ngraph
double alpha = node.get_attribute_value<double>("alpha", 0.2);
double beta = node.get_attribute_value<double>("beta", 0.5);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>(
std::shared_ptr<ngraph::Node> beta_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{beta});
beta_node = make_broadcast_node(beta_node, data->get_shape());
......@@ -54,7 +58,8 @@ namespace ngraph
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
std::shared_ptr<ngraph::Node> zero_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
......@@ -64,7 +69,9 @@ namespace ngraph
alpha_node * data + beta_node))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector hard_sigmoid(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,15 @@ namespace ngraph
{
namespace op
{
inline NodeVector identity(const Node& node) { return {node.get_ng_inputs().at(0)}; }
} // namespace op
namespace set_1
{
inline NodeVector identity(const Node& node)
{
return {node.get_ng_inputs().at(0)};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -37,6 +37,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector leaky_relu(const Node& node)
{
......@@ -46,13 +48,16 @@ namespace ngraph
ASSERT_VALID_ARGUMENT(node, ((alpha >= 0) && (alpha <= 1)))
<< " alpha value should be in range (0,1)";
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector leaky_relu(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector less(const Node& node)
{
......@@ -35,7 +37,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector log(const Node& node)
{
return {std::make_shared<ngraph::op::Log>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -30,13 +30,17 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector log_softmax(const Node& node)
{
return {std::make_shared<ngraph::op::Log>(softmax(node).at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,6 +26,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector lrn(const Node& node)
{
......@@ -38,7 +40,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::LRN>(data, alpha, beta, bias, size)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,13 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector lrn(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,6 +26,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector matmul(const Node& node)
{
......@@ -33,7 +35,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Dot>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,13 +27,17 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector max(const Node& node)
{
return variadic::make_ng_variadic_op<ngraph::op::Maximum>(node);
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,13 +25,17 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector max_pool(const Node& node)
{
return convpool::make_ng_pool<ngraph::op::MaxPool>(node);
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,6 +25,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
/**
* @brief Convert ONNX MaxPool operation to an nGraph node.
......@@ -36,7 +38,9 @@ namespace ngraph
*/
NodeVector max_pool(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,6 +26,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector mean(const Node& node)
{
......@@ -41,7 +43,9 @@ namespace ngraph
return {sum / count};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,10 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector mean(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,13 +27,17 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector min(const Node& node)
{
return variadic::make_ng_variadic_op<ngraph::op::Minimum>(node);
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,15 +27,20 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector mul(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))};
return {
std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,9 +26,13 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector neg(const Node& node) { return {-node.get_ng_inputs().at(0)}; }
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,13 +27,17 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector logical_not(const Node& node)
{
return {std::make_shared<ngraph::op::Not>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector logical_or(const Node& node)
{
......@@ -35,7 +37,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector pow(const Node& node)
{
......@@ -35,7 +37,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Power>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -36,6 +36,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector prelu(const Node& node)
{
......@@ -47,8 +49,8 @@ namespace ngraph
if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
{
auto it =
std::find(std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
auto it = std::find(
std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
auto index = std::distance(std::begin(data_shape), it);
slope = make_broadcast_node(slope, data->get_shape(), index);
}
......@@ -61,7 +63,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Maximum>(data * slope, data)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector prelu(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -30,6 +30,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector reciprocal(const Node& node)
{
......@@ -42,7 +44,9 @@ namespace ngraph
return {one_node / data};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector reciprocal(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -31,6 +31,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector reduce_mean(const Node& node)
{
......@@ -55,6 +57,10 @@ namespace ngraph
return {std::make_shared<ngraph::op::Divide>(sum_node, broadcasted_const_node)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -38,6 +38,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Compute the log sum of the input tensor's elements along the provided axes.
///
......@@ -71,7 +73,8 @@ namespace ngraph
inline NodeVector reduce_log_sum_exp(const Node& node)
{
auto exp_node = std::make_shared<ngraph::op::Exp>(node.get_ng_inputs().at(0));
auto sum_node = reduction::make_ng_reduction_op<ngraph::op::Sum>(node, exp_node);
auto sum_node =
reduction::make_ng_reduction_op<ngraph::op::Sum>(node, exp_node);
return {std::make_shared<ngraph::op::Log>(sum_node)};
}
......@@ -108,7 +111,8 @@ namespace ngraph
NodeVector ng_inputs{node.get_ng_inputs()};
auto square_node =
std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(0));
auto sum_node = reduction::make_ng_reduction_op<ngraph::op::Sum>(node, square_node);
auto sum_node =
reduction::make_ng_reduction_op<ngraph::op::Sum>(node, square_node);
return {std::make_shared<ngraph::op::Sqrt>(sum_node)};
}
......@@ -212,6 +216,10 @@ namespace ngraph
return {reduction::make_ng_reduction_op<ngraph::op::Sum>(node, square_node)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -26,6 +26,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector relu(const Node& node)
{
......@@ -33,7 +35,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Relu>(ng_inputs.at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -32,6 +32,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector reshape(const Node& node)
{
......@@ -39,7 +41,8 @@ namespace ngraph
auto data = ng_inputs.at(0);
auto data_shape = data->get_shape();
auto output_shape = node.get_attribute_value<std::vector<std::size_t>>("shape", {});
auto output_shape =
node.get_attribute_value<std::vector<std::size_t>>("shape", {});
// If no shape argument (opset >= 5) and there is second input.
if (output_shape.empty() && ng_inputs.size() == 2)
......@@ -58,14 +61,17 @@ namespace ngraph
return {data};
}
output_shape = reshape::infer_dimensions(node.get_name(), data_shape, output_shape);
output_shape =
reshape::infer_dimensions(node.get_name(), data_shape, output_shape);
return {std::make_shared<ngraph::op::Reshape>(
data,
reshape::get_default_axis_vector(data_shape.size()),
Shape{output_shape})};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,6 +25,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
///
/// \brief Reshape the input tensor similar to numpy.reshape.
......@@ -35,7 +37,9 @@ namespace ngraph
///
NodeVector reshape(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -40,33 +40,42 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector selu(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1.67326319217681884765625);
double gamma = node.get_attribute_value<double>("gamma", 1.05070102214813232421875);
double alpha =
node.get_attribute_value<double>("alpha", 1.67326319217681884765625);
double gamma =
node.get_attribute_value<double>("gamma", 1.05070102214813232421875);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> gamma_node = std::make_shared<ngraph::op::Constant>(
std::shared_ptr<ngraph::Node> gamma_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{gamma});
gamma_node = make_broadcast_node(gamma_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
std::shared_ptr<ngraph::Node> zero_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
return {gamma_node *
(std::make_shared<ngraph::op::Maximum>(data, zero_node) +
return {gamma_node * (std::make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node * std::make_shared<ngraph::op::Exp>(
std::make_shared<ngraph::op::Minimum>(data, zero_node)) -
std::make_shared<ngraph::op::Minimum>(
data, zero_node)) -
alpha_node)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector selu(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -29,6 +29,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector shape(const Node& node)
{
......@@ -39,7 +41,9 @@ namespace ngraph
ngraph::element::i64, Shape{data_shape.size()}, data_shape)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector shape(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector sigmoid(const Node& node)
{
return {std::make_shared<ngraph::op::Sigmoid>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -34,6 +34,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector slice(const Node& node)
{
......@@ -54,13 +56,16 @@ namespace ngraph
size_t axis = axes.at(idx);
lower_bounds.at(axis) =
get_valid_array_idx(starts.at(idx), data_shape.at(axis));
upper_bounds.at(axis) = get_valid_array_idx(ends.at(idx), data_shape.at(axis));
upper_bounds.at(axis) =
get_valid_array_idx(ends.at(idx), data_shape.at(axis));
}
return {std::make_shared<ngraph::op::Slice>(data, lower_bounds, upper_bounds)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector slice(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,6 +26,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector softmax(const Node& node)
{
......@@ -50,7 +52,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Softmax>(data, axes)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector softmax(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -32,6 +32,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector softplus(const Node& node)
{
......@@ -41,11 +43,13 @@ namespace ngraph
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
return {std::make_shared<ngraph::op::Log>(std::make_shared<ngraph::op::Exp>(data) +
one_node)};
return {std::make_shared<ngraph::op::Log>(
std::make_shared<ngraph::op::Exp>(data) + one_node)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector softplus(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -32,6 +32,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector softsign(const Node& node)
{
......@@ -44,7 +46,9 @@ namespace ngraph
return {data / (std::make_shared<ngraph::op::Abs>(data) + one_node)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector softsign(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -79,6 +79,8 @@ namespace ngraph
} // namespace error
namespace op
{
namespace set_1
{
namespace detail
{
......@@ -105,7 +107,8 @@ namespace ngraph
upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), node->get_shape().at(axis));
}
return std::make_shared<ngraph::op::Slice>(node, lower_bounds, upper_bounds);
return std::make_shared<ngraph::op::Slice>(
node, lower_bounds, upper_bounds);
}
} // namespace detail
......@@ -145,14 +148,16 @@ namespace ngraph
for (const auto& length_part : length_parts)
{
std::size_t end_index{start_index + length_part};
outputs.push_back(
detail::make_ng_slice(input, {axis_to_split}, {start_index}, {end_index}));
outputs.push_back(detail::make_ng_slice(
input, {axis_to_split}, {start_index}, {end_index}));
start_index = end_index;
}
return outputs;
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector split(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector sqrt(const Node& node)
{
return {std::make_shared<ngraph::op::Sqrt>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -29,6 +29,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector squeeze(const Node& node)
{
......@@ -60,7 +62,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector squeeze(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,15 +27,20 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector sub(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))};
return {
std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,13 +27,17 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector sum(const Node& node)
{
return variadic::make_ng_variadic_op<ngraph::op::Add>(node);
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,13 +28,17 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector tanh(const Node& node)
{
return {std::make_shared<ngraph::op::Tanh>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -35,13 +35,16 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector thresholded_relu(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1.0);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
......@@ -51,7 +54,9 @@ namespace ngraph
return {data * data_map};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector thresholded_relu(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,18 +27,23 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector transpose(const Node& node)
{
std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0);
auto permute_axes = node.get_attribute_value<std::vector<std::size_t>>("perm", {});
auto permute_axes =
node.get_attribute_value<std::vector<std::size_t>>("perm", {});
return {(permute_axes.empty()) ? reshape::transpose(data)
: reshape::reorder_axes(data, permute_axes)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector transpose(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,6 +27,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector unsqueeze(const Node& node)
{
......@@ -52,7 +54,9 @@ namespace ngraph
return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,9 +25,14 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector unsqueeze(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -29,6 +29,8 @@ namespace ngraph
namespace onnx_import
{
namespace op
{
namespace set_1
{
inline NodeVector logical_xor(const Node& node)
{
......@@ -43,7 +45,9 @@ namespace ngraph
std::make_shared<ngraph::op::And>(not_left, right))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment