Commit 9debc0bc authored by Artur Wojcik's avatar Artur Wojcik Committed by Scott Cyphers

[ONNX] Add support to OperatorSet (Part 1) - namespace change only (#1812)

* onnx: add register operator macro
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>

* onnx: add set information to 'op' namespace
Signed-off-by: 's avatarArtur Wojcik <artur.wojcik@intel.com>
parent 1d5f047a
......@@ -29,12 +29,16 @@ namespace ngraph
{
namespace op
{
inline NodeVector abs(const Node& node)
namespace set_1
{
return {std::make_shared<ngraph::op::Abs>(node.get_ng_inputs().at(0))};
}
inline NodeVector abs(const Node& node)
{
return {std::make_shared<ngraph::op::Abs>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,14 +28,18 @@ namespace ngraph
{
namespace op
{
inline NodeVector add(const Node& node)
namespace set_1
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
}
inline NodeVector add(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,14 +28,18 @@ namespace ngraph
{
namespace op
{
inline NodeVector logical_and(const Node& node)
namespace set_1
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))};
}
inline NodeVector logical_and(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,12 +26,16 @@ namespace ngraph
{
namespace op
{
NodeVector average_pool(const Node& node)
namespace set_1
{
return convpool::make_ng_pool<ngraph::op::AvgPool>(node);
}
NodeVector average_pool(const Node& node)
{
return convpool::make_ng_pool<ngraph::op::AvgPool>(node);
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,17 +26,21 @@ namespace ngraph
{
namespace op
{
/**
* @brief Convert ONNX AveragePool operation to an nGraph node.
*
* @param node The ONNX node object representing this operation.
*
* @return The vector containing Ngraph nodes producing output of ONNX AveragePool
* operation.
*/
NodeVector average_pool(const Node& node);
} // namespace op
namespace set_1
{
/**
* @brief Convert ONNX AveragePool operation to an nGraph node.
*
* @param node The ONNX node object representing this operation.
*
* @return The vector containing Ngraph nodes producing output of ONNX AveragePool
* operation.
*/
NodeVector average_pool(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,38 +28,42 @@ namespace ngraph
{
namespace op
{
NodeVector batch_norm(const Node& node)
namespace set_1
{
NodeVector inputs{node.get_ng_inputs()};
auto x = inputs.at(0);
auto scale = inputs.at(1);
auto bias = inputs.at(2);
std::shared_ptr<ngraph::Node> mean{nullptr};
std::shared_ptr<ngraph::Node> var{nullptr};
NodeVector batch_norm(const Node& node)
{
NodeVector inputs{node.get_ng_inputs()};
auto x = inputs.at(0);
auto scale = inputs.at(1);
auto bias = inputs.at(2);
std::shared_ptr<ngraph::Node> mean{nullptr};
std::shared_ptr<ngraph::Node> var{nullptr};
int is_test{node.get_attribute_value<int>("is_test", 1)};
int spatial{node.get_attribute_value<int>("spatial", 1)};
double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)};
// TODO: Implement learning mode support
// float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
bool training = false;
int is_test{node.get_attribute_value<int>("is_test", 1)};
int spatial{node.get_attribute_value<int>("spatial", 1)};
double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)};
// TODO: Implement learning mode support
// float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
bool training = false;
ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported.";
ASSERT_IS_SUPPORTED(node, spatial) << "only 'spatial' mode is supported.";
ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported.";
ASSERT_IS_SUPPORTED(node, spatial) << "only 'spatial' mode is supported.";
if (inputs.size() >= 5)
{
mean = inputs.at(3);
var = inputs.at(4);
return {std::make_shared<ngraph::op::BatchNorm>(
epsilon, scale, bias, x, mean, var, training)};
if (inputs.size() >= 5)
{
mean = inputs.at(3);
var = inputs.at(4);
return {std::make_shared<ngraph::op::BatchNorm>(
epsilon, scale, bias, x, mean, var, training)};
}
return {std::make_shared<ngraph::op::BatchNorm>(epsilon, scale, bias, x)};
}
return {std::make_shared<ngraph::op::BatchNorm>(epsilon, scale, bias, x)};
}
} // namespace set_1
} // namespace op
} //namespace op
} // namespace onnx_import
} // namespace onnx_import
} // namespace ngraph
} // namespace ngraph
......@@ -24,9 +24,14 @@ namespace ngraph
{
namespace op
{
NodeVector batch_norm(const Node& node);
} // namespace op
namespace set_1
{
NodeVector batch_norm(const Node& node);
} // namespace onnx_import
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -30,34 +30,40 @@ namespace ngraph
{
namespace op
{
NodeVector cast(const Node& node)
namespace set_1
{
auto data = node.get_ng_inputs().at(0);
int64_t target_type = node.get_attribute_value<int64_t>("to");
element::Type elem_type;
switch (target_type)
NodeVector cast(const Node& node)
{
case onnx::TensorProto_DataType_BOOL: elem_type = element::boolean; break;
case onnx::TensorProto_DataType_DOUBLE: elem_type = element::f64; break;
case onnx::TensorProto_DataType_FLOAT16:
case onnx::TensorProto_DataType_FLOAT: elem_type = element::f32; break;
case onnx::TensorProto_DataType_INT8: elem_type = element::i8; break;
case onnx::TensorProto_DataType_INT16: elem_type = element::i16; break;
case onnx::TensorProto_DataType_INT32: elem_type = element::i32; break;
case onnx::TensorProto_DataType_INT64: elem_type = element::i64; break;
case onnx::TensorProto_DataType_UINT8: elem_type = element::u8; break;
case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break;
case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break;
case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break;
case onnx::TensorProto_DataType_UNDEFINED: elem_type = element::unspecified; break;
default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type";
auto data = node.get_ng_inputs().at(0);
int64_t target_type = node.get_attribute_value<int64_t>("to");
element::Type elem_type;
switch (target_type)
{
case onnx::TensorProto_DataType_BOOL: elem_type = element::boolean; break;
case onnx::TensorProto_DataType_DOUBLE: elem_type = element::f64; break;
case onnx::TensorProto_DataType_FLOAT16:
case onnx::TensorProto_DataType_FLOAT: elem_type = element::f32; break;
case onnx::TensorProto_DataType_INT8: elem_type = element::i8; break;
case onnx::TensorProto_DataType_INT16: elem_type = element::i16; break;
case onnx::TensorProto_DataType_INT32: elem_type = element::i32; break;
case onnx::TensorProto_DataType_INT64: elem_type = element::i64; break;
case onnx::TensorProto_DataType_UINT8: elem_type = element::u8; break;
case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break;
case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break;
case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break;
case onnx::TensorProto_DataType_UNDEFINED:
elem_type = element::unspecified;
break;
default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type";
}
return {std::make_shared<ngraph::op::Convert>(data, elem_type)};
}
return {std::make_shared<ngraph::op::Convert>(data, elem_type)};
}
} // namespace set_1
} // namespace op
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector cast(const Node& node);
} // namespace op
namespace set_1
{
NodeVector cast(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -29,12 +29,16 @@ namespace ngraph
{
namespace op
{
inline NodeVector ceil(const Node& node)
namespace set_1
{
return {std::make_shared<ngraph::op::Ceiling>(node.get_ng_inputs().at(0))};
}
inline NodeVector ceil(const Node& node)
{
return {std::make_shared<ngraph::op::Ceiling>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -34,30 +34,37 @@ namespace ngraph
{
namespace op
{
NodeVector clip(const Node& node)
namespace set_1
{
auto data = node.get_ng_inputs().at(0);
NodeVector clip(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double max_value =
node.get_attribute_value<double>("max", std::numeric_limits<double>::max());
double min_value =
node.get_attribute_value<double>("min", std::numeric_limits<double>::lowest());
double max_value =
node.get_attribute_value<double>("max", std::numeric_limits<double>::max());
double min_value = node.get_attribute_value<double>(
"min", std::numeric_limits<double>::lowest());
std::shared_ptr<ngraph::Node> max_value_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{max_value});
max_value_node = make_broadcast_node(max_value_node, data->get_shape());
std::shared_ptr<ngraph::Node> max_value_node =
std::make_shared<ngraph::op::Constant>(data->get_element_type(),
ngraph::Shape{},
std::vector<double>{max_value});
max_value_node = make_broadcast_node(max_value_node, data->get_shape());
std::shared_ptr<ngraph::Node> min_value_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{min_value});
min_value_node = make_broadcast_node(min_value_node, data->get_shape());
std::shared_ptr<ngraph::Node> min_value_node =
std::make_shared<ngraph::op::Constant>(data->get_element_type(),
ngraph::Shape{},
std::vector<double>{min_value});
min_value_node = make_broadcast_node(min_value_node, data->get_shape());
return {std::make_shared<ngraph::op::Minimum>(
max_value_node, std::make_shared<ngraph::op::Maximum>(data, min_value_node))};
}
return {std::make_shared<ngraph::op::Minimum>(
max_value_node,
std::make_shared<ngraph::op::Maximum>(data, min_value_node))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector clip(const Node& node);
} // namespace op
namespace set_1
{
NodeVector clip(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -24,16 +24,20 @@ namespace ngraph
{
namespace op
{
NodeVector concat(const Node& node)
namespace set_1
{
NodeVector inputs{node.get_ng_inputs()};
auto axis = node.get_attribute_value<int64_t>("axis");
NodeVector concat(const Node& node)
{
NodeVector inputs{node.get_ng_inputs()};
auto axis = node.get_attribute_value<int64_t>("axis");
return {std::make_shared<ngraph::op::Concat>(inputs, axis)};
}
return {std::make_shared<ngraph::op::Concat>(inputs, axis)};
}
} // namespace op
} // namespace set_1
} // namespace onnx_import
} //namespace op
} // namespace ngraph
} // namespace onnx_import
} // namespace ngraph
......@@ -26,9 +26,14 @@ namespace ngraph
{
namespace op
{
NodeVector concat(const Node& node);
} // namespace op
namespace set_1
{
NodeVector concat(const Node& node);
} // namespace onnx_import
} // namespace set_1
} // namespace ngraph
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -25,96 +25,101 @@ namespace ngraph
{
namespace op
{
namespace
namespace set_1
{
template <typename T>
inline std::shared_ptr<ngraph::op::Constant>
__make_ng_constant(const element::Type& type, const Tensor& tensor)
namespace
{
return std::make_shared<ngraph::op::Constant>(
type, tensor.get_shape(), tensor.get_data<T>());
}
template <typename T>
inline std::shared_ptr<ngraph::op::Constant>
__make_ng_constant(const element::Type& type, const Tensor& tensor)
{
return std::make_shared<ngraph::op::Constant>(
type, tensor.get_shape(), tensor.get_data<T>());
}
template <Tensor::Type>
inline std::shared_ptr<ngraph::op::Constant> make_ng_constant(const Tensor& tensor)
{
throw error::tensor::unsupported_data_type{tensor};
}
template <Tensor::Type>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant(const Tensor& tensor)
{
throw error::tensor::unsupported_data_type{tensor};
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float16>(const Tensor& tensor)
{
return __make_ng_constant<float>(element::f32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float16>(const Tensor& tensor)
{
return __make_ng_constant<float>(element::f32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float32>(const Tensor& tensor)
{
return __make_ng_constant<float>(element::f32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float32>(const Tensor& tensor)
{
return __make_ng_constant<float>(element::f32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float64>(const Tensor& tensor)
{
return __make_ng_constant<double>(element::f64, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::float64>(const Tensor& tensor)
{
return __make_ng_constant<double>(element::f64, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int32>(const Tensor& tensor)
{
return __make_ng_constant<int32_t>(element::i32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int32>(const Tensor& tensor)
{
return __make_ng_constant<int32_t>(element::i32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int64>(const Tensor& tensor)
{
return __make_ng_constant<int64_t>(element::i64, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::int64>(const Tensor& tensor)
{
return __make_ng_constant<int64_t>(element::i64, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor)
{
return __make_ng_constant<uint32_t>(element::u32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor)
{
return __make_ng_constant<uint32_t>(element::u32, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint64>(const Tensor& tensor)
{
return __make_ng_constant<uint64_t>(element::u64, tensor);
}
template <>
inline std::shared_ptr<ngraph::op::Constant>
make_ng_constant<Tensor::Type::uint64>(const Tensor& tensor)
{
return __make_ng_constant<uint64_t>(element::u64, tensor);
}
inline std::shared_ptr<ngraph::op::Constant> make_constant(const Tensor& tensor)
{
inline std::shared_ptr<ngraph::op::Constant> make_constant(const Tensor& tensor)
{
#define MAKE_NG_CONSTANT(data_type_) \
case data_type_: return make_ng_constant<data_type_>(tensor)
switch (tensor.get_type())
{
MAKE_NG_CONSTANT(Tensor::Type::float16);
MAKE_NG_CONSTANT(Tensor::Type::float32);
MAKE_NG_CONSTANT(Tensor::Type::float64);
MAKE_NG_CONSTANT(Tensor::Type::int32);
MAKE_NG_CONSTANT(Tensor::Type::int64);
MAKE_NG_CONSTANT(Tensor::Type::uint32);
MAKE_NG_CONSTANT(Tensor::Type::uint64);
default: throw error::tensor::invalid_data_type{tensor};
switch (tensor.get_type())
{
MAKE_NG_CONSTANT(Tensor::Type::float16);
MAKE_NG_CONSTANT(Tensor::Type::float32);
MAKE_NG_CONSTANT(Tensor::Type::float64);
MAKE_NG_CONSTANT(Tensor::Type::int32);
MAKE_NG_CONSTANT(Tensor::Type::int64);
MAKE_NG_CONSTANT(Tensor::Type::uint32);
MAKE_NG_CONSTANT(Tensor::Type::uint64);
default: throw error::tensor::invalid_data_type{tensor};
}
}
}
}
NodeVector constant(const onnx_import::Node& node)
{
return {make_constant(node.get_attribute_value<Tensor>("value"))};
}
NodeVector constant(const onnx_import::Node& node)
{
return {make_constant(node.get_attribute_value<Tensor>("value"))};
}
} // namespace set_1
} // namespace op
} //namespace op
} // namespace onnx_import
......
......@@ -26,9 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector constant(const Node& node);
namespace set_1
{
NodeVector constant(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
This diff is collapsed.
......@@ -26,15 +26,19 @@ namespace ngraph
{
namespace op
{
/// \brief Performs ONNX Conv operation.
///
/// \param node The ONNX node object representing this operation.
///
/// \return The vector containing Ngraph nodes producing output of ONNX convolution
/// operation.
NodeVector conv(const Node& node);
} // namespace op
namespace set_1
{
/// \brief Performs ONNX Conv operation.
///
/// \param node The ONNX node object representing this operation.
///
/// \return The vector containing Ngraph nodes producing output of ONNX convolution
/// operation.
NodeVector conv(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,14 +28,18 @@ namespace ngraph
{
namespace op
{
inline NodeVector div(const Node& node)
namespace set_1
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
}
inline NodeVector div(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -38,26 +38,33 @@ namespace ngraph
{
namespace op
{
NodeVector elu(const Node& node)
namespace set_1
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1);
NodeVector elu(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node * std::make_shared<ngraph::op::Exp>(
std::make_shared<ngraph::op::Minimum>(data, zero_node)) -
alpha_node};
}
return {std::make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node *
std::make_shared<ngraph::op::Exp>(
std::make_shared<ngraph::op::Minimum>(data, zero_node)) -
alpha_node};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector elu(const Node& node);
} // namespace op
namespace set_1
{
NodeVector elu(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,14 +28,18 @@ namespace ngraph
{
namespace op
{
inline NodeVector equal(const Node& node)
namespace set_1
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))};
}
inline NodeVector equal(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -29,12 +29,16 @@ namespace ngraph
{
namespace op
{
inline NodeVector exp(const Node& node)
namespace set_1
{
return {std::make_shared<ngraph::op::Exp>(node.get_ng_inputs().at(0))};
}
inline NodeVector exp(const Node& node)
{
return {std::make_shared<ngraph::op::Exp>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -25,19 +25,23 @@ namespace ngraph
{
namespace op
{
NodeVector flatten(const Node& node)
namespace set_1
{
NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0);
auto axis = node.get_attribute_value<int64_t>("axis", 1);
NodeVector flatten(const Node& node)
{
NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0);
auto axis = node.get_attribute_value<int64_t>("axis", 1);
ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= data->get_shape().size()))
<< "provided 'axis' attribute is not valid.";
ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= data->get_shape().size()))
<< "provided 'axis' attribute is not valid.";
return {reshape::flatten(data, axis)};
}
return {reshape::flatten(data, axis)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,8 +28,13 @@ namespace ngraph
{
namespace op
{
NodeVector flatten(const Node& node);
} // namespace op
namespace set_1
{
NodeVector flatten(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -29,12 +29,16 @@ namespace ngraph
{
namespace op
{
inline NodeVector floor(const Node& node)
namespace set_1
{
return {std::make_shared<ngraph::op::Floor>(node.get_ng_inputs().at(0))};
}
inline NodeVector floor(const Node& node)
{
return {std::make_shared<ngraph::op::Floor>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -32,50 +32,58 @@ namespace ngraph
{
namespace op
{
NodeVector gemm(const Node& node)
namespace set_1
{
NodeVector inputs{node.get_ng_inputs()};
auto input_a = inputs.at(0);
auto input_b = inputs.at(1);
auto input_c = inputs.at(2);
double alpha = node.get_attribute_value<double>("alpha", 1);
double beta = node.get_attribute_value<double>("beta", 1);
auto trans_a = node.get_attribute_value<int64_t>("transA", 0);
auto trans_b = node.get_attribute_value<int64_t>("transB", 0);
if (trans_a != 0)
{
input_a = reshape::transpose(input_a);
}
if (trans_b != 0)
NodeVector gemm(const Node& node)
{
input_b = reshape::transpose(input_b);
NodeVector inputs{node.get_ng_inputs()};
auto input_a = inputs.at(0);
auto input_b = inputs.at(1);
auto input_c = inputs.at(2);
double alpha = node.get_attribute_value<double>("alpha", 1);
double beta = node.get_attribute_value<double>("beta", 1);
auto trans_a = node.get_attribute_value<int64_t>("transA", 0);
auto trans_b = node.get_attribute_value<int64_t>("transB", 0);
if (trans_a != 0)
{
input_a = reshape::transpose(input_a);
}
if (trans_b != 0)
{
input_b = reshape::transpose(input_b);
}
// code from python not implemented in c++ yet.
// reshape_for_matmul(node, input_a, input_b);
std::shared_ptr<ngraph::Node> a_dot_b =
std::make_shared<ngraph::op::Dot>(input_a, input_b);
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(a_dot_b->get_element_type(),
ngraph::Shape{},
std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, a_dot_b->get_shape());
a_dot_b = std::make_shared<ngraph::op::Multiply>(alpha_node, a_dot_b);
std::shared_ptr<ngraph::Node> beta_node =
std::make_shared<ngraph::op::Constant>(input_c->get_element_type(),
ngraph::Shape{},
std::vector<double>{beta});
beta_node = make_broadcast_node(beta_node, input_c->get_shape());
input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c);
input_c = make_broadcast_node(input_c, a_dot_b->get_shape());
return {std::make_shared<ngraph::op::Add>(a_dot_b, input_c)};
}
// code from python not implemented in c++ yet.
// reshape_for_matmul(node, input_a, input_b);
std::shared_ptr<ngraph::Node> a_dot_b =
std::make_shared<ngraph::op::Dot>(input_a, input_b);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
a_dot_b->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, a_dot_b->get_shape());
a_dot_b = std::make_shared<ngraph::op::Multiply>(alpha_node, a_dot_b);
std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>(
input_c->get_element_type(), ngraph::Shape{}, std::vector<double>{beta});
beta_node = make_broadcast_node(beta_node, input_c->get_shape());
input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c);
input_c = make_broadcast_node(input_c, a_dot_b->get_shape());
return {std::make_shared<ngraph::op::Add>(a_dot_b, input_c)};
}
} // namespace set_1
} // namespace op
} //namespace op
} // namespace onnx_import
......
......@@ -28,8 +28,13 @@ namespace ngraph
{
namespace op
{
NodeVector gemm(const Node& node);
} // namespace op
namespace set_1
{
NodeVector gemm(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,14 +28,19 @@ namespace ngraph
{
namespace op
{
inline NodeVector greater(const Node& node)
namespace set_1
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))};
}
inline NodeVector greater(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {
std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -35,36 +35,43 @@ namespace ngraph
{
namespace op
{
NodeVector hard_sigmoid(const Node& node)
namespace set_1
{
auto data = node.get_ng_inputs().at(0);
NodeVector hard_sigmoid(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 0.2);
double beta = node.get_attribute_value<double>("beta", 0.5);
double alpha = node.get_attribute_value<double>("alpha", 0.2);
double beta = node.get_attribute_value<double>("beta", 0.5);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{beta});
beta_node = make_broadcast_node(beta_node, data->get_shape());
std::shared_ptr<ngraph::Node> beta_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{beta});
beta_node = make_broadcast_node(beta_node, data->get_shape());
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(
zero_node,
std::make_shared<ngraph::op::Minimum>(one_node,
alpha_node * data + beta_node))};
}
return {std::make_shared<ngraph::op::Maximum>(
zero_node,
std::make_shared<ngraph::op::Minimum>(one_node,
alpha_node * data + beta_node))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector hard_sigmoid(const Node& node);
} // namespace op
namespace set_1
{
NodeVector hard_sigmoid(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,15 @@ namespace ngraph
{
namespace op
{
inline NodeVector identity(const Node& node) { return {node.get_ng_inputs().at(0)}; }
} // namespace op
namespace set_1
{
inline NodeVector identity(const Node& node)
{
return {node.get_ng_inputs().at(0)};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -38,21 +38,26 @@ namespace ngraph
{
namespace op
{
NodeVector leaky_relu(const Node& node)
namespace set_1
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 0.01);
NodeVector leaky_relu(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 0.01);
ASSERT_VALID_ARGUMENT(node, ((alpha >= 0) && (alpha <= 1)))
<< " alpha value should be in range (0,1)";
ASSERT_VALID_ARGUMENT(node, ((alpha >= 0) && (alpha <= 1)))
<< " alpha value should be in range (0,1)";
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)};
}
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector leaky_relu(const Node& node);
} // namespace op
namespace set_1
{
NodeVector leaky_relu(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,14 +28,18 @@ namespace ngraph
{
namespace op
{
inline NodeVector less(const Node& node)
namespace set_1
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))};
}
inline NodeVector less(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -29,12 +29,16 @@ namespace ngraph
{
namespace op
{
inline NodeVector log(const Node& node)
namespace set_1
{
return {std::make_shared<ngraph::op::Log>(node.get_ng_inputs().at(0))};
}
inline NodeVector log(const Node& node)
{
return {std::make_shared<ngraph::op::Log>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -31,13 +31,17 @@ namespace ngraph
{
namespace op
{
inline NodeVector log_softmax(const Node& node)
namespace set_1
{
return {std::make_shared<ngraph::op::Log>(softmax(node).at(0))};
}
inline NodeVector log_softmax(const Node& node)
{
return {std::make_shared<ngraph::op::Log>(softmax(node).at(0))};
}
} // namespace op
} // namespace set_1
} // namespace onnx_import
} //namespace op
} // namespace ngraph
} // namespace onnx_import
} // namespace ngraph
......@@ -27,18 +27,22 @@ namespace ngraph
{
namespace op
{
NodeVector lrn(const Node& node)
namespace set_1
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1e-4);
double beta = node.get_attribute_value<double>("beta", 0.75);
double bias = node.get_attribute_value<double>("bias", 1);
size_t size = node.get_attribute_value<size_t>("size");
NodeVector lrn(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1e-4);
double beta = node.get_attribute_value<double>("beta", 0.75);
double bias = node.get_attribute_value<double>("bias", 1);
size_t size = node.get_attribute_value<size_t>("size");
return {std::make_shared<ngraph::op::LRN>(data, alpha, beta, bias, size)};
}
return {std::make_shared<ngraph::op::LRN>(data, alpha, beta, bias, size)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,12 @@ namespace ngraph
{
namespace op
{
NodeVector lrn(const Node& node);
} // namespace op
namespace set_1
{
NodeVector lrn(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,13 +27,17 @@ namespace ngraph
{
namespace op
{
inline NodeVector matmul(const Node& node)
namespace set_1
{
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Dot>(ng_inputs.at(0), ng_inputs.at(1))};
}
inline NodeVector matmul(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Dot>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,12 +28,16 @@ namespace ngraph
{
namespace op
{
inline NodeVector max(const Node& node)
namespace set_1
{
return variadic::make_ng_variadic_op<ngraph::op::Maximum>(node);
}
inline NodeVector max(const Node& node)
{
return variadic::make_ng_variadic_op<ngraph::op::Maximum>(node);
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,12 +26,16 @@ namespace ngraph
{
namespace op
{
NodeVector max_pool(const Node& node)
namespace set_1
{
return convpool::make_ng_pool<ngraph::op::MaxPool>(node);
}
NodeVector max_pool(const Node& node)
{
return convpool::make_ng_pool<ngraph::op::MaxPool>(node);
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,17 +26,21 @@ namespace ngraph
{
namespace op
{
/**
* @brief Convert ONNX MaxPool operation to an nGraph node.
*
* @param node The ONNX node object representing this operation.
*
* @return The vector containing Ngraph nodes producing output of ONNX MaxPool
* operation.
*/
NodeVector max_pool(const Node& node);
} // namespace op
namespace set_1
{
/**
* @brief Convert ONNX MaxPool operation to an nGraph node.
*
* @param node The ONNX node object representing this operation.
*
* @return The vector containing Ngraph nodes producing output of ONNX MaxPool
* operation.
*/
NodeVector max_pool(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,21 +27,25 @@ namespace ngraph
{
namespace op
{
NodeVector mean(const Node& node)
namespace set_1
{
auto sum = variadic::make_ng_variadic_op<ngraph::op::Add>(node).front();
auto shape = sum->get_shape();
NodeVector mean(const Node& node)
{
auto sum = variadic::make_ng_variadic_op<ngraph::op::Add>(node).front();
auto shape = sum->get_shape();
// Create a Constant representing the number of inputs with the same shape as sum
auto count = ngraph::op::Constant::create(
sum->get_element_type(),
shape,
std::vector<int>(shape_size(shape), node.get_ng_inputs().size()));
// Create a Constant representing the number of inputs with the same shape as sum
auto count = ngraph::op::Constant::create(
sum->get_element_type(),
shape,
std::vector<int>(shape_size(shape), node.get_ng_inputs().size()));
return {sum / count};
}
return {sum / count};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,9 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector mean(const Node& node);
namespace set_1
{
NodeVector mean(const Node& node);
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,12 +28,16 @@ namespace ngraph
{
namespace op
{
inline NodeVector min(const Node& node)
namespace set_1
{
return variadic::make_ng_variadic_op<ngraph::op::Minimum>(node);
}
inline NodeVector min(const Node& node)
{
return variadic::make_ng_variadic_op<ngraph::op::Minimum>(node);
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,14 +28,19 @@ namespace ngraph
{
namespace op
{
inline NodeVector mul(const Node& node)
namespace set_1
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))};
}
inline NodeVector mul(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {
std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,8 +27,12 @@ namespace ngraph
{
namespace op
{
inline NodeVector neg(const Node& node) { return {-node.get_ng_inputs().at(0)}; }
} // namespace op
namespace set_1
{
inline NodeVector neg(const Node& node) { return {-node.get_ng_inputs().at(0)}; }
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,12 +28,16 @@ namespace ngraph
{
namespace op
{
inline NodeVector logical_not(const Node& node)
namespace set_1
{
return {std::make_shared<ngraph::op::Not>(node.get_ng_inputs().at(0))};
}
inline NodeVector logical_not(const Node& node)
{
return {std::make_shared<ngraph::op::Not>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,14 +28,18 @@ namespace ngraph
{
namespace op
{
inline NodeVector logical_or(const Node& node)
namespace set_1
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))};
}
inline NodeVector logical_or(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,14 +28,18 @@ namespace ngraph
{
namespace op
{
inline NodeVector pow(const Node& node)
namespace set_1
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Power>(ng_inputs.at(0), ng_inputs.at(1))};
}
inline NodeVector pow(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Power>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -37,31 +37,35 @@ namespace ngraph
{
namespace op
{
NodeVector prelu(const Node& node)
namespace set_1
{
NodeVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0);
auto data_shape = data->get_shape();
std::shared_ptr<ngraph::Node> slope = ng_inputs.at(1);
auto slope_shape = slope->get_shape();
if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
{
auto it =
std::find(std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
auto index = std::distance(std::begin(data_shape), it);
slope = make_broadcast_node(slope, data->get_shape(), index);
}
else
NodeVector prelu(const Node& node)
{
auto params = numpy_style_broadcast_for_binary_operation(slope, data);
slope = params.at(0);
NodeVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0);
auto data_shape = data->get_shape();
std::shared_ptr<ngraph::Node> slope = ng_inputs.at(1);
auto slope_shape = slope->get_shape();
if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
{
auto it = std::find(
std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
auto index = std::distance(std::begin(data_shape), it);
slope = make_broadcast_node(slope, data->get_shape(), index);
}
else
{
auto params = numpy_style_broadcast_for_binary_operation(slope, data);
slope = params.at(0);
}
return {std::make_shared<ngraph::op::Maximum>(data * slope, data)};
}
return {std::make_shared<ngraph::op::Maximum>(data * slope, data)};
}
} // namespace set_1
} // namespace op
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector prelu(const Node& node);
} // namespace op
namespace set_1
{
NodeVector prelu(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -31,18 +31,22 @@ namespace ngraph
{
namespace op
{
NodeVector reciprocal(const Node& node)
namespace set_1
{
auto data = node.get_ng_inputs().at(0);
NodeVector reciprocal(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
return {one_node / data};
}
return {one_node / data};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector reciprocal(const Node& node);
} // namespace op
namespace set_1
{
NodeVector reciprocal(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -32,29 +32,35 @@ namespace ngraph
{
namespace op
{
NodeVector reduce_mean(const Node& node)
namespace set_1
{
auto input_shape = node.get_ng_inputs().at(0)->get_shape();
auto reduction_axes = reduction::detail::get_reduction_axes(node);
std::size_t elem_count_product =
std::accumulate(std::begin(reduction_axes),
std::end(reduction_axes),
1UL,
[&input_shape](const std::size_t& a, const std::size_t& b) {
return a * input_shape.at(b);
});
auto sum_node = reduction::make_ng_reduction_op<ngraph::op::Sum>(
node, node.get_ng_inputs().at(0));
auto const_node = std::make_shared<ngraph::op::Constant>(
sum_node->get_element_type(),
Shape{},
std::vector<std::size_t>{elem_count_product});
auto broadcasted_const_node =
make_broadcast_node(const_node, sum_node->get_shape());
return {std::make_shared<ngraph::op::Divide>(sum_node, broadcasted_const_node)};
}
} // namespace op
} // namespace onnx_import
NodeVector reduce_mean(const Node& node)
{
auto input_shape = node.get_ng_inputs().at(0)->get_shape();
auto reduction_axes = reduction::detail::get_reduction_axes(node);
std::size_t elem_count_product =
std::accumulate(std::begin(reduction_axes),
std::end(reduction_axes),
1UL,
[&input_shape](const std::size_t& a, const std::size_t& b) {
return a * input_shape.at(b);
});
auto sum_node = reduction::make_ng_reduction_op<ngraph::op::Sum>(
node, node.get_ng_inputs().at(0));
auto const_node = std::make_shared<ngraph::op::Constant>(
sum_node->get_element_type(),
Shape{},
std::vector<std::size_t>{elem_count_product});
auto broadcasted_const_node =
make_broadcast_node(const_node, sum_node->get_shape());
return {std::make_shared<ngraph::op::Divide>(sum_node, broadcasted_const_node)};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -27,13 +27,17 @@ namespace ngraph
{
namespace op
{
inline NodeVector relu(const Node& node)
namespace set_1
{
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Relu>(ng_inputs.at(0))};
}
inline NodeVector relu(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Relu>(ng_inputs.at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -33,39 +33,45 @@ namespace ngraph
{
namespace op
{
NodeVector reshape(const Node& node)
namespace set_1
{
NodeVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0);
auto data_shape = data->get_shape();
NodeVector reshape(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0);
auto data_shape = data->get_shape();
auto output_shape = node.get_attribute_value<std::vector<std::size_t>>("shape", {});
auto output_shape =
node.get_attribute_value<std::vector<std::size_t>>("shape", {});
// If no shape argument (opset >= 5) and there is second input.
if (output_shape.empty() && ng_inputs.size() == 2)
{
// Currently only support Constant node.
ASSERT_IS_SUPPORTED(node, ng_inputs.at(1)->description() == "Constant")
<< "doesn't support shape input of other type than Constant.";
// If no shape argument (opset >= 5) and there is second input.
if (output_shape.empty() && ng_inputs.size() == 2)
{
// Currently only support Constant node.
ASSERT_IS_SUPPORTED(node, ng_inputs.at(1)->description() == "Constant")
<< "doesn't support shape input of other type than Constant.";
auto output_shape_node =
std::dynamic_pointer_cast<ngraph::op::Constant>(ng_inputs.at(1));
output_shape = output_shape_node->get_vector<std::size_t>();
}
// Do nothing if there is no shape argument nor second node input.
else if (output_shape.empty())
{
return {data};
auto output_shape_node =
std::dynamic_pointer_cast<ngraph::op::Constant>(ng_inputs.at(1));
output_shape = output_shape_node->get_vector<std::size_t>();
}
// Do nothing if there is no shape argument nor second node input.
else if (output_shape.empty())
{
return {data};
}
output_shape =
reshape::infer_dimensions(node.get_name(), data_shape, output_shape);
return {std::make_shared<ngraph::op::Reshape>(
data,
reshape::get_default_axis_vector(data_shape.size()),
Shape{output_shape})};
}
output_shape = reshape::infer_dimensions(node.get_name(), data_shape, output_shape);
return {std::make_shared<ngraph::op::Reshape>(
data,
reshape::get_default_axis_vector(data_shape.size()),
Shape{output_shape})};
}
} // namespace set_1
} // namespace op
} //namespace op
} // namespace onnx_import
......
......@@ -26,16 +26,20 @@ namespace ngraph
{
namespace op
{
///
/// \brief Reshape the input tensor similar to numpy.reshape.
///
/// \param[in] node The ONNX node representing this operation.
///
/// \return Ngraph node representing this operation.
///
NodeVector reshape(const Node& node);
} // namespace op
namespace set_1
{
///
/// \brief Reshape the input tensor similar to numpy.reshape.
///
/// \param[in] node The ONNX node representing this operation.
///
/// \return Ngraph node representing this operation.
///
NodeVector reshape(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -41,32 +41,41 @@ namespace ngraph
{
namespace op
{
NodeVector selu(const Node& node)
namespace set_1
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1.67326319217681884765625);
double gamma = node.get_attribute_value<double>("gamma", 1.05070102214813232421875);
NodeVector selu(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha =
node.get_attribute_value<double>("alpha", 1.67326319217681884765625);
double gamma =
node.get_attribute_value<double>("gamma", 1.05070102214813232421875);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> gamma_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{gamma});
gamma_node = make_broadcast_node(gamma_node, data->get_shape());
std::shared_ptr<ngraph::Node> gamma_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{gamma});
gamma_node = make_broadcast_node(gamma_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
return {gamma_node *
(std::make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node * std::make_shared<ngraph::op::Exp>(
std::make_shared<ngraph::op::Minimum>(data, zero_node)) -
alpha_node)};
}
return {gamma_node * (std::make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node * std::make_shared<ngraph::op::Exp>(
std::make_shared<ngraph::op::Minimum>(
data, zero_node)) -
alpha_node)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector selu(const Node& node);
} // namespace op
namespace set_1
{
NodeVector selu(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -30,16 +30,20 @@ namespace ngraph
{
namespace op
{
NodeVector shape(const Node& node)
namespace set_1
{
auto data = node.get_ng_inputs().at(0);
auto data_shape = data->get_shape();
NodeVector shape(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
auto data_shape = data->get_shape();
return {std::make_shared<ngraph::op::Constant>(
ngraph::element::i64, Shape{data_shape.size()}, data_shape)};
}
return {std::make_shared<ngraph::op::Constant>(
ngraph::element::i64, Shape{data_shape.size()}, data_shape)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector shape(const Node& node);
} // namespace op
namespace set_1
{
NodeVector shape(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -29,12 +29,16 @@ namespace ngraph
{
namespace op
{
inline NodeVector sigmoid(const Node& node)
namespace set_1
{
return {std::make_shared<ngraph::op::Sigmoid>(node.get_ng_inputs().at(0))};
}
inline NodeVector sigmoid(const Node& node)
{
return {std::make_shared<ngraph::op::Sigmoid>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -35,32 +35,37 @@ namespace ngraph
{
namespace op
{
NodeVector slice(const Node& node)
namespace set_1
{
std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0);
Shape data_shape = data->get_shape();
NodeVector slice(const Node& node)
{
std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0);
Shape data_shape = data->get_shape();
auto starts = node.get_attribute_value<std::vector<int64_t>>("starts");
auto ends = node.get_attribute_value<std::vector<int64_t>>("ends");
auto starts = node.get_attribute_value<std::vector<int64_t>>("starts");
auto ends = node.get_attribute_value<std::vector<int64_t>>("ends");
auto axes = node.get_attribute_value<std::vector<int64_t>>(
"axes", common::get_monotonic_range<int64_t>(data_shape.size()));
auto axes = node.get_attribute_value<std::vector<int64_t>>(
"axes", common::get_monotonic_range<int64_t>(data_shape.size()));
Shape lower_bounds(data_shape.size());
Shape upper_bounds = data_shape;
Shape lower_bounds(data_shape.size());
Shape upper_bounds = data_shape;
for (auto idx = 0; idx < axes.size(); ++idx)
{
size_t axis = axes.at(idx);
lower_bounds.at(axis) =
get_valid_array_idx(starts.at(idx), data_shape.at(axis));
upper_bounds.at(axis) = get_valid_array_idx(ends.at(idx), data_shape.at(axis));
for (auto idx = 0; idx < axes.size(); ++idx)
{
size_t axis = axes.at(idx);
lower_bounds.at(axis) =
get_valid_array_idx(starts.at(idx), data_shape.at(axis));
upper_bounds.at(axis) =
get_valid_array_idx(ends.at(idx), data_shape.at(axis));
}
return {std::make_shared<ngraph::op::Slice>(data, lower_bounds, upper_bounds)};
}
return {std::make_shared<ngraph::op::Slice>(data, lower_bounds, upper_bounds)};
}
} // namespace set_1
} // namespace op
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector slice(const Node& node);
} // namespace op
namespace set_1
{
NodeVector slice(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -27,31 +27,35 @@ namespace ngraph
{
namespace op
{
NodeVector softmax(const Node& node)
namespace set_1
{
NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0);
auto data_shape = data->get_shape();
NodeVector softmax(const Node& node)
{
NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0);
auto data_shape = data->get_shape();
int axis = node.get_attribute_value<int64_t>("axis", 1);
int axis = node.get_attribute_value<int64_t>("axis", 1);
if (axis < 0)
{
axis = data_shape.size() + axis;
}
if (axis < 0)
{
axis = data_shape.size() + axis;
}
ASSERT_VALID_ARGUMENT(node, axis < data_shape.size())
<< "provided 'axis' value:" << axis
<< " is out of input tensor dimensions range.";
ASSERT_VALID_ARGUMENT(node, axis < data_shape.size())
<< "provided 'axis' value:" << axis
<< " is out of input tensor dimensions range.";
// create vector of capacity data_dimensions - axis_divider position
std::vector<size_t> axes(data_shape.size() - axis);
std::iota(std::begin(axes), std::end(axes), axis);
return {std::make_shared<ngraph::op::Softmax>(data, axes)};
}
// create vector of capacity data_dimensions - axis_divider position
std::vector<size_t> axes(data_shape.size() - axis);
std::iota(std::begin(axes), std::end(axes), axis);
return {std::make_shared<ngraph::op::Softmax>(data, axes)};
}
} // namespace set_1
} // namespace op
} //namespace op
} // namespace onnx_import
} // namespace onnx_import
} // namespace ngraph
} // namespace ngraph
......@@ -26,9 +26,14 @@ namespace ngraph
{
namespace op
{
NodeVector softmax(const Node& node);
} // namespace op
namespace set_1
{
NodeVector softmax(const Node& node);
} // namespace onnx_import
} // namespace set_1
} // namespace ngraph
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -33,19 +33,23 @@ namespace ngraph
{
namespace op
{
NodeVector softplus(const Node& node)
namespace set_1
{
auto data = node.get_ng_inputs().at(0);
NodeVector softplus(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
return {std::make_shared<ngraph::op::Log>(std::make_shared<ngraph::op::Exp>(data) +
one_node)};
}
return {std::make_shared<ngraph::op::Log>(
std::make_shared<ngraph::op::Exp>(data) + one_node)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector softplus(const Node& node);
} // namespace op
namespace set_1
{
NodeVector softplus(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -33,18 +33,22 @@ namespace ngraph
{
namespace op
{
NodeVector softsign(const Node& node)
namespace set_1
{
auto data = node.get_ng_inputs().at(0);
NodeVector softsign(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{1});
one_node = make_broadcast_node(one_node, data->get_shape());
return {data / (std::make_shared<ngraph::op::Abs>(data) + one_node)};
}
return {data / (std::make_shared<ngraph::op::Abs>(data) + one_node)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector softsign(const Node& node);
} // namespace op
namespace set_1
{
NodeVector softsign(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -80,79 +80,84 @@ namespace ngraph
namespace op
{
namespace detail
namespace set_1
{
template <typename T>
inline T get_valid_array_index(T left, T right)
namespace detail
{
return (left >= 0) ? std::min(left, right)
: std::max(static_cast<T>(0), right + left);
}
template <typename T>
inline T get_valid_array_index(T left, T right)
{
return (left >= 0) ? std::min(left, right)
: std::max(static_cast<T>(0), right + left);
}
inline std::shared_ptr<ngraph::op::Slice>
make_ng_slice(const std::shared_ptr<ngraph::Node>& node,
std::vector<std::size_t> axes,
std::vector<std::size_t> starts,
std::vector<std::size_t> ends)
{
std::vector<std::size_t> upper_bounds{node->get_shape()};
std::vector<std::size_t> lower_bounds(upper_bounds.size());
for (std::size_t index{0}; index < axes.size(); ++index)
inline std::shared_ptr<ngraph::op::Slice>
make_ng_slice(const std::shared_ptr<ngraph::Node>& node,
std::vector<std::size_t> axes,
std::vector<std::size_t> starts,
std::vector<std::size_t> ends)
{
std::size_t axis{axes.at(index)};
lower_bounds.at(axis) =
get_valid_array_index(starts.at(index), node->get_shape().at(axis));
upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), node->get_shape().at(axis));
std::vector<std::size_t> upper_bounds{node->get_shape()};
std::vector<std::size_t> lower_bounds(upper_bounds.size());
for (std::size_t index{0}; index < axes.size(); ++index)
{
std::size_t axis{axes.at(index)};
lower_bounds.at(axis) =
get_valid_array_index(starts.at(index), node->get_shape().at(axis));
upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), node->get_shape().at(axis));
}
return std::make_shared<ngraph::op::Slice>(
node, lower_bounds, upper_bounds);
}
return std::make_shared<ngraph::op::Slice>(node, lower_bounds, upper_bounds);
}
} // namespace detail
} // namespace detail
NodeVector split(const Node& node)
{
std::shared_ptr<ngraph::Node> input = node.get_ng_inputs().at(0);
std::size_t count_outputs{node.get_output_names().size()};
int64_t axis{node.get_attribute_value<int64_t>("axis", 0)};
std::size_t axis_to_split{static_cast<std::size_t>(axis)};
if (axis < 0)
{
axis_to_split = input->get_shape().size() + axis;
}
else if (axis_to_split >= input->get_shape().size())
{
throw error::op::split::OutOfRange{node.get_name()};
}
std::size_t length_axis_to_split{input->get_shape().at(axis_to_split)};
std::vector<std::size_t> length_parts;
try
{
length_parts = node.get_attribute_value<std::vector<std::size_t>>("split");
}
catch (const std::exception&)
NodeVector split(const Node& node)
{
if (length_axis_to_split % count_outputs)
std::shared_ptr<ngraph::Node> input = node.get_ng_inputs().at(0);
std::size_t count_outputs{node.get_output_names().size()};
int64_t axis{node.get_attribute_value<int64_t>("axis", 0)};
std::size_t axis_to_split{static_cast<std::size_t>(axis)};
if (axis < 0)
{
throw error::op::split::Parts{
node.get_name(), count_outputs, length_axis_to_split};
axis_to_split = input->get_shape().size() + axis;
}
else if (axis_to_split >= input->get_shape().size())
{
throw error::op::split::OutOfRange{node.get_name()};
}
std::size_t length_axis_to_split{input->get_shape().at(axis_to_split)};
std::vector<std::size_t> length_parts;
try
{
length_parts = node.get_attribute_value<std::vector<std::size_t>>("split");
}
catch (const std::exception&)
{
if (length_axis_to_split % count_outputs)
{
throw error::op::split::Parts{
node.get_name(), count_outputs, length_axis_to_split};
}
length_parts.assign(count_outputs, length_axis_to_split / count_outputs);
}
length_parts.assign(count_outputs, length_axis_to_split / count_outputs);
}
std::size_t start_index{0};
NodeVector outputs;
for (const auto& length_part : length_parts)
{
std::size_t end_index{start_index + length_part};
outputs.push_back(
detail::make_ng_slice(input, {axis_to_split}, {start_index}, {end_index}));
start_index = end_index;
std::size_t start_index{0};
NodeVector outputs;
for (const auto& length_part : length_parts)
{
std::size_t end_index{start_index + length_part};
outputs.push_back(detail::make_ng_slice(
input, {axis_to_split}, {start_index}, {end_index}));
start_index = end_index;
}
return outputs;
}
return outputs;
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector split(const Node& node);
} // namespace op
namespace set_1
{
NodeVector split(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -29,12 +29,16 @@ namespace ngraph
{
namespace op
{
inline NodeVector sqrt(const Node& node)
namespace set_1
{
return {std::make_shared<ngraph::op::Sqrt>(node.get_ng_inputs().at(0))};
}
inline NodeVector sqrt(const Node& node)
{
return {std::make_shared<ngraph::op::Sqrt>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -30,38 +30,42 @@ namespace ngraph
{
namespace op
{
NodeVector squeeze(const Node& node)
namespace set_1
{
NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0);
auto data_shape = data->get_shape();
auto axes = node.get_attribute_value<std::vector<uint64_t>>("axes", {});
if (axes.empty())
NodeVector squeeze(const Node& node)
{
for (auto index = 0; index < data_shape.size(); ++index)
NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0);
auto data_shape = data->get_shape();
auto axes = node.get_attribute_value<std::vector<uint64_t>>("axes", {});
if (axes.empty())
{
if (data_shape.at(index) == 1)
for (auto index = 0; index < data_shape.size(); ++index)
{
axes.push_back(index);
if (data_shape.at(index) == 1)
{
axes.push_back(index);
}
}
}
}
std::sort(std::begin(axes), std::end(axes), std::greater<uint64_t>());
std::sort(std::begin(axes), std::end(axes), std::greater<uint64_t>());
AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};
AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};
for (auto axis : axes)
{
data_shape.erase(std::next(std::begin(data_shape), axis));
for (auto axis : axes)
{
data_shape.erase(std::next(std::begin(data_shape), axis));
}
return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
}
return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
}
} // namespace set_1
} // namespace op
} //namespace op
} // namespace onnx_import
} // namespace onnx_import
} // namespace ngraph
} // namespace ngraph
......@@ -26,9 +26,14 @@ namespace ngraph
{
namespace op
{
NodeVector squeeze(const Node& node);
} // namespace op
namespace set_1
{
NodeVector squeeze(const Node& node);
} // namespace onnx_import
} // namespace set_1
} // namespace ngraph
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -28,14 +28,19 @@ namespace ngraph
{
namespace op
{
inline NodeVector sub(const Node& node)
namespace set_1
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))};
}
inline NodeVector sub(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
return {
std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,12 +28,16 @@ namespace ngraph
{
namespace op
{
inline NodeVector sum(const Node& node)
namespace set_1
{
return variadic::make_ng_variadic_op<ngraph::op::Add>(node);
}
inline NodeVector sum(const Node& node)
{
return variadic::make_ng_variadic_op<ngraph::op::Add>(node);
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -29,12 +29,16 @@ namespace ngraph
{
namespace op
{
inline NodeVector tanh(const Node& node)
namespace set_1
{
return {std::make_shared<ngraph::op::Tanh>(node.get_ng_inputs().at(0))};
}
inline NodeVector tanh(const Node& node)
{
return {std::make_shared<ngraph::op::Tanh>(node.get_ng_inputs().at(0))};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -36,22 +36,27 @@ namespace ngraph
{
namespace op
{
NodeVector thresholded_relu(const Node& node)
namespace set_1
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1.0);
NodeVector thresholded_relu(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 1.0);
std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = make_broadcast_node(alpha_node, data->get_shape());
auto data_map = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Greater>(data, alpha_node),
data->get_element_type());
return {data * data_map};
}
auto data_map = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Greater>(data, alpha_node),
data->get_element_type());
return {data * data_map};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector thresholded_relu(const Node& node);
} // namespace op
namespace set_1
{
NodeVector thresholded_relu(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,17 +28,22 @@ namespace ngraph
{
namespace op
{
NodeVector transpose(const Node& node)
namespace set_1
{
std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0);
NodeVector transpose(const Node& node)
{
std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0);
auto permute_axes = node.get_attribute_value<std::vector<std::size_t>>("perm", {});
auto permute_axes =
node.get_attribute_value<std::vector<std::size_t>>("perm", {});
return {(permute_axes.empty()) ? reshape::transpose(data)
: reshape::reorder_axes(data, permute_axes)};
}
return {(permute_axes.empty()) ? reshape::transpose(data)
: reshape::reorder_axes(data, permute_axes)};
}
} // namespace op
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -26,8 +26,13 @@ namespace ngraph
{
namespace op
{
NodeVector transpose(const Node& node);
} // namespace op
namespace set_1
{
NodeVector transpose(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
......@@ -28,32 +28,36 @@ namespace ngraph
{
namespace op
{
NodeVector unsqueeze(const Node& node)
namespace set_1
{
NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0);
auto data_shape = data->get_shape();
auto axes = node.get_attribute_value<std::vector<int64_t>>("axes");
NodeVector unsqueeze(const Node& node)
{
NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0);
auto data_shape = data->get_shape();
auto axes = node.get_attribute_value<std::vector<int64_t>>("axes");
ASSERT_VALID_ARGUMENT(node, !axes.empty()) << "'axes' attribute is mandatory.";
ASSERT_VALID_ARGUMENT(node, !axes.empty()) << "'axes' attribute is mandatory.";
std::sort(std::begin(axes), std::end(axes), std::less<int64_t>());
std::sort(std::begin(axes), std::end(axes), std::less<int64_t>());
AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};
AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};
for (auto axis : axes)
{
ASSERT_VALID_ARGUMENT(node, axis >= 0 && axis <= data_shape.size())
<< "provided 'axes' attribute is not valid.";
for (auto axis : axes)
{
ASSERT_VALID_ARGUMENT(node, axis >= 0 && axis <= data_shape.size())
<< "provided 'axes' attribute is not valid.";
data_shape.insert(std::next(std::begin(data_shape), axis), 1);
}
data_shape.insert(std::next(std::begin(data_shape), axis), 1);
return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
}
return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
}
} // namespace set_1
} // namespace op
} //namespace op
} // namespace onnx_import
} // namespace onnx_import
} // namespace ngraph
} // namespace ngraph
......@@ -26,9 +26,14 @@ namespace ngraph
{
namespace op
{
NodeVector unsqueeze(const Node& node);
} // namespace op
namespace set_1
{
NodeVector unsqueeze(const Node& node);
} // namespace onnx_import
} // namespace set_1
} // namespace ngraph
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -30,20 +30,24 @@ namespace ngraph
{
namespace op
{
inline NodeVector logical_xor(const Node& node)
namespace set_1
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
auto left = ng_inputs.at(0);
auto not_left = std::make_shared<ngraph::op::Not>(left);
auto right = ng_inputs.at(1);
auto not_right = std::make_shared<ngraph::op::Not>(right);
return {std::make_shared<ngraph::op::Or>(
std::make_shared<ngraph::op::And>(left, not_right),
std::make_shared<ngraph::op::And>(not_left, right))};
}
} // namespace op
inline NodeVector logical_xor(const Node& node)
{
NodeVector ng_inputs{
numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
auto left = ng_inputs.at(0);
auto not_left = std::make_shared<ngraph::op::Not>(left);
auto right = ng_inputs.at(1);
auto not_right = std::make_shared<ngraph::op::Not>(right);
return {std::make_shared<ngraph::op::Or>(
std::make_shared<ngraph::op::And>(left, not_right),
std::make_shared<ngraph::op::And>(not_left, right))};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment