Commit 97eae634 authored by Ewa Tusień's avatar Ewa Tusień Committed by Michał Karzyński

[Py] Added operators HardSigmoid, MVN, PRelu, ScaleShift and SpaceToDepth to Python API. (#3407)

parent c101dcce
...@@ -45,6 +45,7 @@ ngraph.ops ...@@ -45,6 +45,7 @@ ngraph.ops
greater greater
greater_eq greater_eq
grn grn
hard_sigmoid
less less
less_eq less_eq
log log
...@@ -58,23 +59,27 @@ ngraph.ops ...@@ -58,23 +59,27 @@ ngraph.ops
min min
minimum minimum
multiply multiply
mvn
negative negative
not_equal not_equal
one_hot one_hot
pad pad
parameter parameter
power power
prelu
prod prod
relu relu
replace_slice replace_slice
reshape reshape
reverse reverse
scale_shift
select select
sign sign
sin sin
sinh sinh
slice slice
softmax softmax
space_to_depth
sqrt sqrt
subtract subtract
sum sum
......
...@@ -58,6 +58,7 @@ from ngraph.ops import get_output_element ...@@ -58,6 +58,7 @@ from ngraph.ops import get_output_element
from ngraph.ops import greater from ngraph.ops import greater
from ngraph.ops import greater_eq from ngraph.ops import greater_eq
from ngraph.ops import grn from ngraph.ops import grn
from ngraph.ops import hard_sigmoid
from ngraph.ops import less from ngraph.ops import less
from ngraph.ops import less_eq from ngraph.ops import less_eq
from ngraph.ops import log from ngraph.ops import log
...@@ -71,6 +72,7 @@ from ngraph.ops import maximum ...@@ -71,6 +72,7 @@ from ngraph.ops import maximum
from ngraph.ops import min from ngraph.ops import min
from ngraph.ops import minimum from ngraph.ops import minimum
from ngraph.ops import multiply from ngraph.ops import multiply
from ngraph.ops import mvn
from ngraph.ops import negative from ngraph.ops import negative
from ngraph.ops import not_equal from ngraph.ops import not_equal
from ngraph.ops import one_hot from ngraph.ops import one_hot
...@@ -78,16 +80,19 @@ from ngraph.ops import pad ...@@ -78,16 +80,19 @@ from ngraph.ops import pad
from ngraph.ops import parameter from ngraph.ops import parameter
from ngraph.ops import power from ngraph.ops import power
from ngraph.ops import prod from ngraph.ops import prod
from ngraph.ops import prelu
from ngraph.ops import relu from ngraph.ops import relu
from ngraph.ops import replace_slice from ngraph.ops import replace_slice
from ngraph.ops import reshape from ngraph.ops import reshape
from ngraph.ops import reverse from ngraph.ops import reverse
from ngraph.ops import scale_shift
from ngraph.ops import select from ngraph.ops import select
from ngraph.ops import sign from ngraph.ops import sign
from ngraph.ops import sin from ngraph.ops import sin
from ngraph.ops import sinh from ngraph.ops import sinh
from ngraph.ops import slice from ngraph.ops import slice
from ngraph.ops import softmax from ngraph.ops import softmax
from ngraph.ops import space_to_depth
from ngraph.ops import sqrt from ngraph.ops import sqrt
from ngraph.ops import subtract from ngraph.ops import subtract
from ngraph.ops import sum from ngraph.ops import sum
......
...@@ -82,6 +82,7 @@ from _pyngraph.op import GetOutputElement ...@@ -82,6 +82,7 @@ from _pyngraph.op import GetOutputElement
from _pyngraph.op import Greater from _pyngraph.op import Greater
from _pyngraph.op import GreaterEq from _pyngraph.op import GreaterEq
from _pyngraph.op import GRN from _pyngraph.op import GRN
from _pyngraph.op import HardSigmoid
from _pyngraph.op import Less from _pyngraph.op import Less
from _pyngraph.op import LessEq from _pyngraph.op import LessEq
from _pyngraph.op import Log from _pyngraph.op import Log
...@@ -93,6 +94,7 @@ from _pyngraph.op import MaxPoolBackprop ...@@ -93,6 +94,7 @@ from _pyngraph.op import MaxPoolBackprop
from _pyngraph.op import Min from _pyngraph.op import Min
from _pyngraph.op import Minimum from _pyngraph.op import Minimum
from _pyngraph.op import Multiply from _pyngraph.op import Multiply
from _pyngraph.op import MVN
from _pyngraph.op import Negative from _pyngraph.op import Negative
from _pyngraph.op import Not from _pyngraph.op import Not
from _pyngraph.op import NotEqual from _pyngraph.op import NotEqual
...@@ -102,18 +104,21 @@ from _pyngraph.op import Or ...@@ -102,18 +104,21 @@ from _pyngraph.op import Or
from _pyngraph.op import Pad from _pyngraph.op import Pad
from _pyngraph.op import Parameter from _pyngraph.op import Parameter
from _pyngraph.op import Power from _pyngraph.op import Power
from _pyngraph.op import PRelu
from _pyngraph.op import Product from _pyngraph.op import Product
from _pyngraph.op import Relu from _pyngraph.op import Relu
from _pyngraph.op import ReluBackprop from _pyngraph.op import ReluBackprop
from _pyngraph.op import ReplaceSlice from _pyngraph.op import ReplaceSlice
from _pyngraph.op import Reshape from _pyngraph.op import Reshape
from _pyngraph.op import Reverse from _pyngraph.op import Reverse
from _pyngraph.op import ScaleShift
from _pyngraph.op import Select from _pyngraph.op import Select
from _pyngraph.op import Sign from _pyngraph.op import Sign
from _pyngraph.op import Sin from _pyngraph.op import Sin
from _pyngraph.op import Sinh from _pyngraph.op import Sinh
from _pyngraph.op import Slice from _pyngraph.op import Slice
from _pyngraph.op import Softmax from _pyngraph.op import Softmax
from _pyngraph.op import SpaceToDepth
from _pyngraph.op import Sqrt from _pyngraph.op import Sqrt
from _pyngraph.op import Subtract from _pyngraph.op import Subtract
from _pyngraph.op import Sum from _pyngraph.op import Sum
......
...@@ -23,12 +23,13 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio ...@@ -23,12 +23,13 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \ from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \ BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
Convolution, ConvolutionBackpropData, Cos, Cosh, DepthToSpace, Divide, Dot, Elu, FakeQuantize, \ Convolution, ConvolutionBackpropData, Cos, Cosh, DepthToSpace, Divide, Dot, Elu, FakeQuantize, \
Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, Less, LessEq, Log, \ Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, HardSigmoid, Less, \
LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, \ LessEq, Log, LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, MVN, Negative, Not, NotEqual, \
Parameter, Product, Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, \ OneHot, Or, Pad, Parameter, Product, Power, PRelu, Relu, ReplaceSlice, Reshape, Reverse, \
Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh, TopK, Unsqueeze ScaleShift, Select, Sign, Sin, Sinh, Slice, Softmax, SpaceToDepth, Sqrt, Subtract, Sum, Tan, \
Tanh, TopK, Unsqueeze
from typing import Callable, Iterable, List, Union from typing import Callable, Iterable, List, Set, Union
from ngraph.utils.broadcasting import get_broadcast_axes from ngraph.utils.broadcasting import get_broadcast_axes
from ngraph.utils.decorators import nameable_op, binary_op, unary_op from ngraph.utils.decorators import nameable_op, binary_op, unary_op
...@@ -112,6 +113,61 @@ def grn(data, bias, name=None): # type: (Node, float, str) -> Node ...@@ -112,6 +113,61 @@ def grn(data, bias, name=None): # type: (Node, float, str) -> Node
return GRN(data, bias) return GRN(data, bias)
@nameable_op
def scale_shift(data, scale, shift, name=None): # type: (Node, Node, Node, str) -> Node
r"""Perform ScaleShift transformation on input node.
Computes ScaleShift:
.. math:: Y = scale\cdot data + shift
:param data: The node with data tensor.
:param scale: The node with data tensor that scale input data.
:param shift: The node with data tensor that shift input data.
:param name: Optional output node name.spa
:return: The new node performing a ScaleShift operation on input tensor.
"""
return ScaleShift(data, scale, shift)
@nameable_op
def space_to_depth(data, block_size, name=None): # type: (Node, int, str) -> Node
"""Perform SpaceToDepth operation on the input tensor.
SpaceToDepth rearranges blocks of spatial data into depth.
The operator returns a copy of the input tensor where values from the height
and width dimensions are moved to the depth dimension.
:param data: The node with data tensor.
:param block_size: The size of the block of values to be moved. Scalar value.
:param name: Optional output node name.
:return: The new node performing a SpaceToDepth operation on input tensor.
"""
return SpaceToDepth(data, block_size)
@nameable_op
def mvn(data, axes, normalize_variance, eps, name=None):
# type: (Node, Set[int], bool, float, str) -> Node
r"""Perform Mean Variance Normalization operation on data from input node.
Computes MVN on the input tensor :code:`data` (called `X`) using formula:
.. math:: Y = \dfrac{X-EX}{\sqrt{E(X-EX)^2}}
:param data: The node with data tensor.
:param axes: A list of axes, along which to reduce. Array of integers.
:param normalize_variance: Flag that denotes if mean values are shared across channels.
Boolen value.
:param eps: The number added to the variance to avoid division by zero
when normalizing the value. Scalar value.
:param name: Optional output node name.
:return: The new node performing a MVN operation on input tensor.
"""
return MVN(data, AxisSet(axes), normalize_variance, eps)
# Unary ops # Unary ops
@unary_op @unary_op
def absolute(node, name=None): # type: (NodeInput, str) -> Node def absolute(node, name=None): # type: (NodeInput, str) -> Node
...@@ -929,6 +985,46 @@ def min(node, reduction_axes=None, name=None): ...@@ -929,6 +985,46 @@ def min(node, reduction_axes=None, name=None):
return Min(node, AxisSet(get_reduction_axes(node, reduction_axes))) return Min(node, AxisSet(get_reduction_axes(node, reduction_axes)))
@nameable_op
def prelu(data, slope, name=None): # type: (Node, Node, str) -> Node
"""Perform Parametrized Relu operation element-wise on data from input node.
PRelu uses the following logic:
.. code-block:: python
if data < 0:
data = data * slope
elif data >= 0:
data = data
:param data: The node with data tensor.
:param slope: The node with the multipliers for negative values.
:param name: Optional output node name.
:return: The new node performing a PRelu operation on tensor's channels.
"""
return PRelu(data, slope)
@nameable_op
def hard_sigmoid(data, alpha, beta, name=None): # type: (Node, float, float, str) -> Node
"""Perform Hard Sigmoid operation element-wise on data from input node.
Hard Sigmoid uses the following logic:
.. code-block:: python
y = max(0, min(1, alpha * data + beta))
:param data: The node with data tensor.
:param alpha: Alpha parameter. Scalar value.
:param beta: Beta parameter. Scalar value.
:param name: Optional output node name.
:return: The new node performing a Hard Sigmoid element-wise on input tensor.
"""
return HardSigmoid(data, alpha, beta)
@nameable_op @nameable_op
def prod(node, reduction_axes=None, name=None): def prod(node, reduction_axes=None, name=None):
# type: (Node, Iterable[int], str) -> Node # type: (Node, Iterable[int], str) -> Node
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/hard_sigmoid.hpp"
#include "pyngraph/ops/fused/hard_sigmoid.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_HardSigmoid(py::module m)
{
py::class_<ngraph::op::HardSigmoid, std::shared_ptr<ngraph::op::HardSigmoid>, ngraph::op::Op>
hardsigmoid(m, "HardSigmoid");
hardsigmoid.doc() = "ngraph.impl.op.HardSigmoid wraps ngraph::op::HardSigmoid";
hardsigmoid.def(py::init<const std::shared_ptr<ngraph::Node>&, float&, float&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_HardSigmoid(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/mvn.hpp"
#include "pyngraph/ops/fused/mvn.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_MVN(py::module m)
{
py::class_<ngraph::op::MVN, std::shared_ptr<ngraph::op::MVN>, ngraph::op::Op> mvn(m, "MVN");
mvn.doc() = "ngraph.impl.op.MVN wraps ngraph::op::MVN";
mvn.def(
py::init<const std::shared_ptr<ngraph::Node>&, const ngraph::AxisSet&, bool&, double&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_MVN(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/prelu.hpp"
#include "pyngraph/ops/fused/prelu.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_PRelu(py::module m)
{
py::class_<ngraph::op::PRelu, std::shared_ptr<ngraph::op::PRelu>, ngraph::op::Op> prelu(
m, "PRelu");
prelu.doc() = "ngraph.impl.op.PRelu wraps ngraph::op::PRelu";
prelu.def(
py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_PRelu(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/scale_shift.hpp"
#include "pyngraph/ops/fused/scale_shift.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_ScaleShift(py::module m)
{
py::class_<ngraph::op::ScaleShift, std::shared_ptr<ngraph::op::ScaleShift>, ngraph::op::Op>
scaleshift(m, "ScaleShift");
scaleshift.doc() = "ngraph.impl.op.ScaleShift wraps ngraph::op::ScaleShift";
scaleshift.def(py::init<const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_ScaleShift(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/space_to_depth.hpp"
#include "pyngraph/ops/fused/space_to_depth.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_SpaceToDepth(py::module m)
{
py::class_<ngraph::op::SpaceToDepth, std::shared_ptr<ngraph::op::SpaceToDepth>, ngraph::op::Op>
spacetodepth(m, "SpaceToDepth");
spacetodepth.doc() = "ngraph.impl.op.SpaceToDepth wraps ngraph::op::SpaceToDepth";
spacetodepth.def(py::init<const std::shared_ptr<ngraph::Node>&, int&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_SpaceToDepth(py::module m);
...@@ -62,6 +62,7 @@ void regmodule_pyngraph_op(py::module m_op) ...@@ -62,6 +62,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Greater(m_op); regclass_pyngraph_op_Greater(m_op);
regclass_pyngraph_op_GreaterEq(m_op); regclass_pyngraph_op_GreaterEq(m_op);
regclass_pyngraph_op_GRN(m_op); regclass_pyngraph_op_GRN(m_op);
regclass_pyngraph_op_HardSigmoid(m_op);
regclass_pyngraph_op_Less(m_op); regclass_pyngraph_op_Less(m_op);
regclass_pyngraph_op_LessEq(m_op); regclass_pyngraph_op_LessEq(m_op);
regclass_pyngraph_op_Log(m_op); regclass_pyngraph_op_Log(m_op);
...@@ -73,6 +74,7 @@ void regmodule_pyngraph_op(py::module m_op) ...@@ -73,6 +74,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Min(m_op); regclass_pyngraph_op_Min(m_op);
regclass_pyngraph_op_Minimum(m_op); regclass_pyngraph_op_Minimum(m_op);
regclass_pyngraph_op_Multiply(m_op); regclass_pyngraph_op_Multiply(m_op);
regclass_pyngraph_op_MVN(m_op);
regclass_pyngraph_op_Negative(m_op); regclass_pyngraph_op_Negative(m_op);
regclass_pyngraph_op_Not(m_op); regclass_pyngraph_op_Not(m_op);
regclass_pyngraph_op_NotEqual(m_op); regclass_pyngraph_op_NotEqual(m_op);
...@@ -82,18 +84,21 @@ void regmodule_pyngraph_op(py::module m_op) ...@@ -82,18 +84,21 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Parameter(m_op); regclass_pyngraph_op_Parameter(m_op);
regclass_pyngraph_op_Passthrough(m_op); regclass_pyngraph_op_Passthrough(m_op);
regclass_pyngraph_op_Power(m_op); regclass_pyngraph_op_Power(m_op);
regclass_pyngraph_op_PRelu(m_op);
regclass_pyngraph_op_Product(m_op); regclass_pyngraph_op_Product(m_op);
regclass_pyngraph_op_Relu(m_op); regclass_pyngraph_op_Relu(m_op);
regclass_pyngraph_op_ReluBackprop(m_op); regclass_pyngraph_op_ReluBackprop(m_op);
regclass_pyngraph_op_ReplaceSlice(m_op); regclass_pyngraph_op_ReplaceSlice(m_op);
regclass_pyngraph_op_Reshape(m_op); regclass_pyngraph_op_Reshape(m_op);
regclass_pyngraph_op_Reverse(m_op); regclass_pyngraph_op_Reverse(m_op);
regclass_pyngraph_op_ScaleShift(m_op);
regclass_pyngraph_op_Select(m_op); regclass_pyngraph_op_Select(m_op);
regclass_pyngraph_op_Sign(m_op); regclass_pyngraph_op_Sign(m_op);
regclass_pyngraph_op_Sin(m_op); regclass_pyngraph_op_Sin(m_op);
regclass_pyngraph_op_Sinh(m_op); regclass_pyngraph_op_Sinh(m_op);
regclass_pyngraph_op_Slice(m_op); regclass_pyngraph_op_Slice(m_op);
regclass_pyngraph_op_Softmax(m_op); regclass_pyngraph_op_Softmax(m_op);
regclass_pyngraph_op_SpaceToDepth(m_op);
regclass_pyngraph_op_Sqrt(m_op); regclass_pyngraph_op_Sqrt(m_op);
regclass_pyngraph_op_Subtract(m_op); regclass_pyngraph_op_Subtract(m_op);
regclass_pyngraph_op_Sum(m_op); regclass_pyngraph_op_Sum(m_op);
......
...@@ -49,6 +49,11 @@ ...@@ -49,6 +49,11 @@
#include "pyngraph/ops/fused/gelu.hpp" #include "pyngraph/ops/fused/gelu.hpp"
#include "pyngraph/ops/fused/gemm.hpp" #include "pyngraph/ops/fused/gemm.hpp"
#include "pyngraph/ops/fused/grn.hpp" #include "pyngraph/ops/fused/grn.hpp"
#include "pyngraph/ops/fused/hard_sigmoid.hpp"
#include "pyngraph/ops/fused/mvn.hpp"
#include "pyngraph/ops/fused/prelu.hpp"
#include "pyngraph/ops/fused/scale_shift.hpp"
#include "pyngraph/ops/fused/space_to_depth.hpp"
#include "pyngraph/ops/fused/unsqueeze.hpp" #include "pyngraph/ops/fused/unsqueeze.hpp"
#include "pyngraph/ops/get_output_element.hpp" #include "pyngraph/ops/get_output_element.hpp"
#include "pyngraph/ops/greater.hpp" #include "pyngraph/ops/greater.hpp"
......
...@@ -191,6 +191,7 @@ sources = [ ...@@ -191,6 +191,7 @@ sources = [
'pyngraph/ops/greater.cpp', 'pyngraph/ops/greater.cpp',
'pyngraph/ops/greater_eq.cpp', 'pyngraph/ops/greater_eq.cpp',
'pyngraph/ops/fused/grn.cpp', 'pyngraph/ops/fused/grn.cpp',
'pyngraph/ops/fused/hard_sigmoid.cpp',
'pyngraph/ops/less.cpp', 'pyngraph/ops/less.cpp',
'pyngraph/ops/less_eq.cpp', 'pyngraph/ops/less_eq.cpp',
'pyngraph/ops/log.cpp', 'pyngraph/ops/log.cpp',
...@@ -201,6 +202,7 @@ sources = [ ...@@ -201,6 +202,7 @@ sources = [
'pyngraph/ops/max_pool.cpp', 'pyngraph/ops/max_pool.cpp',
'pyngraph/ops/minimum.cpp', 'pyngraph/ops/minimum.cpp',
'pyngraph/ops/multiply.cpp', 'pyngraph/ops/multiply.cpp',
'pyngraph/ops/fused/mvn.cpp',
'pyngraph/ops/negative.cpp', 'pyngraph/ops/negative.cpp',
'pyngraph/ops/not.cpp', 'pyngraph/ops/not.cpp',
'pyngraph/ops/not_equal.cpp', 'pyngraph/ops/not_equal.cpp',
...@@ -211,16 +213,19 @@ sources = [ ...@@ -211,16 +213,19 @@ sources = [
'pyngraph/ops/parameter.cpp', 'pyngraph/ops/parameter.cpp',
'pyngraph/ops/passthrough.cpp', 'pyngraph/ops/passthrough.cpp',
'pyngraph/ops/power.cpp', 'pyngraph/ops/power.cpp',
'pyngraph/ops/fused/prelu.cpp',
'pyngraph/ops/regmodule_pyngraph_op.cpp', 'pyngraph/ops/regmodule_pyngraph_op.cpp',
'pyngraph/ops/relu.cpp', 'pyngraph/ops/relu.cpp',
'pyngraph/ops/replace_slice.cpp', 'pyngraph/ops/replace_slice.cpp',
'pyngraph/ops/reshape.cpp', 'pyngraph/ops/reshape.cpp',
'pyngraph/ops/reverse.cpp', 'pyngraph/ops/reverse.cpp',
'pyngraph/ops/fused/scale_shift.cpp',
'pyngraph/ops/select.cpp', 'pyngraph/ops/select.cpp',
'pyngraph/ops/sign.cpp', 'pyngraph/ops/sign.cpp',
'pyngraph/ops/sin.cpp', 'pyngraph/ops/sin.cpp',
'pyngraph/ops/sinh.cpp', 'pyngraph/ops/sinh.cpp',
'pyngraph/ops/slice.cpp', 'pyngraph/ops/slice.cpp',
'pyngraph/ops/fused/space_to_depth.cpp',
'pyngraph/ops/sqrt.cpp', 'pyngraph/ops/sqrt.cpp',
'pyngraph/ops/subtract.cpp', 'pyngraph/ops/subtract.cpp',
'pyngraph/ops/sum.cpp', 'pyngraph/ops/sum.cpp',
......
...@@ -282,3 +282,120 @@ def test_grn_operator(): ...@@ -282,3 +282,120 @@ def test_grn_operator():
[0.91914505, 0.9103665, 0.9021342, 0.8944272]]]], dtype=np.float32) [0.91914505, 0.9103665, 0.9021342, 0.8944272]]]], dtype=np.float32)
assert np.allclose(result, expected) assert np.allclose(result, expected)
def test_prelu_operator():
runtime = get_runtime()
data_shape = [1, 2, 3, 4]
slope_shape = [2, 3, 1]
data_value = np.arange(start=1.0, stop=25.0, dtype=np.float32).reshape(data_shape)
slope_value = np.arange(start=-10.0, stop=-4.0, dtype=np.float32).reshape(slope_shape)\
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
parameter_slope = ng.parameter(slope_shape, name='Slope', dtype=np.float32)
model = ng.prelu(parameter_data, parameter_slope)
computation = runtime.computation(model, parameter_data, parameter_slope)
result = computation(data_value, slope_value)
expected = np.clip(data_value, 0, np.inf) + np.clip(data_value, -np.inf, 0) * slope_value
assert np.allclose(result, expected)
def test_hard_sigmoid_operator():
runtime = get_runtime()
data_shape = [3]
alpha = np.float32(0.5)
beta = np.float32(0.6)
data_value = np.array([-1, 0, 1], dtype=np.float32)
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
model = ng.hard_sigmoid(parameter_data, alpha, beta)
computation = runtime.computation(model, parameter_data)
result = computation(data_value)
expected = [0.1, 0.6, 1.]
assert np.allclose(result, expected)
def test_mvn_operator():
runtime = get_runtime()
data_shape = [3, 3, 3, 1]
axis = [0, 2, 3]
normalize_variance = True
eps = np.float32(1e-9)
data_value = np.array([[[[0.8439683], [0.5665144], [0.05836735]],
[[0.02916367], [0.12964272], [0.5060197]],
[[0.79538304], [0.9411346], [0.9546573]]],
[[[0.17730942], [0.46192095], [0.26480448]],
[[0.6746842], [0.01665257], [0.62473077]],
[[0.9240844], [0.9722341], [0.11965699]]],
[[[0.41356155], [0.9129373], [0.59330076]],
[[0.81929934], [0.7862604], [0.11799799]],
[[0.69248444], [0.54119414], [0.07513223]]]], dtype=np.float32)
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
model = ng.mvn(parameter_data, axis, normalize_variance, eps)
computation = runtime.computation(model, parameter_data)
result = computation(data_value)
data_mean = np.mean(data_value, axis=(0, 2, 3), keepdims=1)
data_mean_squared = np.power(data_mean, 2)
data_squared = np.power(data_value, 2)
data_squared_mean = np.mean(data_squared, axis=(0, 2, 3), keepdims=1)
std = np.sqrt(data_squared_mean - data_mean_squared)
expected = (data_value - data_mean) / (std + 1e-9)
assert np.allclose(result, expected)
def test_scale_shift_operator():
runtime = get_runtime()
data_shape = [3, 6]
scale_shape = [3, 6]
shift_shape = [1]
data_value = np.arange(start=19.0, stop=1.0, step=-1.0, dtype=np.float32).reshape(data_shape)
scale_value = np.arange(start=19.0, stop=1.0, step=-1.0, dtype=np.float32).reshape(scale_shape)
shift_value = [2.0]
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
parameter_scale = ng.parameter(scale_shape, name='Scale', dtype=np.float32)
parameter_shift = ng.parameter(shift_shape, name='Shift', dtype=np.float32)
model = ng.scale_shift(parameter_data, parameter_scale, parameter_shift)
computation = runtime.computation(model, parameter_data, parameter_scale, parameter_shift)
result = computation(data_value, scale_value, shift_value)
expected = np.add(np.multiply(data_value, scale_value), shift_value)
assert np.allclose(result, expected)
def test_space_to_depth_operator():
runtime = get_runtime()
data_shape = [1, 2, 4, 4]
data_value = np.arange(start=0, stop=32, step=1.0, dtype=np.float32).reshape(data_shape)
block_size = 2
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
model = ng.space_to_depth(parameter_data, block_size)
computation = runtime.computation(model, parameter_data)
result = computation(data_value)
expected = np.array([0, 2, 8, 10, 16, 18, 24, 26,
1, 3, 9, 11, 17, 19, 25, 27,
4, 6, 12, 14, 20, 22, 28, 30,
5, 7, 13, 15, 21, 23, 29, 31], dtype=np.float32).reshape(1, 8, 2, 2)
assert np.allclose(result, expected)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment