Commit b18cb73d authored by Mateusz Bencer, committed by Scott Cyphers

[Spec] Implement support for axes input of LRN op in reference implementation (#3454)

* Axes input was added to LRN

* Unit tests for axes shape check were added

* LRN node deserialization was updated

* Fixed EOF and applied clang style

* Changed Constant to Parameter type in unit tests

* Expanded LRN reference interface

* Fixed LRN assert description

* Fixed passing arguments

* Reference implementation for one axis

* Implementation for channel

* Implementation for hw

* Working on recurrence version

* Implemented recurrence version for hw

* Reference implementation code refactor

* Fixed ref LRN implementation and added tests

* Added 6D unit test

* Clang styles applied

* Code review remarks introduced

* Support for dynamic shape of axes input

* Clang styles applied

* Code review remarks introduced

* Added check that axes values are in the correct range

* Clang styles applied

* Removed redundant include

* Code review remarks introduced
parent 5f53b417
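At a glance, the change adds an explicit second input to LRN that selects the reduction axes. A minimal construction sketch, with the axes values borrowed from the lrn_across_hw test added below (the variable names here are illustrative only):

    // LRN with an explicit axes input (new in this change)
    auto data = make_shared<op::Parameter>(element::f32, Shape{2, 3, 2, 1});
    auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{2, 3});
    auto lrn = make_shared<op::LRN>(data, axes, /*alpha*/ 3.0, /*beta*/ 0.5, /*bias*/ 1.0, /*size*/ 3);

When the axes input is omitted, the delegating constructor in the first hunk below supplies a one-element i64 constant {1}, so existing graphs keep normalizing across the channel axis.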
......@@ -24,7 +24,7 @@ using namespace ngraph;
const string op::LRN::type_name{"LRN"};
op::LRN::LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size)
: LRN(arg, op::Constant::create(element::i32, Shape{1}, {1}), alpha, beta, bias, size)
: LRN(arg, op::Constant::create(element::i64, Shape{1}, {1}), alpha, beta, bias, size)
{
}
......@@ -43,6 +43,17 @@ op::LRN::LRN(const Output<Node>& arg,
constructor_validate_and_infer_types();
}
AxisSet op::LRN::get_reduction_axes() const
{
AxisSet axes{1}; // channel axis as default
auto axes_input_node = input_value(1).get_node_shared_ptr();
if (auto const_op = dynamic_pointer_cast<op::Constant>(axes_input_node))
{
axes = const_op->get_axis_set_val();
}
return axes;
}
void op::LRN::validate_and_infer_types()
{
element::Type arg_type = get_input_element_type(0);
......@@ -50,30 +61,60 @@ void op::LRN::validate_and_infer_types()
set_output_type(0, arg_type, arg_shape);
const PartialShape& input_shape = get_input_partial_shape(0);
const PartialShape& axes_shape = get_input_partial_shape(1);
const auto input_shape_rank = input_shape.rank();
NODE_VALIDATION_CHECK(this,
input_shape.rank().is_dynamic() ||
input_shape_rank.is_dynamic() ||
static_cast<size_t>(input_shape.rank()) >= 3,
"Argument must have rank >= 3 (argument shape: ",
input_shape,
").");
NODE_VALIDATION_CHECK(this, axes_shape.is_static(), "Input axes must be static.");
PartialShape axes_shape{PartialShape::dynamic()};
if (get_input_partial_shape(1).is_static())
{
axes_shape = get_input_partial_shape(1);
}
auto axes_rank = axes_shape.rank();
NODE_VALIDATION_CHECK(this,
static_cast<size_t>(axes_shape.rank()) == 1,
"Input axes must have rank equals 1 (axes shape: ",
axes_shape,
axes_rank.compatible(1),
"Input axes must have rank equals 1 (axes_rank: ",
axes_rank,
").");
NODE_VALIDATION_CHECK(
this,
static_cast<size_t>(axes_shape[0]) >= 1 &&
static_cast<size_t>(axes_shape[0]) <= static_cast<size_t>(input_shape.rank()),
"Number of elements of axes must be >= 1 and <= argument rank (axes_shape[0]: ",
static_cast<size_t>(axes_shape[0]) >= 0 &&
static_cast<size_t>(axes_shape[0]) <= static_cast<size_t>(input_shape_rank),
"Number of elements of axes must be >= 0 and <= argument rank (axes_shape[0]: ",
axes_shape[0],
").");
if (input_shape_rank.is_static())
{
const auto reduction_axes = get_reduction_axes();
for (auto axis : reduction_axes)
{
NODE_VALIDATION_CHECK(this,
axis < size_t(input_shape_rank),
"Reduction axis (",
axis,
") is out of bounds ",
"(argument shape: ",
input_shape,
", reduction axes: ",
reduction_axes,
")");
}
}
const auto& axes_type = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
axes_type.compatible(element::Type_t::i64),
"Axes input must have element type i64 (axes type: ",
axes_type,
").");
}
shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -69,6 +69,8 @@ namespace ngraph
void set_bias(double bias) { m_bias = bias; }
size_t get_nsize() const { return m_size; }
void set_nsize(size_t size) { m_size = size; }
AxisSet get_reduction_axes() const;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
......
......@@ -64,11 +64,13 @@ namespace ngraph
}
else
{
AxisSet axes = lrn->get_reduction_axes();
double alpha = lrn->get_alpha();
double beta = lrn->get_beta();
double bias = lrn->get_bias();
double nsize = lrn->get_nsize();
Shape arg_shape = args[0].get_shape();
Shape axes_shape = args[1].get_shape();
auto element_type = lrn->get_element_type();
if (element_type == element::f32)
......@@ -78,12 +80,14 @@ namespace ngraph
beta,
bias,
arg_shape,
axes_shape,
nsize,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::lrn<float>(
static_cast<float*>(ctx->buffer_data[arg_buffer_index]),
axes,
static_cast<float*>(ctx->buffer_data[out_buffer_index]),
arg_shape,
alpha,
......@@ -99,12 +103,14 @@ namespace ngraph
beta,
bias,
arg_shape,
axes_shape,
nsize,
arg_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::lrn<double>(
static_cast<double*>(ctx->buffer_data[arg_buffer_index]),
axes,
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
arg_shape,
alpha,
......
......@@ -13,3 +13,11 @@ send_recv_ring
# param not supported in CPU backend
group_conv_data_dilation
# axes input param not supported
lrn_across_h
lrn_across_hw
lrn_across_all_dims
lrn_across_nw
lrn_across_empty
lrn_6D_across_2_axes
......@@ -231,3 +231,12 @@ gelu_backprop_factor_f32
gelu_backprop_factor_f64
backwards_gelu_f32
backwards_gelu_f64
logical_xor
# axes input param not supported
lrn_across_h
lrn_across_hw
lrn_across_all_dims
lrn_across_nw
lrn_across_empty
lrn_6D_across_2_axes
......@@ -4,3 +4,11 @@ batch_norm_inference_f64
batch_norm_inference_f32
send_recv
send_recv_ring
# axes input param not supported
lrn_across_h
lrn_across_hw
lrn_across_all_dims
lrn_across_nw
lrn_across_empty
lrn_6D_across_2_axes
......@@ -141,3 +141,11 @@ quantized_dot_int32_output
# Need to update implementation
divide_python_rounding_int32
# axes input param not supported
lrn_across_h
lrn_across_hw
lrn_across_all_dims
lrn_across_nw
lrn_across_empty
lrn_6D_across_2_axes
......@@ -998,6 +998,7 @@ private:
{
const op::LRN* lrn = static_cast<const op::LRN*>(&node);
reference::lrn<T>(args[0]->get_data_ptr<const T>(),
lrn->get_reduction_axes(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
lrn->get_alpha(),
......
......@@ -339,3 +339,11 @@ avg_pool_3d_uneven_strided_padded
rnn_cell_activation_function
gru_cell_bias_clip
gru_cell_linear_before_reset
# axes input param not supported
lrn_across_h
lrn_across_hw
lrn_across_all_dims
lrn_across_nw
lrn_across_empty
lrn_6D_across_2_axes
......@@ -16,6 +16,7 @@
#pragma once
#include <algorithm>
#include <cmath>
#include <numeric>
......@@ -28,8 +29,43 @@ namespace ngraph
{
namespace reference
{
template <typename T>
static void sum_region_across_axes(const T* arg,
size_t current_axis_index,
const std::vector<size_t>& axes,
Coordinate& sum_coord,
T& square_sum,
const std::vector<size_t>& begin_area,
const std::vector<size_t>& end_area,
const CoordinateTransform& input_transform)
{
// all nested axes were visited
if (current_axis_index == axes.size())
{
square_sum += arg[input_transform.index(sum_coord)] *
arg[input_transform.index(sum_coord)];
return;
}
auto current_axis = axes[current_axis_index];
for (auto current_axis_coord = begin_area[current_axis];
current_axis_coord < end_area[current_axis];
++current_axis_coord)
{
sum_coord.at(current_axis) = current_axis_coord;
sum_region_across_axes(arg,
current_axis_index + 1,
axes,
sum_coord,
square_sum,
begin_area,
end_area,
input_transform);
}
}
template <typename T>
void lrn(const T* arg,
const AxisSet& axes,
T* out,
const Shape& arg_shape,
double dalpha,
......@@ -41,25 +77,33 @@ namespace ngraph
T beta = static_cast<T>(dbeta);
T bias = static_cast<T>(dbias);
std::vector<size_t> begin_area(arg_shape.size());
std::vector<size_t> end_area(arg_shape.size());
CoordinateTransform input_transform(arg_shape);
const size_t CHANNEL_DIM = 1;
const size_t MAX_C = arg_shape.at(CHANNEL_DIM);
for (const Coordinate& in_coord : input_transform)
{
size_t c = in_coord.at(CHANNEL_DIM);
T square_sum = 0;
for (size_t i = c; i < c + size; i++)
// area determined by in_coord local neighborhood
for (const auto& axis_coord : axes)
{
if (i < (size - 1) / 2)
continue;
if (i >= MAX_C + (size - 1) / 2)
continue;
auto sum_coord = in_coord;
sum_coord.at(CHANNEL_DIM) = i - (size - 1) / 2;
square_sum += arg[input_transform.index(sum_coord)] *
arg[input_transform.index(sum_coord)];
begin_area[axis_coord] =
std::max<int>(0, in_coord.at(axis_coord) - (size - 1) / 2);
end_area[axis_coord] = std::min<int>(
arg_shape.at(axis_coord), in_coord.at(axis_coord) + (size - 1) / 2 + 1);
}
T square_sum = 0;
auto sum_coord = in_coord;
auto axes_vec = std::vector<size_t>(axes.begin(), axes.end());
sum_region_across_axes(arg,
0,
axes_vec,
sum_coord,
square_sum,
begin_area,
end_area,
input_transform);
T x = arg[input_transform.index(in_coord)];
out[input_transform.index(in_coord)] =
x / (std::pow(bias + (alpha / size) * square_sum, beta));
......
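For orientation, a standalone sketch of calling the reference kernel above on the same data the backend tests below use. The parameter list after dalpha is truncated in this hunk, so the trailing dbeta, dbias and size arguments, as well as the header paths, are assumptions based on the call sites elsewhere in this commit:

    #include <numeric>
    #include <vector>
    #include "ngraph/axis_set.hpp"
    #include "ngraph/runtime/reference/lrn.hpp"
    #include "ngraph/shape.hpp"

    using namespace ngraph;

    int main()
    {
        Shape shape{2, 3, 2, 1};
        std::vector<float> in(shape_size(shape));
        std::iota(in.begin(), in.end(), 0.0f); // 0, 1, ..., 11
        std::vector<float> out(in.size());

        // normalize across the H axis only, as in the lrn_across_h test below
        runtime::reference::lrn<float>(
            in.data(), AxisSet{2}, out.data(), shape,
            /*dalpha*/ 3.0, /*dbeta*/ 0.5, /*dbias*/ 1.0, /*size*/ 3);

        // out[1] should come out close to 0.7071068f
        return 0;
    }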
......@@ -24,6 +24,8 @@
#endif
// clang-format on
#include <numeric>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
......@@ -37,7 +39,7 @@ using namespace ngraph;
static string s_manifest = "${MANIFEST}";
NGRAPH_TEST(${BACKEND_NAME}, lrn)
NGRAPH_TEST(${BACKEND_NAME}, lrn_across_channel)
{
Shape shape{2, 3, 2, 1};
auto A = make_shared<op::Parameter>(element::f32, shape);
......@@ -45,6 +47,7 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn)
double beta = 0.5;
double bias = 1;
size_t size = 3;
// lrn is performed across channel as default
auto lrn = make_shared<op::LRN>(A, alpha, beta, bias, size);
auto f = make_shared<Function>(lrn, ParameterVector{A});
......@@ -72,3 +75,221 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn)
0.7720487f};
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, lrn_across_h)
{
Shape shape{2, 3, 2, 1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto axes = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{2});
double alpha = 3;
double beta = 0.5;
double bias = 1;
size_t size = 3;
auto lrn = make_shared<op::LRN>(A, axes, alpha, beta, bias, size);
auto f = make_shared<Function>(lrn, ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
vector<float> args{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, args);
auto result = backend->create_tensor(element::f32, shape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
vector<float> expected{0.0f,
0.7071068f,
0.5345225f,
0.8017837f,
0.6172134f,
0.7715167f,
0.6469966f,
0.7548294f,
0.6620847f,
0.7448453f,
0.671156f,
0.7382717f};
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, lrn_across_hw)
{
Shape shape{2, 3, 2, 1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{2, 3});
double alpha = 3;
double beta = 0.5;
double bias = 1;
size_t size = 3;
auto lrn = make_shared<op::LRN>(A, axes, alpha, beta, bias, size);
auto f = make_shared<Function>(lrn, ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
vector<float> args{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, args);
auto result = backend->create_tensor(element::f32, shape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
vector<float> expected{0.0f,
0.7071068f,
0.5345225f,
0.8017837f,
0.6172134f,
0.7715167f,
0.6469966f,
0.7548294f,
0.6620847f,
0.7448453f,
0.671156f,
0.7382717f};
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, lrn_across_all_dims)
{
Shape shape{2, 3, 2, 1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto axes = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 1, 2, 3});
double alpha = 3;
double beta = 0.5;
double bias = 1;
size_t size = 3;
auto lrn = make_shared<op::LRN>(A, axes, alpha, beta, bias, size);
auto f = make_shared<Function>(lrn, ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
vector<float> args{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, args);
auto result = backend->create_tensor(element::f32, shape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
vector<float> expected{0.0f,
0.0638877f,
0.0888231f,
0.1332347f,
0.1949481f,
0.2436851f,
0.3833259f,
0.4472136f,
0.3552925f,
0.399704f,
0.4873702f,
0.5361072f};
EXPECT_TRUE(
test::all_close_f(expected, read_vector<float>(result), DEFAULT_FLOAT_TOLERANCE_BITS + 1));
}
NGRAPH_TEST(${BACKEND_NAME}, lrn_across_nw)
{
Shape shape{2, 3, 2, 1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 3});
double alpha = 3;
double beta = 0.5;
double bias = 1;
size_t size = 3;
auto lrn = make_shared<op::LRN>(A, axes, alpha, beta, bias, size);
auto f = make_shared<Function>(lrn, ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
vector<float> args{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, args);
auto result = backend->create_tensor(element::f32, shape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
vector<float> expected{0.0f,
0.140028f,
0.2407717f,
0.3144855f,
0.3698001f,
0.4123931f,
0.9863939f,
0.9801961f,
0.9630868f,
0.9434564f,
0.9245003f,
0.9072647f};
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, lrn_across_empty)
{
Shape shape{2, 3, 2, 1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto axes = make_shared<op::Constant>(element::i64, Shape{0}, vector<int64_t>{});
double alpha = 3;
double beta = 0.5;
double bias = 1;
size_t size = 3;
auto lrn = make_shared<op::LRN>(A, axes, alpha, beta, bias, size);
auto f = make_shared<Function>(lrn, ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
vector<float> args{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, args);
auto result = backend->create_tensor(element::f32, shape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
vector<float> expected{
0.0f,
0.7071068f,
0.8944272f,
0.9486833f,
0.9701425f,
0.9805807f,
0.9863939f,
0.9899495f,
0.9922779f,
0.9938837f,
0.9950372f,
0.9958932f,
};
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, lrn_6D_across_2_axes)
{
Shape shape{2, 3, 2, 2, 1, 1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{2, 3});
double alpha = 3;
double beta = 0.5;
double bias = 1;
size_t size = 3;
auto lrn = make_shared<op::LRN>(A, axes, alpha, beta, bias, size);
auto f = make_shared<Function>(lrn, ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
vector<float> args(24);
std::iota(std::begin(args), std::end(args), 0);
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, args);
auto result = backend->create_tensor(element::f32, shape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
vector<float> expected{0.0f, 0.2581989f, 0.5163978f, 0.7745967f, 0.3549426f, 0.4436783f,
0.5324139f, 0.6211495f, 0.4175966f, 0.4697962f, 0.5219957f, 0.5741953f,
0.4426267f, 0.4795122f, 0.5163978f, 0.5532833f, 0.4560274f, 0.4845291f,
0.5130308f, 0.5415326f, 0.4643635f, 0.4875816f, 0.5107998f, 0.534018f};
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result)));
}
#!/usr/bin/env python
# ******************************************************************************
# Copyright 2017-2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import copy
import numpy as np
def LRN(input, size=3, bias=1.0, alpha=3.0, beta=0.5):
    output = copy.deepcopy(input)
    N = input.shape[0]
    C = input.shape[1]
    H = input.shape[2]
    W = input.shape[3]
    for n in range(N):
        for c in range(C):
            for h in range(H):
                begin_h = max(0, h - (size-1)//2)
                end_h = min(H, h + (size-1)//2 + 1)
                for w in range(W):
                    begin_w = max(0, w - (size-1)//2)
                    end_w = min(W, w + (size-1)//2 + 1)
                    patch = input[n, c, begin_h:end_h, begin_w:end_w]
                    output[n, c, h, w] /= (
                        np.power(bias + (alpha/size) * np.sum(patch * patch), beta))
    return output
input = np.arange(0, 12, 1).reshape(2, 3, 2, 1).astype(np.float32)
result = LRN(input)
for elem in np.nditer(result):
    print(str(round(elem, 7)) + "f, ")
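
The helper above sums squares over an H/W window, and the values it prints line up with the expected vectors of the lrn_across_h and lrn_across_hw tests (W is 1 for that shape). Below is a sketch of a generalization to an arbitrary axes set, mirroring the recursive C++ reference kernel in this commit; the lrn_axes helper is illustrative and not part of the commit:

    import numpy as np

    def lrn_axes(data, axes, size=3, bias=1.0, alpha=3.0, beta=0.5):
        # windowed sum of squares along each axis listed in `axes`,
        # single-element window along every other axis
        out = np.empty_like(data)
        half = (size - 1) // 2
        for idx in np.ndindex(data.shape):
            window = []
            for ax, dim in enumerate(data.shape):
                if ax in axes:
                    window.append(slice(max(0, idx[ax] - half), min(dim, idx[ax] + half + 1)))
                else:
                    window.append(slice(idx[ax], idx[ax] + 1))
            patch = data[tuple(window)]
            out[idx] = data[idx] / np.power(bias + (alpha / size) * np.sum(patch * patch), beta)
        return out

    # e.g. regenerate the expected vector of the lrn_across_h test (axes input {2}):
    data = np.arange(0, 12, 1).reshape(2, 3, 2, 1).astype(np.float32)
    for elem in np.nditer(lrn_axes(data, axes={2})):
        print(str(round(float(elem), 7)) + "f, ")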
......@@ -70,7 +70,29 @@ TEST(type_prop, lrn_invalid_axes_rank)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Number of elements of axes must be >= 1 and <= argument rank"));
std::string("Number of elements of axes must be >= 0 and <= argument rank"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, lrn_incorrect_axes_value)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
auto axes = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{3, 4});
double alpha = 0.1f, beta = 0.2f, bias = 0.3f;
size_t size = 3;
try
{
auto lrn = make_shared<op::LRN>(data, axes, alpha, beta, bias, size);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input tensor rank not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Reduction axis ("));
}
catch (...)
{
......