Commit f2a8f6e5 authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

Relax check on LRN for rank requirement to be >=3 (#3952)

*  relax check for LRN for requirement rank should be >=3

* rename unit test names

* - Disable lrn unit test with axes for CPU backend

* remove outdated unit test on rank requirement from type_prop

* - disable newly added lrn unit test in plaidMl
parent ee8b9366
......@@ -64,13 +64,6 @@ void op::LRN::validate_and_infer_types()
const PartialShape& input_shape = get_input_partial_shape(0);
const auto input_shape_rank = input_shape.rank();
NODE_VALIDATION_CHECK(this,
input_shape_rank.is_dynamic() ||
static_cast<size_t>(input_shape.rank()) >= 3,
"Argument must have rank >= 3 (argument shape: ",
input_shape,
").");
PartialShape axes_shape{PartialShape::dynamic()};
if (get_input_partial_shape(1).is_static())
{
......
......@@ -21,6 +21,8 @@ lrn_across_all_dims
lrn_across_nw
lrn_across_empty
lrn_6D_across_2_axes
lrn_2d_across_empty
lrn_2d_across_outermost_axis
# ONNX TopK with dynamic K
top_k_opset_10
......@@ -301,6 +301,8 @@ lrn_across_all_dims
lrn_across_nw
lrn_across_empty
lrn_6D_across_2_axes
lrn_2d_across_empty
lrn_2d_across_outermost_axis
# RandomUniform not supported in PlaidML backend
random_uniform_all_static_seed_unused
......
......@@ -293,3 +293,90 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn_6D_across_2_axes)
0.5130308f, 0.5415326f, 0.4643635f, 0.4875816f, 0.5107998f, 0.534018f};
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, lrn_2d_across_empty)
{
Shape shape{12};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto axes = make_shared<op::Constant>(element::i64, Shape{0}, vector<int64_t>{});
double alpha = 3;
double beta = 0.5;
double bias = 1;
size_t size = 3;
auto lrn = make_shared<op::LRN>(A, axes, alpha, beta, bias, size);
auto f = make_shared<Function>(lrn, ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
vector<float> args{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, args);
auto result = backend->create_tensor(element::f32, shape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
vector<float> expected{
0.0f,
0.7071068f,
0.8944272f,
0.9486833f,
0.9701425f,
0.9805807f,
0.9863939f,
0.9899495f,
0.9922779f,
0.9938837f,
0.9950372f,
0.9958932f,
};
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, lrn_2d_across_outermost_axis)
{
Shape shape{6, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto axes = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{0});
double alpha = 0.0002;
double beta = 0.5;
double bias = 2.0;
size_t size = 3;
auto lrn = make_shared<op::LRN>(A, axes, alpha, beta, bias, size);
auto f = make_shared<Function>(lrn, ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
vector<float> args{0.64915806f,
0.21213771f,
-1.48256505f,
-1.41040838f,
0.58189541f,
0.11432108f,
-0.22993855f,
-0.13325502f,
-0.03083259f,
-0.48450908f,
0.50342429f,
-0.99551708f};
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, args);
auto result = backend->create_tensor(element::f32, shape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
vector<float> expected{0.45900404f,
0.14999892f,
-1.04828012f,
-0.99727529f,
0.41144446f,
0.08083449f,
-0.16259004f,
-0.09422511f,
-0.02180192f,
-0.34259823f,
0.35597473f,
-0.70393407f};
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result), 23));
}
......@@ -21,23 +21,6 @@
using namespace std;
using namespace ngraph;
TEST(type_prop, lrn_invalid_arg_rank)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2});
double alpha = 0.1, beta = 0.2, bias = 0.3;
size_t size = 3;
try
{
auto lrn = make_shared<op::LRN>(data, alpha, beta, bias, size);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input tensor rank not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument must have rank >= 3"));
}
}
TEST(type_prop, lrn_invalid_axes_rank)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment