Commit 8d1e2196 authored by Ewa Tusień, committed by Scott Cyphers

[Py] Added operators GroupConvolution and RNNCell to Python API. (#3425)

* [Py] Added operators GroupConvolution and RNNCell to Python API

* [Py] Removed code unrelated to branch.

* [Py] Restored a removed test.

* [Py] Restored a removed decorator.

* [Py] Code formatting.

* [Py] Added ops to the documentation list.

* [Py] Added skipped file.

* [Py] Code formatting.

* [Py] Code formatting.

* Revert "Merge branch 'master' into etusien/GroupConv_RNNCell"

This reverts commit a1848ea48916b293d5260869b2a52827bea21981, reversing
changes made to 6a60068abf8e5391bf875ee22573eb1aa388b047.

* [Py] Reverted changes.

* [Py] Changed the import list.

* [Py] Added missing imports.

* [Py] Added operators GroupConvolution and RNNCell to Python API

* [Py] Removed code unrelated to branch.

* [Py] Restored a removed test.

* [Py] Restored a removed decorator.

* [Py] Added ops to the documentation list.

* [Py] Added skipped file.

* [Py] Code formatting.

* [Py] Code formatting.

* [Py] Reverted changes.

* Revert "Revert "Merge branch 'master' into etusien/GroupConv_RNNCell""

This reverts commit 9c46ce5d289dadc4979e4712c79fff84bb538652.

* [Py] Reverted changes.

* [Py] Code formatting.

* [Py] Code formatting.

* [Py] Added PadType to Group Conv op.
parent 47a727a6
@@ -46,6 +46,7 @@ ngraph.ops
greater
greater_eq
grn
group_convolution
hard_sigmoid
less
less_eq
@@ -76,6 +77,7 @@ ngraph.ops
replace_slice
reshape
reverse
rnn_cell
scale_shift
select
shuffle_channels
...
@@ -59,6 +59,7 @@ from ngraph.ops import get_output_element
from ngraph.ops import greater
from ngraph.ops import greater_eq
from ngraph.ops import grn
from ngraph.ops import group_convolution
from ngraph.ops import hard_sigmoid
from ngraph.ops import less
from ngraph.ops import less_eq
@@ -89,6 +90,7 @@ from ngraph.ops import relu
from ngraph.ops import replace_slice
from ngraph.ops import reshape
from ngraph.ops import reverse
from ngraph.ops import rnn_cell
from ngraph.ops import scale_shift
from ngraph.ops import select
from ngraph.ops import shuffle_channels
...
@@ -83,6 +83,7 @@ from _pyngraph.op import GetOutputElement
from _pyngraph.op import Greater
from _pyngraph.op import GreaterEq
from _pyngraph.op import GRN
from _pyngraph.op import GroupConvolution
from _pyngraph.op import HardSigmoid
from _pyngraph.op import Less
from _pyngraph.op import LessEq
@@ -115,6 +116,7 @@ from _pyngraph.op import ReluBackprop
from _pyngraph.op import ReplaceSlice
from _pyngraph.op import Reshape
from _pyngraph.op import Reverse
from _pyngraph.op import RNNCell
from _pyngraph.op import ScaleShift
from _pyngraph.op import Select
from _pyngraph.op import ShuffleChannels
...
@@ -24,11 +24,12 @@ from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
    BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
    Convolution, ConvolutionBackpropData, Cos, Cosh, DepthToSpace, Dequantize, Divide, Dot, Elu, \
    FakeQuantize, Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, \
    GroupConvolution, HardSigmoid, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, Min, Minimum, \
    Multiply, MVN, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, Power, \
    Quantize, QuantizedConvolution, QuantizedDot, PRelu, Relu, RNNCell, ReplaceSlice, Reshape, \
    Reverse, ScaleShift, Select, ShuffleChannels, Sign, Sin, Sinh, Slice, Softmax, SpaceToDepth, \
    Sqrt, SquaredDifference, Squeeze, Subtract, Sum, Tan, Tanh, TopK, Unsqueeze
from typing import Callable, Iterable, List, Set, Union
@@ -189,6 +190,106 @@ def grn(data, bias, name=None):  # type: (Node, float, str) -> Node
    return GRN(data, bias)

@nameable_op
def group_convolution(data_batch, # type: Node
filters, # type: Node
window_movement_strides, # type: List[int]
window_dilation_strides, # type: List[int]
padding_below, # type: List[int]
padding_above, # type: List[int]
data_dilation_strides, # type: List[int]
groups, # type: int
pad_type='EXPLICIT', # type: str
name=None, # type: str
):
# type: (...) -> Node
"""Perform Group Convolution operation on data from input node.
:param data: The node producing input data.
:param filters: The node producing filters data.
:param window_movement_strides: The strides along each feature axis.
:param window_dilation_strides: The dilations along each feature axis.
:param padding_below: The padding added below each feature axis.
:param padding_above: The padding added above each feature axis.
:data_dilation_strides: The dilations along data.
:param groups: The number of groups the input channels and output channels
are divided into.
:param pad_type: Name describes how to perform padding.
EXPLICITI: Pad dimensions are explicity specified
SAME_LOWER: Pad dimensions computed to match input shape
Ceil(num_dims/2) at the beginning and
Floor(num_dims/2) at the end
SAME_UPPER: Pad dimensions computed to match input shape
Floor(num_dims/2) at the beginning and
Ceil(num_dims/2) at the end
VALID: No padding
:param name: Optional output node name.
:return: The new node performing a Group Convolution operation on tensor from input node.
"""
    # Look the pad type up by name: the pybind11-generated PadType enum is
    # constructible from an integer, not from a string.
    return GroupConvolution(data_batch,
                            filters,
                            Strides(window_movement_strides),
                            Strides(window_dilation_strides),
                            CoordinateDiff(padding_below),
                            CoordinateDiff(padding_above),
                            Strides(data_dilation_strides),
                            groups,
                            getattr(GroupConvolution.PadType, pad_type))
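
For illustration, a minimal usage sketch of the new wrapper; the shapes here are hypothetical and mirror the unit test added at the bottom of this diff:

import numpy as np
import ngraph as ng

# Hypothetical 1x4x2x2 input whose 4 channels are split into 2 groups, plus a
# 2x2x1x1 filter tensor holding one 2-channel 1x1 kernel per group.
data = ng.parameter([1, 4, 2, 2], name='data', dtype=np.float32)
filters = ng.parameter([2, 2, 1, 1], name='filters', dtype=np.float32)
node = ng.group_convolution(data, filters,
                            window_movement_strides=[1, 1],
                            window_dilation_strides=[1, 1],
                            padding_below=[0, 0],
                            padding_above=[0, 0],
                            data_dilation_strides=[1, 1],
                            groups=2)
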
@nameable_op
def rnn_cell(X, # type: Node
W, # type: Node
R, # type: Node
H_t, # type: Node
hidden_size, # type: int
B, # type: Node
activations, # type: List[str]
activation_alpha, # type: List[float]
activation_beta, # type: List[float]
clip, # type: float
name=None, # type: str
):
# type: (...) -> Node
"""Perform RNNCell operation on tensor from input node.
It follows notation and equations defined as in ONNX standard:
https://github.com/onnx/onnx/blob/master/docs/Operators.md#RNN
Note this class represents only single *cell* and not whole RNN *layer*.
:param X: The input tensor with shape: [batch_size, input_size].
:param W: The weight tensor with shape: [hidden_size, input_size].
:param R: The recurrence weight tensor with shape: [hidden_size, hidden_size].
:param H_t: The hidden state tensor at current time step with
shape: [batch_size, hidden_size].
:param hidden_size: The number of hidden units for recurrent cell.
:param B: The bias tensor for input gate with shape: [2*hidden_size].
:param activations: The vector of activation functions used inside recurrent cell.
:param activation_alpha: The vector of alpha parameters for activation
functions in order respective to activation list.
:param activation_beta: The vector of beta parameters for activation functions
in order respective to activation list.
:param clip: The value defining clipping range [-clip, clip] on
input of activation functions.
:param name: Optional output node name.
:return: The new node performing a RNNCell operation on tensor from input node.
"""
return RNNCell(X,
W,
R,
H_t,
hidden_size,
B,
activations,
activation_alpha,
activation_beta,
clip)
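
As a cross-check of the semantics, the single step this cell performs (per the ONNX RNN operator referenced above) can be sketched in plain numpy. The assumption, as in ONNX, is that B concatenates the input and recurrence biases Wb and Rb:

import numpy as np

def rnn_cell_reference(X, W, R, H_t, B, clip):
    # Numpy sketch of one RNN cell step with a sigmoid activation.
    Wb, Rb = np.split(B, 2)                            # B = [Wb, Rb], each of length hidden_size
    i_t = np.dot(X, W.T) + np.dot(H_t, R.T) + Wb + Rb  # i_t = X*W^T + H_{t-1}*R^T + Wb + Rb
    i_t = np.clip(i_t, -clip, clip)                    # clipping applies to the activation input
    return 1.0 / (1.0 + np.exp(-i_t))                  # H_t = sigmoid(i_t)
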
@nameable_op
def scale_shift(data, scale, shift, name=None):  # type: (Node, Node, Node, str) -> Node
    r"""Perform ScaleShift transformation on input node.
@@ -201,7 +302,7 @@ def scale_shift(data, scale, shift, name=None):  # type: (Node, Node, Node, str) -> Node
    :param data: The node with the data tensor.
    :param scale: The node with the data tensor that scales the input data.
    :param shift: The node with the data tensor that shifts the input data.
    :param name: Optional output node name.
    :return: The new node performing a ScaleShift operation on the input tensor.
    """
    return ScaleShift(data, scale, shift)
@@ -892,7 +993,6 @@ def fake_quantize(data, input_low, input_high, output_low, output_high, levels,
    Input floating point values are quantized into a discrete set of floating point values.

    .. code-block:: python

        if x <= input_low:
            output = output_low
        if x > input_high:
@@ -917,6 +1017,7 @@ def fake_quantize(data, input_low, input_high, output_low, output_high, levels,
    return FakeQuantize(data, input_low, input_high, output_low, output_high, levels)


@nameable_op
def gemm(A,  # type: Node
         B,  # type: Node
         C,  # type: Node
...
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/group_conv.hpp"
#include "pyngraph/ops/fused/group_conv.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_GroupConvolution(py::module m)
{
py::class_<ngraph::op::GroupConvolution,
std::shared_ptr<ngraph::op::GroupConvolution>,
ngraph::op::Op>
groupconvolution(m, "GroupConvolution");
groupconvolution.doc() = "ngraph.impl.op.GroupConvolution wraps ngraph::op::GroupConvolution";
groupconvolution.def(py::init<const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const ngraph::Strides&,
const ngraph::Strides&,
const ngraph::CoordinateDiff&,
const ngraph::CoordinateDiff&,
const ngraph::Strides&,
const size_t,
const ngraph::op::PadType&>());
py::enum_<ngraph::op::PadType>(groupconvolution, "PadType", py::arithmetic())
.value("EXPLICIT", ngraph::op::PadType::EXPLICIT)
.value("SAME_LOWER", ngraph::op::PadType::SAME_LOWER)
.value("SAME_UPPER", ngraph::op::PadType::SAME_UPPER)
.value("VALID", ngraph::op::PadType::VALID)
.value("AUTO", ngraph::op::PadType::AUTO)
.value("NOTSET", ngraph::op::PadType::NOTSET)
.export_values();
}
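
Since the PadType enum is registered on the GroupConvolution class, and export_values() additionally copies its members onto that class, Python code can spell the pad modes either way:

from ngraph.impl.op import GroupConvolution

explicit = GroupConvolution.PadType.EXPLICIT    # via the nested enum
same_upper = GroupConvolution.SAME_UPPER        # exported by export_values()
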
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_GroupConvolution(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/rnn_cell.hpp"
#include "pyngraph/ops/fused/rnn_cell.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_RNNCell(py::module m)
{
py::class_<ngraph::op::RNNCell, std::shared_ptr<ngraph::op::RNNCell>, ngraph::op::Op> rnncell(
m, "RNNCell");
rnncell.doc() = "ngraph.impl.op.RNNCell wraps ngraph::op::RNNCell";
rnncell.def(py::init<const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
int&,
const std::shared_ptr<ngraph::Node>&,
const std::vector<std::string>&,
const std::vector<float>&,
const std::vector<float>&,
float&>());
}
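
A note on this binding: the constructor takes the same positional order that the ng.rnn_cell wrapper forwards. A hypothetical direct use, with X, W, R, H_t and B standing for previously built Node objects:

from ngraph.impl.op import RNNCell

# (X, W, R, H_t, hidden_size, B, activations, activation_alpha, activation_beta, clip)
node = RNNCell(X, W, R, H_t, 3, B, ['sigmoid'], [], [], 2.88)
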
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_RNNCell(py::module m);
@@ -63,6 +63,7 @@ void regmodule_pyngraph_op(py::module m_op)
    regclass_pyngraph_op_Greater(m_op);
    regclass_pyngraph_op_GreaterEq(m_op);
    regclass_pyngraph_op_GRN(m_op);
    regclass_pyngraph_op_GroupConvolution(m_op);
    regclass_pyngraph_op_HardSigmoid(m_op);
    regclass_pyngraph_op_Less(m_op);
    regclass_pyngraph_op_LessEq(m_op);
@@ -95,6 +96,7 @@ void regmodule_pyngraph_op(py::module m_op)
    regclass_pyngraph_op_ReplaceSlice(m_op);
    regclass_pyngraph_op_Reshape(m_op);
    regclass_pyngraph_op_Reverse(m_op);
    regclass_pyngraph_op_RNNCell(m_op);
    regclass_pyngraph_op_ScaleShift(m_op);
    regclass_pyngraph_op_Select(m_op);
    regclass_pyngraph_op_ShuffleChannels(m_op);
...
@@ -50,9 +50,11 @@
#include "pyngraph/ops/fused/gelu.hpp"
#include "pyngraph/ops/fused/gemm.hpp"
#include "pyngraph/ops/fused/grn.hpp"
#include "pyngraph/ops/fused/group_conv.hpp"
#include "pyngraph/ops/fused/hard_sigmoid.hpp"
#include "pyngraph/ops/fused/mvn.hpp"
#include "pyngraph/ops/fused/prelu.hpp"
#include "pyngraph/ops/fused/rnn_cell.hpp"
#include "pyngraph/ops/fused/scale_shift.hpp"
#include "pyngraph/ops/fused/shuffle_channels.hpp"
#include "pyngraph/ops/fused/space_to_depth.hpp"
...
@@ -192,6 +192,7 @@ sources = [
    'pyngraph/ops/greater.cpp',
    'pyngraph/ops/greater_eq.cpp',
    'pyngraph/ops/fused/grn.cpp',
    'pyngraph/ops/fused/group_conv.cpp',
    'pyngraph/ops/fused/hard_sigmoid.cpp',
    'pyngraph/ops/less.cpp',
    'pyngraph/ops/less_eq.cpp',
@@ -223,6 +224,7 @@ sources = [
    'pyngraph/ops/replace_slice.cpp',
    'pyngraph/ops/reshape.cpp',
    'pyngraph/ops/reverse.cpp',
    'pyngraph/ops/fused/rnn_cell.cpp',
    'pyngraph/ops/fused/scale_shift.cpp',
    'pyngraph/ops/select.cpp',
    'pyngraph/ops/fused/shuffle_channels.cpp',
...
@@ -441,3 +441,104 @@ def test_space_to_depth_operator():
                         4, 6, 12, 14, 20, 22, 28, 30,
                         5, 7, 13, 15, 21, 23, 29, 31], dtype=np.float32).reshape(1, 8, 2, 2)
    assert np.allclose(result, expected)

def test_rnn_cell_operator():
runtime = get_runtime()
batch_size = 2
input_size = 3
hidden_size = 3
X_shape = [batch_size, input_size]
W_shape = [hidden_size, input_size]
R_shape = [hidden_size, hidden_size]
H_t_shape = [batch_size, hidden_size]
B_shape = [2 * hidden_size]
parameter_X = ng.parameter(X_shape, name='X', dtype=np.float32)
parameter_W = ng.parameter(W_shape, name='W', dtype=np.float32)
parameter_R = ng.parameter(R_shape, name='R', dtype=np.float32)
parameter_H_t = ng.parameter(H_t_shape, name='H_t', dtype=np.float32)
parameter_B = ng.parameter(B_shape, name='B', dtype=np.float32)
X_value = np.array([0.3432185, 0.612268, 0.20272376,
0.9513413, 0.30585995, 0.7265472],
dtype=np.float32).reshape(X_shape)
W_value = np.array([0.41930267, 0.7872176, 0.89940447,
0.23659843, 0.24676207, 0.17101714,
0.3147149, 0.6555601, 0.4559603],
dtype=np.float32).reshape(W_shape)
R_value = np.array([0.8374871, 0.86660194, 0.82114047,
0.71549815, 0.18775631, 0.3182116,
0.25392973, 0.38301638, 0.85531586],
dtype=np.float32).reshape(R_shape)
H_t_value = np.array([0.12444675, 0.52055854, 0.46489045,
0.4983964, 0.7730452, 0.28439692],
dtype=np.float32).reshape(H_t_shape)
B_value = np.array([0.45513555, 0.96227735, 0.24737759,
0.57380486, 0.67398053, 0.18968852],
dtype=np.float32).reshape(B_shape)
activations = ['sigmoid']
activation_alpha = []
activation_beta = []
clip = 2.88
model = ng.rnn_cell(parameter_X,
parameter_W,
parameter_R,
parameter_H_t,
hidden_size,
parameter_B,
activations,
activation_alpha,
activation_beta,
clip)
computation = runtime.computation(model,
parameter_X,
parameter_W,
parameter_R,
parameter_H_t,
parameter_B)
result = computation(X_value, W_value, R_value, H_t_value, B_value)
expected = np.array([0.94126844, 0.9036043, 0.841243,
0.9468489, 0.934215, 0.873708],
dtype=np.float32).reshape(batch_size, hidden_size)
assert np.allclose(result, expected)
def test_group_convolution_operator():
runtime = get_runtime()
data_shape = [1, 4, 2, 2]
filters_shape = [2, 2, 1, 1]
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
parameter_filters = ng.parameter(filters_shape, name='Filters', dtype=np.float32)
data_value = np.arange(start=1.0, stop=17.0, dtype=np.float32).reshape(data_shape)
filters_value = np.arange(start=1.0, stop=5.0, dtype=np.float32).reshape(filters_shape)
window_movement_strides = [1, 1]
window_dilation_strides = [1, 1]
padding_below = [0, 0]
padding_above = [0, 0]
data_dilation_strides = [1, 1]
groups = 2
model = ng.group_convolution(parameter_data,
parameter_filters,
window_movement_strides,
window_dilation_strides,
padding_below, padding_above,
data_dilation_strides,
groups,
'EXPLICIT')
computation = runtime.computation(model, parameter_data, parameter_filters)
result = computation(data_value, filters_value)
expected = np.array([11, 14, 17, 20, 79, 86, 93, 100],
dtype=np.float32).reshape(1, 2, 2, 2)
assert np.allclose(result, expected)
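
The expected values can be checked by hand: with groups=2 and 1x1 kernels, each group's output channel is a weighted sum of that group's two input channels. A standalone numpy check of the same arithmetic:

import numpy as np

data = np.arange(1.0, 17.0, dtype=np.float32).reshape(1, 4, 2, 2)
g0 = 1 * data[0, 0] + 2 * data[0, 1]   # group 0: channels 0-1, kernel weights [1, 2]
g1 = 3 * data[0, 2] + 4 * data[0, 3]   # group 1: channels 2-3, kernel weights [3, 4]
assert np.allclose(np.stack([g0, g1]).reshape(1, 2, 2, 2),
                   np.array([11, 14, 17, 20, 79, 86, 93, 100],
                            dtype=np.float32).reshape(1, 2, 2, 2))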