Commit d07e38e0 authored by Michał Karzyński's avatar Michał Karzyński Committed by Scott Cyphers

[Fused Ops] Add fused version of Elu (#2797)

* Add fused version of Elu op

* Refactor ONNX importer prelu function to use fused op

* Style check

* Add docstrings

* Move make_constant_node to op/util

* Use make_constant_node helper

* Remove unneeded std:: prefixes

* Remove make_constant_node function, use builder::make_constant

* Remove redundant includes

* Add Elu to serializer

* Add Elu tests

* Add Elu tests to type prop

* Add Elu to list of ops unsupported on iGPU

* Add Elu to list of ops unsupported on iGPU

* Disable tests in iGPU manifest
parent 02b04376
...@@ -259,6 +259,8 @@ set (SRC ...@@ -259,6 +259,8 @@ set (SRC
op/topk.hpp op/topk.hpp
op/fused/conv_fused.cpp op/fused/conv_fused.cpp
op/fused/conv_fused.hpp op/fused/conv_fused.hpp
op/fused/elu.cpp
op/fused/elu.hpp
op/fused/prelu.cpp op/fused/prelu.cpp
op/fused/prelu.hpp op/fused/prelu.hpp
op/util/arithmetic_reduction.cpp op/util/arithmetic_reduction.cpp
......
...@@ -21,12 +21,12 @@ ...@@ -21,12 +21,12 @@
#include "exceptions.hpp" #include "exceptions.hpp"
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp" #include "ngraph/op/convert.hpp"
#include "ngraph/op/dequantize.hpp" #include "ngraph/op/dequantize.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "quantize_linear.hpp" #include "quantize_linear.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -48,8 +48,8 @@ namespace ngraph ...@@ -48,8 +48,8 @@ namespace ngraph
} }
else else
{ {
zero_point = common::make_constant_node( zero_point =
x->get_element_type(), Shape{}, std::vector<std::uint8_t>{0}); ngraph::builder::make_constant(x->get_element_type(), Shape{}, 0);
} }
Shape y_scale_shape = x_scale->get_shape(); Shape y_scale_shape = x_scale->get_shape();
......
...@@ -17,16 +17,8 @@ ...@@ -17,16 +17,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/exp.hpp" #include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "elu.hpp" #include "elu.hpp"
...@@ -46,18 +38,8 @@ namespace ngraph ...@@ -46,18 +38,8 @@ namespace ngraph
std::shared_ptr<ngraph::Node> alpha_node = std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>( std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha}); data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape());
std::shared_ptr<ngraph::Node> zero_node = return NodeVector{std::make_shared<ngraph::op::Elu>(data, alpha_node)};
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{0});
zero_node = ngraph::op::make_broadcast_node(zero_node, data->get_shape());
return {std::make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node *
std::make_shared<ngraph::op::Exp>(
std::make_shared<ngraph::op::Minimum>(data, zero_node)) -
alpha_node};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "exceptions.hpp" #include "exceptions.hpp"
#include "lstm.hpp" #include "lstm.hpp"
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
...@@ -49,7 +50,6 @@ ...@@ -49,7 +50,6 @@
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp" #include "utils/reshape.hpp"
#include "utils/rnn/activation_functions.hpp" #include "utils/rnn/activation_functions.hpp"
...@@ -166,10 +166,8 @@ namespace ngraph ...@@ -166,10 +166,8 @@ namespace ngraph
} }
else else
{ {
m_map[LSTMInput::LSTM_INPUT_B] = common::make_constant_node<float>( m_map[LSTMInput::LSTM_INPUT_B] = ngraph::builder::make_constant<float>(
element::f32, element::f32, {num_directions, 2 * gates_count * hidden_size}, 0.f);
{num_directions, 2 * gates_count * hidden_size},
{0.f});
} }
// The lengths of the sequences in a batch. Shape [batch_size] // The lengths of the sequences in a batch. Shape [batch_size]
if (ng_inputs.size() > 4 && !ng_inputs.at(4)->is_null()) if (ng_inputs.size() > 4 && !ng_inputs.at(4)->is_null())
...@@ -191,8 +189,9 @@ namespace ngraph ...@@ -191,8 +189,9 @@ namespace ngraph
} }
else else
{ {
m_map[LSTMInput::LSTM_INPUT_INIT_H] = common::make_constant_node<float>( m_map[LSTMInput::LSTM_INPUT_INIT_H] =
element::f32, {num_directions, batch_size, hidden_size}, {0.f}); ngraph::builder::make_constant<float>(
element::f32, {num_directions, batch_size, hidden_size}, 0.f);
} }
// The initial value of the cell. Shape [num_directions, batch_size, hidden_size] // The initial value of the cell. Shape [num_directions, batch_size, hidden_size]
if (ng_inputs.size() > 6 && !ng_inputs.at(6)->is_null()) if (ng_inputs.size() > 6 && !ng_inputs.at(6)->is_null())
...@@ -201,8 +200,9 @@ namespace ngraph ...@@ -201,8 +200,9 @@ namespace ngraph
} }
else else
{ {
m_map[LSTMInput::LSTM_INPUT_INIT_C] = common::make_constant_node<float>( m_map[LSTMInput::LSTM_INPUT_INIT_C] =
element::f32, {num_directions, batch_size, hidden_size}, {0.f}); ngraph::builder::make_constant<float>(
element::f32, {num_directions, batch_size, hidden_size}, 0.f);
} }
// The weight tensor for peepholes. Shape [num_directions, 3*hidde_size] // The weight tensor for peepholes. Shape [num_directions, 3*hidde_size]
if (ng_inputs.size() > 7 && !ng_inputs.at(7)->is_null()) if (ng_inputs.size() > 7 && !ng_inputs.at(7)->is_null())
...@@ -211,10 +211,8 @@ namespace ngraph ...@@ -211,10 +211,8 @@ namespace ngraph
} }
else else
{ {
m_map[LSTMInput::LSTM_INPUT_P] = common::make_constant_node<float>( m_map[LSTMInput::LSTM_INPUT_P] = ngraph::builder::make_constant<float>(
element::f32, element::f32, {num_directions, peepholes_count * hidden_size}, 0.f);
{num_directions, peepholes_count * hidden_size},
{0.f});
} }
} }
......
...@@ -19,17 +19,7 @@ ...@@ -19,17 +19,7 @@
#include <memory> #include <memory>
#include "core/node.hpp" #include "core/node.hpp"
#include "ngraph/node.hpp" #include "ngraph/op/fused/prelu.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "prelu.hpp" #include "prelu.hpp"
namespace ngraph namespace ngraph
...@@ -43,44 +33,9 @@ namespace ngraph ...@@ -43,44 +33,9 @@ namespace ngraph
NodeVector prelu(const Node& node) NodeVector prelu(const Node& node)
{ {
NodeVector ng_inputs{node.get_ng_inputs()}; NodeVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0); const auto& data = ng_inputs.at(0);
auto data_shape = data->get_shape(); const auto& slope = ng_inputs.at(1);
std::shared_ptr<ngraph::Node> slope = ng_inputs.at(1); return {std::make_shared<ngraph::op::PRelu>(data, slope)};
auto slope_shape = slope->get_shape();
if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
{
auto it = std::find(
std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
auto index = std::distance(std::begin(data_shape), it);
slope = ngraph::op::make_broadcast_node(slope, data->get_shape(), index);
}
else if (data_shape != slope_shape)
{
slope = ngraph::op::numpy_style_broadcast({slope, data})[0];
}
// x < 0 => f(x) = x * slope
// x >= 0 => f(x) = x
std::shared_ptr<ngraph::Node> zero_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
zero_node = ngraph::op::make_broadcast_node(zero_node, data->get_shape());
std::shared_ptr<ngraph::Node> negative_map =
std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Less>(data, zero_node),
data->get_element_type());
std::shared_ptr<ngraph::Node> positive_map =
std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Greater>(data, zero_node),
data->get_element_type());
slope = negative_map * slope + positive_map;
return {data * slope};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -63,36 +63,6 @@ namespace ngraph ...@@ -63,36 +63,6 @@ namespace ngraph
return range; return range;
} }
/// \brief Makes a Constant Ngraph node.
///
/// \param[in] type The node element type.
/// \param[in] shape The tensor data shape.
/// \param[in] data The data to initialize node with.
///
/// \tparam T Input data value type.
///
/// \return The Ngraph node representing Constant data.
///
template <typename T>
std::shared_ptr<ngraph::Node> make_constant_node(const ngraph::element::Type& type,
const ngraph::Shape& shape,
const std::vector<T>& data)
{
std::shared_ptr<ngraph::Node> node;
// Make constant node filled with single value.
if (data.size() == 1)
{
node = std::make_shared<ngraph::op::Constant>(type, ngraph::Shape{}, data);
node = ngraph::op::make_broadcast_node(node, shape);
}
else
{
node = std::make_shared<ngraph::op::Constant>(type, shape, data);
}
return node;
}
/// \brief Handle negative axis value. /// \brief Handle negative axis value.
/// ///
/// \param[in] axis The requested axis value. /// \param[in] axis The requested axis value.
......
...@@ -95,6 +95,7 @@ ...@@ -95,6 +95,7 @@
#include "ngraph/op/experimental/transpose.hpp" #include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp" #include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/conv_fused.hpp" #include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/prelu.hpp" #include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp" #include "ngraph/op/gather_nd.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace std;
using namespace ngraph;
op::Elu::Elu(const shared_ptr<Node>& data, const shared_ptr<Node>& alpha)
: FusedOp("Elu", {data, alpha})
{
constructor_validate_and_infer_types();
}
NodeVector op::Elu::decompose_op() const
{
auto data = get_argument(0);
auto alpha_node = get_argument(1);
alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape());
shared_ptr<ngraph::Node> zero_node =
builder::make_constant(data->get_element_type(), data->get_shape(), 0);
return {make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node *
make_shared<ngraph::op::Exp>(make_shared<ngraph::op::Minimum>(data, zero_node)) -
alpha_node};
}
shared_ptr<Node> op::Elu::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Elu>(new_args.at(0), new_args.at(1));
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Exponential Linear Unit
/// x < 0 => f(x) = alpha * (exp(x) - 1.)
/// x >= 0 => f(x) = x
///
class Elu : public ngraph::op::util::FusedOp
{
public:
/// \brief Constructs an Elu operation.
///
/// \param data Input tensor
/// \param alpha Multiplier for negative values
Elu(const std::shared_ptr<ngraph::Node>& data,
const std::shared_ptr<ngraph::Node>& alpha);
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
// This collection contains one entry for each fused op. // This collection contains one entry for each fused op.
// //
NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op) NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(ConvolutionBias, ngraph::op) NGRAPH_OP(ConvolutionBias, ngraph::op)
NGRAPH_OP(ConvolutionBiasAdd, ngraph::op) NGRAPH_OP(ConvolutionBiasAdd, ngraph::op)
......
...@@ -75,6 +75,7 @@ ...@@ -75,6 +75,7 @@
#include "ngraph/op/equal.hpp" #include "ngraph/op/equal.hpp"
#include "ngraph/op/erf.hpp" #include "ngraph/op/erf.hpp"
#include "ngraph/op/fused/conv_fused.hpp" #include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp" #include "ngraph/op/greater_eq.hpp"
...@@ -1911,31 +1912,32 @@ shared_ptr<runtime::Executable> ...@@ -1911,31 +1912,32 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::BatchMatMul: case OP_TYPEID::BatchMatMul:
case OP_TYPEID::BroadcastDistributed: case OP_TYPEID::BroadcastDistributed:
case OP_TYPEID::BroadcastLike: case OP_TYPEID::BroadcastLike:
case OP_TYPEID::DynBroadcast:
case OP_TYPEID::DynPad:
case OP_TYPEID::DynReshape: case OP_TYPEID::DynReshape:
case OP_TYPEID::DynSlice: case OP_TYPEID::DynSlice:
case OP_TYPEID::Elu:
case OP_TYPEID::EmbeddingLookup:
case OP_TYPEID::Erf: case OP_TYPEID::Erf:
case OP_TYPEID::Gather: case OP_TYPEID::Gather:
case OP_TYPEID::GatherND: case OP_TYPEID::GatherND:
case OP_TYPEID::GenerateMask:
case OP_TYPEID::PRelu:
case OP_TYPEID::Passthrough:
case OP_TYPEID::QuantizedAvgPool: case OP_TYPEID::QuantizedAvgPool:
case OP_TYPEID::QuantizedConvolution:
case OP_TYPEID::QuantizedConvolutionBias: case OP_TYPEID::QuantizedConvolutionBias:
case OP_TYPEID::QuantizedConvolutionBiasAdd: case OP_TYPEID::QuantizedConvolutionBiasAdd:
case OP_TYPEID::QuantizedConvolutionBiasSignedAdd: case OP_TYPEID::QuantizedConvolutionBiasSignedAdd:
case OP_TYPEID::QuantizedConvolutionRelu: case OP_TYPEID::QuantizedConvolutionRelu:
case OP_TYPEID::QuantizedConvolution:
case OP_TYPEID::QuantizedDot: case OP_TYPEID::QuantizedDot:
case OP_TYPEID::QuantizedDotBias: case OP_TYPEID::QuantizedDotBias:
case OP_TYPEID::QuantizedMaxPool: case OP_TYPEID::QuantizedMaxPool:
case OP_TYPEID::ReplaceSlice: case OP_TYPEID::ReplaceSlice:
case OP_TYPEID::GenerateMask:
case OP_TYPEID::ScalarConstantLike: case OP_TYPEID::ScalarConstantLike:
case OP_TYPEID::ShapeOf: case OP_TYPEID::ShapeOf:
case OP_TYPEID::StopGradient: case OP_TYPEID::StopGradient:
case OP_TYPEID::Transpose: case OP_TYPEID::Transpose:
case OP_TYPEID::EmbeddingLookup:
case OP_TYPEID::DynBroadcast:
case OP_TYPEID::Passthrough:
case OP_TYPEID::DynPad:
case OP_TYPEID::PRelu:
default: default:
{ {
throw unsupported_op("Unsupported op '" + op->description() + throw unsupported_op("Unsupported op '" + op->description() +
......
...@@ -29,6 +29,10 @@ prelu ...@@ -29,6 +29,10 @@ prelu
prelu_shared_slope prelu_shared_slope
prelu_negative_slope prelu_negative_slope
group_conv group_conv
elu
elu_negative_alpha
space_to_depth
depth_to_space
# Unsupported extra padding modes # Unsupported extra padding modes
pad_edge_1d pad_edge_1d
......
...@@ -66,6 +66,7 @@ ...@@ -66,6 +66,7 @@
#include "ngraph/op/experimental/transpose.hpp" #include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp" #include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/conv_fused.hpp" #include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/prelu.hpp" #include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp" #include "ngraph/op/gather_nd.hpp"
...@@ -889,6 +890,11 @@ static shared_ptr<ngraph::Function> ...@@ -889,6 +890,11 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::DynSlice>(args[0], args[1], args[2], args[3]); node = make_shared<op::DynSlice>(args[0], args[1], args[2], args[3]);
break; break;
} }
case OP_TYPEID::Elu:
{
node = make_shared<op::Elu>(args[0], args[1]);
break;
}
case OP_TYPEID::EmbeddingLookup: case OP_TYPEID::EmbeddingLookup:
{ {
node = make_shared<op::EmbeddingLookup>(args[0], args[1]); node = make_shared<op::EmbeddingLookup>(args[0], args[1]);
...@@ -1722,6 +1728,8 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1722,6 +1728,8 @@ static json write(const Node& n, bool binary_constant_data)
} }
case OP_TYPEID::DynSlice: { break; case OP_TYPEID::DynSlice: { break;
} }
case OP_TYPEID::Elu: { break;
}
case OP_TYPEID::EmbeddingLookup: { break; case OP_TYPEID::EmbeddingLookup: { break;
} }
case OP_TYPEID::Equal: { break; case OP_TYPEID::Equal: { break;
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "util/all_close_f.hpp" #include "util/all_close_f.hpp"
#include "util/ndarray.hpp" #include "util/ndarray.hpp"
#include "util/random.hpp" #include "util/random.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp" #include "util/test_control.hpp"
#include "util/test_tools.hpp" #include "util/test_tools.hpp"
...@@ -35,6 +36,36 @@ using namespace ngraph; ...@@ -35,6 +36,36 @@ using namespace ngraph;
static string s_manifest = "${MANIFEST}"; static string s_manifest = "${MANIFEST}";
NGRAPH_TEST(${BACKEND_NAME}, elu)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{3, 2});
auto B = make_shared<op::Parameter>(element::f32, Shape{});
auto elu = make_shared<op::Elu>(A, B);
auto function = make_shared<Function>(NodeVector{elu}, ParameterVector{A, B});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input(std::vector<float>{-2.f, 3.f, -2.f, 1.f, -1.f, 0.f});
test_case.add_input(std::vector<float>{0.5f});
test_case.add_expected_output(
std::vector<float>{-0.432332358f, 3.f, -0.432332358f, 1.f, -0.316060279f, 0.f});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, elu_negative_alpha)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{3, 2});
auto B = make_shared<op::Parameter>(element::f32, Shape{});
auto elu = make_shared<op::Elu>(A, B);
auto function = make_shared<Function>(NodeVector{elu}, ParameterVector{A, B});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input(std::vector<float>{-2.f, 3.f, -2.f, 1.f, -1.f, 0.f});
test_case.add_input(std::vector<float>{-1.f});
test_case.add_expected_output(
std::vector<float>{0.864664717f, 3.f, 0.864664717f, 1.f, 0.632120559f, 0.f});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, prelu) NGRAPH_TEST(${BACKEND_NAME}, prelu)
{ {
Shape shape{3, 2}; Shape shape{3, 2};
......
...@@ -13371,6 +13371,16 @@ TEST(type_prop, prelu) ...@@ -13371,6 +13371,16 @@ TEST(type_prop, prelu)
ASSERT_EQ(prelu->get_shape(), prelu_shape); ASSERT_EQ(prelu->get_shape(), prelu_shape);
} }
TEST(type_prop, elu)
{
Shape data_shape{2, 4};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto alpha = make_shared<op::Parameter>(element::f32, Shape{});
auto elu = make_shared<op::Elu>(data, alpha);
ASSERT_EQ(elu->get_element_type(), element::f32);
ASSERT_EQ(elu->get_shape(), data_shape);
}
TEST(type_prop, gather_no_axis) TEST(type_prop, gather_no_axis)
{ {
Shape params_shape{3, 2}; Shape params_shape{3, 2};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment