Commit 5c56923a authored by Michał Karzyński's avatar Michał Karzyński Committed by Scott Cyphers

[Py] Add convolution_backprop_data to API (#1292)

* [Py] Add convolution_backprop_data to API

* Conv fix
parent bb94fa85
......@@ -31,6 +31,7 @@ from ngraph.ops import concat
from ngraph.ops import constant
from ngraph.ops import convert
from ngraph.ops import convolution
from ngraph.ops import convolution_backprop_data
from ngraph.ops import cos
from ngraph.ops import cosh
from ngraph.ops import divide
......
......@@ -21,11 +21,11 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
NodeVector, Shape, Strides
from ngraph.impl.op import Abs, Acos, Add, And, Asin, Atan, AvgPool, BatchNorm, Broadcast, \
Ceiling, Concat, Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, \
FunctionCall, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, Max, Maximum, MaxPool, \
Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, Power, \
Reduce, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, Sqrt, \
Subtract, Sum, Tan, Tanh
Ceiling, Concat, Constant, Convert, Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, \
Dot, Equal, Exp, Floor, FunctionCall, GetOutputElement, Greater, GreaterEq, Less, LessEq, \
Log, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, \
Parameter, Product, Power, Reduce, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, \
Sinh, Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh
from typing import Callable, Iterable, List, Union
......@@ -600,6 +600,49 @@ def convolution(data_batch, # type: Node
CoordinateDiff(padding_above), Strides(data_dilation_strides))
@nameable_op
def convolution_backprop_data(data_batch_shape,  # type: TensorShape
                              filters,  # type: Node
                              output_delta,  # type: Node
                              window_movement_strides_forward=None,  # type: List[int]
                              window_dilation_strides_forward=None,  # type: List[int]
                              padding_below_forward=None,  # type: List[int]
                              padding_above_forward=None,  # type: List[int]
                              data_dilation_strides_forward=None,  # type: List[int]
                              name=None,  # type: str
                              ):
    # type: (...) -> Node
    """Return node performing a batched-convolution data batch-backprop operation.

    :param data_batch_shape: The shape of the data batch from forward-prop.
    :param filters: The node producing the filters from forward-prop.
    :param output_delta: The node producing output delta.
    :param window_movement_strides_forward: The window movement strides from forward-prop.
    :param window_dilation_strides_forward: The window dilation strides from forward-prop.
    :param padding_below_forward: The padding-below sizes from forward-prop.
    :param padding_above_forward: The padding-above sizes from forward-prop.
    :param data_dilation_strides_forward: The data dilation strides from forward-prop.
    :param name: Optional name for the new node.
    :return: The new node performing the data batch-backprop.
    """
    # The first two dimensions of the data batch shape are batch size and
    # channel count; the remainder are spatial dimensions.
    num_spatial_dims = len(data_batch_shape) - 2

    def _per_spatial_dim(value, fill):
        # Fall back to one `fill` entry per spatial dimension when not given.
        return [fill] * num_spatial_dims if value is None else value

    window_movement_strides_forward = _per_spatial_dim(window_movement_strides_forward, 1)
    window_dilation_strides_forward = _per_spatial_dim(window_dilation_strides_forward, 1)
    padding_below_forward = _per_spatial_dim(padding_below_forward, 0)
    padding_above_forward = _per_spatial_dim(padding_above_forward, 0)
    data_dilation_strides_forward = _per_spatial_dim(data_dilation_strides_forward, 1)

    return ConvolutionBackpropData(Shape(data_batch_shape), filters, output_delta,
                                   Strides(window_movement_strides_forward),
                                   Strides(window_dilation_strides_forward),
                                   CoordinateDiff(padding_below_forward),
                                   CoordinateDiff(padding_above_forward),
                                   Strides(data_dilation_strides_forward))
@nameable_op
def avg_pool(data_batch, # type: Node
window_shape, # type: TensorShape
......
......@@ -93,3 +93,45 @@ def test_convolution_2d():
[0, 0, 20, 20, 0],
[0, 0, 20, 20, 0]]]],
dtype=np.float32))
@pytest.config.gpu_skip(reason='Not implemented')
def test_convolution_backprop_data():
    """Check convolution_backprop_data against a precomputed 2D result."""
    runtime = get_runtime()

    data_batch_shape = [1, 1, 9, 9]
    filter_shape = [1, 1, 3, 3]
    output_delta_shape = [1, 1, 7, 7]

    filter_param = ng.parameter(shape=filter_shape)
    output_delta_param = ng.parameter(shape=output_delta_shape)

    deconvolution = ng.convolution_backprop_data(data_batch_shape, filter_param, output_delta_param)

    # This array is the *output delta* fed into the backprop op (its shape
    # matches output_delta_shape), not the forward-prop data batch — it was
    # previously misnamed `data_batch_data`.
    output_delta_data = np.array([[[[-20, -20, 20, 20, 0, 0, 0],
                                    [-20, -20, 20, 20, 0, 0, 0],
                                    [-20, -20, 20, 20, 0, 0, 0],
                                    [-20, -20, 20, 20, 0, 0, 0],
                                    [-20, -20, 20, 20, 0, 0, 0],
                                    [-20, -20, 20, 20, 0, 0, 0],
                                    [-20, -20, 20, 20, 0, 0, 0]]]],
                                 dtype=np.float32)

    # 3x3 vertical Sobel-like kernel, reshaped to NCHW filter layout.
    filter_data = np.array([
        [1., 0., -1.],
        [2., 0., -2.],
        [1., 0., -1.]], dtype=np.float32).reshape(1, 1, 3, 3)

    model = runtime.computation(deconvolution, filter_param, output_delta_param)
    result = model(filter_data, output_delta_data)
    assert np.allclose(result,
                       np.array([[[[-20., -20., 40., 40., -20., -20., 0., 0., 0.],
                                   [-60., -60., 120., 120., -60., -60., 0., 0., 0.],
                                   [-80., -80., 160., 160., -80., -80., 0., 0., 0.],
                                   [-80., -80., 160., 160., -80., -80., 0., 0., 0.],
                                   [-80., -80., 160., 160., -80., -80., 0., 0., 0.],
                                   [-80., -80., 160., 160., -80., -80., 0., 0., 0.],
                                   [-80., -80., 160., 160., -80., -80., 0., 0., 0.],
                                   [-60., -60., 120., 120., -60., -60., 0., 0., 0.],
                                   [-20., -20., 40., 40., -20., -20., 0., 0., 0.]]]],
                                dtype=np.float32))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment