Commit 5ed2c588 authored by Mateusz Bencer, committed by Scott Cyphers

[Spec][Fused] Adjust op Normalize (#3406)

* Changed name from Normalize to Normalize2

* Changed name from Normalize to Normalize2 in tests

* Changed name of normalize source files

* Removed across spatial and channel shared params

* Removed scale and introduced input axes

* Support for axes input was introduced

* Added possibility to choose the method of applying bias

* Clang style applied

* Code review remarks introduced

* Code review remarks introduced

* Added python script to generate normalize_l2 test data
parent 802bca81
@@ -334,8 +334,8 @@ set (SRC
op/fused/lstm_cell.hpp
op/fused/mvn.cpp
op/fused/mvn.hpp
op/fused/normalize.cpp
op/fused/normalize.hpp
op/fused/normalize_l2.cpp
op/fused/normalize_l2.hpp
op/fused/prelu.cpp
op/fused/prelu.hpp
op/fused/rnn_cell.cpp
@@ -19,6 +19,7 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/power.hpp"
@@ -97,8 +98,10 @@ namespace ngraph
return values + bias_node;
}
shared_ptr<Node>
l2_norm(const Output<Node>& value, const AxisSet& reduction_axes, float bias)
shared_ptr<Node> l2_norm(const Output<Node>& value,
const AxisSet& reduction_axes,
float bias,
BiasMode bias_mode)
{
shared_ptr<Node> values{make_shared<op::Sum>(value * value, reduction_axes)};
@@ -106,8 +109,16 @@ namespace ngraph
op::Constant::create(values->get_element_type(),
values->get_shape(),
vector<float>(shape_size(values->get_shape()), bias))};
return {make_shared<op::Sqrt>(values + bias_node)};
switch (bias_mode)
{
case BiasMode::MAX:
{
return {make_shared<op::Sqrt>(make_shared<op::Maximum>(values, bias_node))};
}
case BiasMode::ADD:
default: { return {make_shared<op::Sqrt>(values + bias_node)};
}
}
}
shared_ptr<Node> lp_norm(const Output<Node>& value,
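The new BiasMode switch only changes how the bias stabilizes the sum of squares before the square root. A minimal numpy sketch of the builder's math, with illustrative names rather than the nGraph API:

```python
import numpy as np

# Sketch of builder::l2_norm with the BiasMode switch (illustrative, not the real API).
def l2_norm(value, reduction_axes, bias=0.0, bias_mode="ADD"):
    squares = np.sum(value * value, axis=reduction_axes)
    if bias_mode == "MAX":
        return np.sqrt(np.maximum(squares, bias))  # BiasMode::MAX
    return np.sqrt(squares + bias)                 # BiasMode::ADD (default)

x = np.arange(1, 25, dtype=np.float32).reshape(1, 2, 3, 4)
print(l2_norm(x, (1, 2, 3)))                              # [70.] -- sum of squares of 1..24 is 4900
print(l2_norm(x, (1, 2, 3), bias=5000, bias_mode="MAX"))  # [70.710678] -- sqrt(5000)
```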
@@ -25,6 +25,15 @@ namespace ngraph
{
namespace builder
{
/// \brief Specifies the method of bias application to avoid numerical problems
enum class BiasMode
{
// Add bias to intermediate result
ADD,
// Calculate max of intermediate result and bias
MAX
};
/// \brief Calculates L-0 norm of input tensor.
///
/// \note The L-0 norm represents the cardinality of elements different
@@ -57,12 +66,15 @@ namespace ngraph
///
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] bias The bias added to the calculated sum.
/// \param[in] bias The bias combined with the calculated sum.
/// \param[in] bias_mode The method of bias application.
///
/// \return L-2 norm of value.
///
std::shared_ptr<Node>
l2_norm(const Output<Node>& value, const AxisSet& reduction_axes, float bias = 0.f);
std::shared_ptr<Node> l2_norm(const Output<Node>& value,
const AxisSet& reduction_axes,
float bias = 0.f,
BiasMode bias_mode = BiasMode::ADD);
/// \brief Creates node which calculates L-p norm on input tensor.
///
@@ -123,7 +123,8 @@ namespace ngraph
auto l2_norm_reduction = std::bind(ngraph::builder::l2_norm,
std::placeholders::_1,
std::placeholders::_2,
0.f);
0.f,
ngraph::builder::BiasMode::ADD);
return {reduction::make_ng_reduction_op(
node, node.get_ng_inputs().at(0), l2_norm_reduction)};
}
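For context, the rebound builder implements the ONNX ReduceL2 reduction; in numpy terms it is roughly the following (a sketch for intuition, not the importer code):

```python
import numpy as np

# Rough numpy equivalent of ONNX ReduceL2 (keepdims is the ONNX default).
def reduce_l2(x, axes):
    return np.sqrt(np.sum(x * x, axis=tuple(axes), keepdims=True))
```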
@@ -131,7 +131,7 @@
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
@@ -18,75 +18,54 @@
#include "ngraph/builder/norm.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "normalize.hpp"
using namespace std;
using namespace ngraph;
const string op::Normalize::type_name{"Normalize"};
const string op::NormalizeL2::type_name{"NormalizeL2"};
op::Normalize::Normalize(const Output<Node>& data,
const Output<Node>& scale,
bool across_spatial,
bool channel_shared,
float eps)
: FusedOp({data, scale})
, m_across_spatial{across_spatial}
, m_channel_shared{channel_shared}
op::NormalizeL2::NormalizeL2(const Output<Node>& data,
const Output<Node>& axes,
float eps,
EpsMode eps_mode)
: FusedOp({data, axes})
, m_eps{eps}
, m_eps_mode{eps_mode}
{
constructor_validate_and_infer_types();
}
void op::Normalize::pre_validate_and_infer_types()
void op::NormalizeL2::pre_validate_and_infer_types()
{
const auto& data_pshape = get_input_partial_shape(0);
const auto& scale_pshape = get_input_partial_shape(1);
if (data_pshape.is_static() && scale_pshape.is_static())
{
const Shape data_shape{data_pshape.to_shape()};
const Shape scale_shape{scale_pshape.to_shape()};
// Input data must be 2, 3 or 4D tensor.
NODE_VALIDATION_CHECK(this,
(data_shape.size() >= 2 && data_shape.size() <= 4),
"Input tensor rank must be 2, 3 or 4 dimensional (actual input "
"shape: ",
data_shape,
").");
if (m_channel_shared)
{
NODE_VALIDATION_CHECK(this,
scale_shape.size() == 0,
"Scale must be a scalar if 'channels_shared' parameter is true");
}
else
{
// only HW
if (data_shape.size() == 2)
{
NODE_VALIDATION_CHECK(this,
scale_shape.size() == 0,
"Scale must be a scalar if input tensor is of rank 2.");
}
else
{
size_t n_channels = data_shape.size() == 3 ? data_shape.at(0) : data_shape.at(1);
NODE_VALIDATION_CHECK(
this,
(scale_shape.size() == 1 && scale_shape.at(0) == n_channels),
"Scale must be a vector of size of input tensor channels if input tensor is "
"of rank greater equal 3.");
}
}
}
const auto& axes_pshape = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this, data_pshape.is_static(), "Input data must be static.");
NODE_VALIDATION_CHECK(this, axes_pshape.is_static(), "Input axes must be static.");
const Shape data_shape{data_pshape.to_shape()};
// Input data must be 2, 3 or 4D tensor.
NODE_VALIDATION_CHECK(this,
(data_shape.size() >= 2 && data_shape.size() <= 4),
"Input tensor rank must be 2, 3 or 4 dimensional (actual input "
"shape: ",
data_shape,
").");
NODE_VALIDATION_CHECK(this,
static_cast<size_t>(axes_pshape.rank()) == 1,
"Input axes must have rank equals 1 (axes shape: ",
axes_pshape,
").");
}
NodeVector op::Normalize::decompose_op() const
NodeVector op::NormalizeL2::decompose_op() const
{
Output<Node> data{input_value(0)};
const Shape input_shape{data.get_shape()};
@@ -99,33 +78,23 @@ NodeVector op::Normalize::decompose_op() const
data = builder::reshape(data, data_shape);
}
// Calculate norm over CHW axes.
AxisSet reduction_axes{1, 2, 3};
if (m_across_spatial)
{
// Calculate norm only over HW axes.
reduction_axes = AxisSet{2, 3};
}
// Calculate l2 norm across channels.
Output<Node> norm = builder::l2_norm(data, reduction_axes, m_eps);
norm = make_broadcast_node(norm, data.get_shape(), 0);
auto axes_node = input(1).get_source_output().get_node_shared_ptr();
NODE_VALIDATION_CHECK(this,
axes_node->is_constant(),
"doesn't support 'axes' input of other type than a Constant.");
Output<Node> scale_node{input_value(1)};
// Calculate norm over axes indicated by axes input param
auto axes_constant = dynamic_pointer_cast<op::Constant>(axes_node);
auto axes_vector = axes_constant->get_vector<size_t>();
AxisSet reduction_axes{axes_vector};
// Broadcast scale to data tensor shape.
if (m_channel_shared)
{
// Scale is a scalar.
scale_node = make_broadcast_node(scale_node, data.get_shape());
}
else
{
// Scale is a vector of size equal to C axis.
scale_node = make_broadcast_node(scale_node, data.get_shape(), 1);
}
// Calculate l2 norm across axes determined by axes input
auto builder_bias_mode =
(m_eps_mode == EpsMode::MAX) ? builder::BiasMode::MAX : builder::BiasMode::ADD;
Output<Node> norm = builder::l2_norm(data, reduction_axes, m_eps, builder_bias_mode);
norm = make_broadcast_node(norm, data.get_shape(), 0);
data = data / norm * scale_node;
data = data / norm;
// get back original input tensor rank
if (input_shape.size() != 4)
@@ -136,12 +105,11 @@ NodeVector op::Normalize::decompose_op() const
return as_node_vector({data});
}
shared_ptr<Node> op::Normalize::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::NormalizeL2::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<Normalize>(
new_args.at(0), new_args.at(1), m_across_spatial, m_channel_shared, m_eps);
return make_shared<NormalizeL2>(new_args.at(0), new_args.at(1), m_eps, m_eps_mode);
}
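Taken together, the op now decomposes to dividing the input by its eps-stabilized L2 norm over the requested axes. A minimal numpy model of decompose_op, assuming a constant axes input (names are this sketch's own):

```python
import numpy as np

# Illustrative model of op::NormalizeL2::decompose_op (not the nGraph API).
def normalize_l2(data, axes, eps, eps_mode="ADD"):
    squares = np.sum(data * data, axis=tuple(axes), keepdims=True)
    if eps_mode == "MAX":
        norm = np.sqrt(np.maximum(squares, eps))  # EpsMode::MAX
    else:
        norm = np.sqrt(squares + eps)             # EpsMode::ADD
    return data / norm  # the norm broadcasts back over the reduced axes

data = np.arange(1, 25, dtype=np.float32).reshape(1, 2, 3, 4)
print(normalize_l2(data, (1, 2, 3), 1e-6).flat[0])  # ~0.01428571, i.e. 1/70
```

Note that the scale multiplication from the old Normalize is gone; if scaling is still wanted it can be expressed as a separate multiply after this op.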
@@ -19,6 +19,7 @@
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
@@ -27,31 +28,28 @@ namespace ngraph
{
/// \brief Normalizes the input tensor with the L2 norm.
///
class Normalize : public ngraph::op::util::FusedOp
class NormalizeL2 : public ngraph::op::util::FusedOp
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Normalize() = default;
NormalizeL2() = default;
///
/// \brief Constructs a NormalizeL2 operation.
///
/// \param data - Node producing the input tensor
/// \param scale - Node producing the scale tensor
/// \param across_spatial - Whether calculate norm across all channels.
/// \param channel_shared - Whether scale is shared across channels.
/// \param axes - Node indicating axes along which reduction is calculated
/// \param eps - The epsilon added to L2 norm.
/// \param eps_mode - Specifies how eps is combined with the L2 value calculated before division
///
Normalize(const Output<Node>& data,
const Output<Node>& scale,
bool across_spatial,
bool channel_shared,
float eps);
NormalizeL2(const Output<Node>& data,
const Output<Node>& axes,
float eps,
EpsMode eps_mode);
float get_across_spatial() const { return m_across_spatial; }
float get_channel_shared() const { return m_channel_shared; }
float get_eps() const { return m_eps; }
EpsMode get_eps_mode() const { return m_eps_mode; }
virtual NodeVector decompose_op() const override;
virtual void pre_validate_and_infer_types() override;
@@ -59,9 +57,8 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override;
protected:
bool m_across_spatial{false};
bool m_channel_shared{false};
float m_eps{1.f};
float m_eps;
EpsMode m_eps_mode;
};
}
}
@@ -39,7 +39,7 @@ NGRAPH_OP(GRUCell, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(Normalize, ngraph::op)
NGRAPH_OP(NormalizeL2, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(RNNCell, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op)
@@ -81,6 +81,15 @@ namespace ngraph
NUMPY
};
/// \brief Specifies how eps is combined with the L2 value
enum class EpsMode
{
// Add bias to norm
ADD,
// Calculate max of norm and bias
MAX
};
/// \brief Implicit broadcast specification
struct AutoBroadcastSpec
{
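In formula form, writing $s = \sum_{i \in \text{axes}} x_i^2$ for the reduced sum of squares, the two modes differ only in how $\varepsilon$ enters the norm:

$$\text{ADD: } \lVert x \rVert = \sqrt{s + \varepsilon} \qquad\qquad \text{MAX: } \lVert x \rVert = \sqrt{\max(s, \varepsilon)}$$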
@@ -90,7 +90,7 @@
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/shuffle_channels.hpp"
@@ -2072,7 +2072,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::HardSigmoid:
case OP_TYPEID::LSTMCell:
case OP_TYPEID::MVN:
case OP_TYPEID::Normalize:
case OP_TYPEID::NormalizeL2:
case OP_TYPEID::PRelu:
case OP_TYPEID::Passthrough:
case OP_TYPEID::RNNCell:
@@ -2199,7 +2199,7 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node)
case OP_TYPEID::GRUCell:
case OP_TYPEID::LSTMCell:
case OP_TYPEID::MVN:
case OP_TYPEID::Normalize:
case OP_TYPEID::NormalizeL2:
case OP_TYPEID::PRelu:
case OP_TYPEID::RNNCell:
case OP_TYPEID::ScaleShift:
@@ -68,14 +68,14 @@ gather_3d_indices_no_axis_2d_input
gather_4d_indices_no_axis_2d_input
gemm
gemm_broadcast_input_C
normalize_across_chw_scalar_scale_4d
normalize_across_chw_scalar_scale_3d
normalize_across_chw_scalar_scale_2d
normalize_across_chw_w_scale
normalize_across_hw_w_scale
normalize_across_chw_4d
normalize_across_chw_4d_max_bias
normalize_across_chw_3d
normalize_across_chw_2d
normalize_across_hw_4d
normalize_invalid_input_tensor_rank
normalize_invalid_scale_rank
normalize
normalize_invalid_axes_rank
normalize_output_shape_across_chw
hardsigmoid
model_erf
model_erf_int32
@@ -182,10 +182,11 @@ conv_bias_bprop_2d
conv_bias_add_2d
space_to_depth
depth_to_space
normalize_across_chw_scalar_scale_4d
normalize_across_chw_scalar_scale_3d
normalize_across_chw_scalar_scale_2d
normalize_across_chw_w_scale
normalize_across_chw_4d
normalize_across_chw_4d_max_bias
normalize_across_chw_3d
normalize_across_chw_2d
normalize_across_hw_4d
gemm
fused_clamp
mvn_mean_normalization
@@ -84,7 +84,7 @@
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/normalize.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/rnn_cell.hpp"
#include "ngraph/op/fused/scale_shift.hpp"
@@ -145,6 +145,7 @@
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/provenance.hpp"
#include "ngraph/serializer.hpp"
@@ -1503,13 +1504,11 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Negative>(args[0]);
break;
}
case OP_TYPEID::Normalize:
case OP_TYPEID::NormalizeL2:
{
bool across_spatial = node_js.at("across_spatial").get<bool>();
bool channel_shared = node_js.at("channel_shared").get<bool>();
float eps = node_js.at("eps").get<float>();
node =
make_shared<op::Normalize>(args[0], args[1], across_spatial, channel_shared, eps);
auto eps_mode = node_js.at("eps_mode").get<op::EpsMode>();
node = make_shared<op::NormalizeL2>(args[0], args[1], eps, eps_mode);
break;
}
case OP_TYPEID::NotEqual:
@@ -2617,12 +2616,11 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Negative: { break;
}
case OP_TYPEID::Normalize:
case OP_TYPEID::NormalizeL2:
{
auto tmp = dynamic_cast<const op::Normalize*>(&n);
node["across_spatial"] = tmp->get_across_spatial();
node["channel_shared"] = tmp->get_channel_shared();
auto tmp = dynamic_cast<const op::NormalizeL2*>(&n);
node["eps"] = tmp->get_eps();
node["eps_mode"] = tmp->get_eps_mode();
break;
}
case OP_TYPEID::NotEqual:
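After this change a NormalizeL2 node round-trips through JSON with only eps and eps_mode in place of the removed across_spatial/channel_shared flags. A hypothetical attribute payload, for illustration only (the exact JSON encoding of op::EpsMode is not shown in this diff):

```python
# Hypothetical sketch of the per-node attributes; not actual serializer output.
normalize_l2_attrs = {
    "eps": 1e-06,
    "eps_mode": "ADD",  # or "MAX"
}
```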
@@ -26,6 +26,7 @@
#include "gtest/gtest.h"
#include "ngraph/check.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
@@ -577,17 +578,16 @@ NGRAPH_TEST(${BACKEND_NAME}, depth_to_space)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, normalize_across_chw_scalar_scale_4d)
NGRAPH_TEST(${BACKEND_NAME}, normalize_across_chw_4d)
{
Shape data_shape{1, 2, 3, 4};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto scale = make_shared<op::Parameter>(element::f32, Shape{});
bool across_spatial{false};
bool channel_shared{true};
const auto axes = make_shared<op::Constant>(element::u64, Shape{3}, vector<int64_t>{1, 2, 3});
float eps{1e-6f};
auto eps_mode = op::EpsMode::ADD;
auto normalize = make_shared<op::Normalize>(data, scale, across_spatial, channel_shared, eps);
auto function = make_shared<Function>(NodeVector{normalize}, ParameterVector{data, scale});
auto normalize = make_shared<op::NormalizeL2>(data, axes, eps, eps_mode);
auto function = make_shared<Function>(NodeVector{normalize}, ParameterVector{data});
auto test_case = test::NgraphTestCase(function, "${BACKEND_NAME}");
@@ -595,28 +595,26 @@ NGRAPH_TEST(${BACKEND_NAME}, normalize_across_chw_scalar_scale_4d)
iota(begin(input_data), end(input_data), 1);
test_case.add_input<float>(input_data);
test_case.add_input<float>({2.f});
test_case.add_expected_output<float>(
data_shape, {0.02857143f, 0.05714286f, 0.08571429f, 0.11428571f, 0.14285714f, 0.17142857f,
0.2f, 0.22857143f, 0.25714286f, 0.28571429f, 0.31428571f, 0.34285714f,
0.37142857f, 0.4f, 0.42857143f, 0.45714286f, 0.48571429f, 0.51428571f,
0.54285714f, 0.57142857f, 0.6f, 0.62857143f, 0.65714286f, 0.68571429f});
data_shape, {0.01428571f, 0.02857143f, 0.04285714f, 0.05714286f, 0.07142857f, 0.08571429f,
0.1f, 0.11428571f, 0.12857144f, 0.14285715f, 0.15714286f, 0.17142858f,
0.18571429f, 0.2f, 0.21428572f, 0.22857143f, 0.24285714f, 0.25714287f,
0.27142859f, 0.2857143f, 0.30000001f, 0.31428573f, 0.32857144f, 0.34285715f});
test_case.run();
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
}
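A quick check of the expected data: the inputs are $1 \dots 24$, so reducing over all CHW axes gives

$$\sum_{i=1}^{24} i^2 = 4900, \qquad \sqrt{4900 + 10^{-6}} \approx 70,$$

and the $i$-th output is $i/70$ (first element $1/70 \approx 0.01428571$, as above).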
NGRAPH_TEST(${BACKEND_NAME}, normalize_across_chw_scalar_scale_3d)
NGRAPH_TEST(${BACKEND_NAME}, normalize_across_chw_3d)
{
Shape data_shape{2, 3, 4};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto scale = make_shared<op::Parameter>(element::f32, Shape{});
bool across_spatial{false};
bool channel_shared{true};
const auto axes = make_shared<op::Constant>(element::u64, Shape{3}, vector<int64_t>{1, 2, 3});
float eps{1e-6f};
auto eps_mode = op::EpsMode::ADD;
auto normalize = make_shared<op::Normalize>(data, scale, across_spatial, channel_shared, eps);
auto function = make_shared<Function>(NodeVector{normalize}, ParameterVector{data, scale});
auto normalize = make_shared<op::NormalizeL2>(data, axes, eps, eps_mode);
auto function = make_shared<Function>(NodeVector{normalize}, ParameterVector{data});
auto test_case = test::NgraphTestCase(function, "${BACKEND_NAME}");
@@ -624,28 +622,26 @@ NGRAPH_TEST(${BACKEND_NAME}, normalize_across_chw_scalar_scale_3d)
iota(begin(input_data), end(input_data), 1);
test_case.add_input<float>(input_data);
test_case.add_input<float>({2.f});
test_case.add_expected_output<float>(
data_shape, {0.02857143f, 0.05714286f, 0.08571429f, 0.11428571f, 0.14285714f, 0.17142857f,
0.2f, 0.22857143f, 0.25714286f, 0.28571429f, 0.31428571f, 0.34285714f,
0.37142857f, 0.4f, 0.42857143f, 0.45714286f, 0.48571429f, 0.51428571f,
0.54285714f, 0.57142857f, 0.6f, 0.62857143f, 0.65714286f, 0.68571429f});
data_shape, {0.01428571f, 0.02857143f, 0.04285714f, 0.05714286f, 0.07142857f, 0.08571429f,
0.1f, 0.11428571f, 0.12857144f, 0.14285715f, 0.15714286f, 0.17142858f,
0.18571429f, 0.2f, 0.21428572f, 0.22857143f, 0.24285714f, 0.25714287f,
0.27142859f, 0.2857143f, 0.30000001f, 0.31428573f, 0.32857144f, 0.34285715f});
test_case.run();
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
}
NGRAPH_TEST(${BACKEND_NAME}, normalize_across_chw_scalar_scale_2d)
NGRAPH_TEST(${BACKEND_NAME}, normalize_across_chw_2d)
{
Shape data_shape{3, 4};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto scale = make_shared<op::Parameter>(element::f32, Shape{});
bool across_spatial{false};
bool channel_shared{true};
const auto axes = make_shared<op::Constant>(element::u64, Shape{3}, vector<int64_t>{1, 2, 3});
float eps{1e-6f};
auto eps_mode = op::EpsMode::ADD;
auto normalize = make_shared<op::Normalize>(data, scale, across_spatial, channel_shared, eps);
auto function = make_shared<Function>(NodeVector{normalize}, ParameterVector{data, scale});
auto normalize = make_shared<op::NormalizeL2>(data, axes, eps, eps_mode);
auto function = make_shared<Function>(NodeVector{normalize}, ParameterVector{data});
auto test_case = test::NgraphTestCase(function, "${BACKEND_NAME}");
@@ -653,36 +649,34 @@ NGRAPH_TEST(${BACKEND_NAME}, normalize_across_chw_scalar_scale_2d)
iota(begin(input_data), end(input_data), 1);
test_case.add_input<float>(input_data);
test_case.add_input<float>({2.f});
test_case.add_expected_output<float>(data_shape,
{0.07844645f,
{0.03922323f,
0.07844646f,
0.11766968f,
0.15689291f,
0.19611613f,
0.23533936f,
0.2745626f,
0.31378582f,
0.35300905f,
0.39223227f,
0.47067872f,
0.54912518f,
0.62757163f,
0.70601809f,
0.78446454f,
0.86291099f,
0.94135745f});
0.43145549f,
0.47067872f});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, normalize_across_chw_w_scale)
NGRAPH_TEST(${BACKEND_NAME}, normalize_across_empty_axes_input)
{
Shape data_shape{1, 2, 3, 4};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto scale = make_shared<op::Parameter>(element::f32, Shape{2});
bool across_spatial{false};
bool channel_shared{false};
const auto axes = make_shared<op::Constant>(element::u64, Shape{0}, vector<int64_t>{});
float eps{1e-6f};
auto eps_mode = op::EpsMode::ADD;
auto normalize = make_shared<op::Normalize>(data, scale, across_spatial, channel_shared, eps);
auto function = make_shared<Function>(NodeVector{normalize}, ParameterVector{data, scale});
auto normalize = make_shared<op::NormalizeL2>(data, axes, eps, eps_mode);
auto function = make_shared<Function>(NodeVector{normalize}, ParameterVector{data});
auto test_case = test::NgraphTestCase(function, "${BACKEND_NAME}");
@@ -690,29 +684,23 @@ NGRAPH_TEST(${BACKEND_NAME}, normalize_across_chw_w_scale)
iota(begin(input_data), end(input_data), 1);
test_case.add_input<float>(input_data);
test_case.add_input<float>({2.f, 3.f});
test_case.add_expected_output<float>(
data_shape, {0.02857143f, 0.05714286f, 0.08571429f, 0.11428571f, 0.14285714f, 0.17142857f,
0.2f, 0.22857143f, 0.25714286f, 0.28571429f, 0.31428571f, 0.34285714f,
0.55714286f, 0.6f, 0.64285714f, 0.68571429f, 0.72857143f, 0.77142857f,
0.81428571f, 0.85714286f, 0.9f, 0.94285714f, 0.98571429f, 1.02857143f});
// With no reduction axes each element is divided by sqrt(x^2 + eps) ~= |x|,
// so this positive input should produce an output filled with 1.f values
test_case.add_expected_output<float>(data_shape, vector<float>(shape_size(data_shape), 1));
test_case.run();
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
}
// TODO lower tolerance; mismatch at 4th decimal positions
NGRAPH_TEST(DISABLED_${BACKEND_NAME}, normalize_across_hw_w_scale)
NGRAPH_TEST(${BACKEND_NAME}, normalize_across_hw_4d)
{
Shape data_shape{1, 2, 3, 4};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto scale = make_shared<op::Parameter>(element::f32, Shape{2});
bool across_spatial{true};
bool channel_shared{false};
float eps{0.25f};
const auto axes = make_shared<op::Constant>(element::u64, Shape{2}, vector<int64_t>{2, 3});
float eps{1e-6f};
auto eps_mode = op::EpsMode::ADD;
auto normalize = make_shared<op::Normalize>(data, scale, across_spatial, channel_shared, eps);
auto function = make_shared<Function>(NodeVector{normalize}, ParameterVector{data, scale});
auto normalize = make_shared<op::NormalizeL2>(data, axes, eps, eps_mode);
auto function = make_shared<Function>(NodeVector{normalize}, ParameterVector{data});
auto test_case = test::NgraphTestCase(function, "${BACKEND_NAME}");
@@ -720,14 +708,40 @@ NGRAPH_TEST(DISABLED_${BACKEND_NAME}, normalize_across_hw_w_scale)
iota(begin(input_data), end(input_data), 1);
test_case.add_input<float>(input_data);
test_case.add_input<float>({2.f, 3.f});
test_case.add_expected_output<float>(
data_shape, {0.07844646f, 0.15689291f, 0.23533936f, 0.31378582f, 0.39223227f, 0.47067872f,
0.5491252f, 0.62757164f, 0.7060181f, 0.78446454f, 0.862911f, 0.94135743f,
0.5982327f, 0.64425063f, 0.6902685f, 0.7362864f, 0.7823043f, 0.8283222f,
0.87434006f, 0.920358f, 0.9663758f, 1.0123938f, 1.0584116f, 1.1044296f});
test_case.run();
data_shape, {0.03922323f, 0.07844646f, 0.11766968f, 0.15689291f, 0.19611613f, 0.23533936f,
0.2745626f, 0.31378582f, 0.35300905f, 0.39223227f, 0.43145549f, 0.47067872f,
0.1994109f, 0.2147502f, 0.2300895f, 0.2454288f, 0.26076809f, 0.2761074f,
0.29144669f, 0.306786f, 0.32212529f, 0.3374646f, 0.35280389f, 0.3681432f});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
}
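Sanity check: with axes $\{2, 3\}$ the norm is taken per HW slice. The first slice (values $1 \dots 12$) has norm $\sqrt{650} \approx 25.495$ and the second (values $13 \dots 24$) has $\sqrt{4250} \approx 65.192$, matching the leading elements $0.03922\ldots$ and $0.19941\ldots$ above.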
NGRAPH_TEST(${BACKEND_NAME}, normalize_across_chw_4d_max_bias)
{
Shape data_shape{1, 2, 3, 4};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
const auto axes = make_shared<op::Constant>(element::u64, Shape{3}, vector<int64_t>{1, 2, 3});
float eps{5000};
auto eps_mode = op::EpsMode::MAX;
auto normalize = make_shared<op::NormalizeL2>(data, axes, eps, eps_mode);
auto function = make_shared<Function>(NodeVector{normalize}, ParameterVector{data});
auto test_case = test::NgraphTestCase(function, "${BACKEND_NAME}");
vector<float> input_data(shape_size(data_shape));
iota(begin(input_data), end(input_data), 1);
test_case.add_input<float>(input_data);
test_case.add_expected_output<float>(
data_shape, {0.01414214f, 0.02828427f, 0.04242641f, 0.05656854f, 0.07071068f, 0.08485281f,
0.09899495f, 0.11313709f, 0.12727922f, 0.14142136f, 0.15556349f, 0.16970563f,
0.18384777f, 0.1979899f, 0.21213204f, 0.22627418f, 0.2404163f, 0.25455844f,
0.26870057f, 0.28284273f, 0.29698485f, 0.31112698f, 0.32526913f, 0.33941126f});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
}
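Sanity check: $\sum_{i=1}^{24} i^2 = 4900 < 5000$, so MAX mode clamps the sum to the bias; the norm is $\sqrt{5000} \approx 70.7107$ and the first expected element is $1/70.7107 \approx 0.01414214$.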
NGRAPH_TEST(${BACKEND_NAME}, gemm)
#!/usr/bin/env python
# ******************************************************************************
# Copyright 2017-2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import numpy as np

# Test data generator for the normalize_l2 backend tests.
# Values 1..24 shaped (N, C, H, W) = (1, 2, 3, 4), matching the unit tests.
data = np.arange(1, 25, 1).reshape(1, 2, 3, 4).astype(np.float32)
eps = np.array([1e-6]).astype(np.float32)
# L2 norm across the CHW axes; maximum() keeps the norm away from zero.
norm = np.sqrt(np.maximum(np.sum(np.power(data, 2), axis=(1, 2, 3)), eps))
result = data / norm
# Print as C++ float literals ready to paste into the expected-output vectors.
for elem in np.nditer(result):
    print(str(round(elem, 8)) + 'f, ')
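The script above reproduces the ADD-style expectations (with eps this small, maximum() is a no-op against the 4900 sum of squares). A hedged variant for the max-bias test data, under the same numpy conventions:

```python
import numpy as np

# Sketch: expected data for normalize_across_chw_4d_max_bias (EpsMode::MAX, eps = 5000).
data = np.arange(1, 25, 1).reshape(1, 2, 3, 4).astype(np.float32)
eps = np.float32(5000)  # larger than the sum of squares (4900), so it dominates
norm = np.sqrt(np.maximum(np.sum(np.square(data), axis=(1, 2, 3)), eps))
for elem in np.nditer(data / norm):
    print(str(round(float(elem), 8)) + 'f, ')
```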
@@ -20,20 +20,17 @@
using namespace std;
using namespace ngraph;
TEST(type_prop, normalize_invalid_input_tensor_rank)
{
Shape data_shape{1, 2, 3, 4, 5};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto scale = make_shared<op::Parameter>(element::f32, Shape{});
bool across_spatial{false};
bool channel_shared{true};
auto axes = make_shared<op::Parameter>(element::u64, Shape{1, 2});
float eps{1e-6f};
auto eps_mode = op::EpsMode::ADD;
try
{
auto normalize =
make_shared<op::Normalize>(data, scale, across_spatial, channel_shared, eps);
auto normalize = make_shared<op::NormalizeL2>(data, axes, eps, eps_mode);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input tensor rank.";
}
@@ -51,8 +48,7 @@ TEST(type_prop, normalize_invalid_input_tensor_rank)
try
{
auto normalize =
make_shared<op::Normalize>(data, scale, across_spatial, channel_shared, eps);
auto normalize = make_shared<op::NormalizeL2>(data, axes, eps, eps_mode);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input tensor rank.";
}
@@ -67,64 +63,23 @@ TEST(type_prop, normalize_invalid_input_tensor_rank)
}
}
TEST(type_prop, normalize_invalid_scale_rank)
TEST(type_prop, normalize_invalid_axes_rank)
{
Shape data_shape{1, 2, 3, 4};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto scale = make_shared<op::Parameter>(element::f32, Shape{3});
bool across_spatial{false};
bool channel_shared{true};
auto axes = make_shared<op::Parameter>(element::u64, Shape{1, 2});
float eps{1e-6f};
auto eps_mode = op::EpsMode::ADD;
try
{
auto normalize =
make_shared<op::Normalize>(data, scale, across_spatial, channel_shared, eps);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input tensor rank.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Scale must be a scalar if 'channels_shared' "
"parameter is true"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
channel_shared = false;
try
{
auto normalize =
make_shared<op::Normalize>(data, scale, across_spatial, channel_shared, eps);
auto normalize = make_shared<op::NormalizeL2>(data, axes, eps, eps_mode);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input tensor rank.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Scale must be a vector of size of input tensor "
"channels"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
data = make_shared<op::Parameter>(element::f32, Shape{4, 3});
try
{
auto normalize =
make_shared<op::Normalize>(data, scale, across_spatial, channel_shared, eps);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input tensor rank.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Scale must be a scalar if input tensor is of rank 2"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input axes must have rank equal to 1"));
}
catch (...)
{
@@ -132,16 +87,15 @@ TEST(type_prop, normalize_invalid_scale_rank)
}
}
TEST(type_prop, normalize)
TEST(type_prop, normalize_output_shape_across_chw)
{
Shape data_shape{2, 3, 4};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto scale = make_shared<op::Parameter>(element::f32, Shape{2});
bool across_spatial{false};
bool channel_shared{false};
const auto axes = make_shared<op::Constant>(element::u64, Shape{3}, vector<int64_t>{1, 2, 3});
float eps{1e-6f};
auto eps_mode = op::EpsMode::ADD;
auto normalize = make_shared<op::Normalize>(data, scale, across_spatial, channel_shared, eps);
auto normalize = make_shared<op::NormalizeL2>(data, axes, eps, eps_mode);
EXPECT_EQ(normalize->get_element_type(), element::f32);
EXPECT_EQ(normalize->get_shape(), (Shape{2, 3, 4}));
}