Commit 385770d8 authored by Sang Ik Lee, committed by Scott Cyphers

LayerNorm (#3678)

* LayerNorm (#3630)

* Constructors.

Type prop.

Decompose LayerNorm.

Add serialize.

* Add dummy test case.

* Add dummy type prop test.

* Fix some build errors.

* Remove build errors.

* Update decompose for bprop.

* Change begin_norm_axis default value to 1.

* Style.

* Reorder class members.

* Add actual type prop tests.

* Add fprop test.

* Working on bprop test.

* Bprop tests.

* Allow flattened scale and bias.

* Add support for flattened scale and bias.

* Fix incorrect type_name.

* PlaidML: Decompose fused_op LayerNorm

* Update Backprop constructors.

* PlaidML: Add missing header file.

* Remove doc about removed param.

* Fix type prop tests.

* PlaidML: Disable unit test.

* Fix stats flattening axes bug.

* Upgrade description to type_info.
parent 40eb9587
......@@ -336,6 +336,8 @@ set (SRC
op/fused/group_conv_transpose.cpp
op/fused/gru_cell.cpp
op/fused/gru_cell.hpp
op/fused/layer_norm.cpp
op/fused/layer_norm.hpp
op/fused/lstm_cell.cpp
op/fused/lstm_cell.hpp
op/fused/matmul.cpp
......
......@@ -136,6 +136,7 @@ namespace ngraph
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/layer_norm.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/fused/mvn.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cmath>
#include <numeric>
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/reduce_ops.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/fused/layer_norm.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/util/broadcasting.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::LayerNorm::type_info;
constexpr NodeTypeInfo op::LayerNormBackprop::type_info;
op::LayerNorm::LayerNorm(const Output<Node>& data,
const Output<Node>& scale,
const Output<Node>& bias,
bool keep_stats,
int64_t begin_norm_axis,
double epsilon)
: FusedOp({data, scale, bias})
, m_keep_stats(keep_stats)
, m_use_affine(true)
, m_begin_norm_axis(begin_norm_axis)
, m_epsilon(epsilon)
{
constructor_validate_and_infer_types();
}
op::LayerNorm::LayerNorm(const Output<Node>& data,
bool keep_stats,
int64_t begin_norm_axis,
double epsilon)
: FusedOp({data})
, m_keep_stats(keep_stats)
, m_use_affine(false)
, m_begin_norm_axis(begin_norm_axis)
, m_epsilon(epsilon)
{
constructor_validate_and_infer_types();
}
// All input shapes should be static by this point
NodeVector op::LayerNorm::decompose_op() const
{
const PartialShape& data_shape = get_input_partial_shape(0);
if (data_shape.is_dynamic())
{
throw ngraph_error("Data needs to have static shape to decompose");
}
if (m_use_affine)
{
const PartialShape& scale_shape = get_input_partial_shape(1);
const PartialShape& bias_shape = get_input_partial_shape(2);
if (scale_shape.is_dynamic())
{
throw ngraph_error("Scale needs to have static shape to decompose");
}
if (bias_shape.is_dynamic())
{
throw ngraph_error("Bias needs to have static shape to decompose");
}
}
// Compute real axis
auto shape = data_shape.to_shape();
int64_t n_axis = m_begin_norm_axis >= 0 ? m_begin_norm_axis : shape.size() + m_begin_norm_axis;
// Get input data
auto data = input_value(0);
// Compute mean
std::vector<size_t> post_reduction_axes(shape.size() - n_axis);
std::iota(post_reduction_axes.begin(), post_reduction_axes.end(), n_axis);
auto mean = builder::mean(data, post_reduction_axes);
AxisSet post_axis_set;
for (size_t i = static_cast<size_t>(n_axis); i < shape.size(); i++)
{
post_axis_set.insert(i);
}
auto b_mean = make_shared<ngraph::op::Broadcast>(mean, shape, post_axis_set);
// Compute variance
auto var = builder::variance(data, post_reduction_axes);
// Compute standard deviation with epsilon
auto epsilon = builder::make_constant(var->get_element_type(), var->get_shape(), m_epsilon);
auto stddev = make_shared<op::Sqrt>(var + epsilon);
auto b_stddev = make_shared<op::Broadcast>(stddev, shape, post_axis_set);
// Get normalized input
auto norm = (data - b_mean) / b_stddev;
// Apply affine transformation
if (m_use_affine)
{
AxisSet pre_axis_set;
for (size_t i = 0; i < static_cast<size_t>(n_axis); i++)
{
pre_axis_set.insert(i);
}
auto scale = input_value(1);
auto bias = input_value(2);
auto scale_shape = get_input_partial_shape(1).to_shape();
if (shape.size() - n_axis != scale_shape.size())
{
Shape reshape_shape(shape.begin() + m_begin_norm_axis, shape.end());
scale = make_shared<op::Reshape>(scale, AxisVector{0}, reshape_shape);
bias = make_shared<op::Reshape>(bias, AxisVector{0}, reshape_shape);
}
auto b_scale = make_shared<op::Broadcast>(scale, shape, pre_axis_set);
auto b_bias = make_shared<op::Broadcast>(bias, shape, pre_axis_set);
norm = norm * b_scale + b_bias;
}
// Return output nodes
NodeVector retval;
retval.emplace_back(norm);
if (m_keep_stats)
{
retval.emplace_back(mean);
retval.emplace_back(var);
}
return retval;
}
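// For reference, the decomposition above computes the usual layer-normalization
// formula over the axes starting at begin_norm_axis (scale \gamma and bias \beta
// are applied only when m_use_affine is set; mean and variance are returned as
// extra outputs when m_keep_stats is true):
//
//     \mu = mean(x),   \sigma^2 = var(x),
//     y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta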
shared_ptr<Node> op::LayerNorm::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1 && new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
if (!m_use_affine)
{
return make_shared<LayerNorm>(new_args.at(0), m_keep_stats, m_begin_norm_axis, m_epsilon);
}
else
{
return make_shared<LayerNorm>(new_args.at(0),
new_args.at(1),
new_args.at(2),
m_keep_stats,
m_begin_norm_axis,
m_epsilon);
}
}
void op::LayerNorm::pre_validate_and_infer_types()
{
element::Type input_element_type = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type,
").");
const PartialShape& data_shape = get_input_partial_shape(0);
Rank data_rank = data_shape.rank();
int64_t d_rank = -1;
int64_t n_axis = -1;
if (data_rank.is_static())
{
d_rank = static_cast<int64_t>(data_rank);
n_axis = m_begin_norm_axis >= 0 ? m_begin_norm_axis : d_rank + m_begin_norm_axis;
NODE_VALIDATION_CHECK(
this, n_axis >= 0 && n_axis < d_rank, "begin_norm_axis is out of range");
if (m_use_affine)
{
const PartialShape& scale_shape = get_input_partial_shape(1);
const PartialShape& bias_shape = get_input_partial_shape(2);
Rank scale_rank = scale_shape.rank();
Rank bias_rank = bias_shape.rank();
if (scale_rank.is_static() && bias_rank.is_static())
{
int64_t s_rank = static_cast<int64_t>(scale_rank);
int64_t b_rank = static_cast<int64_t>(bias_rank);
NODE_VALIDATION_CHECK(this,
s_rank == b_rank &&
((s_rank == (d_rank - n_axis)) || s_rank == 1),
"Scale and/or bias rank is incorrect");
}
}
}
if (m_keep_stats)
{
set_output_size(3);
// output shape: data_shape[:begin_norm_axis]
if (d_rank > 0)
{
std::vector<Dimension> stats_dim;
for (int64_t i = 0; i < n_axis; i++)
{
stats_dim.emplace_back(data_shape[i]);
}
PartialShape stats_shape(stats_dim);
set_output_type(1, input_element_type, stats_shape);
set_output_type(2, input_element_type, stats_shape);
}
else // set shape to dynamic
{
set_output_type(1, input_element_type, PartialShape::dynamic());
set_output_type(2, input_element_type, PartialShape::dynamic());
}
}
PartialShape norm_shape{data_shape};
set_output_type(0, input_element_type, norm_shape);
}
void op::LayerNorm::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
auto delta = deltas.at(0);
auto data = input_value(0);
if (m_use_affine)
{
auto scale = input_value(1);
auto bias = input_value(2);
if (m_keep_stats)
{
auto mean = outputs()[1];
auto variance = outputs()[2];
auto bprop = make_shared<op::LayerNormBackprop>(
data, delta, mean, variance, scale, m_begin_norm_axis, m_epsilon);
adjoints.add_delta(data, bprop->outputs()[0]);
adjoints.add_delta(scale, bprop->outputs()[1]);
adjoints.add_delta(bias, bprop->outputs()[2]);
}
else
{
auto bprop = make_shared<op::LayerNormBackprop>(
data, delta, scale, m_begin_norm_axis, m_epsilon);
adjoints.add_delta(data, bprop->outputs()[0]);
adjoints.add_delta(scale, bprop->outputs()[1]);
adjoints.add_delta(bias, bprop->outputs()[2]);
}
}
else
{
if (m_keep_stats)
{
auto mean = outputs()[1];
auto variance = outputs()[2];
auto bprop = make_shared<op::LayerNormBackprop>(
data, delta, mean, variance, m_begin_norm_axis, m_epsilon);
adjoints.add_delta(data, bprop->outputs()[0]);
}
else
{
auto bprop =
make_shared<op::LayerNormBackprop>(data, delta, m_begin_norm_axis, m_epsilon);
adjoints.add_delta(data, bprop->outputs()[0]);
}
}
}
op::LayerNormBackprop::LayerNormBackprop(const Output<Node>& data,
const Output<Node>& delta,
const Output<Node>& mean,
const Output<Node>& variance,
const Output<Node>& scale,
int64_t begin_norm_axis,
double epsilon)
: FusedOp({data, delta, mean, variance, scale})
, m_use_stats(true)
, m_use_affine(true)
, m_begin_norm_axis(begin_norm_axis)
, m_epsilon(epsilon)
{
constructor_validate_and_infer_types();
}
op::LayerNormBackprop::LayerNormBackprop(const Output<Node>& data,
const Output<Node>& delta,
const Output<Node>& mean,
const Output<Node>& variance,
int64_t begin_norm_axis,
double epsilon)
: FusedOp({data, delta, mean, variance})
, m_use_stats(true)
, m_use_affine(false)
, m_begin_norm_axis(begin_norm_axis)
, m_epsilon(epsilon)
{
constructor_validate_and_infer_types();
}
op::LayerNormBackprop::LayerNormBackprop(const Output<Node>& data,
const Output<Node>& delta,
const Output<Node>& scale,
int64_t begin_norm_axis,
double epsilon)
: FusedOp({data, delta, scale})
, m_use_stats(false)
, m_use_affine(true)
, m_begin_norm_axis(begin_norm_axis)
, m_epsilon(epsilon)
{
constructor_validate_and_infer_types();
}
op::LayerNormBackprop::LayerNormBackprop(const Output<Node>& data,
const Output<Node>& delta,
int64_t begin_norm_axis,
double epsilon)
: FusedOp({data, delta})
, m_use_stats(false)
, m_use_affine(false)
, m_begin_norm_axis(begin_norm_axis)
, m_epsilon(epsilon)
{
constructor_validate_and_infer_types();
}
// All input shapes should be static by this point
NodeVector op::LayerNormBackprop::decompose_op() const
{
const PartialShape& data_shape = get_input_partial_shape(0);
if (data_shape.is_dynamic())
{
throw ngraph_error("Data needs to have static shape to decompose");
}
const PartialShape& delta_shape = get_input_partial_shape(1);
if (delta_shape.is_dynamic())
{
throw ngraph_error("Delta needs to have static shape to decompose");
}
if (m_use_stats)
{
const PartialShape& mean_shape = get_input_partial_shape(2);
const PartialShape& var_shape = get_input_partial_shape(3);
if (mean_shape.is_dynamic())
{
throw ngraph_error("Mean needs to have static shape to decompose");
}
if (var_shape.is_dynamic())
{
throw ngraph_error("Variance needs to have static shape to decompose");
}
}
if (m_use_affine)
{
const PartialShape& scale_shape = get_input_partial_shape(m_use_stats ? 4 : 2);
if (scale_shape.is_dynamic())
{
throw ngraph_error("Scale needs to have static shape to decompose");
}
}
// Compute real axis
auto shape = data_shape.to_shape();
int64_t n_axis = m_begin_norm_axis >= 0 ? m_begin_norm_axis : shape.size() + m_begin_norm_axis;
// Get input data
auto data = input_value(0);
// Get delta
auto delta = input_value(1);
// Get mean
std::vector<size_t> post_reduction_axes(shape.size() - n_axis);
std::iota(post_reduction_axes.begin(), post_reduction_axes.end(), n_axis);
auto mean =
m_use_stats ? input_value(2) : builder::mean(data, post_reduction_axes)->outputs()[0];
AxisSet post_axis_set;
for (size_t i = static_cast<size_t>(n_axis); i < shape.size(); i++)
{
post_axis_set.insert(i);
}
auto b_mean = make_shared<ngraph::op::Broadcast>(mean, shape, post_axis_set);
// Get variance
auto var =
m_use_stats ? input_value(3) : builder::variance(data, post_reduction_axes)->outputs()[0];
// Compute standard deviation with epsilon
auto epsilon = builder::make_constant(var.get_element_type(), var.get_shape(), m_epsilon);
auto stddev = make_shared<op::Sqrt>(var + epsilon);
auto b_stddev = make_shared<op::Broadcast>(stddev, shape, post_axis_set);
// Get normalized input
auto norm = (data - b_mean) / b_stddev;
// Get gradient for data
auto d_data = delta / b_stddev;
bool scale_flattened = false;
if (m_use_affine)
{
AxisSet pre_axis_set;
for (size_t i = 0; i < static_cast<size_t>(n_axis); i++)
{
pre_axis_set.insert(i);
}
size_t scale_idx = m_use_stats ? 4 : 2;
auto scale = input_value(scale_idx);
auto scale_shape = get_input_partial_shape(scale_idx).to_shape();
if (shape.size() - n_axis != scale_shape.size())
{
scale_flattened = true;
Shape reshape_shape(shape.begin() + m_begin_norm_axis, shape.end());
scale = make_shared<op::Reshape>(scale, AxisVector{0}, reshape_shape);
}
auto b_scale = make_shared<op::Broadcast>(scale, shape, pre_axis_set);
d_data = d_data * b_scale;
}
auto d_mean = make_shared<op::Broadcast>(
builder::mean(-d_data, post_reduction_axes), shape, post_axis_set);
auto d_stddev =
norm * make_shared<op::Broadcast>(
builder::mean(-d_data * norm, post_reduction_axes), shape, post_axis_set);
d_data = d_data + d_mean + d_stddev;
NodeVector retval;
retval.emplace_back(d_data);
// Get gradients for affine
if (m_use_affine)
{
std::vector<size_t> pre_reduction_axes(n_axis);
std::iota(pre_reduction_axes.begin(), pre_reduction_axes.end(), 0);
auto d_bias = make_shared<op::Sum>(delta, pre_reduction_axes);
auto d_scale = make_shared<op::Sum>(delta * norm, pre_reduction_axes);
if (scale_flattened)
{
std::vector<size_t> flatten_axes_vector(shape.size() - n_axis);
std::iota(flatten_axes_vector.begin(), flatten_axes_vector.end(), 0);
AxisVector flatten_axes = AxisVector(flatten_axes_vector);
Shape reshape_shape(shape.begin() + m_begin_norm_axis, shape.end());
size_t reshape_size = shape_size(reshape_shape);
auto flatten_d_scale =
make_shared<op::Reshape>(d_scale, flatten_axes, Shape{reshape_size});
auto flatten_d_bias =
make_shared<op::Reshape>(d_bias, flatten_axes, Shape{reshape_size});
retval.emplace_back(flatten_d_scale);
retval.emplace_back(flatten_d_bias);
}
else
{
retval.emplace_back(d_scale);
retval.emplace_back(d_bias);
}
}
return retval;
}
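// In equation form, the gradient graph assembled above corresponds to the following
// (a sketch of what the code builds: \hat{x} = (x - \mu)/\sqrt{\sigma^2 + \epsilon},
// incoming delta \delta, g = \delta \cdot \gamma / \sqrt{\sigma^2 + \epsilon} with
// \gamma = 1 when m_use_affine is false; means are over the normalized trailing axes,
// sums over the leading axes):
//
//     dx      = g - mean(g) - \hat{x} \cdot mean(g \cdot \hat{x})
//     d\gamma = \sum \delta \cdot \hat{x}
//     d\beta  = \sum \delta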
shared_ptr<Node> op::LayerNormBackprop::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() < 2 || new_args.size() > 5)
{
throw ngraph_error("Incorrect number of new arguments");
}
if (new_args.size() == 5)
{
return make_shared<LayerNormBackprop>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
m_begin_norm_axis,
m_epsilon);
}
else if (new_args.size() == 4)
{
return make_shared<LayerNormBackprop>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
m_begin_norm_axis,
m_epsilon);
}
else if (new_args.size() == 3)
{
return make_shared<LayerNormBackprop>(
new_args.at(0), new_args.at(1), new_args.at(2), m_begin_norm_axis, m_epsilon);
}
else
{
return make_shared<LayerNormBackprop>(
new_args.at(0), new_args.at(1), m_begin_norm_axis, m_epsilon);
}
}
void op::LayerNormBackprop::pre_validate_and_infer_types()
{
element::Type input_element_type = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type,
").");
const PartialShape& data_shape = get_input_partial_shape(0);
Rank data_rank = data_shape.rank();
int64_t d_rank = -1;
int64_t n_axis = -1;
if (data_rank.is_static())
{
d_rank = static_cast<int64_t>(data_rank);
n_axis = m_begin_norm_axis >= 0 ? m_begin_norm_axis : d_rank + m_begin_norm_axis;
NODE_VALIDATION_CHECK(
this, n_axis >= 0 && n_axis < d_rank, "begin_norm_axis is out of range");
const PartialShape& delta_shape = get_input_partial_shape(1);
Rank delta_rank = delta_shape.rank();
NODE_VALIDATION_CHECK(this,
delta_rank.is_dynamic() || static_cast<int64_t>(delta_rank) == d_rank,
"Delta rank is incorrect");
if (m_use_stats)
{
const PartialShape& mean_shape = get_input_partial_shape(2);
const PartialShape& var_shape = get_input_partial_shape(3);
Rank mean_rank = mean_shape.rank();
Rank var_rank = var_shape.rank();
if (mean_rank.is_static() && var_rank.is_static())
{
int64_t m_rank = static_cast<int64_t>(mean_rank);
int64_t v_rank = static_cast<int64_t>(var_rank);
NODE_VALIDATION_CHECK(this,
m_rank == v_rank && m_rank == n_axis,
"Mean and/or variance rank is incorrect");
}
}
if (m_use_affine)
{
const PartialShape& scale_shape = get_input_partial_shape(m_use_stats ? 4 : 2);
Rank scale_rank = scale_shape.rank();
if (scale_rank.is_static())
{
int64_t s_rank = static_cast<int64_t>(scale_rank);
NODE_VALIDATION_CHECK(
this, (s_rank == (d_rank - n_axis)) || s_rank == 1, "Scale rank is incorrect");
}
}
}
if (m_use_affine)
{
set_output_size(3);
// output shape: data_shape[begin_norm_axis:]
if (d_rank > 0)
{
std::vector<Dimension> affine_dim;
for (int64_t i = n_axis; i < d_rank; i++)
{
affine_dim.emplace_back(data_shape[i]);
}
PartialShape affine_shape(affine_dim);
set_output_type(1, input_element_type, affine_shape);
set_output_type(2, input_element_type, affine_shape);
}
else // set shape to dynamic
{
set_output_type(1, input_element_type, PartialShape::dynamic());
set_output_type(2, input_element_type, PartialShape::dynamic());
}
}
PartialShape norm_shape{data_shape};
set_output_type(0, input_element_type, norm_shape);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Layer Normalization
///
class LayerNorm : public ngraph::op::util::FusedOp
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"LayerNorm", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
LayerNorm() = default;
/// \brief Constructs a LayerNorm operation.
///
/// \param data Input tensor
/// \param scale Scale tensor
/// \param bias Bias tensor
/// \param keep_stats Generate the additional mean and variance outputs, default true
/// \param begin_norm_axis Axis where normalization starts, default 1
/// \param epsilon Small number to add for stability of rsqrt, default 1e-5
LayerNorm(const Output<Node>& data,
const Output<Node>& scale,
const Output<Node>& bias,
bool keep_stats = true,
int64_t begin_norm_axis = 1,
double epsilon = 1e-5);
LayerNorm(const Output<Node>& data,
bool keep_stats = true,
int64_t begin_norm_axis = 1,
double epsilon = 1e-5);
virtual NodeVector decompose_op() const override;
void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_keep_stats() const { return m_keep_stats; }
bool get_use_affine() const { return m_use_affine; }
double get_epsilon() const { return m_epsilon; }
int64_t get_begin_norm_axis() const { return m_begin_norm_axis; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
bool m_keep_stats{true};
bool m_use_affine{true};
int64_t m_begin_norm_axis{1};
double m_epsilon{1e-5};
};
/// \brief Layer Normalization Backprop
///
class LayerNormBackprop : public ngraph::op::util::FusedOp
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"LayerNormBackprop", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
LayerNormBackprop() = default;
/// \brief Constructs a LayerNormBackprop operation.
///
/// \param data Input tensor
/// \param delta Delta tensor
/// \param mean Mean tensor from fprop
/// \param variance Variance tensor from fprop
/// \param scale Scale tensor
/// \param begin_norm_axis Axis where normalization starts, default 1
/// \param epsilon Small number to add for stability of rsqrt, default 1e-5
LayerNormBackprop(const Output<Node>& data,
const Output<Node>& delta,
const Output<Node>& mean,
const Output<Node>& variance,
const Output<Node>& scale,
int64_t begin_norm_axis = 1,
double epsilon = 1e-5);
LayerNormBackprop(const Output<Node>& data,
const Output<Node>& delta,
const Output<Node>& mean,
const Output<Node>& variance,
int64_t begin_norm_axis = 1,
double epsilon = 1e-5);
LayerNormBackprop(const Output<Node>& data,
const Output<Node>& delta,
const Output<Node>& scale,
int64_t begin_norm_axis = 1,
double epsilon = 1e-5);
LayerNormBackprop(const Output<Node>& data,
const Output<Node>& delta,
int64_t begin_norm_axis = 1,
double epsilon = 1e-5);
virtual NodeVector decompose_op() const override;
void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_use_stats() const { return m_use_stats; }
bool get_use_affine() const { return m_use_affine; }
double get_epsilon() const { return m_epsilon; }
int64_t get_begin_norm_axis() const { return m_begin_norm_axis; }
private:
bool m_use_stats{true};
bool m_use_affine{true};
int64_t m_begin_norm_axis{1};
double m_epsilon{1e-5};
};
}
}
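A minimal usage sketch of the new op (it mirrors the layer_norm_affine_stats backend test added later in this commit; the parameter shapes are illustrative only):
auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{2, 4});
auto scale = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{4});
auto bias = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{4});
// keep_stats defaults to true, begin_norm_axis to 1, epsilon to 1e-5
auto ln = std::make_shared<ngraph::op::LayerNorm>(data, scale, bias);
// Output 0 is the normalized tensor; outputs 1 and 2 are mean and variance
// (present only when keep_stats is true).
auto f = std::make_shared<ngraph::Function>(ln->outputs(),
                                            ngraph::ParameterVector{data, scale, bias});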
......@@ -37,6 +37,8 @@ NGRAPH_OP(GroupConvolution, ngraph::op)
NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
NGRAPH_OP(GRUCell, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
NGRAPH_OP(LayerNorm, ngraph::op)
NGRAPH_OP(LayerNormBackprop, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(MatMul, ngraph::op)
NGRAPH_OP(MVN, ngraph::op)
......
......@@ -18,6 +18,7 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/fused/group_conv.hpp"
#include "ngraph/op/fused/layer_norm.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/cse.hpp"
......@@ -90,7 +91,13 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
// We apply the same general-purposes passes as the CPU backend.
pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>([](const Node& node) -> bool {
if (node.description() == ngraph::op::GroupConvolution().description())
{
return true;
}
else if (node.description() == ngraph::op::LayerNorm().description())
{
return true;
}
return false;
});
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
......
......@@ -282,5 +282,10 @@ random_uniform_seed_use_dynamic
random_uniform_all_static_range_dynamic
random_uniform_dynamic_shapes
# Fused op test fails on mac
layer_norm_affine_stats
layer_norm_bprop_affine_stats
layer_norm_bprop_affine
# shapes with zeros dimensions like (5, 0, 5) not supported in PlaidML backend
dyn_replace_slice
\ No newline at end of file
dyn_replace_slice
......@@ -80,6 +80,7 @@
#include "ngraph/op/fused/group_conv_transpose.hpp"
#include "ngraph/op/fused/gru_cell.hpp"
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "ngraph/op/fused/layer_norm.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/fused/matmul.hpp"
#include "ngraph/op/fused/mvn.hpp"
......@@ -1404,7 +1405,51 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::HardSigmoid>(args[0], alpha, beta);
break;
}
case OP_TYPEID::LayerNorm:
{
auto keep_stats = node_js.at("keep_stats").get<bool>();
auto use_affine = node_js.at("use_affine").get<bool>();
auto epsilon = node_js.at("epsilon").get<double>();
auto begin_norm_axis = node_js.at("begin_norm_axis").get<int64_t>();
if (use_affine)
{
node = make_shared<op::LayerNorm>(
args[0], args[1], args[2], keep_stats, begin_norm_axis, epsilon);
}
else
{
node = make_shared<op::LayerNorm>(args[0], keep_stats, begin_norm_axis, epsilon);
}
break;
}
case OP_TYPEID::LayerNormBackprop:
{
auto use_stats = node_js.at("use_stats").get<bool>();
auto use_affine = node_js.at("use_affine").get<bool>();
auto epsilon = node_js.at("epsilon").get<double>();
auto begin_norm_axis = node_js.at("begin_norm_axis").get<int64_t>();
if (use_stats && use_affine)
{
node = make_shared<op::LayerNormBackprop>(
args[0], args[1], args[2], args[3], args[4], begin_norm_axis, epsilon);
}
else if (use_stats)
{
node = make_shared<op::LayerNormBackprop>(
args[0], args[1], args[2], args[3], begin_norm_axis, epsilon);
}
else if (use_affine)
{
node = make_shared<op::LayerNormBackprop>(
args[0], args[1], args[2], begin_norm_axis, epsilon);
}
else
{
node =
make_shared<op::LayerNormBackprop>(args[0], args[1], begin_norm_axis, epsilon);
}
break;
}
case OP_TYPEID::Less:
{
node = make_shared<op::Less>(
......@@ -2687,6 +2732,24 @@ json JSONSerializer::serialize_node(const Node& n)
node["beta"] = tmp->get_beta();
break;
}
case OP_TYPEID::LayerNorm:
{
auto tmp = dynamic_cast<const op::LayerNorm*>(&n);
node["keep_stats"] = tmp->get_keep_stats();
node["use_affine"] = tmp->get_use_affine();
node["epsilon"] = tmp->get_epsilon();
node["begin_norm_axis"] = tmp->get_begin_norm_axis();
break;
}
case OP_TYPEID::LayerNormBackprop:
{
auto tmp = dynamic_cast<const op::LayerNormBackprop*>(&n);
node["use_stats"] = tmp->get_use_stats();
node["use_affine"] = tmp->get_use_affine();
node["epsilon"] = tmp->get_epsilon();
node["begin_norm_axis"] = tmp->get_begin_norm_axis();
break;
}
case OP_TYPEID::Less:
{
auto tmp = dynamic_cast<const op::Less*>(&n);
......
......@@ -124,6 +124,7 @@ set(SRC
type_prop/gru_cell.cpp
type_prop/hard_sigmoid.cpp
type_prop/index_reduction.cpp
type_prop/layer_norm.cpp
type_prop/lrn.cpp
type_prop/lstm_cell.cpp
type_prop/matmul.cpp
......@@ -272,6 +273,7 @@ set(MULTI_TEST_SRC
backend/gather.in.cpp
backend/gelu.in.cpp
backend/generate_mask.in.cpp
backend/layer_norm.in.cpp
backend/log.in.cpp
backend/logical_and.in.cpp
backend/logical_or.in.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <random>
#include <string>
// clang-format off
#ifdef ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
#define DEFAULT_FLOAT_TOLERANCE_BITS ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
#endif
#ifdef ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
#define DEFAULT_DOUBLE_TOLERANCE_BITS ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
#endif
// clang-format on
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/autodiff/numeric_compare.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
NGRAPH_TEST(${BACKEND_NAME}, layer_norm_affine_stats)
{
auto p_data = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto p_scale = make_shared<op::Parameter>(element::f32, Shape{4});
auto p_bias = make_shared<op::Parameter>(element::f32, Shape{4});
auto ln = make_shared<op::LayerNorm>(p_data, p_scale, p_bias);
auto f = make_shared<Function>(ln->outputs(), ParameterVector{p_data, p_scale, p_bias});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create tensors for input
auto data = backend->create_tensor(element::f32, Shape{2, 4});
auto scale = backend->create_tensor(element::f32, Shape{4});
auto bias = backend->create_tensor(element::f32, Shape{4});
// Fill in input tensors
vector<float> d_input{-4.0f, -3.0f, -2.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f};
copy_data(data, d_input);
vector<float> s_input{-1.0f, 1.0f, 2.0f, 3.0f};
copy_data(scale, s_input);
vector<float> b_input{-4.0f, -3.0f, -2.0f, -1.0f};
copy_data(bias, b_input);
// Create tensors for output
auto norm = backend->create_tensor(element::f32, Shape{2, 4});
auto mean = backend->create_tensor(element::f32, Shape{2});
auto var = backend->create_tensor(element::f32, Shape{2});
// Expected results (Manually computed)
vector<float> exp_norm{-2.658364534378051758f,
-3.447211742401123047f,
-1.105576276779174805f,
3.024906158447265625f,
-2.658364534378051758f,
-3.447211742401123047f,
-1.105576276779174805f,
3.024906158447265625f};
vector<float> exp_mean{-2.5f, 1.5f};
vector<float> exp_var{1.25f, 1.25f};
auto handle = backend->compile(f);
handle->call_with_validate({norm, mean, var}, {data, scale, bias});
EXPECT_TRUE(test::all_close_f(exp_norm, read_vector<float>(norm)));
EXPECT_TRUE(test::all_close_f(exp_mean, read_vector<float>(mean)));
EXPECT_TRUE(test::all_close_f(exp_var, read_vector<float>(var)));
}
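// Sanity check for the expected statistics above: each row of the 2x4 input is
// normalized independently. For the first row {-4, -3, -2, -1}:
//
//     \mu = -2.5,   \sigma^2 = ((-1.5)^2 + (-0.5)^2 + (0.5)^2 + (1.5)^2) / 4 = 1.25
//
// and the first output element is (-4 - (-2.5)) / sqrt(1.25 + 1e-5) * (-1) + (-4)
// ~= -2.6584, matching exp_norm[0].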
NGRAPH_TEST(${BACKEND_NAME}, layer_norm_bprop_affine_stats)
{
auto p_data = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto p_delta = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto p_mean = make_shared<op::Parameter>(element::f32, Shape{2});
auto p_var = make_shared<op::Parameter>(element::f32, Shape{2});
auto p_scale = make_shared<op::Parameter>(element::f32, Shape{4});
auto lnb = make_shared<op::LayerNormBackprop>(p_data, p_delta, p_mean, p_var, p_scale);
auto f = make_shared<Function>(lnb->outputs(),
ParameterVector{p_data, p_delta, p_mean, p_var, p_scale});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create tensors for input
auto data = backend->create_tensor(element::f32, Shape{2, 4});
auto delta = backend->create_tensor(element::f32, Shape{2, 4});
auto mean = backend->create_tensor(element::f32, Shape{2});
auto var = backend->create_tensor(element::f32, Shape{2});
auto scale = backend->create_tensor(element::f32, Shape{4});
// Fill in input tensors
vector<float> d_input{-4.0f, -3.0f, -2.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f};
copy_data(data, d_input);
vector<float> dt_input{0.1f, -0.1f, 0.2f, -0.2f, 0.1f, -0.1f, 0.2f, -0.2f};
copy_data(delta, dt_input);
vector<float> s_input{-1.0f, 1.0f, 2.0f, 3.0f};
copy_data(scale, s_input);
vector<float> m_input{-2.5f, 1.5f};
copy_data(mean, m_input);
vector<float> v_input{1.25f, 1.25f};
copy_data(var, v_input);
// Create tensors for output
auto d_data = backend->create_tensor(element::f32, Shape{2, 4});
auto d_scale = backend->create_tensor(element::f32, Shape{4});
auto d_bias = backend->create_tensor(element::f32, Shape{4});
// Expected results (Manually computed)
vector<float> exp_d_data{-0.1341624855995178223f,
-0.04472083225846290588f,
0.4919326305389404297f,
-0.31304931640625f,
-0.1341624855995178223f,
-0.04472083225846290588f,
0.4919326305389404297f,
-0.31304931640625f};
vector<float> exp_d_scale{-0.2683270871639251709f,
0.08944236487150192261f,
0.1788847297430038452f,
-0.5366541743278503418f};
vector<float> exp_d_bias{0.2f, -0.2f, 0.4f, -0.4f};
auto handle = backend->compile(f);
handle->call_with_validate({d_data, d_scale, d_bias}, {data, delta, mean, var, scale});
EXPECT_TRUE(test::all_close_f(exp_d_data, read_vector<float>(d_data)));
EXPECT_TRUE(test::all_close_f(exp_d_scale, read_vector<float>(d_scale)));
EXPECT_TRUE(test::all_close_f(exp_d_bias, read_vector<float>(d_bias)));
}
NGRAPH_TEST(${BACKEND_NAME}, layer_norm_bprop_affine)
{
auto p_data = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto p_delta = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto p_scale = make_shared<op::Parameter>(element::f32, Shape{4});
auto lnb = make_shared<op::LayerNormBackprop>(p_data, p_delta, p_scale);
auto f = make_shared<Function>(lnb->outputs(), ParameterVector{p_data, p_delta, p_scale});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create tensors for input
auto data = backend->create_tensor(element::f32, Shape{2, 4});
auto delta = backend->create_tensor(element::f32, Shape{2, 4});
auto scale = backend->create_tensor(element::f32, Shape{4});
// Fill in input tensors
vector<float> d_input{-4.0f, -3.0f, -2.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f};
copy_data(data, d_input);
vector<float> dt_input{0.1f, -0.1f, 0.2f, -0.2f, 0.1f, -0.1f, 0.2f, -0.2f};
copy_data(delta, dt_input);
vector<float> s_input{-1.0f, 1.0f, 2.0f, 3.0f};
copy_data(scale, s_input);
// Create tensors for output
auto d_data = backend->create_tensor(element::f32, Shape{2, 4});
auto d_scale = backend->create_tensor(element::f32, Shape{4});
auto d_bias = backend->create_tensor(element::f32, Shape{4});
// Expected results (Manually computed)
vector<float> exp_d_data{-0.1341624855995178223f,
-0.04472083225846290588f,
0.4919326305389404297f,
-0.31304931640625f,
-0.1341624855995178223f,
-0.04472083225846290588f,
0.4919326305389404297f,
-0.31304931640625f};
vector<float> exp_d_scale{-0.2683270871639251709f,
0.08944236487150192261f,
0.1788847297430038452f,
-0.5366541743278503418f};
vector<float> exp_d_bias{0.2f, -0.2f, 0.4f, -0.4f};
auto handle = backend->compile(f);
handle->call_with_validate({d_data, d_scale, d_bias}, {data, delta, scale});
EXPECT_TRUE(test::all_close_f(exp_d_data, read_vector<float>(d_data)));
EXPECT_TRUE(test::all_close_f(exp_d_scale, read_vector<float>(d_scale)));
EXPECT_TRUE(test::all_close_f(exp_d_bias, read_vector<float>(d_bias)));
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, layer_norm_element_type)
{
auto data = make_shared<op::Parameter>(element::i32, Shape{2, 4});
auto scale = make_shared<op::Parameter>(element::f32, Shape{4});
auto bias = make_shared<op::Parameter>(element::f32, Shape{4});
try
{
auto ln = make_shared<op::LayerNorm>(data, scale, bias);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect element type";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Argument element type must be f16, bf16, f32, f64 or dynamic"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, layer_norm_begin_norm_axis)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto scale = make_shared<op::Parameter>(element::f32, Shape{4});
auto bias = make_shared<op::Parameter>(element::f32, Shape{4});
try
{
auto ln = make_shared<op::LayerNorm>(data, scale, bias, false, 2);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect begin norm axis";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("begin_norm_axis is out of range"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, layer_norm_affine_rank)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto scale = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto bias = make_shared<op::Parameter>(element::f32, Shape{4});
try
{
auto ln = make_shared<op::LayerNorm>(data, scale, bias);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect affine ranks";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Scale and/or bias rank is incorrect"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, layer_norm_bprop_element_type)
{
auto data = make_shared<op::Parameter>(element::i32, Shape{2, 4});
auto delta = make_shared<op::Parameter>(element::f32, Shape{2, 4});
try
{
auto lnb = make_shared<op::LayerNormBackprop>(data, delta);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect element type";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Argument element type must be f16, bf16, f32, f64 or dynamic"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, layer_norm_bprop_begin_norm_axis)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto delta = make_shared<op::Parameter>(element::f32, Shape{2, 4});
try
{
auto lnb = make_shared<op::LayerNormBackprop>(data, delta, 2);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect begin norm axis";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("begin_norm_axis is out of range"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, layer_norm_bprop_delta)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto delta = make_shared<op::Parameter>(element::f32, Shape{4});
try
{
auto lnb = make_shared<op::LayerNormBackprop>(data, delta);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect delta rank";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Delta rank is incorrect"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, layer_norm_bprop_stats)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto delta = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto mean = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto variance = make_shared<op::Parameter>(element::f32, Shape{2});
try
{
auto lnb = make_shared<op::LayerNormBackprop>(data, delta, mean, variance);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect stats rank";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Mean and/or variance rank is incorrect"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, layer_norm_bprop_affine)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto delta = make_shared<op::Parameter>(element::f32, Shape{2, 4});
auto scale = make_shared<op::Parameter>(element::f32, Shape{2, 4});
try
{
auto lnb = make_shared<op::LayerNormBackprop>(data, delta, scale);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect affine rank";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Scale rank is incorrect"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}