Commit 22c4f3fb authored by tsocha's avatar tsocha Committed by Scott Cyphers

[ONNX] Add support for onnx MeanVarianceNormalization op (#3065)

* [ONNX] Add support for onnx MeanVarianceNormalization op

* Fix docstring

* Delete commented code

* Add support for legacy onnx MVN version

* Fix a typo in attribute name

* Fix for legacy onnx MVN

* Add a header for int64_t
parent 6025acc5
......@@ -121,6 +121,8 @@ add_library(onnx_import STATIC
op/max.hpp
op/mean.cpp
op/mean.hpp
op/mean_variance_normalization.cpp
op/mean_variance_normalization.hpp
op/min.hpp
op/mul.hpp
op/neg.hpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstdint>
#include <memory>
#include "mean_variance_normalization.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/op/fused/mvn.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector mean_variance_normalization(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
bool across_channels =
node.get_attribute_value<std::int64_t>("across_channels", 0);
bool normalize_variance =
node.get_attribute_value<std::int64_t>("normalize_variance", 1);
return {std::make_shared<ngraph::op::MVN>(
data, across_channels, normalize_variance)};
}
} // namespace set_1
namespace set_9
{
NodeVector mean_variance_normalization(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
auto axes = node.get_attribute_value<std::vector<size_t>>("axes", {0, 2, 3});
return {std::make_shared<ngraph::op::MVN>(data, AxisSet(axes))};
}
} // namespace set_9
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector mean_variance_normalization(const Node& node);
} // namespace set_1
namespace set_9
{
NodeVector mean_variance_normalization(const Node& node);
} // namespace set_9
} //namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -75,6 +75,7 @@
#include "op/max.hpp"
#include "op/max_pool.hpp"
#include "op/mean.hpp"
#include "op/mean_variance_normalization.hpp"
#include "op/min.hpp"
#include "op/mul.hpp"
#include "op/neg.hpp"
......@@ -283,6 +284,8 @@ namespace ngraph
REGISTER_OPERATOR("Max", 8, max);
REGISTER_OPERATOR("Mean", 1, mean);
REGISTER_OPERATOR("Mean", 8, mean);
REGISTER_OPERATOR("MeanVarianceNormalization", 1, mean_variance_normalization);
REGISTER_OPERATOR("MeanVarianceNormalization", 9, mean_variance_normalization);
REGISTER_OPERATOR("Min", 1, min);
REGISTER_OPERATOR("Min", 8, min);
REGISTER_OPERATOR("Mul", 1, mul);
......
......@@ -37,25 +37,38 @@ op::MVN::MVN(const std::shared_ptr<Node>& data,
, m_normalize_variance{normalize_variance}
{
constructor_validate_and_infer_types();
}
NodeVector op::MVN::decompose_op() const
{
auto data = get_argument(0);
auto data_shape = data->get_shape(); // assume that data has n and c channels.
// if m_across_channels is true we should calculate mean and variance per batch
// else we calculate these per channel
AxisSet reduction_axes;
m_reduction_axes.insert(0);
size_t start_axis = m_across_channels ? 1 : 2;
for (size_t i = start_axis; i < data_shape.size(); ++i)
for (size_t i = start_axis; i < data->get_shape().size(); ++i)
{
reduction_axes.insert(i);
m_reduction_axes.insert(i);
}
}
op::MVN::MVN(const std::shared_ptr<Node>& data,
AxisSet reduction_axes,
bool normalize_variance,
double eps)
: FusedOp("MVN", {data})
, m_eps{eps}
, m_across_channels{false}
, m_normalize_variance{normalize_variance}
, m_reduction_axes{reduction_axes}
{
constructor_validate_and_infer_types();
}
NodeVector op::MVN::decompose_op() const
{
auto data = get_argument(0);
auto data_shape = data->get_shape(); // assume that data has n and c channels.
// calculate mean normalization
auto mean = builder::mean(data, reduction_axes);
mean = legacy_style_broadcast_for_binary_operation(data, mean, 0).at(1);
auto mean = builder::mean(data, m_reduction_axes);
mean = std::make_shared<op::Broadcast>(mean, data_shape, m_reduction_axes);
auto mean_normalization = data - mean;
if (!m_normalize_variance)
......@@ -65,14 +78,13 @@ NodeVector op::MVN::decompose_op() const
else
{
// calculate variance
auto variance = builder::variance(mean_normalization, reduction_axes);
auto variance = builder::variance(data, m_reduction_axes);
variance = make_shared<op::Sqrt>(variance);
// add epsilon
auto eps_node = op::Constant::create(
data->get_element_type(), variance->get_shape(), vector<double>{m_eps});
variance = variance + eps_node;
variance =
legacy_style_broadcast_for_binary_operation(mean_normalization, variance, 0).at(1);
variance = std::make_shared<op::Broadcast>(variance, data_shape, m_reduction_axes);
return {mean_normalization / variance};
}
......@@ -84,5 +96,5 @@ shared_ptr<Node> op::MVN::copy_with_new_args(const NodeVector& new_args) const
new_args.size() == 1,
"Expected 1 element in new_args for the MVN op but got ",
new_args.size());
return make_shared<MVN>(new_args.at(0), m_across_channels, m_normalize_variance, m_eps);
return make_shared<MVN>(new_args.at(0), m_reduction_axes, m_normalize_variance, m_eps);
}
......@@ -41,18 +41,31 @@ namespace ngraph
bool normalize_variance = true,
double eps = 1e-9);
/// \brief Constructs an MVN operation.
///
/// \param data Input tensor with data
/// \param reduction_axes A list of axes, along which to reduce.
/// \param normalize_variance flag that denotes whether to perform variance normalization.
/// \param eps the number to be added to the variance to avoid division by zero when normalizing the value
///
MVN(const std::shared_ptr<ngraph::Node>& data,
AxisSet reduction_axes,
bool normalize_variance = true,
double eps = 1e-9);
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
double get_eps() const { return m_eps; }
bool get_across_channels() const { return m_across_channels; }
bool get_normalize_variance() const { return m_normalize_variance; }
AxisSet get_reduction_axes() const { return m_reduction_axes; }
private:
const double m_eps;
const bool m_across_channels;
const bool m_normalize_variance;
AxisSet m_reduction_axes;
};
} // namespace op
} // namespace ngraph
......@@ -1382,9 +1382,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case OP_TYPEID::MVN:
{
auto normalize_variance = node_js.at("normalize_variance").get<bool>();
auto across_channels = node_js.at("across_channels").get<bool>();
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
auto eps = node_js.at("eps").get<double>();
node = make_shared<op::MVN>(args[0], normalize_variance, across_channels, eps);
node = make_shared<op::MVN>(args[0], normalize_variance, normalize_variance, eps);
break;
}
case OP_TYPEID::Negative:
......@@ -2387,8 +2387,8 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::MVN:
{
auto tmp = dynamic_cast<const op::MVN*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
node["normalize_variance"] = tmp->get_normalize_variance();
node["across_channels"] = tmp->get_across_channels();
node["eps"] = tmp->get_eps();
break;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment