Commit 70738769 authored by tsocha, committed by Scott Cyphers

[Py] Wrapper for LRN (#1313)

* [Py] Wrapper for LRN

* Add missing header

* Add default param values, docs and some unit tests

* Fixes

* clang-format
parent 48da82cb
@@ -49,6 +49,7 @@ from ngraph.ops import log
from ngraph.ops import logical_and
from ngraph.ops import logical_or
from ngraph.ops import logical_not
from ngraph.ops import lrn
from ngraph.ops import max
from ngraph.ops import max_pool
from ngraph.ops import maximum
...
@@ -76,6 +76,7 @@ from _pyngraph.op import GreaterEq
from _pyngraph.op import Less
from _pyngraph.op import LessEq
from _pyngraph.op import Log
from _pyngraph.op import LRN
from _pyngraph.op import Max
from _pyngraph.op import Maximum
from _pyngraph.op import MaxPool
...
@@ -23,9 +23,9 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
from ngraph.impl.op import Abs, Acos, Add, And, Asin, Atan, AvgPool, BatchNorm, Broadcast, \
    Ceiling, Concat, Constant, Convert, Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, \
    Dot, Equal, Exp, Floor, FunctionCall, GetOutputElement, Greater, GreaterEq, Less, LessEq, \
    Log, LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, \
    Pad, Parameter, Product, Power, Reduce, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, \
    Sin, Sinh, Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh

from typing import Callable, Iterable, List, Union

@@ -929,6 +929,28 @@ def batch_norm(eps,  # type: float
    return BatchNorm(eps, gamma, beta, data, mean, variance, training)

@nameable_op
def lrn(data,       # type: Node
        alpha=1.0,  # type: float
        beta=0.5,   # type: float
        bias=1.0,   # type: float
        size=5,     # type: int
        name=None,  # type: str
        ):
    # type: (...) -> Node
    """Return a node which performs an element-wise Local Response Normalization (LRN) operation.

    :param data: Input data.
    :param alpha: A scale factor (usually positive).
    :param beta: An exponent.
    :param bias: An offset (usually positive) to avoid dividing by 0.
    :param size: Width of the 1-D normalization window.
    :param name: An optional name of the output node.
    :return: The new node which performs LRN.
    """
    return LRN(data, alpha, beta, bias, size)


@nameable_op
def function_call(function_to_call, args):  # type: (Node, NodeVector) -> Node
    """Return Function call op."""
...
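For orientation, the parameters of `lrn` combine in the commonly used across-channel LRN formula `out = x / (bias + alpha / size * S) ** beta`, where `S` is the sum of squares of the input over a window of `size` channels centred on the current channel. That formula is an assumption (it is not spelled out in this diff), but it is consistent with the expected values in the unit test added below; a minimal single-element sketch:

```python
# Hypothetical single-element check of the assumed across-channel LRN formula.
x, sqr_sum = 1.0, 10.0                      # a clipped 2-channel window {1, 3}: 1**2 + 3**2 = 10
alpha, beta, bias, size = 1.0, 2.0, 1.0, 3  # the non-default parameters used in the test below
out = x / (bias + alpha / size * sqr_sum) ** beta
print(out)  # ~0.05325444, the second expected value in the test
```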
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ngraph/op/lrn.hpp"
#include "pyngraph/ops/lrn.hpp"

namespace py = pybind11;

void regclass_pyngraph_op_LRN(py::module m)
{
    py::class_<ngraph::op::LRN,
               std::shared_ptr<ngraph::op::LRN>,
               ngraph::op::util::RequiresTensorViewArgs>
        lrn(m, "LRN");
    lrn.doc() = "ngraph.impl.op.LRN wraps ngraph::op::LRN";
    lrn.def(py::init<const std::shared_ptr<ngraph::Node>&, double, double, double, size_t>());
}
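For completeness, the binding above exposes the `LRN` constructor directly under `ngraph.impl.op`, with arguments in the order given to `py::init`. A minimal sketch of using it without the `ng.lrn` convenience wrapper (the constant input is an illustrative assumption, mirroring the unit test below):

```python
import numpy as np
import ngraph as ng
from ngraph.impl.op import LRN  # the class registered by regclass_pyngraph_op_LRN

data = ng.constant(np.ones((1, 3, 2, 2), dtype=np.float32))
node = LRN(data, 1.0, 0.5, 1.0, 3)  # (arg, alpha, beta, bias, size), matching the py::init signature
```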
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once

#include <pybind11/pybind11.h>

namespace py = pybind11;

void regclass_pyngraph_op_LRN(py::module m);
@@ -49,6 +49,7 @@ void regmodule_pyngraph_op(py::module m_op)
    regclass_pyngraph_op_Less(m_op);
    regclass_pyngraph_op_LessEq(m_op);
    regclass_pyngraph_op_Log(m_op);
    regclass_pyngraph_op_LRN(m_op);
    regclass_pyngraph_op_MaxPool(m_op);
    regclass_pyngraph_op_MaxPoolBackprop(m_op);
    regclass_pyngraph_op_Maximum(m_op);
...
@@ -42,6 +42,7 @@
#include "pyngraph/ops/less.hpp"
#include "pyngraph/ops/less_eq.hpp"
#include "pyngraph/ops/log.hpp"
#include "pyngraph/ops/lrn.hpp"
#include "pyngraph/ops/max_pool.hpp"
#include "pyngraph/ops/maximum.hpp"
#include "pyngraph/ops/minimum.hpp"
...
@@ -169,6 +169,7 @@ sources = ['pyngraph/function.cpp',
           'pyngraph/ops/less.cpp',
           'pyngraph/ops/less_eq.cpp',
           'pyngraph/ops/log.cpp',
           'pyngraph/ops/lrn.cpp',
           'pyngraph/ops/maximum.cpp',
           'pyngraph/ops/max.cpp',
           'pyngraph/ops/product.cpp',
...
# ******************************************************************************
# Copyright 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import numpy as np
import pytest
import ngraph as ng
from test.ngraph.util import get_runtime


@pytest.config.gpu_skip(reason='Not implemented')
def test_lrn():
    input_image_shape = (2, 3, 2, 1)
    input_image = np.arange(int(np.prod(input_image_shape))).reshape(input_image_shape).astype('f')
    runtime = get_runtime()
    model = ng.lrn(ng.constant(input_image), alpha=1.0, beta=2.0, bias=1.0, size=3)
    computation = runtime.computation(model)
    result = computation()
    assert np.allclose(result,
                       np.array([[[[0.0],
                                   [0.05325444]],
                                  [[0.03402646],
                                   [0.01869806]],
                                  [[0.06805293],
                                   [0.03287071]]],
                                 [[[0.00509002],
                                   [0.00356153]],
                                  [[0.00174719],
                                   [0.0012555]],
                                  [[0.00322708],
                                   [0.00235574]]]], dtype=np.float32))

    # Test LRN default parameter values
    model = ng.lrn(ng.constant(input_image))
    computation = runtime.computation(model)
    result = computation()
    assert np.allclose(result,
                       np.array([[[[0.0],
                                   [0.35355338]],
                                  [[0.8944272],
                                   [1.0606602]],
                                  [[1.7888544],
                                   [1.767767]]],
                                 [[[0.93704253],
                                   [0.97827977]],
                                  [[1.2493901],
                                   [1.2577883]],
                                  [[1.5617375],
                                   [1.5372968]]]], dtype=np.float32))
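For cross-checking, a NumPy sketch of the across-channel formula assumed earlier reproduces both expected arrays above. It is an illustrative reference only, not the nGraph implementation; the edge-clipped, centred window (and odd `size`) is part of the assumption:

```python
import numpy as np


def lrn_reference(x, alpha=1.0, beta=0.5, bias=1.0, size=5):
    """Divide each element by (bias + alpha / size * sum of squared neighbours) ** beta.

    The neighbourhood is a window of `size` channels centred on the current channel
    and clipped at the channel-axis edges (odd `size` assumed, as in the test above).
    """
    out = np.empty_like(x)
    channels = x.shape[1]
    half = (size - 1) // 2
    for c in range(channels):
        begin, end = max(0, c - half), min(channels, c + half + 1)
        sqr_sum = np.sum(np.square(x[:, begin:end]), axis=1)
        out[:, c] = x[:, c] / np.power(bias + alpha / size * sqr_sum, beta)
    return out


data = np.arange(12).reshape(2, 3, 2, 1).astype(np.float32)
print(lrn_reference(data, alpha=1.0, beta=2.0, bias=1.0, size=3))  # matches the first expected array
print(lrn_reference(data))                                         # matches the default-parameter array
```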