Commit 5c56923a authored by Michał Karzyński's avatar Michał Karzyński Committed by Scott Cyphers

[Py] Add convolution_backprop_data to API (#1292)

* [Py] Add convolution_backprop_data to API

* Conv fix
parent bb94fa85
...@@ -31,6 +31,7 @@ from ngraph.ops import concat ...@@ -31,6 +31,7 @@ from ngraph.ops import concat
from ngraph.ops import constant from ngraph.ops import constant
from ngraph.ops import convert from ngraph.ops import convert
from ngraph.ops import convolution from ngraph.ops import convolution
from ngraph.ops import convolution_backprop_data
from ngraph.ops import cos from ngraph.ops import cos
from ngraph.ops import cosh from ngraph.ops import cosh
from ngraph.ops import divide from ngraph.ops import divide
......
...@@ -21,11 +21,11 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio ...@@ -21,11 +21,11 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
NodeVector, Shape, Strides NodeVector, Shape, Strides
from ngraph.impl.op import Abs, Acos, Add, And, Asin, Atan, AvgPool, BatchNorm, Broadcast, \ from ngraph.impl.op import Abs, Acos, Add, And, Asin, Atan, AvgPool, BatchNorm, Broadcast, \
Ceiling, Concat, Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, \ Ceiling, Concat, Constant, Convert, Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, \
FunctionCall, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, Max, Maximum, MaxPool, \ Dot, Equal, Exp, Floor, FunctionCall, GetOutputElement, Greater, GreaterEq, Less, LessEq, \
Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, Power, \ Log, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, \
Reduce, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, Sqrt, \ Parameter, Product, Power, Reduce, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, \
Subtract, Sum, Tan, Tanh Sinh, Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh
from typing import Callable, Iterable, List, Union from typing import Callable, Iterable, List, Union
...@@ -600,6 +600,49 @@ def convolution(data_batch, # type: Node ...@@ -600,6 +600,49 @@ def convolution(data_batch, # type: Node
CoordinateDiff(padding_above), Strides(data_dilation_strides)) CoordinateDiff(padding_above), Strides(data_dilation_strides))
@nameable_op
def convolution_backprop_data(data_batch_shape, # type: TensorShape
filters, # type: Node
output_delta, # type: Node
window_movement_strides_forward=None, # type: List[int]
window_dilation_strides_forward=None, # type: List[int]
padding_below_forward=None, # type: List[int]
padding_above_forward=None, # type: List[int]
data_dilation_strides_forward=None, # type: List[int]
name=None, # type: str
):
# type: (...) -> Node
"""Return node performing a batched-convolution data batch-backprop operation.
:param data_batch_shape: The shape of the data batch from forward-prop.
:param filters: The node producing the filters from forward-prop.
:param output_delta: The node producing output delta.
:param window_movement_strides_forward: The window movement strides from forward-prop.
:param window_dilation_strides_forward: The window dilation strides from forward-prop.
:param padding_below_forward: The padding-below sizes from forward-prop.
:param padding_above_forward: The padding-above sizes from forward-prop.
:param data_dilation_strides_forward: The data dilation strides from forward-prop.
"""
spatial_dim_count = len(data_batch_shape) - 2
if window_movement_strides_forward is None:
window_movement_strides_forward = [1] * spatial_dim_count
if window_dilation_strides_forward is None:
window_dilation_strides_forward = [1] * spatial_dim_count
if padding_below_forward is None:
padding_below_forward = [0] * spatial_dim_count
if padding_above_forward is None:
padding_above_forward = [0] * spatial_dim_count
if data_dilation_strides_forward is None:
data_dilation_strides_forward = [1] * spatial_dim_count
return ConvolutionBackpropData(Shape(data_batch_shape), filters, output_delta,
Strides(window_movement_strides_forward),
Strides(window_dilation_strides_forward),
CoordinateDiff(padding_below_forward),
CoordinateDiff(padding_above_forward),
Strides(data_dilation_strides_forward))
@nameable_op @nameable_op
def avg_pool(data_batch, # type: Node def avg_pool(data_batch, # type: Node
window_shape, # type: TensorShape window_shape, # type: TensorShape
......
...@@ -93,3 +93,45 @@ def test_convolution_2d(): ...@@ -93,3 +93,45 @@ def test_convolution_2d():
[0, 0, 20, 20, 0], [0, 0, 20, 20, 0],
[0, 0, 20, 20, 0]]]], [0, 0, 20, 20, 0]]]],
dtype=np.float32)) dtype=np.float32))
@pytest.config.gpu_skip(reason='Not implemented')
def test_convolution_backprop_data():
runtime = get_runtime()
data_batch_shape = [1, 1, 9, 9]
filter_shape = [1, 1, 3, 3]
output_delta_shape = [1, 1, 7, 7]
filter_param = ng.parameter(shape=filter_shape)
output_delta_param = ng.parameter(shape=output_delta_shape)
deconvolution = ng.convolution_backprop_data(data_batch_shape, filter_param, output_delta_param)
data_batch_data = np.array([[[[-20, -20, 20, 20, 0, 0, 0],
[-20, -20, 20, 20, 0, 0, 0],
[-20, -20, 20, 20, 0, 0, 0],
[-20, -20, 20, 20, 0, 0, 0],
[-20, -20, 20, 20, 0, 0, 0],
[-20, -20, 20, 20, 0, 0, 0],
[-20, -20, 20, 20, 0, 0, 0]]]],
dtype=np.float32)
filter_data = np.array([
[1., 0., -1.],
[2., 0., -2.],
[1., 0., -1.]], dtype=np.float32).reshape(1, 1, 3, 3)
model = runtime.computation(deconvolution, filter_param, output_delta_param)
result = model(filter_data, data_batch_data)
assert np.allclose(result,
np.array([[[[-20., -20., 40., 40., -20., -20., 0., 0., 0.],
[-60., -60., 120., 120., -60., -60., 0., 0., 0.],
[-80., -80., 160., 160., -80., -80., 0., 0., 0.],
[-80., -80., 160., 160., -80., -80., 0., 0., 0.],
[-80., -80., 160., 160., -80., -80., 0., 0., 0.],
[-80., -80., 160., 160., -80., -80., 0., 0., 0.],
[-80., -80., 160., 160., -80., -80., 0., 0., 0.],
[-60., -60., 120., 120., -60., -60., 0., 0., 0.],
[-20., -20., 40., 40., -20., -20., 0., 0., 0.]]]],
dtype=np.float32))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment