Commit 9debc0bc authored by Artur Wojcik's avatar Artur Wojcik Committed by Scott Cyphers

[ONNX] Add support to OperatorSet (Part 1) - namespace change only (#1812)

* onnx: add register operator macro
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: add set information to 'op' namespace
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>
parent 1d5f047a
...@@ -29,12 +29,16 @@ namespace ngraph ...@@ -29,12 +29,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector abs(const Node& node) namespace set_1
{ {
return {std::make_shared<ngraph::op::Abs>(node.get_ng_inputs().at(0))}; inline NodeVector abs(const Node& node)
} {
return {std::make_shared<ngraph::op::Abs>(node.get_ng_inputs().at(0))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,14 +28,18 @@ namespace ngraph ...@@ -28,14 +28,18 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector add(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{ inline NodeVector add(const Node& node)
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())}; {
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))}; NodeVector ng_inputs{
} numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,14 +28,18 @@ namespace ngraph ...@@ -28,14 +28,18 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector logical_and(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{ inline NodeVector logical_and(const Node& node)
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())}; {
return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))}; NodeVector ng_inputs{
} numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,12 +26,16 @@ namespace ngraph ...@@ -26,12 +26,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector average_pool(const Node& node) namespace set_1
{ {
return convpool::make_ng_pool<ngraph::op::AvgPool>(node); NodeVector average_pool(const Node& node)
} {
return convpool::make_ng_pool<ngraph::op::AvgPool>(node);
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,17 +26,21 @@ namespace ngraph ...@@ -26,17 +26,21 @@ namespace ngraph
{ {
namespace op namespace op
{ {
/** namespace set_1
* @brief Convert ONNX AveragePool operation to an nGraph node. {
* /**
* @param node The ONNX node object representing this operation. * @brief Convert ONNX AveragePool operation to an nGraph node.
* *
* @return The vector containing Ngraph nodes producing output of ONNX AveragePool * @param node The ONNX node object representing this operation.
* operation. *
*/ * @return The vector containing Ngraph nodes producing output of ONNX AveragePool
NodeVector average_pool(const Node& node); * operation.
*/
} // namespace op NodeVector average_pool(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,38 +28,42 @@ namespace ngraph ...@@ -28,38 +28,42 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector batch_norm(const Node& node) namespace set_1
{ {
NodeVector inputs{node.get_ng_inputs()}; NodeVector batch_norm(const Node& node)
auto x = inputs.at(0); {
auto scale = inputs.at(1); NodeVector inputs{node.get_ng_inputs()};
auto bias = inputs.at(2); auto x = inputs.at(0);
std::shared_ptr<ngraph::Node> mean{nullptr}; auto scale = inputs.at(1);
std::shared_ptr<ngraph::Node> var{nullptr}; auto bias = inputs.at(2);
std::shared_ptr<ngraph::Node> mean{nullptr};
std::shared_ptr<ngraph::Node> var{nullptr};
int is_test{node.get_attribute_value<int>("is_test", 1)}; int is_test{node.get_attribute_value<int>("is_test", 1)};
int spatial{node.get_attribute_value<int>("spatial", 1)}; int spatial{node.get_attribute_value<int>("spatial", 1)};
double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)}; double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)};
// TODO: Implement learning mode support // TODO: Implement learning mode support
// float momentum{node.get_attribute_value<float>("momentum", 0.9f)}; // float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
bool training = false; bool training = false;
ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported."; ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported.";
ASSERT_IS_SUPPORTED(node, spatial) << "only 'spatial' mode is supported."; ASSERT_IS_SUPPORTED(node, spatial) << "only 'spatial' mode is supported.";
if (inputs.size() >= 5) if (inputs.size() >= 5)
{ {
mean = inputs.at(3); mean = inputs.at(3);
var = inputs.at(4); var = inputs.at(4);
return {std::make_shared<ngraph::op::BatchNorm>( return {std::make_shared<ngraph::op::BatchNorm>(
epsilon, scale, bias, x, mean, var, training)}; epsilon, scale, bias, x, mean, var, training)};
}
return {std::make_shared<ngraph::op::BatchNorm>(epsilon, scale, bias, x)};
} }
return {std::make_shared<ngraph::op::BatchNorm>(epsilon, scale, bias, x)}; } // namespace set_1
}
} // namespace op } //namespace op
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -24,9 +24,14 @@ namespace ngraph ...@@ -24,9 +24,14 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector batch_norm(const Node& node); namespace set_1
} // namespace op {
NodeVector batch_norm(const Node& node);
} // namespace onnx_import } // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -30,34 +30,40 @@ namespace ngraph ...@@ -30,34 +30,40 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector cast(const Node& node) namespace set_1
{ {
auto data = node.get_ng_inputs().at(0); NodeVector cast(const Node& node)
int64_t target_type = node.get_attribute_value<int64_t>("to");
element::Type elem_type;
switch (target_type)
{ {
case onnx::TensorProto_DataType_BOOL: elem_type = element::boolean; break; auto data = node.get_ng_inputs().at(0);
case onnx::TensorProto_DataType_DOUBLE: elem_type = element::f64; break; int64_t target_type = node.get_attribute_value<int64_t>("to");
case onnx::TensorProto_DataType_FLOAT16: element::Type elem_type;
case onnx::TensorProto_DataType_FLOAT: elem_type = element::f32; break;
case onnx::TensorProto_DataType_INT8: elem_type = element::i8; break; switch (target_type)
case onnx::TensorProto_DataType_INT16: elem_type = element::i16; break; {
case onnx::TensorProto_DataType_INT32: elem_type = element::i32; break; case onnx::TensorProto_DataType_BOOL: elem_type = element::boolean; break;
case onnx::TensorProto_DataType_INT64: elem_type = element::i64; break; case onnx::TensorProto_DataType_DOUBLE: elem_type = element::f64; break;
case onnx::TensorProto_DataType_UINT8: elem_type = element::u8; break; case onnx::TensorProto_DataType_FLOAT16:
case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break; case onnx::TensorProto_DataType_FLOAT: elem_type = element::f32; break;
case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break; case onnx::TensorProto_DataType_INT8: elem_type = element::i8; break;
case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break; case onnx::TensorProto_DataType_INT16: elem_type = element::i16; break;
case onnx::TensorProto_DataType_UNDEFINED: elem_type = element::unspecified; break; case onnx::TensorProto_DataType_INT32: elem_type = element::i32; break;
default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type"; case onnx::TensorProto_DataType_INT64: elem_type = element::i64; break;
case onnx::TensorProto_DataType_UINT8: elem_type = element::u8; break;
case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break;
case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break;
case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break;
case onnx::TensorProto_DataType_UNDEFINED:
elem_type = element::unspecified;
break;
default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type";
}
return {std::make_shared<ngraph::op::Convert>(data, elem_type)};
} }
return {std::make_shared<ngraph::op::Convert>(data, elem_type)}; } // namespace set_1
}
} // namespace op } //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,13 @@ namespace ngraph ...@@ -26,8 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector cast(const Node& node); namespace set_1
} // namespace op {
NodeVector cast(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -29,12 +29,16 @@ namespace ngraph ...@@ -29,12 +29,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector ceil(const Node& node) namespace set_1
{ {
return {std::make_shared<ngraph::op::Ceiling>(node.get_ng_inputs().at(0))}; inline NodeVector ceil(const Node& node)
} {
return {std::make_shared<ngraph::op::Ceiling>(node.get_ng_inputs().at(0))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -34,30 +34,37 @@ namespace ngraph ...@@ -34,30 +34,37 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector clip(const Node& node) namespace set_1
{ {
auto data = node.get_ng_inputs().at(0); NodeVector clip(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double max_value = double max_value =
node.get_attribute_value<double>("max", std::numeric_limits<double>::max()); node.get_attribute_value<double>("max", std::numeric_limits<double>::max());
double min_value = double min_value = node.get_attribute_value<double>(
node.get_attribute_value<double>("min", std::numeric_limits<double>::lowest()); "min", std::numeric_limits<double>::lowest());
std::shared_ptr<ngraph::Node> max_value_node = std::shared_ptr<ngraph::Node> max_value_node =
std::make_shared<ngraph::op::Constant>( std::make_shared<ngraph::op::Constant>(data->get_element_type(),
data->get_element_type(), ngraph::Shape{}, std::vector<double>{max_value}); ngraph::Shape{},
max_value_node = make_broadcast_node(max_value_node, data->get_shape()); std::vector<double>{max_value});
max_value_node = make_broadcast_node(max_value_node, data->get_shape());
std::shared_ptr<ngraph::Node> min_value_node = std::shared_ptr<ngraph::Node> min_value_node =
std::make_shared<ngraph::op::Constant>( std::make_shared<ngraph::op::Constant>(data->get_element_type(),
data->get_element_type(), ngraph::Shape{}, std::vector<double>{min_value}); ngraph::Shape{},
min_value_node = make_broadcast_node(min_value_node, data->get_shape()); std::vector<double>{min_value});
min_value_node = make_broadcast_node(min_value_node, data->get_shape());
return {std::make_shared<ngraph::op::Minimum>( return {std::make_shared<ngraph::op::Minimum>(
max_value_node, std::make_shared<ngraph::op::Maximum>(data, min_value_node))}; max_value_node,
} std::make_shared<ngraph::op::Maximum>(data, min_value_node))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,13 @@ namespace ngraph ...@@ -26,8 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector clip(const Node& node); namespace set_1
} // namespace op {
NodeVector clip(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -24,16 +24,20 @@ namespace ngraph ...@@ -24,16 +24,20 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector concat(const Node& node) namespace set_1
{ {
NodeVector inputs{node.get_ng_inputs()}; NodeVector concat(const Node& node)
auto axis = node.get_attribute_value<int64_t>("axis"); {
NodeVector inputs{node.get_ng_inputs()};
auto axis = node.get_attribute_value<int64_t>("axis");
return {std::make_shared<ngraph::op::Concat>(inputs, axis)}; return {std::make_shared<ngraph::op::Concat>(inputs, axis)};
} }
} // namespace op } // namespace set_1
} // namespace onnx_import } //namespace op
} // namespace ngraph } // namespace onnx_import
} // namespace ngraph
...@@ -26,9 +26,14 @@ namespace ngraph ...@@ -26,9 +26,14 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector concat(const Node& node); namespace set_1
} // namespace op {
NodeVector concat(const Node& node);
} // namespace onnx_import } // namespace set_1
} // namespace ngraph } //namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -25,96 +25,101 @@ namespace ngraph ...@@ -25,96 +25,101 @@ namespace ngraph
{ {
namespace op namespace op
{ {
namespace namespace set_1
{ {
template <typename T> namespace
inline std::shared_ptr<ngraph::op::Constant>
__make_ng_constant(const element::Type& type, const Tensor& tensor)
{ {
return std::make_shared<ngraph::op::Constant>( template <typename T>
type, tensor.get_shape(), tensor.get_data<T>()); inline std::shared_ptr<ngraph::op::Constant>
} __make_ng_constant(const element::Type& type, const Tensor& tensor)
{
return std::make_shared<ngraph::op::Constant>(
type, tensor.get_shape(), tensor.get_data<T>());
}
template <Tensor::Type> template <Tensor::Type>
inline std::shared_ptr<ngraph::op::Constant> make_ng_constant(const Tensor& tensor) inline std::shared_ptr<ngraph::op::Constant>
{ make_ng_constant(const Tensor& tensor)
throw error::tensor::unsupported_data_type{tensor}; {
} throw error::tensor::unsupported_data_type{tensor};
}
template <> template <>
inline std::shared_ptr<ngraph::op::Constant> inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float16>(const Tensor& tensor) make_ng_constant<Tensor::Type::float16>(const Tensor& tensor)
{ {
return __make_ng_constant<float>(element::f32, tensor); return __make_ng_constant<float>(element::f32, tensor);
} }
template <> template <>
inline std::shared_ptr<ngraph::op::Constant> inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float32>(const Tensor& tensor) make_ng_constant<Tensor::Type::float32>(const Tensor& tensor)
{ {
return __make_ng_constant<float>(element::f32, tensor); return __make_ng_constant<float>(element::f32, tensor);
} }
template <> template <>
inline std::shared_ptr<ngraph::op::Constant> inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float64>(const Tensor& tensor) make_ng_constant<Tensor::Type::float64>(const Tensor& tensor)
{ {
return __make_ng_constant<double>(element::f64, tensor); return __make_ng_constant<double>(element::f64, tensor);
} }
template <> template <>
inline std::shared_ptr<ngraph::op::Constant> inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int32>(const Tensor& tensor) make_ng_constant<Tensor::Type::int32>(const Tensor& tensor)
{ {
return __make_ng_constant<int32_t>(element::i32, tensor); return __make_ng_constant<int32_t>(element::i32, tensor);
} }
template <> template <>
inline std::shared_ptr<ngraph::op::Constant> inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int64>(const Tensor& tensor) make_ng_constant<Tensor::Type::int64>(const Tensor& tensor)
{ {
return __make_ng_constant<int64_t>(element::i64, tensor); return __make_ng_constant<int64_t>(element::i64, tensor);
} }
template <> template <>
inline std::shared_ptr<ngraph::op::Constant> inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor) make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor)
{ {
return __make_ng_constant<uint32_t>(element::u32, tensor); return __make_ng_constant<uint32_t>(element::u32, tensor);
} }
template <> template <>
inline std::shared_ptr<ngraph::op::Constant> inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint64>(const Tensor& tensor) make_ng_constant<Tensor::Type::uint64>(const Tensor& tensor)
{ {
return __make_ng_constant<uint64_t>(element::u64, tensor); return __make_ng_constant<uint64_t>(element::u64, tensor);
} }
inline std::shared_ptr<ngraph::op::Constant> make_constant(const Tensor& tensor) inline std::shared_ptr<ngraph::op::Constant> make_constant(const Tensor& tensor)
{ {
#define MAKE_NG_CONSTANT(data_type_) \ #define MAKE_NG_CONSTANT(data_type_) \
case data_type_: return make_ng_constant<data_type_>(tensor) case data_type_: return make_ng_constant<data_type_>(tensor)
switch (tensor.get_type()) switch (tensor.get_type())
{ {
MAKE_NG_CONSTANT(Tensor::Type::float16); MAKE_NG_CONSTANT(Tensor::Type::float16);
MAKE_NG_CONSTANT(Tensor::Type::float32); MAKE_NG_CONSTANT(Tensor::Type::float32);
MAKE_NG_CONSTANT(Tensor::Type::float64); MAKE_NG_CONSTANT(Tensor::Type::float64);
MAKE_NG_CONSTANT(Tensor::Type::int32); MAKE_NG_CONSTANT(Tensor::Type::int32);
MAKE_NG_CONSTANT(Tensor::Type::int64); MAKE_NG_CONSTANT(Tensor::Type::int64);
MAKE_NG_CONSTANT(Tensor::Type::uint32); MAKE_NG_CONSTANT(Tensor::Type::uint32);
MAKE_NG_CONSTANT(Tensor::Type::uint64); MAKE_NG_CONSTANT(Tensor::Type::uint64);
default: throw error::tensor::invalid_data_type{tensor}; default: throw error::tensor::invalid_data_type{tensor};
}
} }
} }
}
NodeVector constant(const onnx_import::Node& node) NodeVector constant(const onnx_import::Node& node)
{ {
return {make_constant(node.get_attribute_value<Tensor>("value"))}; return {make_constant(node.get_attribute_value<Tensor>("value"))};
} }
} // namespace set_1
} // namespace op } //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,9 +26,13 @@ namespace ngraph ...@@ -26,9 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector constant(const Node& node); namespace set_1
{
NodeVector constant(const Node& node);
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
This diff is collapsed.
...@@ -26,15 +26,19 @@ namespace ngraph ...@@ -26,15 +26,19 @@ namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Performs ONNX Conv operation. namespace set_1
/// {
/// \param node The ONNX node object representing this operation. /// \brief Performs ONNX Conv operation.
/// ///
/// \return The vector containing Ngraph nodes producing output of ONNX convolution /// \param node The ONNX node object representing this operation.
/// operation. ///
NodeVector conv(const Node& node); /// \return The vector containing Ngraph nodes producing output of ONNX convolution
/// operation.
} // namespace op NodeVector conv(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,14 +28,18 @@ namespace ngraph ...@@ -28,14 +28,18 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector div(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{ inline NodeVector div(const Node& node)
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())}; {
return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))}; NodeVector ng_inputs{
} numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -38,26 +38,33 @@ namespace ngraph ...@@ -38,26 +38,33 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector elu(const Node& node) namespace set_1
{ {
auto data = node.get_ng_inputs().at(0); NodeVector elu(const Node& node)
double alpha = node.get_attribute_value<double>("alpha", 1); {
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> alpha_node =
data->get_element_type(), Shape{}, std::vector<double>{alpha}); std::make_shared<ngraph::op::Constant>(
alpha_node = make_broadcast_node(alpha_node, data->get_shape()); data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> zero_node =
data->get_element_type(), Shape{}, std::vector<double>{0}); std::make_shared<ngraph::op::Constant>(
zero_node = make_broadcast_node(zero_node, data->get_shape()); data->get_element_type(), Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(data, zero_node) + return {std::make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node * std::make_shared<ngraph::op::Exp>( alpha_node *
std::make_shared<ngraph::op::Minimum>(data, zero_node)) - std::make_shared<ngraph::op::Exp>(
alpha_node}; std::make_shared<ngraph::op::Minimum>(data, zero_node)) -
} alpha_node};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,13 @@ namespace ngraph ...@@ -26,8 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector elu(const Node& node); namespace set_1
} // namespace op {
NodeVector elu(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,14 +28,18 @@ namespace ngraph ...@@ -28,14 +28,18 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector equal(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{ inline NodeVector equal(const Node& node)
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())}; {
return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))}; NodeVector ng_inputs{
} numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -29,12 +29,16 @@ namespace ngraph ...@@ -29,12 +29,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector exp(const Node& node) namespace set_1
{ {
return {std::make_shared<ngraph::op::Exp>(node.get_ng_inputs().at(0))}; inline NodeVector exp(const Node& node)
} {
return {std::make_shared<ngraph::op::Exp>(node.get_ng_inputs().at(0))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -25,19 +25,23 @@ namespace ngraph ...@@ -25,19 +25,23 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector flatten(const Node& node) namespace set_1
{ {
NodeVector inputs{node.get_ng_inputs()}; NodeVector flatten(const Node& node)
auto data = inputs.at(0); {
auto axis = node.get_attribute_value<int64_t>("axis", 1); NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0);
auto axis = node.get_attribute_value<int64_t>("axis", 1);
ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= data->get_shape().size())) ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= data->get_shape().size()))
<< "provided 'axis' attribute is not valid."; << "provided 'axis' attribute is not valid.";
return {reshape::flatten(data, axis)}; return {reshape::flatten(data, axis)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,8 +28,13 @@ namespace ngraph ...@@ -28,8 +28,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector flatten(const Node& node); namespace set_1
} // namespace op {
NodeVector flatten(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -29,12 +29,16 @@ namespace ngraph ...@@ -29,12 +29,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector floor(const Node& node) namespace set_1
{ {
return {std::make_shared<ngraph::op::Floor>(node.get_ng_inputs().at(0))}; inline NodeVector floor(const Node& node)
} {
return {std::make_shared<ngraph::op::Floor>(node.get_ng_inputs().at(0))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -32,50 +32,58 @@ namespace ngraph ...@@ -32,50 +32,58 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector gemm(const Node& node) namespace set_1
{ {
NodeVector inputs{node.get_ng_inputs()}; NodeVector gemm(const Node& node)
auto input_a = inputs.at(0);
auto input_b = inputs.at(1);
auto input_c = inputs.at(2);
double alpha = node.get_attribute_value<double>("alpha", 1);
double beta = node.get_attribute_value<double>("beta", 1);
auto trans_a = node.get_attribute_value<int64_t>("transA", 0);
auto trans_b = node.get_attribute_value<int64_t>("transB", 0);
if (trans_a != 0)
{
input_a = reshape::transpose(input_a);
}
if (trans_b != 0)
{ {
input_b = reshape::transpose(input_b); NodeVector inputs{node.get_ng_inputs()};
auto input_a = inputs.at(0);
auto input_b = inputs.at(1);
auto input_c = inputs.at(2);
double alpha = node.get_attribute_value<double>("alpha", 1);
double beta = node.get_attribute_value<double>("beta", 1);
auto trans_a = node.get_attribute_value<int64_t>("transA", 0);
auto trans_b = node.get_attribute_value<int64_t>("transB", 0);
if (trans_a != 0)
{
input_a = reshape::transpose(input_a);
}
if (trans_b != 0)
{
input_b = reshape::transpose(input_b);
}
// code from python not implemented in c++ yet.
// reshape_for_matmul(node, input_a, input_b);
std::shared_ptr<ngraph::Node> a_dot_b =
std::make_shared<ngraph::op::Dot>(input_a, input_b);
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(a_dot_b->get_element_type(),
ngraph::Shape{},
std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, a_dot_b->get_shape());
a_dot_b = std::make_shared<ngraph::op::Multiply>(alpha_node, a_dot_b);
std::shared_ptr<ngraph::Node> beta_node =
std::make_shared<ngraph::op::Constant>(input_c->get_element_type(),
ngraph::Shape{},
std::vector<double>{beta});
beta_node = make_broadcast_node(beta_node, input_c->get_shape());
input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c);
input_c = make_broadcast_node(input_c, a_dot_b->get_shape());
return {std::make_shared<ngraph::op::Add>(a_dot_b, input_c)};
} }
// code from python not implemented in c++ yet. } // namespace set_1
// reshape_for_matmul(node, input_a, input_b);
std::shared_ptr<ngraph::Node> a_dot_b =
std::make_shared<ngraph::op::Dot>(input_a, input_b);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
a_dot_b->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, a_dot_b->get_shape());
a_dot_b = std::make_shared<ngraph::op::Multiply>(alpha_node, a_dot_b);
std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>(
input_c->get_element_type(), ngraph::Shape{}, std::vector<double>{beta});
beta_node = make_broadcast_node(beta_node, input_c->get_shape());
input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c);
input_c = make_broadcast_node(input_c, a_dot_b->get_shape());
return {std::make_shared<ngraph::op::Add>(a_dot_b, input_c)};
}
} // namespace op } //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,8 +28,13 @@ namespace ngraph ...@@ -28,8 +28,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector gemm(const Node& node); namespace set_1
} // namespace op {
NodeVector gemm(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,14 +28,19 @@ namespace ngraph ...@@ -28,14 +28,19 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector greater(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{ inline NodeVector greater(const Node& node)
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())}; {
return {std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))}; NodeVector ng_inputs{
} numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {
std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -35,36 +35,43 @@ namespace ngraph ...@@ -35,36 +35,43 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector hard_sigmoid(const Node& node) namespace set_1
{ {
auto data = node.get_ng_inputs().at(0); NodeVector hard_sigmoid(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 0.2); double alpha = node.get_attribute_value<double>("alpha", 0.2);
double beta = node.get_attribute_value<double>("beta", 0.5); double beta = node.get_attribute_value<double>("beta", 0.5);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> alpha_node =
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha}); std::make_shared<ngraph::op::Constant>(
alpha_node = make_broadcast_node(alpha_node, data->get_shape()); data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> beta_node =
data->get_element_type(), ngraph::Shape{}, std::vector<double>{beta}); std::make_shared<ngraph::op::Constant>(
beta_node = make_broadcast_node(beta_node, data->get_shape()); data->get_element_type(), ngraph::Shape{}, std::vector<double>{beta});
beta_node = make_broadcast_node(beta_node, data->get_shape());
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1}); data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape()); one_node = make_broadcast_node(one_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> zero_node =
data->get_element_type(), Shape{}, std::vector<double>{0}); std::make_shared<ngraph::op::Constant>(
zero_node = make_broadcast_node(zero_node, data->get_shape()); data->get_element_type(), Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>( return {std::make_shared<ngraph::op::Maximum>(
zero_node, zero_node,
std::make_shared<ngraph::op::Minimum>(one_node, std::make_shared<ngraph::op::Minimum>(one_node,
alpha_node * data + beta_node))}; alpha_node * data + beta_node))};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,13 @@ namespace ngraph ...@@ -26,8 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector hard_sigmoid(const Node& node); namespace set_1
} // namespace op {
NodeVector hard_sigmoid(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,15 @@ namespace ngraph ...@@ -26,8 +26,15 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector identity(const Node& node) { return {node.get_ng_inputs().at(0)}; } namespace set_1
} // namespace op {
inline NodeVector identity(const Node& node)
{
return {node.get_ng_inputs().at(0)};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -38,21 +38,26 @@ namespace ngraph ...@@ -38,21 +38,26 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector leaky_relu(const Node& node) namespace set_1
{ {
auto data = node.get_ng_inputs().at(0); NodeVector leaky_relu(const Node& node)
double alpha = node.get_attribute_value<double>("alpha", 0.01); {
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 0.01);
ASSERT_VALID_ARGUMENT(node, ((alpha >= 0) && (alpha <= 1))) ASSERT_VALID_ARGUMENT(node, ((alpha >= 0) && (alpha <= 1)))
<< " alpha value should be in range (0,1)"; << " alpha value should be in range (0,1)";
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> alpha_node =
data->get_element_type(), Shape{}, std::vector<double>{alpha}); std::make_shared<ngraph::op::Constant>(
alpha_node = make_broadcast_node(alpha_node, data->get_shape()); data->get_element_type(), Shape{}, std::vector<double>{alpha});
return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)}; alpha_node = make_broadcast_node(alpha_node, data->get_shape());
} return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,13 @@ namespace ngraph ...@@ -26,8 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector leaky_relu(const Node& node); namespace set_1
} // namespace op {
NodeVector leaky_relu(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,14 +28,18 @@ namespace ngraph ...@@ -28,14 +28,18 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector less(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{ inline NodeVector less(const Node& node)
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())}; {
return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))}; NodeVector ng_inputs{
} numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -29,12 +29,16 @@ namespace ngraph ...@@ -29,12 +29,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector log(const Node& node) namespace set_1
{ {
return {std::make_shared<ngraph::op::Log>(node.get_ng_inputs().at(0))}; inline NodeVector log(const Node& node)
} {
return {std::make_shared<ngraph::op::Log>(node.get_ng_inputs().at(0))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -31,13 +31,17 @@ namespace ngraph ...@@ -31,13 +31,17 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector log_softmax(const Node& node) namespace set_1
{ {
return {std::make_shared<ngraph::op::Log>(softmax(node).at(0))}; inline NodeVector log_softmax(const Node& node)
} {
return {std::make_shared<ngraph::op::Log>(softmax(node).at(0))};
}
} // namespace op } // namespace set_1
} // namespace onnx_import } //namespace op
} // namespace ngraph } // namespace onnx_import
} // namespace ngraph
...@@ -27,18 +27,22 @@ namespace ngraph ...@@ -27,18 +27,22 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector lrn(const Node& node) namespace set_1
{ {
auto data = node.get_ng_inputs().at(0); NodeVector lrn(const Node& node)
double alpha = node.get_attribute_value<double>("alpha", 1e-4); {
double beta = node.get_attribute_value<double>("beta", 0.75); auto data = node.get_ng_inputs().at(0);
double bias = node.get_attribute_value<double>("bias", 1); double alpha = node.get_attribute_value<double>("alpha", 1e-4);
size_t size = node.get_attribute_value<size_t>("size"); double beta = node.get_attribute_value<double>("beta", 0.75);
double bias = node.get_attribute_value<double>("bias", 1);
size_t size = node.get_attribute_value<size_t>("size");
return {std::make_shared<ngraph::op::LRN>(data, alpha, beta, bias, size)}; return {std::make_shared<ngraph::op::LRN>(data, alpha, beta, bias, size)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,12 @@ namespace ngraph ...@@ -26,8 +26,12 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector lrn(const Node& node); namespace set_1
} // namespace op {
NodeVector lrn(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,13 +27,17 @@ namespace ngraph ...@@ -27,13 +27,17 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector matmul(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{node.get_ng_inputs()}; inline NodeVector matmul(const Node& node)
return {std::make_shared<ngraph::op::Dot>(ng_inputs.at(0), ng_inputs.at(1))}; {
} NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Dot>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,12 +28,16 @@ namespace ngraph ...@@ -28,12 +28,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector max(const Node& node) namespace set_1
{ {
return variadic::make_ng_variadic_op<ngraph::op::Maximum>(node); inline NodeVector max(const Node& node)
} {
return variadic::make_ng_variadic_op<ngraph::op::Maximum>(node);
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,12 +26,16 @@ namespace ngraph ...@@ -26,12 +26,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector max_pool(const Node& node) namespace set_1
{ {
return convpool::make_ng_pool<ngraph::op::MaxPool>(node); NodeVector max_pool(const Node& node)
} {
return convpool::make_ng_pool<ngraph::op::MaxPool>(node);
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,17 +26,21 @@ namespace ngraph ...@@ -26,17 +26,21 @@ namespace ngraph
{ {
namespace op namespace op
{ {
/** namespace set_1
* @brief Convert ONNX MaxPool operation to an nGraph node. {
* /**
* @param node The ONNX node object representing this operation. * @brief Convert ONNX MaxPool operation to an nGraph node.
* *
* @return The vector containing Ngraph nodes producing output of ONNX MaxPool * @param node The ONNX node object representing this operation.
* operation. *
*/ * @return The vector containing Ngraph nodes producing output of ONNX MaxPool
NodeVector max_pool(const Node& node); * operation.
*/
} // namespace op NodeVector max_pool(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,21 +27,25 @@ namespace ngraph ...@@ -27,21 +27,25 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector mean(const Node& node) namespace set_1
{ {
auto sum = variadic::make_ng_variadic_op<ngraph::op::Add>(node).front(); NodeVector mean(const Node& node)
auto shape = sum->get_shape(); {
auto sum = variadic::make_ng_variadic_op<ngraph::op::Add>(node).front();
auto shape = sum->get_shape();
// Create a Constant representing the number of inputs with the same shape as sum // Create a Constant representing the number of inputs with the same shape as sum
auto count = ngraph::op::Constant::create( auto count = ngraph::op::Constant::create(
sum->get_element_type(), sum->get_element_type(),
shape, shape,
std::vector<int>(shape_size(shape), node.get_ng_inputs().size())); std::vector<int>(shape_size(shape), node.get_ng_inputs().size()));
return {sum / count}; return {sum / count};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,9 +26,13 @@ namespace ngraph ...@@ -26,9 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector mean(const Node& node); namespace set_1
{
NodeVector mean(const Node& node);
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,12 +28,16 @@ namespace ngraph ...@@ -28,12 +28,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector min(const Node& node) namespace set_1
{ {
return variadic::make_ng_variadic_op<ngraph::op::Minimum>(node); inline NodeVector min(const Node& node)
} {
return variadic::make_ng_variadic_op<ngraph::op::Minimum>(node);
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,14 +28,19 @@ namespace ngraph ...@@ -28,14 +28,19 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector mul(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{ inline NodeVector mul(const Node& node)
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())}; {
return {std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))}; NodeVector ng_inputs{
} numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {
std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,8 +27,12 @@ namespace ngraph ...@@ -27,8 +27,12 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector neg(const Node& node) { return {-node.get_ng_inputs().at(0)}; } namespace set_1
} // namespace op {
inline NodeVector neg(const Node& node) { return {-node.get_ng_inputs().at(0)}; }
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,12 +28,16 @@ namespace ngraph ...@@ -28,12 +28,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector logical_not(const Node& node) namespace set_1
{ {
return {std::make_shared<ngraph::op::Not>(node.get_ng_inputs().at(0))}; inline NodeVector logical_not(const Node& node)
} {
return {std::make_shared<ngraph::op::Not>(node.get_ng_inputs().at(0))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,14 +28,18 @@ namespace ngraph ...@@ -28,14 +28,18 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector logical_or(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{ inline NodeVector logical_or(const Node& node)
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())}; {
return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))}; NodeVector ng_inputs{
} numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,14 +28,18 @@ namespace ngraph ...@@ -28,14 +28,18 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector pow(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{ inline NodeVector pow(const Node& node)
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())}; {
return {std::make_shared<ngraph::op::Power>(ng_inputs.at(0), ng_inputs.at(1))}; NodeVector ng_inputs{
} numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Power>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -37,31 +37,35 @@ namespace ngraph ...@@ -37,31 +37,35 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector prelu(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{node.get_ng_inputs()}; NodeVector prelu(const Node& node)
auto data = ng_inputs.at(0);
auto data_shape = data->get_shape();
std::shared_ptr<ngraph::Node> slope = ng_inputs.at(1);
auto slope_shape = slope->get_shape();
if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
{
auto it =
std::find(std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
auto index = std::distance(std::begin(data_shape), it);
slope = make_broadcast_node(slope, data->get_shape(), index);
}
else
{ {
auto params = numpy_style_broadcast_for_binary_operation(slope, data); NodeVector ng_inputs{node.get_ng_inputs()};
slope = params.at(0); auto data = ng_inputs.at(0);
auto data_shape = data->get_shape();
std::shared_ptr<ngraph::Node> slope = ng_inputs.at(1);
auto slope_shape = slope->get_shape();
if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
{
auto it = std::find(
std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
auto index = std::distance(std::begin(data_shape), it);
slope = make_broadcast_node(slope, data->get_shape(), index);
}
else
{
auto params = numpy_style_broadcast_for_binary_operation(slope, data);
slope = params.at(0);
}
return {std::make_shared<ngraph::op::Maximum>(data * slope, data)};
} }
return {std::make_shared<ngraph::op::Maximum>(data * slope, data)}; } // namespace set_1
}
} // namespace op } //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,13 @@ namespace ngraph ...@@ -26,8 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector prelu(const Node& node); namespace set_1
} // namespace op {
NodeVector prelu(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -31,18 +31,22 @@ namespace ngraph ...@@ -31,18 +31,22 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector reciprocal(const Node& node) namespace set_1
{ {
auto data = node.get_ng_inputs().at(0); NodeVector reciprocal(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1}); data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape()); one_node = make_broadcast_node(one_node, data->get_shape());
return {one_node / data}; return {one_node / data};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,13 @@ namespace ngraph ...@@ -26,8 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector reciprocal(const Node& node); namespace set_1
} // namespace op {
NodeVector reciprocal(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -32,29 +32,35 @@ namespace ngraph ...@@ -32,29 +32,35 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector reduce_mean(const Node& node) namespace set_1
{ {
auto input_shape = node.get_ng_inputs().at(0)->get_shape(); NodeVector reduce_mean(const Node& node)
auto reduction_axes = reduction::detail::get_reduction_axes(node); {
std::size_t elem_count_product = auto input_shape = node.get_ng_inputs().at(0)->get_shape();
std::accumulate(std::begin(reduction_axes), auto reduction_axes = reduction::detail::get_reduction_axes(node);
std::end(reduction_axes), std::size_t elem_count_product =
1UL, std::accumulate(std::begin(reduction_axes),
[&input_shape](const std::size_t& a, const std::size_t& b) { std::end(reduction_axes),
return a * input_shape.at(b); 1UL,
}); [&input_shape](const std::size_t& a, const std::size_t& b) {
auto sum_node = reduction::make_ng_reduction_op<ngraph::op::Sum>( return a * input_shape.at(b);
node, node.get_ng_inputs().at(0)); });
auto const_node = std::make_shared<ngraph::op::Constant>( auto sum_node = reduction::make_ng_reduction_op<ngraph::op::Sum>(
sum_node->get_element_type(), node, node.get_ng_inputs().at(0));
Shape{}, auto const_node = std::make_shared<ngraph::op::Constant>(
std::vector<std::size_t>{elem_count_product}); sum_node->get_element_type(),
Shape{},
auto broadcasted_const_node = std::vector<std::size_t>{elem_count_product});
make_broadcast_node(const_node, sum_node->get_shape());
return {std::make_shared<ngraph::op::Divide>(sum_node, broadcasted_const_node)}; auto broadcasted_const_node =
} make_broadcast_node(const_node, sum_node->get_shape());
return {std::make_shared<ngraph::op::Divide>(sum_node, broadcasted_const_node)};
} // namespace op }
} // namespace onnx_import
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -27,13 +27,17 @@ namespace ngraph ...@@ -27,13 +27,17 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector relu(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{node.get_ng_inputs()}; inline NodeVector relu(const Node& node)
return {std::make_shared<ngraph::op::Relu>(ng_inputs.at(0))}; {
} NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Relu>(ng_inputs.at(0))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -33,39 +33,45 @@ namespace ngraph ...@@ -33,39 +33,45 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector reshape(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{node.get_ng_inputs()}; NodeVector reshape(const Node& node)
auto data = ng_inputs.at(0); {
auto data_shape = data->get_shape(); NodeVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0);
auto data_shape = data->get_shape();
auto output_shape = node.get_attribute_value<std::vector<std::size_t>>("shape", {}); auto output_shape =
node.get_attribute_value<std::vector<std::size_t>>("shape", {});
// If no shape argument (opset >= 5) and there is second input. // If no shape argument (opset >= 5) and there is second input.
if (output_shape.empty() && ng_inputs.size() == 2) if (output_shape.empty() && ng_inputs.size() == 2)
{ {
// Currently only support Constant node. // Currently only support Constant node.
ASSERT_IS_SUPPORTED(node, ng_inputs.at(1)->description() == "Constant") ASSERT_IS_SUPPORTED(node, ng_inputs.at(1)->description() == "Constant")
<< "doesn't support shape input of other type than Constant."; << "doesn't support shape input of other type than Constant.";
auto output_shape_node = auto output_shape_node =
std::dynamic_pointer_cast<ngraph::op::Constant>(ng_inputs.at(1)); std::dynamic_pointer_cast<ngraph::op::Constant>(ng_inputs.at(1));
output_shape = output_shape_node->get_vector<std::size_t>(); output_shape = output_shape_node->get_vector<std::size_t>();
} }
// Do nothing if there is no shape argument nor second node input. // Do nothing if there is no shape argument nor second node input.
else if (output_shape.empty()) else if (output_shape.empty())
{ {
return {data}; return {data};
}
output_shape =
reshape::infer_dimensions(node.get_name(), data_shape, output_shape);
return {std::make_shared<ngraph::op::Reshape>(
data,
reshape::get_default_axis_vector(data_shape.size()),
Shape{output_shape})};
} }
output_shape = reshape::infer_dimensions(node.get_name(), data_shape, output_shape); } // namespace set_1
return {std::make_shared<ngraph::op::Reshape>(
data,
reshape::get_default_axis_vector(data_shape.size()),
Shape{output_shape})};
}
} // namespace op } //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,16 +26,20 @@ namespace ngraph ...@@ -26,16 +26,20 @@ namespace ngraph
{ {
namespace op namespace op
{ {
/// namespace set_1
/// \brief Reshape the input tensor similar to numpy.reshape. {
/// ///
/// \param[in] node The ONNX node representing this operation. /// \brief Reshape the input tensor similar to numpy.reshape.
/// ///
/// \return Ngraph node representing this operation. /// \param[in] node The ONNX node representing this operation.
/// ///
NodeVector reshape(const Node& node); /// \return Ngraph node representing this operation.
///
} // namespace op NodeVector reshape(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -41,32 +41,41 @@ namespace ngraph ...@@ -41,32 +41,41 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector selu(const Node& node) namespace set_1
{ {
auto data = node.get_ng_inputs().at(0); NodeVector selu(const Node& node)
double alpha = node.get_attribute_value<double>("alpha", 1.67326319217681884765625); {
double gamma = node.get_attribute_value<double>("gamma", 1.05070102214813232421875); auto data = node.get_ng_inputs().at(0);
double alpha =
node.get_attribute_value<double>("alpha", 1.67326319217681884765625);
double gamma =
node.get_attribute_value<double>("gamma", 1.05070102214813232421875);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> alpha_node =
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha}); std::make_shared<ngraph::op::Constant>(
alpha_node = make_broadcast_node(alpha_node, data->get_shape()); data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> gamma_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> gamma_node =
data->get_element_type(), ngraph::Shape{}, std::vector<double>{gamma}); std::make_shared<ngraph::op::Constant>(
gamma_node = make_broadcast_node(gamma_node, data->get_shape()); data->get_element_type(), ngraph::Shape{}, std::vector<double>{gamma});
gamma_node = make_broadcast_node(gamma_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> zero_node =
data->get_element_type(), ngraph::Shape{}, std::vector<double>{0}); std::make_shared<ngraph::op::Constant>(
zero_node = make_broadcast_node(zero_node, data->get_shape()); data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
return {gamma_node * return {gamma_node * (std::make_shared<ngraph::op::Maximum>(data, zero_node) +
(std::make_shared<ngraph::op::Maximum>(data, zero_node) + alpha_node * std::make_shared<ngraph::op::Exp>(
alpha_node * std::make_shared<ngraph::op::Exp>( std::make_shared<ngraph::op::Minimum>(
std::make_shared<ngraph::op::Minimum>(data, zero_node)) - data, zero_node)) -
alpha_node)}; alpha_node)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,13 @@ namespace ngraph ...@@ -26,8 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector selu(const Node& node); namespace set_1
} // namespace op {
NodeVector selu(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -30,16 +30,20 @@ namespace ngraph ...@@ -30,16 +30,20 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector shape(const Node& node) namespace set_1
{ {
auto data = node.get_ng_inputs().at(0); NodeVector shape(const Node& node)
auto data_shape = data->get_shape(); {
auto data = node.get_ng_inputs().at(0);
auto data_shape = data->get_shape();
return {std::make_shared<ngraph::op::Constant>( return {std::make_shared<ngraph::op::Constant>(
ngraph::element::i64, Shape{data_shape.size()}, data_shape)}; ngraph::element::i64, Shape{data_shape.size()}, data_shape)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,13 @@ namespace ngraph ...@@ -26,8 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector shape(const Node& node); namespace set_1
} // namespace op {
NodeVector shape(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -29,12 +29,16 @@ namespace ngraph ...@@ -29,12 +29,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector sigmoid(const Node& node) namespace set_1
{ {
return {std::make_shared<ngraph::op::Sigmoid>(node.get_ng_inputs().at(0))}; inline NodeVector sigmoid(const Node& node)
} {
return {std::make_shared<ngraph::op::Sigmoid>(node.get_ng_inputs().at(0))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -35,32 +35,37 @@ namespace ngraph ...@@ -35,32 +35,37 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector slice(const Node& node) namespace set_1
{ {
std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0); NodeVector slice(const Node& node)
Shape data_shape = data->get_shape(); {
std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0);
Shape data_shape = data->get_shape();
auto starts = node.get_attribute_value<std::vector<int64_t>>("starts"); auto starts = node.get_attribute_value<std::vector<int64_t>>("starts");
auto ends = node.get_attribute_value<std::vector<int64_t>>("ends"); auto ends = node.get_attribute_value<std::vector<int64_t>>("ends");
auto axes = node.get_attribute_value<std::vector<int64_t>>( auto axes = node.get_attribute_value<std::vector<int64_t>>(
"axes", common::get_monotonic_range<int64_t>(data_shape.size())); "axes", common::get_monotonic_range<int64_t>(data_shape.size()));
Shape lower_bounds(data_shape.size()); Shape lower_bounds(data_shape.size());
Shape upper_bounds = data_shape; Shape upper_bounds = data_shape;
for (auto idx = 0; idx < axes.size(); ++idx) for (auto idx = 0; idx < axes.size(); ++idx)
{ {
size_t axis = axes.at(idx); size_t axis = axes.at(idx);
lower_bounds.at(axis) = lower_bounds.at(axis) =
get_valid_array_idx(starts.at(idx), data_shape.at(axis)); get_valid_array_idx(starts.at(idx), data_shape.at(axis));
upper_bounds.at(axis) = get_valid_array_idx(ends.at(idx), data_shape.at(axis)); upper_bounds.at(axis) =
get_valid_array_idx(ends.at(idx), data_shape.at(axis));
}
return {std::make_shared<ngraph::op::Slice>(data, lower_bounds, upper_bounds)};
} }
return {std::make_shared<ngraph::op::Slice>(data, lower_bounds, upper_bounds)}; } // namespace set_1
}
} // namespace op } //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,13 @@ namespace ngraph ...@@ -26,8 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector slice(const Node& node); namespace set_1
} // namespace op {
NodeVector slice(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,31 +27,35 @@ namespace ngraph ...@@ -27,31 +27,35 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector softmax(const Node& node) namespace set_1
{ {
NodeVector inputs{node.get_ng_inputs()}; NodeVector softmax(const Node& node)
auto data = inputs.at(0); {
auto data_shape = data->get_shape(); NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0);
auto data_shape = data->get_shape();
int axis = node.get_attribute_value<int64_t>("axis", 1); int axis = node.get_attribute_value<int64_t>("axis", 1);
if (axis < 0) if (axis < 0)
{ {
axis = data_shape.size() + axis; axis = data_shape.size() + axis;
} }
ASSERT_VALID_ARGUMENT(node, axis < data_shape.size()) ASSERT_VALID_ARGUMENT(node, axis < data_shape.size())
<< "provided 'axis' value:" << axis << "provided 'axis' value:" << axis
<< " is out of input tensor dimensions range."; << " is out of input tensor dimensions range.";
// create vector of capacity data_dimensions - axis_divider position
std::vector<size_t> axes(data_shape.size() - axis);
std::iota(std::begin(axes), std::end(axes), axis);
return {std::make_shared<ngraph::op::Softmax>(data, axes)};
}
// create vector of capacity data_dimensions - axis_divider position } // namespace set_1
std::vector<size_t> axes(data_shape.size() - axis);
std::iota(std::begin(axes), std::end(axes), axis);
return {std::make_shared<ngraph::op::Softmax>(data, axes)};
}
} // namespace op } //namespace op
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -26,9 +26,14 @@ namespace ngraph ...@@ -26,9 +26,14 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector softmax(const Node& node); namespace set_1
} // namespace op {
NodeVector softmax(const Node& node);
} // namespace onnx_import } // namespace set_1
} // namespace ngraph } //namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -33,19 +33,23 @@ namespace ngraph ...@@ -33,19 +33,23 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector softplus(const Node& node) namespace set_1
{ {
auto data = node.get_ng_inputs().at(0); NodeVector softplus(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1}); data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape()); one_node = make_broadcast_node(one_node, data->get_shape());
return {std::make_shared<ngraph::op::Log>(std::make_shared<ngraph::op::Exp>(data) + return {std::make_shared<ngraph::op::Log>(
one_node)}; std::make_shared<ngraph::op::Exp>(data) + one_node)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,13 @@ namespace ngraph ...@@ -26,8 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector softplus(const Node& node); namespace set_1
} // namespace op {
NodeVector softplus(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -33,18 +33,22 @@ namespace ngraph ...@@ -33,18 +33,22 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector softsign(const Node& node) namespace set_1
{ {
auto data = node.get_ng_inputs().at(0); NodeVector softsign(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1}); data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape()); one_node = make_broadcast_node(one_node, data->get_shape());
return {data / (std::make_shared<ngraph::op::Abs>(data) + one_node)}; return {data / (std::make_shared<ngraph::op::Abs>(data) + one_node)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,13 @@ namespace ngraph ...@@ -26,8 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector softsign(const Node& node); namespace set_1
} // namespace op {
NodeVector softsign(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -80,79 +80,84 @@ namespace ngraph ...@@ -80,79 +80,84 @@ namespace ngraph
namespace op namespace op
{ {
namespace detail namespace set_1
{ {
template <typename T> namespace detail
inline T get_valid_array_index(T left, T right)
{ {
return (left >= 0) ? std::min(left, right) template <typename T>
: std::max(static_cast<T>(0), right + left); inline T get_valid_array_index(T left, T right)
} {
return (left >= 0) ? std::min(left, right)
: std::max(static_cast<T>(0), right + left);
}
inline std::shared_ptr<ngraph::op::Slice> inline std::shared_ptr<ngraph::op::Slice>
make_ng_slice(const std::shared_ptr<ngraph::Node>& node, make_ng_slice(const std::shared_ptr<ngraph::Node>& node,
std::vector<std::size_t> axes, std::vector<std::size_t> axes,
std::vector<std::size_t> starts, std::vector<std::size_t> starts,
std::vector<std::size_t> ends) std::vector<std::size_t> ends)
{
std::vector<std::size_t> upper_bounds{node->get_shape()};
std::vector<std::size_t> lower_bounds(upper_bounds.size());
for (std::size_t index{0}; index < axes.size(); ++index)
{ {
std::size_t axis{axes.at(index)}; std::vector<std::size_t> upper_bounds{node->get_shape()};
lower_bounds.at(axis) = std::vector<std::size_t> lower_bounds(upper_bounds.size());
get_valid_array_index(starts.at(index), node->get_shape().at(axis)); for (std::size_t index{0}; index < axes.size(); ++index)
upper_bounds.at(axis) = {
get_valid_array_index(ends.at(index), node->get_shape().at(axis)); std::size_t axis{axes.at(index)};
lower_bounds.at(axis) =
get_valid_array_index(starts.at(index), node->get_shape().at(axis));
upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), node->get_shape().at(axis));
}
return std::make_shared<ngraph::op::Slice>(
node, lower_bounds, upper_bounds);
} }
return std::make_shared<ngraph::op::Slice>(node, lower_bounds, upper_bounds);
}
} // namespace detail } // namespace detail
NodeVector split(const Node& node) NodeVector split(const Node& node)
{
std::shared_ptr<ngraph::Node> input = node.get_ng_inputs().at(0);
std::size_t count_outputs{node.get_output_names().size()};
int64_t axis{node.get_attribute_value<int64_t>("axis", 0)};
std::size_t axis_to_split{static_cast<std::size_t>(axis)};
if (axis < 0)
{
axis_to_split = input->get_shape().size() + axis;
}
else if (axis_to_split >= input->get_shape().size())
{
throw error::op::split::OutOfRange{node.get_name()};
}
std::size_t length_axis_to_split{input->get_shape().at(axis_to_split)};
std::vector<std::size_t> length_parts;
try
{
length_parts = node.get_attribute_value<std::vector<std::size_t>>("split");
}
catch (const std::exception&)
{ {
if (length_axis_to_split % count_outputs) std::shared_ptr<ngraph::Node> input = node.get_ng_inputs().at(0);
std::size_t count_outputs{node.get_output_names().size()};
int64_t axis{node.get_attribute_value<int64_t>("axis", 0)};
std::size_t axis_to_split{static_cast<std::size_t>(axis)};
if (axis < 0)
{ {
throw error::op::split::Parts{ axis_to_split = input->get_shape().size() + axis;
node.get_name(), count_outputs, length_axis_to_split}; }
else if (axis_to_split >= input->get_shape().size())
{
throw error::op::split::OutOfRange{node.get_name()};
}
std::size_t length_axis_to_split{input->get_shape().at(axis_to_split)};
std::vector<std::size_t> length_parts;
try
{
length_parts = node.get_attribute_value<std::vector<std::size_t>>("split");
}
catch (const std::exception&)
{
if (length_axis_to_split % count_outputs)
{
throw error::op::split::Parts{
node.get_name(), count_outputs, length_axis_to_split};
}
length_parts.assign(count_outputs, length_axis_to_split / count_outputs);
} }
length_parts.assign(count_outputs, length_axis_to_split / count_outputs);
}
std::size_t start_index{0}; std::size_t start_index{0};
NodeVector outputs; NodeVector outputs;
for (const auto& length_part : length_parts) for (const auto& length_part : length_parts)
{ {
std::size_t end_index{start_index + length_part}; std::size_t end_index{start_index + length_part};
outputs.push_back( outputs.push_back(detail::make_ng_slice(
detail::make_ng_slice(input, {axis_to_split}, {start_index}, {end_index})); input, {axis_to_split}, {start_index}, {end_index}));
start_index = end_index; start_index = end_index;
}
return outputs;
} }
return outputs;
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,13 @@ namespace ngraph ...@@ -26,8 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector split(const Node& node); namespace set_1
} // namespace op {
NodeVector split(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -29,12 +29,16 @@ namespace ngraph ...@@ -29,12 +29,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector sqrt(const Node& node) namespace set_1
{ {
return {std::make_shared<ngraph::op::Sqrt>(node.get_ng_inputs().at(0))}; inline NodeVector sqrt(const Node& node)
} {
return {std::make_shared<ngraph::op::Sqrt>(node.get_ng_inputs().at(0))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -30,38 +30,42 @@ namespace ngraph ...@@ -30,38 +30,42 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector squeeze(const Node& node) namespace set_1
{ {
NodeVector inputs{node.get_ng_inputs()}; NodeVector squeeze(const Node& node)
auto data = inputs.at(0);
auto data_shape = data->get_shape();
auto axes = node.get_attribute_value<std::vector<uint64_t>>("axes", {});
if (axes.empty())
{ {
for (auto index = 0; index < data_shape.size(); ++index) NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0);
auto data_shape = data->get_shape();
auto axes = node.get_attribute_value<std::vector<uint64_t>>("axes", {});
if (axes.empty())
{ {
if (data_shape.at(index) == 1) for (auto index = 0; index < data_shape.size(); ++index)
{ {
axes.push_back(index); if (data_shape.at(index) == 1)
{
axes.push_back(index);
}
} }
} }
}
std::sort(std::begin(axes), std::end(axes), std::greater<uint64_t>()); std::sort(std::begin(axes), std::end(axes), std::greater<uint64_t>());
AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())}; AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};
for (auto axis : axes) for (auto axis : axes)
{ {
data_shape.erase(std::next(std::begin(data_shape), axis)); data_shape.erase(std::next(std::begin(data_shape), axis));
}
return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
} }
return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)}; } // namespace set_1
}
} // namespace op } //namespace op
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -26,9 +26,14 @@ namespace ngraph ...@@ -26,9 +26,14 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector squeeze(const Node& node); namespace set_1
} // namespace op {
NodeVector squeeze(const Node& node);
} // namespace onnx_import } // namespace set_1
} // namespace ngraph } //namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -28,14 +28,19 @@ namespace ngraph ...@@ -28,14 +28,19 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector sub(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{ inline NodeVector sub(const Node& node)
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())}; {
return {std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))}; NodeVector ng_inputs{
} numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {
std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,12 +28,16 @@ namespace ngraph ...@@ -28,12 +28,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector sum(const Node& node) namespace set_1
{ {
return variadic::make_ng_variadic_op<ngraph::op::Add>(node); inline NodeVector sum(const Node& node)
} {
return variadic::make_ng_variadic_op<ngraph::op::Add>(node);
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -29,12 +29,16 @@ namespace ngraph ...@@ -29,12 +29,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector tanh(const Node& node) namespace set_1
{ {
return {std::make_shared<ngraph::op::Tanh>(node.get_ng_inputs().at(0))}; inline NodeVector tanh(const Node& node)
} {
return {std::make_shared<ngraph::op::Tanh>(node.get_ng_inputs().at(0))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -36,22 +36,27 @@ namespace ngraph ...@@ -36,22 +36,27 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector thresholded_relu(const Node& node) namespace set_1
{ {
auto data = node.get_ng_inputs().at(0); NodeVector thresholded_relu(const Node& node)
double alpha = node.get_attribute_value<double>("alpha", 1.0); {
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1.0);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> alpha_node =
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha}); std::make_shared<ngraph::op::Constant>(
alpha_node = make_broadcast_node(alpha_node, data->get_shape()); data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
auto data_map = std::make_shared<ngraph::op::Convert>( auto data_map = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Greater>(data, alpha_node), std::make_shared<ngraph::op::Greater>(data, alpha_node),
data->get_element_type()); data->get_element_type());
return {data * data_map}; return {data * data_map};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,13 @@ namespace ngraph ...@@ -26,8 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector thresholded_relu(const Node& node); namespace set_1
} // namespace op {
NodeVector thresholded_relu(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,17 +28,22 @@ namespace ngraph ...@@ -28,17 +28,22 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector transpose(const Node& node) namespace set_1
{ {
std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0); NodeVector transpose(const Node& node)
{
std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0);
auto permute_axes = node.get_attribute_value<std::vector<std::size_t>>("perm", {}); auto permute_axes =
node.get_attribute_value<std::vector<std::size_t>>("perm", {});
return {(permute_axes.empty()) ? reshape::transpose(data) return {(permute_axes.empty()) ? reshape::transpose(data)
: reshape::reorder_axes(data, permute_axes)}; : reshape::reorder_axes(data, permute_axes)};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,8 +26,13 @@ namespace ngraph ...@@ -26,8 +26,13 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector transpose(const Node& node); namespace set_1
} // namespace op {
NodeVector transpose(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,32 +28,36 @@ namespace ngraph ...@@ -28,32 +28,36 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector unsqueeze(const Node& node) namespace set_1
{ {
NodeVector inputs{node.get_ng_inputs()}; NodeVector unsqueeze(const Node& node)
auto data = inputs.at(0); {
auto data_shape = data->get_shape(); NodeVector inputs{node.get_ng_inputs()};
auto axes = node.get_attribute_value<std::vector<int64_t>>("axes"); auto data = inputs.at(0);
auto data_shape = data->get_shape();
auto axes = node.get_attribute_value<std::vector<int64_t>>("axes");
ASSERT_VALID_ARGUMENT(node, !axes.empty()) << "'axes' attribute is mandatory."; ASSERT_VALID_ARGUMENT(node, !axes.empty()) << "'axes' attribute is mandatory.";
std::sort(std::begin(axes), std::end(axes), std::less<int64_t>()); std::sort(std::begin(axes), std::end(axes), std::less<int64_t>());
AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())}; AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};
for (auto axis : axes) for (auto axis : axes)
{ {
ASSERT_VALID_ARGUMENT(node, axis >= 0 && axis <= data_shape.size()) ASSERT_VALID_ARGUMENT(node, axis >= 0 && axis <= data_shape.size())
<< "provided 'axes' attribute is not valid."; << "provided 'axes' attribute is not valid.";
data_shape.insert(std::next(std::begin(data_shape), axis), 1);
}
data_shape.insert(std::next(std::begin(data_shape), axis), 1); return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
} }
return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)}; } // namespace set_1
}
} // namespace op } //namespace op
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -26,9 +26,14 @@ namespace ngraph ...@@ -26,9 +26,14 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector unsqueeze(const Node& node); namespace set_1
} // namespace op {
NodeVector unsqueeze(const Node& node);
} // namespace onnx_import } // namespace set_1
} // namespace ngraph } //namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -30,20 +30,24 @@ namespace ngraph ...@@ -30,20 +30,24 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector logical_xor(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{ inline NodeVector logical_xor(const Node& node)
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())}; {
auto left = ng_inputs.at(0); NodeVector ng_inputs{
auto not_left = std::make_shared<ngraph::op::Not>(left); numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
auto right = ng_inputs.at(1); auto left = ng_inputs.at(0);
auto not_right = std::make_shared<ngraph::op::Not>(right); auto not_left = std::make_shared<ngraph::op::Not>(left);
return {std::make_shared<ngraph::op::Or>( auto right = ng_inputs.at(1);
std::make_shared<ngraph::op::And>(left, not_right), auto not_right = std::make_shared<ngraph::op::Not>(right);
std::make_shared<ngraph::op::And>(not_left, right))}; return {std::make_shared<ngraph::op::Or>(
} std::make_shared<ngraph::op::And>(left, not_right),
std::make_shared<ngraph::op::And>(not_left, right))};
} // namespace op }
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment