Unverified Commit c5144d48 authored by Adam Procter's avatar Adam Procter Committed by GitHub

Negative convolution padding (#396)

parent 68ef3faa
# API Changes # API Changes
`Parameter` and `Function` no longer take a type argument.
## Negative convolution padding
`Convolution` now allows negative padding. This means that the `padding_below` and `padding_above`
arguments now take type `CoordinateDiff` instead of `Shape`. `CoordinateDiff` is an alias for
`std::vector<std::ptrdiff_t>`, which "is like `size_t` but is allowed to be negative". Callers may
need to be adapted.
## `Parameter` and `Function` no longer take a type argument.
To update, remove the passed argument. For example, To update, remove the passed argument. For example,
```C++ ```C++
// Old // Old
......
...@@ -53,6 +53,13 @@ namespace ngraph ...@@ -53,6 +53,13 @@ namespace ngraph
/// @brief Strides of a tensor /// @brief Strides of a tensor
using Strides = std::vector<size_t>; using Strides = std::vector<size_t>;
/// @brief A coordinate-like type whose elements are allowed to be
/// negative.
///
/// Currently used only to express negative padding; in the future,
/// could conceivably be used to express
using CoordinateDiff = std::vector<std::ptrdiff_t>;
Coordinate project_coordinate(const Coordinate& coord, const AxisSet& deleted_axes); Coordinate project_coordinate(const Coordinate& coord, const AxisSet& deleted_axes);
Shape project_shape(const Shape& shape, const AxisSet& deleted_axes); Shape project_shape(const Shape& shape, const AxisSet& deleted_axes);
......
...@@ -30,8 +30,8 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -30,8 +30,8 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
const Coordinate& source_end_corner, const Coordinate& source_end_corner,
const Strides& source_strides, const Strides& source_strides,
const AxisVector& source_axis_order, const AxisVector& source_axis_order,
const Shape& target_padding_below, const CoordinateDiff& target_padding_below,
const Shape& target_padding_above, const CoordinateDiff& target_padding_above,
const Strides& target_dilation_strides) const Strides& target_dilation_strides)
: m_source_shape(source_shape) : m_source_shape(source_shape)
, m_source_start_corner(source_start_corner) , m_source_start_corner(source_start_corner)
...@@ -96,48 +96,61 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -96,48 +96,61 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
for (size_t i = 0; i < m_n_axes; i++) for (size_t i = 0; i < m_n_axes; i++)
{ {
if (source_start_corner[i] >= (source_shape[i] - 1) * target_dilation_strides[i] + 1 + if (target_dilation_strides[i] == 0)
target_padding_below[i] + target_padding_above[i] &&
!(source_start_corner[i] == 0 && source_shape[i] == 0))
{ {
std::stringstream ss; std::stringstream ss;
ss << "The start corner is out of bounds at axis " << i; ss << "The target dilation stride is 0 at axis " << i;
throw std::domain_error(ss.str()); throw std::domain_error(ss.str());
} }
} }
std::vector<std::ptrdiff_t> padded_upper_bounds;
for (size_t i = 0; i < m_n_axes; i++) for (size_t i = 0; i < m_n_axes; i++)
{ {
if (source_end_corner[i] > std::ptrdiff_t padded_upper_bound =
subtract_or_zero(source_shape[i], size_t(1)) * target_dilation_strides[i] + 1 + subtract_or_zero(source_shape[i], size_t(1)) * target_dilation_strides[i] + 1 +
target_padding_below[i] + target_padding_above[i]) target_padding_below[i] + target_padding_above[i];
if (padded_upper_bound < 0)
{ {
std::stringstream ss; std::stringstream ss;
ss << "The end corner is out of bounds at axis " << i; ss << "The end corner is out of bounds at axis " << i;
throw std::domain_error(ss.str()); throw std::domain_error(ss.str());
} }
padded_upper_bounds.push_back(padded_upper_bound);
} }
for (size_t i = 0; i < m_n_axes; i++) for (size_t i = 0; i < m_n_axes; i++)
{ {
if (source_strides[i] == 0) if (source_start_corner[i] >= padded_upper_bounds[i] &&
!(source_start_corner[i] == 0 && source_shape[i] == 0))
{ {
std::stringstream ss; std::stringstream ss;
ss << "The source stride is 0 at axis " << i; ss << "The start corner is out of bounds at axis " << i;
throw std::domain_error(ss.str());
}
if (source_end_corner[i] > padded_upper_bounds[i])
{
std::stringstream ss;
ss << "The end corner is out of bounds at axis " << i;
throw std::domain_error(ss.str()); throw std::domain_error(ss.str());
} }
} }
for (size_t i = 0; i < m_n_axes; i++) for (size_t i = 0; i < m_n_axes; i++)
{ {
if (target_dilation_strides[i] == 0) if (source_strides[i] == 0)
{ {
std::stringstream ss; std::stringstream ss;
ss << "The target dilation stride is 0 at axis " << i; ss << "The source stride is 0 at axis " << i;
throw std::domain_error(ss.str()); throw std::domain_error(ss.str());
} }
} }
...@@ -160,8 +173,8 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -160,8 +173,8 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
const Coordinate& source_end_corner, const Coordinate& source_end_corner,
const Strides& source_strides, const Strides& source_strides,
const AxisVector& source_axis_order, const AxisVector& source_axis_order,
const Shape& target_padding_below, const CoordinateDiff& target_padding_below,
const Shape& target_padding_above) const CoordinateDiff& target_padding_above)
: CoordinateTransform(source_shape, : CoordinateTransform(source_shape,
source_start_corner, source_start_corner,
source_end_corner, source_end_corner,
...@@ -173,9 +186,9 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -173,9 +186,9 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
{ {
} }
Shape CoordinateTransform::default_padding(size_t n_axes) CoordinateDiff CoordinateTransform::default_padding(size_t n_axes)
{ {
return Shape(n_axes, 0); return CoordinateDiff(n_axes, 0);
} }
CoordinateTransform::CoordinateTransform(const Shape& source_shape, CoordinateTransform::CoordinateTransform(const Shape& source_shape,
...@@ -321,18 +334,18 @@ bool CoordinateTransform::has_source_coordinate(const Coordinate& c_target) cons ...@@ -321,18 +334,18 @@ bool CoordinateTransform::has_source_coordinate(const Coordinate& c_target) cons
// The rest of this is a replay of the corresponding logic in `to_source_coordinate`, with // The rest of this is a replay of the corresponding logic in `to_source_coordinate`, with
// bounds and divisibility checking. // bounds and divisibility checking.
size_t source_axis = m_source_axis_order[target_axis]; std::ptrdiff_t source_axis = m_source_axis_order[target_axis];
size_t target_pos = c_target[target_axis]; std::ptrdiff_t target_pos = c_target[target_axis];
size_t pos_destrided = target_pos * m_source_strides[source_axis]; std::ptrdiff_t pos_destrided = target_pos * m_source_strides[source_axis];
size_t pos_deshifted = pos_destrided + m_source_start_corner[source_axis]; std::ptrdiff_t pos_deshifted = pos_destrided + m_source_start_corner[source_axis];
// If we are in the below-padding or the above-padding. // If we are in the below-padding or the above-padding.
if (pos_deshifted < m_target_padding_below[target_axis]) if (pos_deshifted < m_target_padding_below[target_axis])
{ {
return false; return false;
} }
size_t pos_depadded = pos_deshifted - m_target_padding_below[target_axis]; std::ptrdiff_t pos_depadded = pos_deshifted - m_target_padding_below[target_axis];
// If we are in the above-padding, we have no source coordinate. // If we are in the above-padding, we have no source coordinate.
if (m_source_shape[source_axis] == 0 || if (m_source_shape[source_axis] == 0 ||
......
...@@ -31,8 +31,8 @@ namespace ngraph ...@@ -31,8 +31,8 @@ namespace ngraph
const Coordinate& source_end_corner, const Coordinate& source_end_corner,
const Strides& source_strides, const Strides& source_strides,
const AxisVector& source_axis_order, const AxisVector& source_axis_order,
const Shape& target_padding_below, const CoordinateDiff& target_padding_below,
const Shape& target_padding_above, const CoordinateDiff& target_padding_above,
const Strides& source_dilation_strides); const Strides& source_dilation_strides);
CoordinateTransform(const Shape& source_shape, CoordinateTransform(const Shape& source_shape,
...@@ -40,8 +40,8 @@ namespace ngraph ...@@ -40,8 +40,8 @@ namespace ngraph
const Coordinate& source_end_corner, const Coordinate& source_end_corner,
const Strides& source_strides, const Strides& source_strides,
const AxisVector& source_axis_order, const AxisVector& source_axis_order,
const Shape& target_padding_below, const CoordinateDiff& target_padding_below,
const Shape& target_padding_above); const CoordinateDiff& target_padding_above);
CoordinateTransform(const Shape& source_shape, CoordinateTransform(const Shape& source_shape,
const Coordinate& source_start_corner, const Coordinate& source_start_corner,
...@@ -96,7 +96,7 @@ namespace ngraph ...@@ -96,7 +96,7 @@ namespace ngraph
private: private:
size_t index_source(const Coordinate& c) const; size_t index_source(const Coordinate& c) const;
static Strides default_strides(size_t n_axes); static Strides default_strides(size_t n_axes);
static Shape default_padding(size_t n_axes); static CoordinateDiff default_padding(size_t n_axes);
static AxisVector default_axis_order(size_t n_axes); static AxisVector default_axis_order(size_t n_axes);
static Coordinate default_source_start_corner(size_t n_axes); static Coordinate default_source_start_corner(size_t n_axes);
static Coordinate default_source_end_corner(const Shape& source_shape); static Coordinate default_source_end_corner(const Shape& source_shape);
...@@ -106,8 +106,8 @@ namespace ngraph ...@@ -106,8 +106,8 @@ namespace ngraph
Shape m_source_end_corner; Shape m_source_end_corner;
Strides m_source_strides; Strides m_source_strides;
AxisVector m_source_axis_order; AxisVector m_source_axis_order;
Shape m_target_padding_below; CoordinateDiff m_target_padding_below;
Shape m_target_padding_above; CoordinateDiff m_target_padding_above;
Strides m_target_dilation_strides; Strides m_target_dilation_strides;
Shape m_target_shape; Shape m_target_shape;
......
...@@ -22,8 +22,8 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -22,8 +22,8 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
const std::shared_ptr<Node>& filters, const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const Shape& padding_below, const CoordinateDiff& padding_below,
const Shape& padding_above, const CoordinateDiff& padding_above,
const Strides& image_dilation_strides) const Strides& image_dilation_strides)
: RequiresTensorViewArgs("Convolution", {image_batch, filters}) : RequiresTensorViewArgs("Convolution", {image_batch, filters})
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
...@@ -128,7 +128,16 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -128,7 +128,16 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
size_t dim_size = image_batch_shape[1 + 1 + i]; size_t dim_size = image_batch_shape[1 + 1 + i];
m_input_image_physical_shape.push_back(dim_size); m_input_image_physical_shape.push_back(dim_size);
size_t dilated_dim_size = (dim_size - 1) * image_dilation_strides[i] + 1; size_t dilated_dim_size = (dim_size - 1) * image_dilation_strides[i] + 1;
size_t padded_dilated_dim_size = padding_below[i] + dilated_dim_size + padding_above[i];
std::ptrdiff_t padded_dilated_dim_size =
padding_below[i] + dilated_dim_size + padding_above[i];
if (padded_dilated_dim_size < 0)
{
throw ngraph_error(
"Convolution input image dimension after padding and dilation is negative.");
}
m_input_image_virtual_shape.push_back(padded_dilated_dim_size); m_input_image_virtual_shape.push_back(padded_dilated_dim_size);
if (m_input_image_virtual_shape[i] == 0) if (m_input_image_virtual_shape[i] == 0)
...@@ -214,8 +223,8 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -214,8 +223,8 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
const std::shared_ptr<Node>& filters, const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const Shape& padding_below, const CoordinateDiff& padding_below,
const Shape& padding_above) const CoordinateDiff& padding_above)
: Convolution(image_batch, : Convolution(image_batch,
filters, filters,
window_movement_strides, window_movement_strides,
...@@ -226,7 +235,7 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -226,7 +235,7 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
{ {
} }
Shape op::Convolution::default_padding(const std::shared_ptr<Node>& image_batch) CoordinateDiff op::Convolution::default_padding(const std::shared_ptr<Node>& image_batch)
{ {
auto& image_batch_shape = image_batch->get_shape(); auto& image_batch_shape = image_batch->get_shape();
if (image_batch_shape.size() < 3) if (image_batch_shape.size() < 3)
...@@ -236,7 +245,7 @@ Shape op::Convolution::default_padding(const std::shared_ptr<Node>& image_batch) ...@@ -236,7 +245,7 @@ Shape op::Convolution::default_padding(const std::shared_ptr<Node>& image_batch)
"Convolution image batch input must have rank of at least 3 (one batch axis, one " "Convolution image batch input must have rank of at least 3 (one batch axis, one "
"input-channel axis, at least one image dimension)."); "input-channel axis, at least one image dimension).");
} }
return Shape(image_batch_shape.size() - 2, 0); return CoordinateDiff(image_batch_shape.size() - 2, 0);
} }
op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
......
...@@ -34,8 +34,8 @@ namespace ngraph ...@@ -34,8 +34,8 @@ namespace ngraph
/// ///
/// 3. <i>(the window movement strides)</i> a vector of positive integers \f$(s_1,\dots,s_n)\f$ (default is all ones), /// 3. <i>(the window movement strides)</i> a vector of positive integers \f$(s_1,\dots,s_n)\f$ (default is all ones),
/// 4. <i>(the window dilation strides)</i> a vector of positive integers \f$(l_1,\dots,l_n)\f$ (default is all ones), /// 4. <i>(the window dilation strides)</i> a vector of positive integers \f$(l_1,\dots,l_n)\f$ (default is all ones),
/// 5. <i>(the padding below)</i> a vector of non-negative integers \f$(p_1,\dots,p_n)\f$ (default is all zeros), /// 5. <i>(the padding below)</i> a vector of (possibly negative) integers \f$(p_1,\dots,p_n)\f$ (default is all zeros),
/// 6. <i>(the padding above)</i> a vector of non-negative integers \f$(q_1,\dots,q_n)\f$ (default is all zeros), and /// 6. <i>(the padding above)</i> a vector of (possibly negative) integers \f$(q_1,\dots,q_n)\f$ (default is all zeros), and
/// 7. <i>(the image dilation strides)</i> a vector of non-negative integers \f$(q_1,\dots,q_n)\f$ (default is all ones). /// 7. <i>(the image dilation strides)</i> a vector of non-negative integers \f$(q_1,\dots,q_n)\f$ (default is all ones).
/// ///
/// The output has the shape \f$(N,C_\textit{out},d'_1,\dots,d'_n)\f$, where \f$d'_n = \lceil \frac{(d_i - 1) * g_i + 1 + p_i + q_i - l_i(d^f_i - 1)}{s_i} \rceil\f$. /// The output has the shape \f$(N,C_\textit{out},d'_1,\dots,d'_n)\f$, where \f$d'_n = \lceil \frac{(d_i - 1) * g_i + 1 + p_i + q_i - l_i(d^f_i - 1)}{s_i} \rceil\f$.
...@@ -68,8 +68,8 @@ namespace ngraph ...@@ -68,8 +68,8 @@ namespace ngraph
const std::shared_ptr<Node>& filters, const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const Shape& padding_below, const CoordinateDiff& padding_below,
const Shape& padding_above, const CoordinateDiff& padding_above,
const Strides& image_dilation_strides); const Strides& image_dilation_strides);
/// \brief Constructs a batched convolution operation with no image dilation (i.e., all image dilation strides are 1). /// \brief Constructs a batched convolution operation with no image dilation (i.e., all image dilation strides are 1).
...@@ -84,8 +84,8 @@ namespace ngraph ...@@ -84,8 +84,8 @@ namespace ngraph
const std::shared_ptr<Node>& filters, const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const Shape& padding_below, const CoordinateDiff& padding_below,
const Shape& padding_above); const CoordinateDiff& padding_above);
/// \brief Constructs a batched convolution operation with no padding or image dilation (i.e., padding above and below are 0 everywhere, and all image dilation strides are 1). /// \brief Constructs a batched convolution operation with no padding or image dilation (i.e., padding above and below are 0 everywhere, and all image dilation strides are 1).
/// ///
...@@ -121,10 +121,10 @@ namespace ngraph ...@@ -121,10 +121,10 @@ namespace ngraph
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
/// \return The window dilation strides. /// \return The window dilation strides.
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; } const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
/// \return The padding-below sizes. /// \return The padding-below sizes (possibly negative).
const Shape& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
/// \return The padding-above sizes. /// \return The padding-above sizes (possibly negative).
const Shape& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
/// \return The input image dilation strides. /// \return The input image dilation strides.
const Strides& get_image_dilation_strides() const { return m_image_dilation_strides; } const Strides& get_image_dilation_strides() const { return m_image_dilation_strides; }
/// \return The number of input channels. /// \return The number of input channels.
...@@ -156,8 +156,8 @@ namespace ngraph ...@@ -156,8 +156,8 @@ namespace ngraph
protected: protected:
Strides m_window_movement_strides; Strides m_window_movement_strides;
Strides m_window_dilation_strides; Strides m_window_dilation_strides;
Shape m_padding_below; CoordinateDiff m_padding_below;
Shape m_padding_above; CoordinateDiff m_padding_above;
Strides m_image_dilation_strides; Strides m_image_dilation_strides;
// TODO: Some or all of these values should probably be computed dynamically rather than stored here. // TODO: Some or all of these values should probably be computed dynamically rather than stored here.
...@@ -173,7 +173,7 @@ namespace ngraph ...@@ -173,7 +173,7 @@ namespace ngraph
private: private:
static Strides default_strides(const std::shared_ptr<Node>& image_batch); static Strides default_strides(const std::shared_ptr<Node>& image_batch);
static Shape default_padding(const std::shared_ptr<Node>& image_batch); static CoordinateDiff default_padding(const std::shared_ptr<Node>& image_batch);
}; };
} }
} }
...@@ -63,12 +63,12 @@ namespace ngraph ...@@ -63,12 +63,12 @@ namespace ngraph
size_t n_image_dimensions = arg_shape.size() - 2; size_t n_image_dimensions = arg_shape.size() - 2;
Shape input_batch_transform_start(2 + n_image_dimensions); Coordinate input_batch_transform_start(2 + n_image_dimensions);
Shape input_batch_transform_end(2 + n_image_dimensions); Coordinate input_batch_transform_end(2 + n_image_dimensions);
Shape input_batch_transform_source_strides(2 + n_image_dimensions, 1); Strides input_batch_transform_source_strides(2 + n_image_dimensions, 1);
Shape input_batch_transform_source_axis_order(2 + n_image_dimensions); AxisVector input_batch_transform_source_axis_order(2 + n_image_dimensions);
Shape input_batch_transform_padding_below(2 + n_image_dimensions); CoordinateDiff input_batch_transform_padding_below(2 + n_image_dimensions);
Shape input_batch_transform_padding_above(2 + n_image_dimensions); CoordinateDiff input_batch_transform_padding_above(2 + n_image_dimensions);
input_batch_transform_start[0] = img_index; input_batch_transform_start[0] = img_index;
input_batch_transform_end[0] = img_index + 1; input_batch_transform_end[0] = img_index + 1;
......
...@@ -35,8 +35,8 @@ namespace ngraph ...@@ -35,8 +35,8 @@ namespace ngraph
const Shape& out_shape, const Shape& out_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const Shape& padding_below, const CoordinateDiff& padding_below,
const Shape& padding_above, const CoordinateDiff& padding_above,
const Strides& image_dilation_strides) const Strides& image_dilation_strides)
{ {
// At the outermost level we will walk over every output coordinate O. // At the outermost level we will walk over every output coordinate O.
...@@ -71,12 +71,12 @@ namespace ngraph ...@@ -71,12 +71,12 @@ namespace ngraph
size_t n_image_dimensions = arg0_shape.size() - 2; size_t n_image_dimensions = arg0_shape.size() - 2;
size_t n_input_channels = arg0_shape[1]; size_t n_input_channels = arg0_shape[1];
Shape input_batch_transform_start(2 + n_image_dimensions); Coordinate input_batch_transform_start(2 + n_image_dimensions);
Shape input_batch_transform_end(2 + n_image_dimensions); Coordinate input_batch_transform_end(2 + n_image_dimensions);
Shape input_batch_transform_movement_strides(2 + n_image_dimensions, 1); Strides input_batch_transform_movement_strides(2 + n_image_dimensions, 1);
Shape input_batch_transform_padding_below(2 + n_image_dimensions, 0); CoordinateDiff input_batch_transform_padding_below(2 + n_image_dimensions, 0);
Shape input_batch_transform_padding_above(2 + n_image_dimensions, 0); CoordinateDiff input_batch_transform_padding_above(2 + n_image_dimensions, 0);
Shape input_batch_transform_dilation_strides(2 + n_image_dimensions, 1); Strides input_batch_transform_dilation_strides(2 + n_image_dimensions, 1);
input_batch_transform_start[0] = img_index; input_batch_transform_start[0] = img_index;
input_batch_transform_end[0] = img_index + 1; input_batch_transform_end[0] = img_index + 1;
...@@ -87,8 +87,8 @@ namespace ngraph ...@@ -87,8 +87,8 @@ namespace ngraph
{ {
size_t window_dilation_stride = window_dilation_strides[i - 2]; size_t window_dilation_stride = window_dilation_strides[i - 2];
size_t window_movement_stride = window_movement_strides[i - 2]; size_t window_movement_stride = window_movement_strides[i - 2];
size_t below_pad = padding_below[i - 2]; std::ptrdiff_t below_pad = padding_below[i - 2];
size_t above_pad = padding_above[i - 2]; std::ptrdiff_t above_pad = padding_above[i - 2];
size_t image_dilation_stride = image_dilation_strides[i - 2]; size_t image_dilation_stride = image_dilation_strides[i - 2];
input_batch_transform_start[i] = window_movement_stride * out_coord[i]; input_batch_transform_start[i] = window_movement_stride * out_coord[i];
......
...@@ -59,8 +59,8 @@ namespace ngraph ...@@ -59,8 +59,8 @@ namespace ngraph
size_t n_image_dimensions = arg_shape.size() - 2; size_t n_image_dimensions = arg_shape.size() - 2;
Shape input_batch_transform_start(2 + n_image_dimensions); Coordinate input_batch_transform_start(2 + n_image_dimensions);
Shape input_batch_transform_end(2 + n_image_dimensions); Coordinate input_batch_transform_end(2 + n_image_dimensions);
input_batch_transform_start[0] = img_index; input_batch_transform_start[0] = img_index;
input_batch_transform_end[0] = img_index + 1; input_batch_transform_end[0] = img_index + 1;
......
...@@ -54,13 +54,23 @@ namespace ngraph ...@@ -54,13 +54,23 @@ namespace ngraph
input_dilation[i] = padding_interior[i] + 1; input_dilation[i] = padding_interior[i] + 1;
} }
// Need to cast these to CoordinateDiff in order to make CoordinateTransform happy.
CoordinateDiff padding_below_signed;
CoordinateDiff padding_above_signed;
for (size_t i = 0; i < padding_below.size(); i++)
{
padding_below_signed.push_back(padding_below[i]);
padding_above_signed.push_back(padding_above[i]);
}
CoordinateTransform input_transform(arg0_shape, CoordinateTransform input_transform(arg0_shape,
input_start, input_start,
input_end, input_end,
input_strides, input_strides,
input_axis_order, input_axis_order,
padding_below, padding_below_signed,
padding_above, padding_above_signed,
input_dilation); input_dilation);
CoordinateTransform output_transform(out_shape); CoordinateTransform output_transform(out_shape);
......
...@@ -369,8 +369,8 @@ static shared_ptr<ngraph::Function> ...@@ -369,8 +369,8 @@ static shared_ptr<ngraph::Function>
node_js.at("window_movement_strides").get<vector<size_t>>(); node_js.at("window_movement_strides").get<vector<size_t>>();
auto window_dilation_strides = auto window_dilation_strides =
node_js.at("window_dilation_strides").get<vector<size_t>>(); node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<size_t>>(); auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<size_t>>(); auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
node = make_shared<op::Convolution>(args[0], node = make_shared<op::Convolution>(args[0],
args[1], args[1],
window_movement_strides, window_movement_strides,
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment