Commit 58dc9d09 authored by tsocha, committed by Scott Cyphers

[Fused Op] Add new fused operator: MVN (Mean Variance Normalization) (#2887)

* Basic mean normalization

* Add MVN to serializer

* Add test for mean normalization

* Add support for across_channel attribute

* Add test for mvn_mean_normalization split by channels

* Assume that data has N and C channels

* Add support for normalize_variance attribute

* Add test for full mean variance normalization

* Add type prop test

* Skip tests on GPU

* Use ngraph builder functions instead of my own

* Update mvn.cpp

* Change order in initializer list

* Review fix
parent 48991823
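For reference, the decomposition below implements the following transform (a sketch reconstructed from decompose_op; R is the set of reduction axes: {1, ..., rank-1} when across_channels is true, {2, ..., rank-1} otherwise):

\[
    \mu = \operatorname{mean}_{R}(x), \qquad
    \sigma = \sqrt{\operatorname{Var}_{R}(x)}, \qquad
    \mathrm{MVN}(x) =
    \begin{cases}
        x - \mu & \text{if normalize\_variance is false} \\
        \dfrac{x - \mu}{\sigma + \varepsilon} & \text{otherwise}
    \end{cases}
\]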
@@ -290,6 +290,8 @@ set (SRC
op/fused/gemm.hpp
op/fused/group_conv.hpp
op/fused/group_conv.cpp
op/fused/mvn.cpp
op/fused/mvn.hpp
op/fused/prelu.cpp
op/fused/prelu.hpp
op/fused/space_to_depth.cpp
...
@@ -102,6 +102,7 @@
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/gather.hpp"
...
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include "mvn.hpp"
#include "ngraph/builder/reduce_ops.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace std;
using namespace ngraph;
op::MVN::MVN(const std::shared_ptr<Node>& data,
             bool across_channels,
             bool normalize_variance,
             double eps)
    : FusedOp("MVN", {data})
    , m_eps{eps}
    , m_across_channels{across_channels}
    , m_normalize_variance{normalize_variance}
{
    constructor_validate_and_infer_types();
}
NodeVector op::MVN::decompose_op() const
{
    auto data = get_argument(0);
    auto data_shape = data->get_shape(); // assume that data has N and C dimensions

    // If m_across_channels is true, mean and variance are calculated per batch
    // (over the channel and spatial axes); otherwise they are calculated per
    // channel (over the spatial axes only).
    AxisSet reduction_axes;
    size_t start_axis = m_across_channels ? 1 : 2;
    for (size_t i = start_axis; i < data_shape.size(); ++i)
    {
        reduction_axes.insert(i);
    }

    // calculate mean normalization: subtract the mean, broadcast back to the data shape
    auto mean = builder::mean(data, reduction_axes);
    mean = legacy_style_broadcast_for_binary_operation(data, mean, 0).at(1);
    auto mean_normalization = data - mean;

    if (!m_normalize_variance)
    {
        return {mean_normalization};
    }
    else
    {
        // calculate the standard deviation: the square root of the variance
        auto variance = builder::variance(mean_normalization, reduction_axes);
        variance = make_shared<op::Sqrt>(variance);
        // add epsilon to avoid division by zero
        auto eps_node = op::Constant::create(
            data->get_element_type(), variance->get_shape(), vector<double>{m_eps});
        variance = variance + eps_node;
        variance =
            legacy_style_broadcast_for_binary_operation(mean_normalization, variance, 0).at(1);
        return {mean_normalization / variance};
    }
}
shared_ptr<Node> op::MVN::copy_with_new_args(const NodeVector& new_args) const
{
    NODE_VALIDATION_CHECK(this,
                          new_args.size() == 1,
                          "Expected 1 element in new_args for the MVN op but got ",
                          new_args.size());

    return make_shared<MVN>(new_args.at(0), m_across_channels, m_normalize_variance, m_eps);
}
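A convention note on the decomposition above: epsilon is added after the square root, so the op computes the left-hand variant below rather than the right-hand one; the two differ noticeably only when the variance is near zero.

\[
    \frac{x - \mu}{\sqrt{\operatorname{Var}_R(x)} + \varepsilon}
    \qquad \text{vs.} \qquad
    \frac{x - \mu}{\sqrt{\operatorname{Var}_R(x) + \varepsilon}}
\]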
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"

namespace ngraph
{
    namespace op
    {
        /// \brief Operator performing Mean Variance Normalization
        ///
        class MVN : public ngraph::op::util::FusedOp
        {
        public:
            /// \brief Constructs an MVN operation.
            ///
            /// \param data Input tensor with data
            /// \param across_channels flag that denotes if mean values are shared across channels.
            /// \param normalize_variance flag that denotes whether to perform variance normalization.
            /// \param eps the number added to the square root of the variance
            ///        to avoid division by zero when normalizing the value
            ///
            MVN(const std::shared_ptr<ngraph::Node>& data,
                bool across_channels = true,
                bool normalize_variance = true,
                double eps = 1e-9);

            virtual NodeVector decompose_op() const override;

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;

            double get_eps() const { return m_eps; }
            bool get_across_channels() const { return m_across_channels; }
            bool get_normalize_variance() const { return m_normalize_variance; }

        private:
            const double m_eps;
            const bool m_across_channels;
            const bool m_normalize_variance;
        };
    } // namespace op
} // namespace ngraph
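To make the declared API concrete, a minimal construction sketch; the shape, flag values, and helper name are illustrative only, not part of the diff:

#include <memory>

#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/parameter.hpp"

using namespace ngraph;

// Normalize each channel of an NCHW tensor independently
// (across_channels = false) and divide by the epsilon-stabilized
// standard deviation (normalize_variance = true).
std::shared_ptr<Node> make_mvn_example()
{
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{1, 3, 224, 224});
    return std::make_shared<op::MVN>(data,
                                     /*across_channels=*/false,
                                     /*normalize_variance=*/true,
                                     /*eps=*/1e-9);
}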
@@ -17,14 +17,15 @@
// This collection contains one entry for each fused op.
//
NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(Clamp, ngraph::op)
NGRAPH_OP(ConvolutionBias, ngraph::op)
NGRAPH_OP(ConvolutionBiasAdd, ngraph::op)
NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(DepthToSpace, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GroupConvolution, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
@@ -156,3 +156,7 @@ gather_nd_single_indices
gemm
gemm_broadcast_input_C
model_hardmax
mvn_mean_normalization
mvn_mean_normalization_split_channels
mvn_mean_variance_normalization
mvn_mean_variance_normalization_split_channels
@@ -83,6 +83,7 @@
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
@@ -1991,6 +1992,7 @@ shared_ptr<runtime::Executable>
    case OP_TYPEID::Gemm:
    case OP_TYPEID::GenerateMask:
    case OP_TYPEID::HardSigmoid:
    case OP_TYPEID::MVN:
    case OP_TYPEID::PRelu:
    case OP_TYPEID::Passthrough:
    case OP_TYPEID::QuantizedAvgPool:
...
@@ -67,4 +67,8 @@ gather_nd_single_indices
gemm
gemm_broadcast_input_C
hardsigmoid
mvn_mean_normalization
mvn_mean_normalization_split_channels
mvn_mean_variance_normalization
mvn_mean_variance_normalization_split_channels
zero_sized_erf
@@ -73,6 +73,7 @@
#include "ngraph/op/fused/gemm.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "ngraph/op/fused/prelu.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/gather.hpp"
@@ -1131,6 +1132,14 @@ static shared_ptr<ngraph::Function>
        node = make_shared<op::Multiply>(args[0], args[1]);
        break;
    }
    case OP_TYPEID::MVN:
    {
        auto normalize_variance = node_js.at("normalize_variance").get<bool>();
        auto across_channels = node_js.at("across_channels").get<bool>();
        auto eps = node_js.at("eps").get<double>();
        // Note: the constructor takes across_channels before normalize_variance.
        node = make_shared<op::MVN>(args[0], across_channels, normalize_variance, eps);
        break;
    }
    case OP_TYPEID::Negative:
    {
        node = make_shared<op::Negative>(args[0]);
@@ -1934,6 +1943,14 @@ static json write(const Node& n, bool binary_constant_data)
    }
    case OP_TYPEID::Multiply: { break;
    }
    case OP_TYPEID::MVN:
    {
        auto tmp = dynamic_cast<const op::MVN*>(&n);
        node["normalize_variance"] = tmp->get_normalize_variance();
        node["across_channels"] = tmp->get_across_channels();
        node["eps"] = tmp->get_eps();
        break;
    }
    case OP_TYPEID::Negative: { break;
    }
    case OP_TYPEID::NotEqual: { break;
...
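As a usage note for the two serializer cases above, a round-trip sketch; it assumes the serialize/deserialize entry points declared in ngraph/serializer.hpp, and the helper name is illustrative:

#include <memory>
#include <string>

#include "ngraph/function.hpp"
#include "ngraph/serializer.hpp"

// write() emits normalize_variance / across_channels / eps for each MVN
// node; the reader reconstructs an equivalent op::MVN from those fields.
std::shared_ptr<ngraph::Function> roundtrip(const std::shared_ptr<ngraph::Function>& f)
{
    const std::string js = ngraph::serialize(f);
    return ngraph::deserialize(js);
}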
@@ -464,3 +464,101 @@ NGRAPH_TEST(${BACKEND_NAME}, fused_clamp)
    test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_normalization)
{
    Shape data_shape{1, 2, 5};
    auto data = make_shared<op::Parameter>(element::f64, data_shape);

    auto mvn_func = make_shared<op::MVN>(data, true, false);
    auto function = make_shared<Function>(NodeVector{mvn_func}, ParameterVector{data});
    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");

    // data
    vector<double> data_vector(shape_size(data_shape));
    iota(begin(data_vector), end(data_vector), 0);
    test_case.add_input<double>(data_vector);

    // expected result
    test_case.add_expected_output<double>(
        data_shape, vector<double>{-4.5, -3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5});

    test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_normalization_split_channels)
{
    Shape data_shape{1, 2, 5, 1};
    auto data = make_shared<op::Parameter>(element::f64, data_shape);

    auto mvn_func = make_shared<op::MVN>(data, false, false);
    auto function = make_shared<Function>(NodeVector{mvn_func}, ParameterVector{data});
    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");

    // data
    vector<double> data_vector(shape_size(data_shape));
    iota(begin(data_vector), end(data_vector), 0);
    test_case.add_input<double>(data_vector);

    // expected result
    test_case.add_expected_output<double>({1, 2, 5, 1},
                                          vector<double>{-2, -1, 0, 1, 2, -2, -1, 0, 1, 2});

    test_case.run();
}
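A quick check of the expected outputs in the two tests above: with across_channels = true the ten inputs 0..9 share a single mean,

\[
    \mu = \tfrac{0 + 1 + \cdots + 9}{10} = 4.5,
\]

so x - mu runs from -4.5 to 4.5; with across_channels = false the channels {0, ..., 4} and {5, ..., 9} have means 2 and 7, so each channel maps to -2, -1, 0, 1, 2.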
NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_variance_normalization)
{
    Shape data_shape{1, 2, 5};
    auto data = make_shared<op::Parameter>(element::f64, data_shape);

    auto mvn_func = make_shared<op::MVN>(data);
    auto function = make_shared<Function>(NodeVector{mvn_func}, ParameterVector{data});
    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");

    // data
    vector<double> data_vector(shape_size(data_shape));
    iota(begin(data_vector), end(data_vector), 0);
    test_case.add_input<double>(data_vector);

    // expected result
    test_case.add_expected_output<double>(data_shape,
                                          vector<double>{-1.566698903055826,
                                                         -1.2185435912656424,
                                                         -0.87038827947545883,
                                                         -0.52223296768527527,
                                                         -0.17407765589509178,
                                                         0.17407765589509178,
                                                         0.52223296768527527,
                                                         0.87038827947545883,
                                                         1.2185435912656424,
                                                         1.566698903055826});

    test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_variance_normalization_split_channels)
{
    Shape data_shape{1, 2, 5};
    auto data = make_shared<op::Parameter>(element::f64, data_shape);

    auto mvn_func = make_shared<op::MVN>(data, false);
    auto function = make_shared<Function>(NodeVector{mvn_func}, ParameterVector{data});
    auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");

    // data
    vector<double> data_vector(shape_size(data_shape));
    iota(begin(data_vector), end(data_vector), 0);
    test_case.add_input<double>(data_vector);

    // expected result
    test_case.add_expected_output<double>(data_shape,
                                          vector<double>{-1.4142135613730948,
                                                         -0.70710678068654742,
                                                         0.000000000000000,
                                                         0.70710678068654742,
                                                         1.4142135613730948,
                                                         -1.4142135613730948,
                                                         -0.70710678068654742,
                                                         0.000000000000000,
                                                         0.70710678068654742,
                                                         1.4142135613730948});

    test_case.run();
}
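The variance-normalized expectations check out the same way: shared statistics give

\[
    \sigma = \sqrt{\tfrac{1}{10} \sum_{i=0}^{9} (i - 4.5)^2} = \sqrt{8.25} \approx 2.87228,
    \qquad \frac{0 - 4.5}{\sigma + 10^{-9}} \approx -1.5667,
\]

matching the first entry above, while per-channel statistics give sigma = sqrt(2), so (0 - 2)/sqrt(2) ≈ -1.41421 in the split-channels test.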
@@ -14010,6 +14010,14 @@ TEST(type_prop, gemm_broadcast_input_C)
    EXPECT_EQ(gemm_func->get_shape(), (Shape{3, 4}));
}

TEST(type_prop, mvn)
{
    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
    auto mvn_func = make_shared<op::MVN>(data);

    EXPECT_EQ(mvn_func->get_element_type(), element::f32);
    EXPECT_EQ(mvn_func->get_shape(), (Shape{1, 3, 6}));
}

TEST(type_prop, fused_clamp)
{
    const auto data = make_shared<op::Parameter>(element::f64, Shape{2, 2});
...