Commit 97eae634 authored by Ewa Tusień's avatar Ewa Tusień Committed by Michał Karzyński

[Py] Added operators HardSigmoid, MVN, PRelu, ScaleShift and SpaceToDepth to Python API. (#3407)

parent c101dcce
......@@ -45,6 +45,7 @@ ngraph.ops
greater
greater_eq
grn
hard_sigmoid
less
less_eq
log
......@@ -58,23 +59,27 @@ ngraph.ops
min
minimum
multiply
mvn
negative
not_equal
one_hot
pad
parameter
power
prelu
prod
relu
replace_slice
reshape
reverse
scale_shift
select
sign
sin
sinh
slice
softmax
space_to_depth
sqrt
subtract
sum
......
......@@ -58,6 +58,7 @@ from ngraph.ops import get_output_element
from ngraph.ops import greater
from ngraph.ops import greater_eq
from ngraph.ops import grn
from ngraph.ops import hard_sigmoid
from ngraph.ops import less
from ngraph.ops import less_eq
from ngraph.ops import log
......@@ -71,6 +72,7 @@ from ngraph.ops import maximum
from ngraph.ops import min
from ngraph.ops import minimum
from ngraph.ops import multiply
from ngraph.ops import mvn
from ngraph.ops import negative
from ngraph.ops import not_equal
from ngraph.ops import one_hot
......@@ -78,16 +80,19 @@ from ngraph.ops import pad
from ngraph.ops import parameter
from ngraph.ops import power
from ngraph.ops import prod
from ngraph.ops import prelu
from ngraph.ops import relu
from ngraph.ops import replace_slice
from ngraph.ops import reshape
from ngraph.ops import reverse
from ngraph.ops import scale_shift
from ngraph.ops import select
from ngraph.ops import sign
from ngraph.ops import sin
from ngraph.ops import sinh
from ngraph.ops import slice
from ngraph.ops import softmax
from ngraph.ops import space_to_depth
from ngraph.ops import sqrt
from ngraph.ops import subtract
from ngraph.ops import sum
......
......@@ -82,6 +82,7 @@ from _pyngraph.op import GetOutputElement
from _pyngraph.op import Greater
from _pyngraph.op import GreaterEq
from _pyngraph.op import GRN
from _pyngraph.op import HardSigmoid
from _pyngraph.op import Less
from _pyngraph.op import LessEq
from _pyngraph.op import Log
......@@ -93,6 +94,7 @@ from _pyngraph.op import MaxPoolBackprop
from _pyngraph.op import Min
from _pyngraph.op import Minimum
from _pyngraph.op import Multiply
from _pyngraph.op import MVN
from _pyngraph.op import Negative
from _pyngraph.op import Not
from _pyngraph.op import NotEqual
......@@ -102,18 +104,21 @@ from _pyngraph.op import Or
from _pyngraph.op import Pad
from _pyngraph.op import Parameter
from _pyngraph.op import Power
from _pyngraph.op import PRelu
from _pyngraph.op import Product
from _pyngraph.op import Relu
from _pyngraph.op import ReluBackprop
from _pyngraph.op import ReplaceSlice
from _pyngraph.op import Reshape
from _pyngraph.op import Reverse
from _pyngraph.op import ScaleShift
from _pyngraph.op import Select
from _pyngraph.op import Sign
from _pyngraph.op import Sin
from _pyngraph.op import Sinh
from _pyngraph.op import Slice
from _pyngraph.op import Softmax
from _pyngraph.op import SpaceToDepth
from _pyngraph.op import Sqrt
from _pyngraph.op import Subtract
from _pyngraph.op import Sum
......
......@@ -23,12 +23,13 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
Convolution, ConvolutionBackpropData, Cos, Cosh, DepthToSpace, Divide, Dot, Elu, FakeQuantize, \
Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, Less, LessEq, Log, \
LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, \
Parameter, Product, Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, \
Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh, TopK, Unsqueeze
Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, HardSigmoid, Less, \
LessEq, Log, LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, MVN, Negative, Not, NotEqual, \
OneHot, Or, Pad, Parameter, Product, Power, PRelu, Relu, ReplaceSlice, Reshape, Reverse, \
ScaleShift, Select, Sign, Sin, Sinh, Slice, Softmax, SpaceToDepth, Sqrt, Subtract, Sum, Tan, \
Tanh, TopK, Unsqueeze
from typing import Callable, Iterable, List, Union
from typing import Callable, Iterable, List, Set, Union
from ngraph.utils.broadcasting import get_broadcast_axes
from ngraph.utils.decorators import nameable_op, binary_op, unary_op
......@@ -112,6 +113,61 @@ def grn(data, bias, name=None): # type: (Node, float, str) -> Node
return GRN(data, bias)
@nameable_op
def scale_shift(data, scale, shift, name=None): # type: (Node, Node, Node, str) -> Node
r"""Perform ScaleShift transformation on input node.
Computes ScaleShift:
.. math:: Y = scale\cdot data + shift
:param data: The node with data tensor.
:param scale: The node with the data tensor that scales the input data.
:param shift: The node with the data tensor that shifts the input data.
:param name: Optional output node name.
:return: The new node performing a ScaleShift operation on input tensor.
"""
return ScaleShift(data, scale, shift)
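A minimal usage sketch of the new wrapper, using the same shapes as the test at the bottom of this commit; the parameter names are illustrative:

    import numpy as np
    import ngraph as ng

    data = ng.parameter([3, 6], name='Data', dtype=np.float32)
    scale = ng.parameter([3, 6], name='Scale', dtype=np.float32)
    shift = ng.parameter([1], name='Shift', dtype=np.float32)
    node = ng.scale_shift(data, scale, shift)  # element-wise data * scale + shift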
@nameable_op
def space_to_depth(data, block_size, name=None): # type: (Node, int, str) -> Node
"""Perform SpaceToDepth operation on the input tensor.
SpaceToDepth rearranges blocks of spatial data into depth.
The operator returns a copy of the input tensor where values from the height
and width dimensions are moved to the depth dimension.
:param data: The node with data tensor.
:param block_size: The size of the block of values to be moved. Scalar value.
:param name: Optional output node name.
:return: The new node performing a SpaceToDepth operation on input tensor.
"""
return SpaceToDepth(data, block_size)
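A minimal sketch, using the NCHW shape and block size exercised in the test below; with block_size=2 a [1, 2, 4, 4] input becomes a [1, 8, 2, 2] output:

    import numpy as np
    import ngraph as ng

    data = ng.parameter([1, 2, 4, 4], name='Data', dtype=np.float32)
    node = ng.space_to_depth(data, 2)  # 2x2 spatial blocks are moved into the channel dimension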
@nameable_op
def mvn(data, axes, normalize_variance, eps, name=None):
# type: (Node, Set[int], bool, float, str) -> Node
r"""Perform Mean Variance Normalization operation on data from input node.
Computes MVN on the input tensor :code:`data` (called `X`) using the formula:
.. math:: Y = \dfrac{X-EX}{\sqrt{E(X-EX)^2}}
:param data: The node with data tensor.
:param axes: A list of axes along which to reduce. Array of integers.
:param normalize_variance: Flag that denotes whether to perform variance normalization.
Boolean value.
:param eps: The number added to the variance to avoid division by zero
when normalizing the value. Scalar value.
:param name: Optional output node name.
:return: The new node performing an MVN operation on the input tensor.
"""
return MVN(data, AxisSet(axes), normalize_variance, eps)
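A minimal sketch, normalizing over axes {0, 2, 3} of an NCHW tensor as in the test below; the eps value is the one used there:

    import numpy as np
    import ngraph as ng

    data = ng.parameter([3, 3, 3, 1], name='Data', dtype=np.float32)
    node = ng.mvn(data, {0, 2, 3}, normalize_variance=True, eps=np.float32(1e-9))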
# Unary ops
@unary_op
def absolute(node, name=None): # type: (NodeInput, str) -> Node
......@@ -929,6 +985,46 @@ def min(node, reduction_axes=None, name=None):
return Min(node, AxisSet(get_reduction_axes(node, reduction_axes)))
@nameable_op
def prelu(data, slope, name=None): # type: (Node, Node, str) -> Node
"""Perform Parametrized Relu operation element-wise on data from input node.
PRelu uses the following logic:
.. code-block:: python
if data < 0:
data = data * slope
elif data >= 0:
data = data
:param data: The node with data tensor.
:param slope: The node with the multipliers for negative values.
:param name: Optional output node name.
:return: The new node performing a PRelu operation on the tensor's channels.
"""
return PRelu(data, slope)
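A minimal sketch, with the slope tensor broadcast against the data exactly as in the test below:

    import numpy as np
    import ngraph as ng

    data = ng.parameter([1, 2, 3, 4], name='Data', dtype=np.float32)
    slope = ng.parameter([2, 3, 1], name='Slope', dtype=np.float32)
    node = ng.prelu(data, slope)  # negative elements are multiplied by the broadcast slope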
@nameable_op
def hard_sigmoid(data, alpha, beta, name=None): # type: (Node, float, float, str) -> Node
"""Perform Hard Sigmoid operation element-wise on data from input node.
Hard Sigmoid uses the following logic:
.. code-block:: python
y = max(0, min(1, alpha * data + beta))
:param data: The node with data tensor.
:param alpha: Alpha parameter. Scalar value.
:param beta: Beta parameter. Scalar value.
:param name: Optional output node name.
:return: The new node performing Hard Sigmoid element-wise on the input tensor.
"""
return HardSigmoid(data, alpha, beta)
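A minimal sketch, using the alpha and beta values from the test below:

    import numpy as np
    import ngraph as ng

    data = ng.parameter([3], name='Data', dtype=np.float32)
    node = ng.hard_sigmoid(data, np.float32(0.5), np.float32(0.6))  # max(0, min(1, 0.5 * x + 0.6))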
@nameable_op
def prod(node, reduction_axes=None, name=None):
# type: (Node, Iterable[int], str) -> Node
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "pyngraph/ops/fused/hard_sigmoid.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_HardSigmoid(py::module m)
{
py::class_<ngraph::op::HardSigmoid, std::shared_ptr<ngraph::op::HardSigmoid>, ngraph::op::Op>
hardsigmoid(m, "HardSigmoid");
hardsigmoid.doc() = "ngraph.impl.op.HardSigmoid wraps ngraph::op::HardSigmoid";
hardsigmoid.def(py::init<const std::shared_ptr<ngraph::Node>&, float&, float&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_HardSigmoid(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/mvn.hpp"
#include "pyngraph/ops/fused/mvn.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_MVN(py::module m)
{
py::class_<ngraph::op::MVN, std::shared_ptr<ngraph::op::MVN>, ngraph::op::Op> mvn(m, "MVN");
mvn.doc() = "ngraph.impl.op.MVN wraps ngraph::op::MVN";
mvn.def(
py::init<const std::shared_ptr<ngraph::Node>&, const ngraph::AxisSet&, bool&, double&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_MVN(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/prelu.hpp"
#include "pyngraph/ops/fused/prelu.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_PRelu(py::module m)
{
py::class_<ngraph::op::PRelu, std::shared_ptr<ngraph::op::PRelu>, ngraph::op::Op> prelu(
m, "PRelu");
prelu.doc() = "ngraph.impl.op.PRelu wraps ngraph::op::PRelu";
prelu.def(
py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_PRelu(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/scale_shift.hpp"
#include "pyngraph/ops/fused/scale_shift.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_ScaleShift(py::module m)
{
py::class_<ngraph::op::ScaleShift, std::shared_ptr<ngraph::op::ScaleShift>, ngraph::op::Op>
scaleshift(m, "ScaleShift");
scaleshift.doc() = "ngraph.impl.op.ScaleShift wraps ngraph::op::ScaleShift";
scaleshift.def(py::init<const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_ScaleShift(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/space_to_depth.hpp"
#include "pyngraph/ops/fused/space_to_depth.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_SpaceToDepth(py::module m)
{
py::class_<ngraph::op::SpaceToDepth, std::shared_ptr<ngraph::op::SpaceToDepth>, ngraph::op::Op>
spacetodepth(m, "SpaceToDepth");
spacetodepth.doc() = "ngraph.impl.op.SpaceToDepth wraps ngraph::op::SpaceToDepth";
spacetodepth.def(py::init<const std::shared_ptr<ngraph::Node>&, int&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_SpaceToDepth(py::module m);
......@@ -62,6 +62,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Greater(m_op);
regclass_pyngraph_op_GreaterEq(m_op);
regclass_pyngraph_op_GRN(m_op);
regclass_pyngraph_op_HardSigmoid(m_op);
regclass_pyngraph_op_Less(m_op);
regclass_pyngraph_op_LessEq(m_op);
regclass_pyngraph_op_Log(m_op);
......@@ -73,6 +74,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Min(m_op);
regclass_pyngraph_op_Minimum(m_op);
regclass_pyngraph_op_Multiply(m_op);
regclass_pyngraph_op_MVN(m_op);
regclass_pyngraph_op_Negative(m_op);
regclass_pyngraph_op_Not(m_op);
regclass_pyngraph_op_NotEqual(m_op);
......@@ -82,18 +84,21 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Parameter(m_op);
regclass_pyngraph_op_Passthrough(m_op);
regclass_pyngraph_op_Power(m_op);
regclass_pyngraph_op_PRelu(m_op);
regclass_pyngraph_op_Product(m_op);
regclass_pyngraph_op_Relu(m_op);
regclass_pyngraph_op_ReluBackprop(m_op);
regclass_pyngraph_op_ReplaceSlice(m_op);
regclass_pyngraph_op_Reshape(m_op);
regclass_pyngraph_op_Reverse(m_op);
regclass_pyngraph_op_ScaleShift(m_op);
regclass_pyngraph_op_Select(m_op);
regclass_pyngraph_op_Sign(m_op);
regclass_pyngraph_op_Sin(m_op);
regclass_pyngraph_op_Sinh(m_op);
regclass_pyngraph_op_Slice(m_op);
regclass_pyngraph_op_Softmax(m_op);
regclass_pyngraph_op_SpaceToDepth(m_op);
regclass_pyngraph_op_Sqrt(m_op);
regclass_pyngraph_op_Subtract(m_op);
regclass_pyngraph_op_Sum(m_op);
......
......@@ -49,6 +49,11 @@
#include "pyngraph/ops/fused/gelu.hpp"
#include "pyngraph/ops/fused/gemm.hpp"
#include "pyngraph/ops/fused/grn.hpp"
#include "pyngraph/ops/fused/hard_sigmoid.hpp"
#include "pyngraph/ops/fused/mvn.hpp"
#include "pyngraph/ops/fused/prelu.hpp"
#include "pyngraph/ops/fused/scale_shift.hpp"
#include "pyngraph/ops/fused/space_to_depth.hpp"
#include "pyngraph/ops/fused/unsqueeze.hpp"
#include "pyngraph/ops/get_output_element.hpp"
#include "pyngraph/ops/greater.hpp"
......
......@@ -191,6 +191,7 @@ sources = [
'pyngraph/ops/greater.cpp',
'pyngraph/ops/greater_eq.cpp',
'pyngraph/ops/fused/grn.cpp',
'pyngraph/ops/fused/hard_sigmoid.cpp',
'pyngraph/ops/less.cpp',
'pyngraph/ops/less_eq.cpp',
'pyngraph/ops/log.cpp',
......@@ -201,6 +202,7 @@ sources = [
'pyngraph/ops/max_pool.cpp',
'pyngraph/ops/minimum.cpp',
'pyngraph/ops/multiply.cpp',
'pyngraph/ops/fused/mvn.cpp',
'pyngraph/ops/negative.cpp',
'pyngraph/ops/not.cpp',
'pyngraph/ops/not_equal.cpp',
......@@ -211,16 +213,19 @@ sources = [
'pyngraph/ops/parameter.cpp',
'pyngraph/ops/passthrough.cpp',
'pyngraph/ops/power.cpp',
'pyngraph/ops/fused/prelu.cpp',
'pyngraph/ops/regmodule_pyngraph_op.cpp',
'pyngraph/ops/relu.cpp',
'pyngraph/ops/replace_slice.cpp',
'pyngraph/ops/reshape.cpp',
'pyngraph/ops/reverse.cpp',
'pyngraph/ops/fused/scale_shift.cpp',
'pyngraph/ops/select.cpp',
'pyngraph/ops/sign.cpp',
'pyngraph/ops/sin.cpp',
'pyngraph/ops/sinh.cpp',
'pyngraph/ops/slice.cpp',
'pyngraph/ops/fused/space_to_depth.cpp',
'pyngraph/ops/sqrt.cpp',
'pyngraph/ops/subtract.cpp',
'pyngraph/ops/sum.cpp',
......
......@@ -282,3 +282,120 @@ def test_grn_operator():
[0.91914505, 0.9103665, 0.9021342, 0.8944272]]]], dtype=np.float32)
assert np.allclose(result, expected)
def test_prelu_operator():
runtime = get_runtime()
data_shape = [1, 2, 3, 4]
slope_shape = [2, 3, 1]
data_value = np.arange(start=1.0, stop=25.0, dtype=np.float32).reshape(data_shape)
slope_value = np.arange(start=-10.0, stop=-4.0, dtype=np.float32).reshape(slope_shape)
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
parameter_slope = ng.parameter(slope_shape, name='Slope', dtype=np.float32)
model = ng.prelu(parameter_data, parameter_slope)
computation = runtime.computation(model, parameter_data, parameter_slope)
result = computation(data_value, slope_value)
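# Reference result: positive values pass through, negative values are scaled by the (broadcast) slope.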
expected = np.clip(data_value, 0, np.inf) + np.clip(data_value, -np.inf, 0) * slope_value
assert np.allclose(result, expected)
def test_hard_sigmoid_operator():
runtime = get_runtime()
data_shape = [3]
alpha = np.float32(0.5)
beta = np.float32(0.6)
data_value = np.array([-1, 0, 1], dtype=np.float32)
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
model = ng.hard_sigmoid(parameter_data, alpha, beta)
computation = runtime.computation(model, parameter_data)
result = computation(data_value)
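# Reference result: max(0, min(1, 0.5 * x + 0.6)) for x in [-1, 0, 1] gives [0.1, 0.6, 1.0].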
expected = [0.1, 0.6, 1.]
assert np.allclose(result, expected)
def test_mvn_operator():
runtime = get_runtime()
data_shape = [3, 3, 3, 1]
axis = [0, 2, 3]
normalize_variance = True
eps = np.float32(1e-9)
data_value = np.array([[[[0.8439683], [0.5665144], [0.05836735]],
[[0.02916367], [0.12964272], [0.5060197]],
[[0.79538304], [0.9411346], [0.9546573]]],
[[[0.17730942], [0.46192095], [0.26480448]],
[[0.6746842], [0.01665257], [0.62473077]],
[[0.9240844], [0.9722341], [0.11965699]]],
[[[0.41356155], [0.9129373], [0.59330076]],
[[0.81929934], [0.7862604], [0.11799799]],
[[0.69248444], [0.54119414], [0.07513223]]]], dtype=np.float32)
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
model = ng.mvn(parameter_data, axis, normalize_variance, eps)
computation = runtime.computation(model, parameter_data)
result = computation(data_value)
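# Reference result: subtract the mean over axes (0, 2, 3) and divide by the standard deviation.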
data_mean = np.mean(data_value, axis=(0, 2, 3), keepdims=1)
data_mean_squared = np.power(data_mean, 2)
data_squared = np.power(data_value, 2)
data_squared_mean = np.mean(data_squared, axis=(0, 2, 3), keepdims=1)
std = np.sqrt(data_squared_mean - data_mean_squared)
expected = (data_value - data_mean) / (std + 1e-9)
assert np.allclose(result, expected)
def test_scale_shift_operator():
runtime = get_runtime()
data_shape = [3, 6]
scale_shape = [3, 6]
shift_shape = [1]
data_value = np.arange(start=19.0, stop=1.0, step=-1.0, dtype=np.float32).reshape(data_shape)
scale_value = np.arange(start=19.0, stop=1.0, step=-1.0, dtype=np.float32).reshape(scale_shape)
shift_value = [2.0]
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
parameter_scale = ng.parameter(scale_shape, name='Scale', dtype=np.float32)
parameter_shift = ng.parameter(shift_shape, name='Shift', dtype=np.float32)
model = ng.scale_shift(parameter_data, parameter_scale, parameter_shift)
computation = runtime.computation(model, parameter_data, parameter_scale, parameter_shift)
result = computation(data_value, scale_value, shift_value)
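# Reference result: element-wise data * scale + shift, with the scalar shift broadcast over the data shape.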
expected = np.add(np.multiply(data_value, scale_value), shift_value)
assert np.allclose(result, expected)
def test_space_to_depth_operator():
runtime = get_runtime()
data_shape = [1, 2, 4, 4]
data_value = np.arange(start=0, stop=32, step=1.0, dtype=np.float32).reshape(data_shape)
block_size = 2
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
model = ng.space_to_depth(parameter_data, block_size)
computation = runtime.computation(model, parameter_data)
result = computation(data_value)
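# Reference result: each 2x2 spatial block is moved into the channel dimension, giving shape (1, 8, 2, 2).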
expected = np.array([0, 2, 8, 10, 16, 18, 24, 26,
1, 3, 9, 11, 17, 19, 25, 27,
4, 6, 12, 14, 20, 22, 28, 30,
5, 7, 13, 15, 21, 23, 29, 31], dtype=np.float32).reshape(1, 8, 2, 2)
assert np.allclose(result, expected)