Commit 83611106 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Scott Cyphers

[Spec] Add axes input to LRN (#3374)

* Axes input was added to LRN

* Unit tests for axes shape check were added

* LRN node deserialization was updated

* Fixed EOF and clang style applied

* Changed Constant to Parameter type in unit tests

* Fixed LRN assert description

* Fixed copy_with_new_args

* Clang style applied
parent 34c084e3
......@@ -15,6 +15,8 @@
//*****************************************************************************
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/multiply.hpp"
using namespace std;
using namespace ngraph;
......@@ -22,7 +24,17 @@ using namespace ngraph;
const string op::LRN::type_name{"LRN"};
op::LRN::LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size)
: UnaryElementwiseArithmetic(arg)
: LRN(arg, op::Constant::create(element::i32, Shape{1}, {1}), alpha, beta, bias, size)
{
}
op::LRN::LRN(const Output<Node>& arg,
const Output<Node>& axes,
double alpha,
double beta,
double bias,
size_t size)
: Op({arg, axes})
, m_alpha(alpha)
, m_beta(beta)
, m_bias(bias)
......@@ -33,9 +45,12 @@ op::LRN::LRN(const Output<Node>& arg, double alpha, double beta, double bias, si
void op::LRN::validate_and_infer_types()
{
UnaryElementwiseArithmetic::validate_and_infer_types();
element::Type arg_type = get_input_element_type(0);
PartialShape arg_shape = get_input_partial_shape(0);
set_output_type(0, arg_type, arg_shape);
const PartialShape& input_shape = get_input_partial_shape(0);
const PartialShape& axes_shape = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this,
input_shape.rank().is_dynamic() ||
......@@ -43,12 +58,28 @@ void op::LRN::validate_and_infer_types()
"Argument must have rank >= 3 (argument shape: ",
input_shape,
").");
NODE_VALIDATION_CHECK(this, axes_shape.is_static(), "Input axes must be static.");
NODE_VALIDATION_CHECK(this,
static_cast<size_t>(axes_shape.rank()) == 1,
"Input axes must have rank equals 1 (axes shape: ",
axes_shape,
").");
NODE_VALIDATION_CHECK(
this,
static_cast<size_t>(axes_shape[0]) >= 1 &&
static_cast<size_t>(axes_shape[0]) <= static_cast<size_t>(input_shape.rank()),
"Number of elements of axes must be >= 1 and <= argument rank (axes_shape[0]: ",
axes_shape[0],
").");
}
shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::LRN>(new_args.at(0), m_alpha, m_beta, m_bias, m_size);
return make_shared<op::LRN>(new_args.at(0), new_args.at(1), m_alpha, m_beta, m_bias, m_size);
}
void op::LRN::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
......
......@@ -16,7 +16,7 @@
#pragma once
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
......@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------ |
/// | \f$N[n, c, d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[n, c, d_1,\dots,d_n] = \frac{N[n,i,d_1,\dots,d_n]}{ (bias + alpha * (\sum_{i=max(0,(nsize-1)/2)}^{min(C, (nsize-1)/2)+1} N[n,i,d_1,\dots,d_n]^{2}) ^ {2})}\f$ |
class LRN : public util::UnaryElementwiseArithmetic
class LRN : public Op
{
public:
NGRAPH_API
......@@ -48,6 +48,13 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size);
LRN(const Output<Node>& arg,
const Output<Node>& axes,
double alpha,
double beta,
double bias,
size_t size);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
......
......@@ -1363,7 +1363,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
auto beta = node_js.at("beta").get<double>();
auto bias = node_js.at("bias").get<double>();
auto nsize = node_js.at("nsize").get<size_t>();
node = make_shared<op::LRN>(args[0], alpha, beta, bias, nsize);
node = make_shared<op::LRN>(args[0], args[1], alpha, beta, bias, nsize);
break;
}
case OP_TYPEID::LSTMCell:
......
......@@ -111,6 +111,7 @@ set(SRC
type_prop/gru_cell.cpp
type_prop/hard_sigmoid.cpp
type_prop/index_reduction.cpp
type_prop/lrn.cpp
type_prop/lstm_cell.cpp
type_prop/max_pool.cpp
type_prop/mvn.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, lrn_invalid_arg_rank)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2});
double alpha = 0.1f, beta = 0.2f, bias = 0.3f;
size_t size = 3;
try
{
auto lrn = make_shared<op::LRN>(data, alpha, beta, bias, size);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input tensor rank not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument must have rank >= 3"));
}
}
TEST(type_prop, lrn_invalid_axes_rank)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
auto axes = make_shared<op::Parameter>(element::f32, Shape{1, 2});
double alpha = 0.1f, beta = 0.2f, bias = 0.3f;
size_t size = 3;
try
{
auto lrn = make_shared<op::LRN>(data, axes, alpha, beta, bias, size);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input tensor rank not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input axes must have rank equals 1"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
axes = make_shared<op::Parameter>(element::f32, Shape{5});
try
{
auto lrn = make_shared<op::LRN>(data, axes, alpha, beta, bias, size);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input tensor rank not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Number of elements of axes must be >= 1 and <= argument rank"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment