Commit 22c4f3fb authored by tsocha's avatar tsocha Committed by Scott Cyphers

[ONNX] Add support for onnx MeanVarianceNormalization op (#3065)

* [ONNX] Add support for onnx MeanVarianceNormalization op

* Fix docstring

* Delete commented code

* Add support for legacy onnx MVN version

* Fix a typo in attribute name

* Fix for legacy onnx MVN

* Add a header for int64_t
parent 6025acc5
...@@ -121,6 +121,8 @@ add_library(onnx_import STATIC ...@@ -121,6 +121,8 @@ add_library(onnx_import STATIC
op/max.hpp op/max.hpp
op/mean.cpp op/mean.cpp
op/mean.hpp op/mean.hpp
op/mean_variance_normalization.cpp
op/mean_variance_normalization.hpp
op/min.hpp op/min.hpp
op/mul.hpp op/mul.hpp
op/neg.hpp op/neg.hpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstdint>
#include <memory>
#include "mean_variance_normalization.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/op/fused/mvn.hpp"
namespace ngraph
{
    namespace onnx_import
    {
        namespace op
        {
            namespace set_1
            {
                /// \brief Legacy (opset 1) ONNX MeanVarianceNormalization.
                ///
                /// Reads the boolean-like int64 attributes `across_channels`
                /// (default 0) and `normalize_variance` (default 1) and builds
                /// the corresponding nGraph fused MVN op.
                ///
                /// \param node The ONNX node object representing this operation.
                /// \return Single-element vector containing the MVN node.
                NodeVector mean_variance_normalization(const Node& node)
                {
                    const auto data = node.get_ng_inputs().at(0);
                    // ONNX encodes these flags as int64; treat any non-zero as true.
                    const bool across_channels =
                        node.get_attribute_value<std::int64_t>("across_channels", 0) != 0;
                    const bool normalize_variance =
                        node.get_attribute_value<std::int64_t>("normalize_variance", 1) != 0;

                    return {std::make_shared<ngraph::op::MVN>(
                        data, across_channels, normalize_variance)};
                }
            } // namespace set_1

            namespace set_9
            {
                /// \brief ONNX MeanVarianceNormalization, opset 9.
                ///
                /// Normalizes over the axes listed in the `axes` attribute
                /// (default {0, 2, 3}).
                ///
                /// \param node The ONNX node object representing this operation.
                /// \return Single-element vector containing the MVN node.
                NodeVector mean_variance_normalization(const Node& node)
                {
                    const auto reduction_axes =
                        node.get_attribute_value<std::vector<std::size_t>>("axes", {0, 2, 3});

                    return {std::make_shared<ngraph::op::MVN>(node.get_ng_inputs().at(0),
                                                              AxisSet(reduction_axes))};
                }
            } // namespace set_9
        } // namespace op
    } // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Converts the ONNX MeanVarianceNormalization op (opset 1,
///        legacy `across_channels`/`normalize_variance` int attributes)
///        into an nGraph MVN fused-op subgraph.
///
/// \param node The ONNX node object representing this operation.
/// \return Vector containing the node performing the normalization.
NodeVector mean_variance_normalization(const Node& node);
} // namespace set_1
namespace set_9
{
/// \brief Converts the ONNX MeanVarianceNormalization op (opset 9,
///        `axes` attribute selecting the reduction axes) into an nGraph
///        MVN fused-op subgraph.
///
/// \param node The ONNX node object representing this operation.
/// \return Vector containing the node performing the normalization.
NodeVector mean_variance_normalization(const Node& node);
} // namespace set_9
} //namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -75,6 +75,7 @@ ...@@ -75,6 +75,7 @@
#include "op/max.hpp" #include "op/max.hpp"
#include "op/max_pool.hpp" #include "op/max_pool.hpp"
#include "op/mean.hpp" #include "op/mean.hpp"
#include "op/mean_variance_normalization.hpp"
#include "op/min.hpp" #include "op/min.hpp"
#include "op/mul.hpp" #include "op/mul.hpp"
#include "op/neg.hpp" #include "op/neg.hpp"
...@@ -283,6 +284,8 @@ namespace ngraph ...@@ -283,6 +284,8 @@ namespace ngraph
REGISTER_OPERATOR("Max", 8, max); REGISTER_OPERATOR("Max", 8, max);
REGISTER_OPERATOR("Mean", 1, mean); REGISTER_OPERATOR("Mean", 1, mean);
REGISTER_OPERATOR("Mean", 8, mean); REGISTER_OPERATOR("Mean", 8, mean);
REGISTER_OPERATOR("MeanVarianceNormalization", 1, mean_variance_normalization);
REGISTER_OPERATOR("MeanVarianceNormalization", 9, mean_variance_normalization);
REGISTER_OPERATOR("Min", 1, min); REGISTER_OPERATOR("Min", 1, min);
REGISTER_OPERATOR("Min", 8, min); REGISTER_OPERATOR("Min", 8, min);
REGISTER_OPERATOR("Mul", 1, mul); REGISTER_OPERATOR("Mul", 1, mul);
......
...@@ -37,25 +37,38 @@ op::MVN::MVN(const std::shared_ptr<Node>& data, ...@@ -37,25 +37,38 @@ op::MVN::MVN(const std::shared_ptr<Node>& data,
, m_normalize_variance{normalize_variance} , m_normalize_variance{normalize_variance}
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
}
NodeVector op::MVN::decompose_op() const
{
auto data = get_argument(0);
auto data_shape = data->get_shape(); // assume that data has n and c channels.
// if m_across_channels is true we should calculate mean and variance per batch // if m_across_channels is true we should calculate mean and variance per batch
// else we calculate these per channel // else we calculate these per channel
AxisSet reduction_axes; m_reduction_axes.insert(0);
size_t start_axis = m_across_channels ? 1 : 2; size_t start_axis = m_across_channels ? 1 : 2;
for (size_t i = start_axis; i < data_shape.size(); ++i) for (size_t i = start_axis; i < data->get_shape().size(); ++i)
{ {
reduction_axes.insert(i); m_reduction_axes.insert(i);
} }
}
op::MVN::MVN(const std::shared_ptr<Node>& data,
AxisSet reduction_axes,
bool normalize_variance,
double eps)
: FusedOp("MVN", {data})
, m_eps{eps}
, m_across_channels{false}
, m_normalize_variance{normalize_variance}
, m_reduction_axes{reduction_axes}
{
constructor_validate_and_infer_types();
}
NodeVector op::MVN::decompose_op() const
{
auto data = get_argument(0);
auto data_shape = data->get_shape(); // assume that data has n and c channels.
// calculate mean normalization // calculate mean normalization
auto mean = builder::mean(data, reduction_axes); auto mean = builder::mean(data, m_reduction_axes);
mean = legacy_style_broadcast_for_binary_operation(data, mean, 0).at(1); mean = std::make_shared<op::Broadcast>(mean, data_shape, m_reduction_axes);
auto mean_normalization = data - mean; auto mean_normalization = data - mean;
if (!m_normalize_variance) if (!m_normalize_variance)
...@@ -65,14 +78,13 @@ NodeVector op::MVN::decompose_op() const ...@@ -65,14 +78,13 @@ NodeVector op::MVN::decompose_op() const
else else
{ {
// calculate variance // calculate variance
auto variance = builder::variance(mean_normalization, reduction_axes); auto variance = builder::variance(data, m_reduction_axes);
variance = make_shared<op::Sqrt>(variance); variance = make_shared<op::Sqrt>(variance);
// add epsilon // add epsilon
auto eps_node = op::Constant::create( auto eps_node = op::Constant::create(
data->get_element_type(), variance->get_shape(), vector<double>{m_eps}); data->get_element_type(), variance->get_shape(), vector<double>{m_eps});
variance = variance + eps_node; variance = variance + eps_node;
variance = variance = std::make_shared<op::Broadcast>(variance, data_shape, m_reduction_axes);
legacy_style_broadcast_for_binary_operation(mean_normalization, variance, 0).at(1);
return {mean_normalization / variance}; return {mean_normalization / variance};
} }
...@@ -84,5 +96,5 @@ shared_ptr<Node> op::MVN::copy_with_new_args(const NodeVector& new_args) const ...@@ -84,5 +96,5 @@ shared_ptr<Node> op::MVN::copy_with_new_args(const NodeVector& new_args) const
new_args.size() == 1, new_args.size() == 1,
"Expected 1 element in new_args for the MVN op but got ", "Expected 1 element in new_args for the MVN op but got ",
new_args.size()); new_args.size());
return make_shared<MVN>(new_args.at(0), m_across_channels, m_normalize_variance, m_eps); return make_shared<MVN>(new_args.at(0), m_reduction_axes, m_normalize_variance, m_eps);
} }
...@@ -41,18 +41,31 @@ namespace ngraph ...@@ -41,18 +41,31 @@ namespace ngraph
bool normalize_variance = true, bool normalize_variance = true,
double eps = 1e-9); double eps = 1e-9);
/// \brief Constructs an MVN operation.
///
/// \param data Input tensor with data
/// \param reduction_axes A list of axes, along which to reduce.
/// \param normalize_variance flag that denotes whether to perform variance normalization.
/// \param eps the number to be added to the variance to avoid division by zero when normalizing the value
///
MVN(const std::shared_ptr<ngraph::Node>& data,
AxisSet reduction_axes,
bool normalize_variance = true,
double eps = 1e-9);
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
double get_eps() const { return m_eps; } double get_eps() const { return m_eps; }
bool get_across_channels() const { return m_across_channels; }
bool get_normalize_variance() const { return m_normalize_variance; } bool get_normalize_variance() const { return m_normalize_variance; }
AxisSet get_reduction_axes() const { return m_reduction_axes; }
private: private:
const double m_eps; const double m_eps;
const bool m_across_channels; const bool m_across_channels;
const bool m_normalize_variance; const bool m_normalize_variance;
AxisSet m_reduction_axes;
}; };
} // namespace op } // namespace op
} // namespace ngraph } // namespace ngraph
...@@ -1382,9 +1382,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js) ...@@ -1382,9 +1382,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case OP_TYPEID::MVN: case OP_TYPEID::MVN:
{ {
auto normalize_variance = node_js.at("normalize_variance").get<bool>(); auto normalize_variance = node_js.at("normalize_variance").get<bool>();
auto across_channels = node_js.at("across_channels").get<bool>(); auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
auto eps = node_js.at("eps").get<double>(); auto eps = node_js.at("eps").get<double>();
node = make_shared<op::MVN>(args[0], normalize_variance, across_channels, eps); node = make_shared<op::MVN>(args[0], reduction_axes, normalize_variance, eps);
break; break;
} }
case OP_TYPEID::Negative: case OP_TYPEID::Negative:
...@@ -2387,8 +2387,8 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -2387,8 +2387,8 @@ json JSONSerializer::serialize_node(const Node& n)
case OP_TYPEID::MVN: case OP_TYPEID::MVN:
{ {
auto tmp = dynamic_cast<const op::MVN*>(&n); auto tmp = dynamic_cast<const op::MVN*>(&n);
node["reduction_axes"] = tmp->get_reduction_axes();
node["normalize_variance"] = tmp->get_normalize_variance(); node["normalize_variance"] = tmp->get_normalize_variance();
node["across_channels"] = tmp->get_across_channels();
node["eps"] = tmp->get_eps(); node["eps"] = tmp->get_eps();
break; break;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment