Commit 237c4803 authored by Nick Korovaiko, committed by Scott Cyphers

LRN (#1282)

* lrn init

* fix comment

* mkldnn lrn (#1295)

* add serializer + fix compiler warnings
parent 83a9d252
......@@ -64,6 +64,7 @@ set (SRC
op/less.cpp
op/less_eq.cpp
op/log.cpp
op/lrn.cpp
op/max.cpp
op/maximum.cpp
op/max_pool.cpp
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/multiply.hpp"
using namespace std;
using namespace ngraph;
op::LRN::LRN(const std::shared_ptr<Node>& arg, double alpha, double beta, double bias, size_t nsize)
: UnaryElementwiseArithmetic("LRN", arg)
, m_alpha(alpha)
, m_beta(beta)
, m_bias(bias)
, m_size(nsize)
{
if (arg->get_shape().size() < 3)
{
throw ngraph_error("LRN expects a tensor at least of rank of 3");
}
}
shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<op::LRN>(new_args.at(0), m_alpha, m_beta, m_bias, m_size);
}
void op::LRN::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
throw ngraph_error("NYI");
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
/// \brief Elementwise Local Response Normalization (LRN) operation.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ----- | --------------------------------- | ----------------------------------------------- |
/// | `arg` | \f$N[n, c, d_1,\dots,d_n]~(n \geq 1)\f$ | A tensor of rank at least 3, whose second axis is the channel axis, with any numeric element type. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------ |
/// | \f$N[n, c, d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[n, c, d_1,\dots,d_n] = \frac{N[n,c,d_1,\dots,d_n]}{\left(bias + \frac{alpha}{nsize}\sum_{i=\max(0,\,c-(nsize-1)/2)}^{\min(C-1,\,c+(nsize-1)/2)} N[n,i,d_1,\dots,d_n]^{2}\right)^{beta}}\f$ and \f$C\f$ is the number of channels. |
class LRN : public util::UnaryElementwiseArithmetic
{
public:
/// \brief Constructs an LRN operation.
///
/// \param arg Node that produces the input tensor.
/// \param alpha Scaling coefficient applied to the windowed sum of squares.
/// \param beta Exponent applied to the normalization denominator.
/// \param bias Additive constant in the normalization denominator.
/// \param size Number of channels in the normalization window (nsize).
LRN(const std::shared_ptr<Node>& arg,
double alpha,
double beta,
double bias,
size_t size);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
double get_alpha() const { return m_alpha; }
double get_beta() const { return m_beta; }
double get_bias() const { return m_bias; }
size_t get_nsize() const { return m_size; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
double m_alpha;
double m_beta;
double m_bias;
size_t m_size;
};
}
}
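For orientation, a node with these attributes might be built and queried as in the minimal sketch below (a hand-written illustration, not part of this diff; it assumes the same using-directives as the backend test at the end of the commit, and the values mirror that test):

auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 2, 1});
// alpha = 1.0, beta = 2.0, bias = 1.0, nsize = 3
auto lrn = std::make_shared<op::LRN>(data, 1.0, 2.0, 1.0, 3);
auto f = std::make_shared<Function>(lrn, op::ParameterVector{data});
// the hyper-parameters remain accessible through the getters declared above
double alpha = lrn->get_alpha(); // 1.0
size_t nsize = lrn->get_nsize(); // 3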
......@@ -52,6 +52,7 @@
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
......@@ -1450,6 +1451,57 @@ namespace ngraph
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::LRN)
{
const ngraph::op::LRN* lrn = static_cast<const ngraph::op::LRN*>(node);
writer.block_begin();
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
auto input_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 0);
auto output_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_data_desc =
mkldnn_emitter->build_memory_descriptor(args[0], input_format);
auto result_desc =
mkldnn_emitter->build_memory_descriptor(out[0], output_format);
auto lrn_index =
mkldnn_emitter->build_lrn_forward(input_data_desc,
result_desc,
static_cast<float>(lrn->get_alpha()),
static_cast<float>(lrn->get_beta()),
static_cast<float>(lrn->get_bias()),
static_cast<int>(lrn->get_nsize()));
auto& deps = mkldnn_emitter->get_primitive_deps(lrn_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(lrn_index) << ");\n";
}
else
{
writer << "reference::lrn<" << lrn->get_element_type().c_type_string() << ">(";
writer << " " << args[0].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " " << lrn->get_alpha() << ",\n";
writer << " " << lrn->get_beta() << ",\n";
writer << " " << lrn->get_bias() << ",\n";
writer << " " << lrn->get_nsize() << ");\n";
}
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Log)
{
......
......@@ -72,6 +72,7 @@
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
......@@ -319,6 +320,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Or), &runtime::cpu::CPU_Emitter::emit<op::Or>},
{TI(ngraph::runtime::cpu::op::LoopKernel),
&runtime::cpu::CPU_Emitter::emit<runtime::cpu::op::LoopKernel>},
{TI(ngraph::op::LRN), &runtime::cpu::CPU_Emitter::emit<ngraph::op::LRN>},
};
const size_t runtime::cpu::CPU_ExternalFunction::CPU_ExternalFunction::s_memory_pool_alignment =
......@@ -419,6 +421,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/convolution.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/runtime/reference/lrn.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/max_pool.hpp"
#include "ngraph/runtime/reference/min.hpp"
......
......@@ -628,6 +628,33 @@ size_t MKLDNNEmitter::build_reorder(const mkldnn::memory::desc& input_desc,
return primitive_index;
}
size_t MKLDNNEmitter::build_lrn_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
float alpha,
float beta,
float bias,
int nsize)
{
size_t input_index = build_memory_primitive(input_desc);
size_t result_index = build_memory_primitive(result_desc);
auto lrn_desc = mkldnn::lrn_forward::desc(mkldnn::prop_kind::forward_scoring,
mkldnn::algorithm::lrn_across_channels,
input_desc,
nsize,
alpha,
beta,
bias);
auto lrn_prim_desc =
mkldnn::lrn_forward::primitive_desc(lrn_desc, mkldnn_utils::global_cpu_engine);
size_t primitive_index = insert_primitive(new mkldnn::lrn_forward(
lrn_prim_desc, *m_mkldnn_primitives[input_index], *m_mkldnn_primitives[result_index]));
m_primitive_deps[primitive_index] = {input_index, result_index};
return primitive_index;
}
size_t MKLDNNEmitter::build_relu_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc)
{
......
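Note that build_lrn_forward records {input_index, result_index} as the dependency vector of the new primitive; this ordering is what the CPU emitter above relies on when it binds deps[0] to the input tensor and deps[1] to the output tensor before invoking the primitive.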
......@@ -351,6 +351,13 @@ namespace ngraph
size_t build_reorder(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
size_t build_lrn_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc,
float alpha,
float beta,
float bias,
int nsize);
size_t build_relu_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc);
......
......@@ -30,6 +30,7 @@
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
......@@ -485,6 +486,23 @@ namespace ngraph
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::LRN)
{
auto lrn = static_cast<op::LRN*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
auto result_shape = node->get_output_shape(0);
if ((arg0_rank == 4) && node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
lrn->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Sigmoid)
{
......@@ -736,6 +754,7 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBias>},
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ConvolutionBiasBackpropFiltersBias>},
{TI(ngraph::op::LRN), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::LRN>},
{TI(ngraph::op::Relu), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Relu>},
{TI(ngraph::op::ReluBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ReluBackprop>},
......
......@@ -32,6 +32,7 @@
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/relu.hpp"
......@@ -1134,6 +1135,23 @@ namespace ngraph
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::LRN)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);
vector<memory::format> prim_output_formats;
prim_output_formats.push_back(input_layout);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::Sigmoid)
{
......@@ -1525,6 +1543,7 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
&runtime::cpu::pass::CPULayout::layout<ngraph::op::BatchNormBackprop>},
{TI(ngraph::op::GetOutputElement),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::GetOutputElement>},
{TI(ngraph::op::LRN), &runtime::cpu::pass::CPULayout::layout<ngraph::op::LRN>},
{TI(ngraph::op::Relu), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Relu>},
{TI(ngraph::op::Result), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Result>},
{TI(ngraph::op::ReluBackprop),
......
......@@ -12,6 +12,7 @@ divide_by_zero_int32
dot_matrix_vector_int64
#no mkldnn on GPU
#error throw is not the same on GPU, not supported yet
lrn
one_hot_scalar_fp_nonint_in_3
one_hot_scalar_oob_in_3
one_hot_vector_1_barely_oob
......
......@@ -33,6 +33,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/min.hpp"
......@@ -78,6 +79,7 @@
#include "ngraph/runtime/reference/less.hpp"
#include "ngraph/runtime/reference/less_eq.hpp"
#include "ngraph/runtime/reference/log.hpp"
#include "ngraph/runtime/reference/lrn.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/max_pool.hpp"
#include "ngraph/runtime/reference/maximum.hpp"
......@@ -576,6 +578,17 @@ private:
reference::log<T>(
args[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), out[0]->get_element_count());
}
else if (node_op == "LRN")
{
const op::LRN* lrn = static_cast<const op::LRN*>(&node);
reference::lrn<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
lrn->get_alpha(),
lrn->get_beta(),
lrn->get_bias(),
lrn->get_nsize());
}
else if (node_op == "Max")
{
const op::Max* max = static_cast<const op::Max*>(&node);
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <cmath>
#include <numeric>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void lrn(const T* arg,
T* out,
const Shape& arg_shape,
double dalpha,
double dbeta,
double dbias,
size_t size)
{
T alpha = static_cast<T>(dalpha);
T beta = static_cast<T>(dbeta);
T bias = static_cast<T>(dbias);
CoordinateTransform input_transform(arg_shape);
const size_t CHANNEL_DIM = 1;
const size_t MAX_C = arg_shape.at(CHANNEL_DIM);
for (const Coordinate& in_coord : input_transform)
{
size_t c = in_coord.at(CHANNEL_DIM);
T square_sum = 0;
// The window spans `size` channels centred on channel c: i runs over
// [c, c + size) and maps to channel i - (size - 1) / 2; positions that
// fall outside the valid channel range [0, MAX_C) are skipped.
for (size_t i = c; i < c + size; i++)
{
if (i < (size - 1) / 2)
continue;
if (i >= MAX_C + (size - 1) / 2)
continue;
auto sum_coord = in_coord;
sum_coord.at(CHANNEL_DIM) = i - (size - 1) / 2;
square_sum += arg[input_transform.index(sum_coord)] *
arg[input_transform.index(sum_coord)];
}
T x = arg[input_transform.index(in_coord)];
out[input_transform.index(in_coord)] =
x / (std::pow(bias + (alpha / size) * square_sum, beta));
}
}
}
}
}
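A minimal sketch of exercising the reference kernel directly (illustrative only, not part of this commit; it assumes the usual <vector> include):

std::vector<float> in{1.0f, 2.0f, 3.0f};
std::vector<float> out(3);
// 1 batch, 3 channels, 1x1 spatial; alpha = 1.0, beta = 0.75, bias = 1.0, size = 3
ngraph::runtime::reference::lrn<float>(in.data(), out.data(), ngraph::Shape{1, 3, 1, 1}, 1.0, 0.75, 1.0, 3);
// middle channel: out[1] = 2 / pow(1 + (1.0 / 3) * (1 + 4 + 9), 0.75)
// edge channels use a window truncated at the channel boundary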
......@@ -49,6 +49,7 @@
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
......@@ -638,6 +639,14 @@ static shared_ptr<ngraph::Function>
{
node = make_shared<op::Log>(args[0]);
}
else if (node_op == "LRN")
{
auto alpha = node_js.at("alpha").get<double>();
auto beta = node_js.at("beta").get<double>();
auto bias = node_js.at("bias").get<double>();
auto nsize = node_js.at("nsize").get<size_t>();
node = make_shared<op::LRN>(args[0], alpha, beta, bias, nsize);
}
else if (node_op == "Max")
{
auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
......@@ -1141,6 +1150,14 @@ static json write(const Node& n, bool binary_constant_data)
else if (node_op == "Log")
{
}
else if (node_op == "LRN")
{
auto tmp = dynamic_cast<const op::LRN*>(&n);
node["alpha"] = tmp->get_alpha();
node["beta"] = tmp->get_beta();
node["bias"] = tmp->get_bias();
node["nsize"] = tmp->get_nsize();
}
else if (node_op == "Max")
{
auto tmp = dynamic_cast<const op::Max*>(&n);
......
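With this change a serialized LRN node carries its four attributes in addition to the usual node fields. An illustrative, hand-written fragment (only the attributes written above are shown; the surrounding fields are elided):

{ ..., "alpha": 1.0, "beta": 2.0, "bias": 1.0, "nsize": 3 }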
......@@ -26,6 +26,7 @@
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/serializer.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
......@@ -1464,6 +1465,37 @@ NGRAPH_TEST(${BACKEND_NAME}, log)
EXPECT_TRUE(test::all_close_f(loga, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, lrn)
{
Shape shape{2, 3, 2, 1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto lrn = make_shared<op::LRN>(A, 1., 2., 1., 3);
auto f = make_shared<Function>(lrn, op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
vector<float> args{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, args);
auto result = backend->create_tensor(element::f32, shape);
backend->call(f, {result}, {a});
vector<float> expected{0.f,
0.05325444f,
0.03402646f,
0.01869806f,
0.06805293f,
0.03287071f,
0.00509002f,
0.00356153f,
0.00174719f,
0.0012555f,
0.00322708f,
0.00235574f};
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result)));
}
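// Cross-check on the expected values: the second input element (value 1.0,
// channel 0 of the first image) has its size-3 window clipped to channels
// {0, 1}, whose co-located values are 1 and 3, so with alpha=1, beta=2,
// bias=1, nsize=3 the output is 1 / (1 + (1/3) * (1 + 9))^2 = 1 / (13/3)^2
// ~= 0.05325444, matching the second entry above.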
NGRAPH_TEST(${BACKEND_NAME}, maximum)
{
Shape shape{2, 2, 2};
......