Unverified Commit c682fbf4 authored by Adam Procter's avatar Adam Procter Committed by GitHub

Image batch dilation for convolution (#363)

Sub-PR: image dilation tests (#362) via @adstraw 
parent 74850150
...@@ -30,15 +30,17 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -30,15 +30,17 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
const Coordinate& source_end_corner, const Coordinate& source_end_corner,
const Strides& source_strides, const Strides& source_strides,
const AxisVector& source_axis_order, const AxisVector& source_axis_order,
const Shape& source_padding_below, const Shape& target_padding_below,
const Shape& source_padding_above) const Shape& target_padding_above,
const Strides& target_dilation_strides)
: m_source_shape(source_shape) : m_source_shape(source_shape)
, m_source_start_corner(source_start_corner) , m_source_start_corner(source_start_corner)
, m_source_end_corner(source_end_corner) , m_source_end_corner(source_end_corner)
, m_source_strides(source_strides) , m_source_strides(source_strides)
, m_source_axis_order(source_axis_order) , m_source_axis_order(source_axis_order)
, m_source_padding_below(source_padding_below) , m_target_padding_below(target_padding_below)
, m_source_padding_above(source_padding_above) , m_target_padding_above(target_padding_above)
, m_target_dilation_strides(target_dilation_strides)
{ {
m_n_axes = source_shape.size(); m_n_axes = source_shape.size();
...@@ -65,16 +67,21 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -65,16 +67,21 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
throw std::domain_error( throw std::domain_error(
"Source axis order does not have the same number of axes as the source space shape"); "Source axis order does not have the same number of axes as the source space shape");
} }
if (m_n_axes != source_padding_below.size()) if (m_n_axes != target_padding_below.size())
{ {
throw std::domain_error( throw std::domain_error(
"Padding-below shape does not have the same number of axes as the source space shape"); "Padding-below shape does not have the same number of axes as the source space shape");
} }
if (m_n_axes != source_padding_above.size()) if (m_n_axes != target_padding_above.size())
{ {
throw std::domain_error( throw std::domain_error(
"Padding-above shape does not have the same number of axes as the source space shape"); "Padding-above shape does not have the same number of axes as the source space shape");
} }
if (m_n_axes != target_dilation_strides.size())
{
throw std::domain_error(
"Target dilation strides do not have the same number of axes as the source shape");
}
AxisVector all_axes(m_n_axes); AxisVector all_axes(m_n_axes);
size_t n = 0; size_t n = 0;
...@@ -89,8 +96,8 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -89,8 +96,8 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
for (size_t i = 0; i < m_n_axes; i++) for (size_t i = 0; i < m_n_axes; i++)
{ {
if (source_start_corner[i] >= if (source_start_corner[i] >= (source_shape[i] - 1) * target_dilation_strides[i] + 1 +
source_shape[i] + source_padding_below[i] + source_padding_above[i] && target_padding_below[i] + target_padding_above[i] &&
!(source_start_corner[i] == 0 && source_shape[i] == 0)) !(source_start_corner[i] == 0 && source_shape[i] == 0))
{ {
std::stringstream ss; std::stringstream ss;
...@@ -102,8 +109,8 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -102,8 +109,8 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
for (size_t i = 0; i < m_n_axes; i++) for (size_t i = 0; i < m_n_axes; i++)
{ {
if (source_end_corner[i] > if (source_end_corner[i] > (source_shape[i] - 1) * target_dilation_strides[i] + 1 +
source_shape[i] + source_padding_below[i] + source_padding_above[i]) target_padding_below[i] + target_padding_above[i])
{ {
std::stringstream ss; std::stringstream ss;
...@@ -123,6 +130,17 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -123,6 +130,17 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
} }
} }
for (size_t i = 0; i < m_n_axes; i++)
{
if (target_dilation_strides[i] == 0)
{
std::stringstream ss;
ss << "The target dilation stride is 0 at axis " << i;
throw std::domain_error(ss.str());
}
}
for (size_t axis = 0; axis < m_n_axes; axis++) for (size_t axis = 0; axis < m_n_axes; axis++)
{ {
m_target_shape.push_back(ceil_div(source_end_corner[source_axis_order[axis]] - m_target_shape.push_back(ceil_div(source_end_corner[source_axis_order[axis]] -
...@@ -131,6 +149,29 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -131,6 +149,29 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
} }
} }
Strides CoordinateTransform::default_strides(size_t n_axes)
{
return Strides(n_axes, 1);
}
CoordinateTransform::CoordinateTransform(const Shape& source_shape,
const Coordinate& source_start_corner,
const Coordinate& source_end_corner,
const Strides& source_strides,
const AxisVector& source_axis_order,
const Shape& target_padding_below,
const Shape& target_padding_above)
: CoordinateTransform(source_shape,
source_start_corner,
source_end_corner,
source_strides,
source_axis_order,
target_padding_below,
target_padding_above,
default_strides(source_shape.size()))
{
}
Shape CoordinateTransform::default_padding(size_t n_axes) Shape CoordinateTransform::default_padding(size_t n_axes)
{ {
return Shape(n_axes, 0); return Shape(n_axes, 0);
...@@ -147,7 +188,8 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -147,7 +188,8 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
source_strides, source_strides,
source_axis_order, source_axis_order,
default_padding(source_shape.size()), default_padding(source_shape.size()),
default_padding(source_shape.size())) default_padding(source_shape.size()),
default_strides(source_shape.size()))
{ {
} }
...@@ -170,13 +212,9 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -170,13 +212,9 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
source_strides, source_strides,
default_axis_order(source_shape.size()), default_axis_order(source_shape.size()),
default_padding(source_shape.size()), default_padding(source_shape.size()),
default_padding(source_shape.size())) default_padding(source_shape.size()),
{ default_strides(source_shape.size()))
}
Strides CoordinateTransform::default_source_strides(size_t n_axes)
{ {
return AxisVector(n_axes, 1);
} }
CoordinateTransform::CoordinateTransform(const Shape& source_shape, CoordinateTransform::CoordinateTransform(const Shape& source_shape,
...@@ -185,10 +223,11 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -185,10 +223,11 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
: CoordinateTransform(source_shape, : CoordinateTransform(source_shape,
source_start_corner, source_start_corner,
source_end_corner, source_end_corner,
default_source_strides(source_shape.size()), default_strides(source_shape.size()),
default_axis_order(source_shape.size()), default_axis_order(source_shape.size()),
default_padding(source_shape.size()), default_padding(source_shape.size()),
default_padding(source_shape.size())) default_padding(source_shape.size()),
default_strides(source_shape.size()))
{ {
} }
...@@ -206,10 +245,11 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape) ...@@ -206,10 +245,11 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape)
: CoordinateTransform(source_shape, : CoordinateTransform(source_shape,
default_source_start_corner(source_shape.size()), default_source_start_corner(source_shape.size()),
default_source_end_corner(source_shape), default_source_end_corner(source_shape),
default_source_strides(source_shape.size()), default_strides(source_shape.size()),
default_axis_order(source_shape.size()), default_axis_order(source_shape.size()),
default_padding(source_shape.size()), default_padding(source_shape.size()),
default_padding(source_shape.size())) default_padding(source_shape.size()),
default_strides(source_shape.size()))
{ {
} }
...@@ -235,64 +275,79 @@ size_t CoordinateTransform::index(const Coordinate& c) const ...@@ -235,64 +275,79 @@ size_t CoordinateTransform::index(const Coordinate& c) const
} }
// Convert a target-space coordinate to a source-space coordinate. // Convert a target-space coordinate to a source-space coordinate.
Coordinate CoordinateTransform::to_source_coordinate(const Coordinate& c) const Coordinate CoordinateTransform::to_source_coordinate(const Coordinate& c_target) const
{ {
if (c.size() != m_n_axes) if (c_target.size() != m_n_axes)
{ {
throw std::domain_error("Coordinate rank does not match the coordinate transform rank"); throw std::domain_error(
"Target coordinate rank does not match the coordinate transform rank");
} }
Coordinate result(c.size()); Coordinate c_source(c_target.size());
for (size_t axis = 0; axis < m_n_axes; axis++) for (size_t target_axis = 0; target_axis < m_n_axes; target_axis++)
{ {
result[m_source_axis_order[axis]] = c[axis] * m_source_strides[axis] + size_t source_axis = m_source_axis_order[target_axis];
m_source_start_corner[axis] -
m_source_padding_below[axis]; size_t target_pos = c_target[target_axis];
size_t pos_destrided = target_pos * m_source_strides[source_axis];
size_t pos_deshifted = pos_destrided + m_source_start_corner[source_axis];
size_t pos_depadded = pos_deshifted - m_target_padding_below[target_axis];
size_t pos_dedilated = pos_depadded / m_target_dilation_strides[target_axis];
c_source[source_axis] = pos_dedilated;
} }
return result; return c_source;
} }
// Check if a coordinate is in bounds of the target space. // A point in the target space is considered not to have a source coordinate if it was inserted due to
bool CoordinateTransform::in_bounds(const Coordinate& c) const // padding or dilation, or if it is out of the bounds of the target space.
bool CoordinateTransform::has_source_coordinate(const Coordinate& c_target) const
{ {
if (c.size() != m_n_axes) if (c_target.size() != m_n_axes)
{ {
return false; throw std::domain_error(
"Target coordinate rank does not match the coordinate transform rank");
} }
for (size_t axis = 0; axis < m_n_axes; axis++) for (size_t target_axis = 0; target_axis < m_n_axes; target_axis++)
{ {
if (c[axis] < m_target_shape[axis] || c[axis] >= m_target_shape[axis]) // Is this coordinate out of bounds of the target space?
if (c_target[target_axis] >= m_target_shape[target_axis])
{ {
return false; return false;
} }
}
return true; // The rest of this is a replay of the corresponding logic in `to_source_coordinate`, with
} // bounds and divisibility checking.
size_t source_axis = m_source_axis_order[target_axis];
// Check if a coordinate corresponds to one of the padding elements that has been added to size_t target_pos = c_target[target_axis];
// the source space. size_t pos_destrided = target_pos * m_source_strides[source_axis];
bool CoordinateTransform::in_padding(const Coordinate& c) const size_t pos_deshifted = pos_destrided + m_source_start_corner[source_axis];
{
if (c.size() != m_n_axes)
{
throw std::domain_error("Coordinate rank does not match the coordinate transform rank");
}
for (size_t axis = 0; axis < m_n_axes; axis++) // If we are in the below-padding or the above-padding.
{ if (pos_deshifted < m_target_padding_below[target_axis])
size_t padded_pos = c[axis] * m_source_strides[axis] + m_source_start_corner[axis]; {
if (padded_pos < m_source_padding_below[axis] || return false;
padded_pos >= m_source_padding_below[axis] + m_source_shape[axis]) }
size_t pos_depadded = pos_deshifted - m_target_padding_below[target_axis];
// If we are in the above-padding, we have no source coordinate.
if (pos_depadded >=
((m_source_shape[source_axis] - 1) * m_target_dilation_strides[target_axis]) + 1)
{ {
return true; return false;
}
// If we are in a dilation gap, we have no source coordinate.
if (pos_depadded % m_target_dilation_strides[target_axis] != 0)
{
return false;
} }
} }
return false; return true;
} }
Coordinate CoordinateTransform::get_target_shape() const Coordinate CoordinateTransform::get_target_shape() const
......
...@@ -31,8 +31,17 @@ namespace ngraph ...@@ -31,8 +31,17 @@ namespace ngraph
const Coordinate& source_end_corner, const Coordinate& source_end_corner,
const Strides& source_strides, const Strides& source_strides,
const AxisVector& source_axis_order, const AxisVector& source_axis_order,
const Shape& source_padding_below, const Shape& target_padding_below,
const Shape& source_padding_above); const Shape& target_padding_above,
const Strides& source_dilation_strides);
CoordinateTransform(const Shape& source_shape,
const Coordinate& source_start_corner,
const Coordinate& source_end_corner,
const Strides& source_strides,
const AxisVector& source_axis_order,
const Shape& target_padding_below,
const Shape& target_padding_above);
CoordinateTransform(const Shape& source_shape, CoordinateTransform(const Shape& source_shape,
const Coordinate& source_start_corner, const Coordinate& source_start_corner,
...@@ -52,8 +61,7 @@ namespace ngraph ...@@ -52,8 +61,7 @@ namespace ngraph
CoordinateTransform(const Shape& source_shape); CoordinateTransform(const Shape& source_shape);
size_t index(const Coordinate& c) const; size_t index(const Coordinate& c) const;
bool in_bounds(const Coordinate& c) const; bool has_source_coordinate(const Coordinate& c) const;
bool in_padding(const Coordinate& c) const;
Coordinate to_source_coordinate(const Coordinate& c) const; Coordinate to_source_coordinate(const Coordinate& c) const;
Coordinate get_target_shape() const; Coordinate get_target_shape() const;
...@@ -62,6 +70,7 @@ namespace ngraph ...@@ -62,6 +70,7 @@ namespace ngraph
Coordinate get_source_end_corner() { return m_source_end_corner; } Coordinate get_source_end_corner() { return m_source_end_corner; }
Strides get_source_strides() { return m_source_strides; } Strides get_source_strides() { return m_source_strides; }
AxisVector get_source_axis_order() { return m_source_axis_order; } AxisVector get_source_axis_order() { return m_source_axis_order; }
Strides get_target_dilation_strides() { return m_target_dilation_strides; }
class Iterator class Iterator
{ {
public: public:
...@@ -86,9 +95,9 @@ namespace ngraph ...@@ -86,9 +95,9 @@ namespace ngraph
Iterator end() noexcept { return Iterator(m_target_shape, true); } Iterator end() noexcept { return Iterator(m_target_shape, true); }
private: private:
size_t index_source(const Coordinate& c) const; size_t index_source(const Coordinate& c) const;
static Strides default_strides(size_t n_axes);
static Shape default_padding(size_t n_axes); static Shape default_padding(size_t n_axes);
static AxisVector default_axis_order(size_t n_axes); static AxisVector default_axis_order(size_t n_axes);
static Strides default_source_strides(size_t n_axes);
static Coordinate default_source_start_corner(size_t n_axes); static Coordinate default_source_start_corner(size_t n_axes);
static Coordinate default_source_end_corner(const Shape& source_shape); static Coordinate default_source_end_corner(const Shape& source_shape);
...@@ -97,8 +106,9 @@ namespace ngraph ...@@ -97,8 +106,9 @@ namespace ngraph
Shape m_source_end_corner; Shape m_source_end_corner;
Strides m_source_strides; Strides m_source_strides;
AxisVector m_source_axis_order; AxisVector m_source_axis_order;
Shape m_source_padding_below; Shape m_target_padding_below;
Shape m_source_padding_above; Shape m_target_padding_above;
Strides m_target_dilation_strides;
Shape m_target_shape; Shape m_target_shape;
size_t m_n_axes; size_t m_n_axes;
......
...@@ -23,12 +23,14 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -23,12 +23,14 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above) const Shape& padding_above,
const Strides& image_dilation_strides)
: RequiresTensorViewArgs("Convolution", {image_batch, filters}) : RequiresTensorViewArgs("Convolution", {image_batch, filters})
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides) , m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
, m_padding_above(padding_above) , m_padding_above(padding_above)
, m_image_dilation_strides(image_dilation_strides)
{ {
auto& image_batch_shape = get_inputs().at(0).get_shape(); auto& image_batch_shape = get_inputs().at(0).get_shape();
auto& filters_shape = get_inputs().at(1).get_shape(); auto& filters_shape = get_inputs().at(1).get_shape();
...@@ -77,7 +79,8 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -77,7 +79,8 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
} }
// //
// Make sure window movement strides and window dilation strades have same rank as Di. // Make sure window movement strides, window dilation strides, and image dilation strides
// have same rank as Di.
// //
if (m_window_movement_strides.size() != m_image_dimension_count) if (m_window_movement_strides.size() != m_image_dimension_count)
{ {
...@@ -91,6 +94,12 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -91,6 +94,12 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
"Convolution window dilation stride rank does not match number of image dimensions."); "Convolution window dilation stride rank does not match number of image dimensions.");
} }
if (m_image_dilation_strides.size() != m_image_dimension_count)
{
throw ngraph_error(
"Convolution image dilation stride rank does not match number of image dimensions.");
}
// //
// Make sure padding-below and padding-above shapes have same rank as Di. // Make sure padding-below and padding-above shapes have same rank as Di.
// //
...@@ -107,28 +116,36 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -107,28 +116,36 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
} }
// //
// Extract input image shape Di and make sure all dimensions are larger than 0 after padding. // Extract input image shape Di and make sure all dimensions are larger than 0 after padding and dilation.
// //
for (size_t i = 0; i < m_image_dimension_count; i++) for (size_t i = 0; i < m_image_dimension_count; i++)
{ {
m_input_image_shape.push_back(image_batch_shape[1 + 1 + i]); if (image_dilation_strides[i] == 0)
m_padded_input_image_shape.push_back(padding_below[i] + image_batch_shape[1 + 1 + i] + {
padding_above[i]); throw ngraph_error("Convolution image dilation stride is zero.");
}
if (m_padded_input_image_shape[i] == 0) size_t dim_size = image_batch_shape[1 + 1 + i];
m_input_image_physical_shape.push_back(dim_size);
size_t dilated_dim_size = (dim_size - 1) * image_dilation_strides[i] + 1;
size_t padded_dilated_dim_size = padding_below[i] + dilated_dim_size + padding_above[i];
m_input_image_virtual_shape.push_back(padded_dilated_dim_size);
if (m_input_image_virtual_shape[i] == 0)
{ {
throw ngraph_error("Convolution input image dimension is zero even with padding."); throw ngraph_error(
"Convolution input image dimension after dilation is zero even with padding.");
} }
} }
// //
// Extract the virtual shape Wv of the convolution window, *not* including dilation, from the filter dimensions. // Extract the physical shape Wp of the convolution window, *not* including dilation, from the filter dimensions.
// At the same time, make sure window shape dimensions are all larger than 0. // At the same time, make sure window shape dimensions are all larger than 0.
// //
for (size_t i = 0; i < m_image_dimension_count; i++) for (size_t i = 0; i < m_image_dimension_count; i++)
{ {
m_window_virtual_shape.push_back(filters_shape[1 + 1 + i]); m_window_physical_shape.push_back(filters_shape[1 + 1 + i]);
if (m_window_virtual_shape[i] == 0) if (m_window_physical_shape[i] == 0)
{ {
throw ngraph_error("Convolution window shape has a zero-length axis."); throw ngraph_error("Convolution window shape has a zero-length axis.");
} }
...@@ -145,10 +162,10 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -145,10 +162,10 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
throw ngraph_error("Convolution window axis dilation stride is zero."); throw ngraph_error("Convolution window axis dilation stride is zero.");
} }
m_window_physical_shape.push_back( m_window_virtual_shape.push_back(
(m_window_virtual_shape[i] - 1) * m_window_dilation_strides[i] + 1); (m_window_physical_shape[i] - 1) * m_window_dilation_strides[i] + 1);
if (m_window_physical_shape[i] > m_padded_input_image_shape[i]) if (m_window_virtual_shape[i] > m_input_image_virtual_shape[i])
{ {
throw ngraph_error( throw ngraph_error(
"Convolution window after dilation is larger than the image even with padding."); "Convolution window after dilation is larger than the image even with padding.");
...@@ -165,7 +182,7 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -165,7 +182,7 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
throw ngraph_error("Convolution window axis movement stride is zero."); throw ngraph_error("Convolution window axis movement stride is zero.");
} }
m_output_image_shape.push_back( m_output_image_shape.push_back(
ceil_div(m_padded_input_image_shape[i] - m_window_physical_shape[i] + 1, ceil_div(m_input_image_virtual_shape[i] - m_window_virtual_shape[i] + 1,
m_window_movement_strides[i])); m_window_movement_strides[i]));
} }
...@@ -180,7 +197,7 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -180,7 +197,7 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
set_value_type_checked(get_inputs().at(0).get_element_type(), result_shape); set_value_type_checked(get_inputs().at(0).get_element_type(), result_shape);
} }
Shape op::Convolution::default_padding(const std::shared_ptr<Node>& image_batch) Strides op::Convolution::default_strides(const std::shared_ptr<Node>& image_batch)
{ {
auto& image_batch_shape = image_batch->get_shape(); auto& image_batch_shape = image_batch->get_shape();
if (image_batch_shape.size() < 3) if (image_batch_shape.size() < 3)
...@@ -190,23 +207,26 @@ Shape op::Convolution::default_padding(const std::shared_ptr<Node>& image_batch) ...@@ -190,23 +207,26 @@ Shape op::Convolution::default_padding(const std::shared_ptr<Node>& image_batch)
"Convolution image batch input must have rank of at least 3 (one batch axis, one " "Convolution image batch input must have rank of at least 3 (one batch axis, one "
"input-channel axis, at least one image dimension)."); "input-channel axis, at least one image dimension).");
} }
return Shape(image_batch_shape.size() - 2, 0); return Strides(image_batch_shape.size() - 2, 1);
} }
op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
const std::shared_ptr<Node>& filters, const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides) const Strides& window_dilation_strides,
const Shape& padding_below,
const Shape& padding_above)
: Convolution(image_batch, : Convolution(image_batch,
filters, filters,
window_movement_strides, window_movement_strides,
window_dilation_strides, window_dilation_strides,
default_padding(image_batch), padding_below,
default_padding(image_batch)) padding_above,
default_strides(image_batch))
{ {
} }
Strides op::Convolution::default_strides(const std::shared_ptr<Node>& image_batch) Shape op::Convolution::default_padding(const std::shared_ptr<Node>& image_batch)
{ {
auto& image_batch_shape = image_batch->get_shape(); auto& image_batch_shape = image_batch->get_shape();
if (image_batch_shape.size() < 3) if (image_batch_shape.size() < 3)
...@@ -216,7 +236,20 @@ Strides op::Convolution::default_strides(const std::shared_ptr<Node>& image_batc ...@@ -216,7 +236,20 @@ Strides op::Convolution::default_strides(const std::shared_ptr<Node>& image_batc
"Convolution image batch input must have rank of at least 3 (one batch axis, one " "Convolution image batch input must have rank of at least 3 (one batch axis, one "
"input-channel axis, at least one image dimension)."); "input-channel axis, at least one image dimension).");
} }
return Strides(image_batch_shape.size() - 2, 1); return Shape(image_batch_shape.size() - 2, 0);
}
op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides)
: Convolution(image_batch,
filters,
window_movement_strides,
window_dilation_strides,
default_padding(image_batch),
default_padding(image_batch))
{
} }
op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
...@@ -254,7 +287,8 @@ std::shared_ptr<Node> ...@@ -254,7 +287,8 @@ std::shared_ptr<Node>
m_window_movement_strides, m_window_movement_strides,
m_window_dilation_strides, m_window_dilation_strides,
m_padding_below, m_padding_below,
m_padding_above); m_padding_above,
m_image_dilation_strides);
} }
bool op::Convolution::is_functionally_identical(const Node& other) const bool op::Convolution::is_functionally_identical(const Node& other) const
...@@ -265,9 +299,13 @@ bool op::Convolution::is_functionally_identical(const Node& other) const ...@@ -265,9 +299,13 @@ bool op::Convolution::is_functionally_identical(const Node& other) const
const Convolution& rhs = dynamic_cast<const Convolution&>(other); const Convolution& rhs = dynamic_cast<const Convolution&>(other);
rc &= m_window_movement_strides == rhs.m_window_movement_strides; rc &= m_window_movement_strides == rhs.m_window_movement_strides;
rc &= m_window_dilation_strides == rhs.m_window_dilation_strides; rc &= m_window_dilation_strides == rhs.m_window_dilation_strides;
rc &= m_padding_below == rhs.m_padding_below;
rc &= m_padding_above == rhs.m_padding_above;
rc &= m_image_dilation_strides == rhs.m_image_dilation_strides;
rc &= m_input_channel_count == rhs.m_input_channel_count; rc &= m_input_channel_count == rhs.m_input_channel_count;
rc &= m_output_channel_count == rhs.m_output_channel_count; rc &= m_output_channel_count == rhs.m_output_channel_count;
rc &= m_input_image_shape == rhs.m_input_image_shape; rc &= m_input_image_physical_shape == rhs.m_input_image_physical_shape;
rc &= m_input_image_virtual_shape == rhs.m_input_image_virtual_shape;
rc &= m_output_image_shape == rhs.m_output_image_shape; rc &= m_output_image_shape == rhs.m_output_image_shape;
rc &= m_window_physical_shape == rhs.m_window_physical_shape; rc &= m_window_physical_shape == rhs.m_window_physical_shape;
rc &= m_window_virtual_shape == rhs.m_window_virtual_shape; rc &= m_window_virtual_shape == rhs.m_window_virtual_shape;
......
...@@ -27,28 +27,29 @@ namespace ngraph ...@@ -27,28 +27,29 @@ namespace ngraph
/// 1. <i>(the image batch)</i> a tensor of shape \f$(N,C_\textit{in},d_1,\dots,d_n)\f$ where \f$n > 0\f$, every \f$d_i > 0\f$, and where \f$N\f$ is the batch size /// 1. <i>(the image batch)</i> a tensor of shape \f$(N,C_\textit{in},d_1,\dots,d_n)\f$ where \f$n > 0\f$, every \f$d_i > 0\f$, and where \f$N\f$ is the batch size
/// (number of images) and \f$C_\textit{in} > 0\f$ is the number of input channels (sometimes called features); and /// (number of images) and \f$C_\textit{in} > 0\f$ is the number of input channels (sometimes called features); and
/// 2. <i>(the filters)</i> a tensor of shape \f$(C_\textit{out},C_\textit{in},d^f_1,\dots,d^f_n)\f$, where \f$C_\textit{out} > 0\f$ is the number of output channels /// 2. <i>(the filters)</i> a tensor of shape \f$(C_\textit{out},C_\textit{in},d^f_1,\dots,d^f_n)\f$, where \f$C_\textit{out} > 0\f$ is the number of output channels
/// (sometimes called features) and \f$(d^f_1,\dots,d^f_n)\f$ are the filter dimensions. It is required that for all \f$i\f$, \f$0 < l_i(d^f_i - 1) + 1 \le d_i\f$. /// (sometimes called features) and \f$(d^f_1,\dots,d^f_n)\f$ are the filter dimensions. It is required that for all \f$i\f$, \f$0 < l_i(d^f_i - 1) + 1 \le (d_i - 1)*g_i + 1\f$.
/// (See below for the definition of the dilation \f$l_i\f$); /// (See below for the definition of the window dilation \f$l_i\f$ and the image dilation \f$g_i\f$);
/// ///
/// and four optional parameters: /// and five optional parameters:
/// ///
/// 3. <i>(the window movement strides)</i> a vector of positive integers \f$(s_1,\dots,s_n)\f$ (default is all ones), /// 3. <i>(the window movement strides)</i> a vector of positive integers \f$(s_1,\dots,s_n)\f$ (default is all ones),
/// 4. <i>(the window dilation strides)</i> a vector of positive integers \f$(l_1,\dots,l_n)\f$ (default is all ones), /// 4. <i>(the window dilation strides)</i> a vector of positive integers \f$(l_1,\dots,l_n)\f$ (default is all ones),
/// 5. <i>(the padding below)</i> a vector of non-negative integers \f$(p_1,\dots,p_n)\f$ (default is all zeros), and /// 5. <i>(the padding below)</i> a vector of non-negative integers \f$(p_1,\dots,p_n)\f$ (default is all zeros),
/// 6. <i>(the padding above)</i> a vector of non-negative integers \f$(q_1,\dots,q_n)\f$ (default is all zeros). /// 6. <i>(the padding above)</i> a vector of non-negative integers \f$(q_1,\dots,q_n)\f$ (default is all zeros), and
/// 7. <i>(the image dilation strides)</i> a vector of non-negative integers \f$(q_1,\dots,q_n)\f$ (default is all ones).
/// ///
/// The output has the shape \f$(N,C_\textit{out},d'_1,\dots,d'_n)\f$, where \f$d'_n = \lceil \frac{d_i + p_i + q_i - l_i(d^f_i - 1)}{s_i} \rceil\f$. /// The output has the shape \f$(N,C_\textit{out},d'_1,\dots,d'_n)\f$, where \f$d'_n = \lceil \frac{(d_i - 1) * g_i + 1 + p_i + q_i - l_i(d^f_i - 1)}{s_i} \rceil\f$.
/// ///
/// Given an input image batch tensor \f$T_\textit{in}\f$, first define the <i>padded input tensor</i> \f$T_\textit{pad}\f$, with shape \f$(N,C_\textit{in},d_1+p_1+q+1,\dots,d_n+p_n+q_n)\f$, as follows: /// Given an input image batch tensor \f$T_\textit{in}\f$, first define the <i>transformed input tensor</i> \f$T_\textit{trans}\f$, with shape \f$(N,C_\textit{in},(d_1 - 1)*g_1+1+p_1+q_1,\dots,(d_n - 1)*g_n+1+p_n+q_n)\f$, as follows:
/// ///
/// \f[ /// \f[
/// T_\textit{pad}[a,c,i_1,\dots,i_n] = T[a,c,i_1 - p_1,\dots,i_n - p_n] \text{ if for all }k, p_k \le i_k \lt p_k + d_k, \text{ else } 0 /// T_\textit{trans}[a,c,i_1,\dots,i_n] = T[a,c,(i_1 - p_1)/g_1,\dots,(i_n - p_n)/g_n] \text{ if for all }k, g_k evenly divides (i_k - p_k) and p_k \le i_k \lt p_k + (d_k - 1)*g_k + 1, \text{ else } 0
/// \f] /// \f]
/// ///
/// then, given an input filter tensor \f$T_\textit{filt}\f$, the output tensor \f$T_\textit{out}\f$ is defined by the equation. /// then, given an input filter tensor \f$T_\textit{filt}\f$, the output tensor \f$T_\textit{out}\f$ is defined by the equation.
/// ///
/// \f[ /// \f[
/// T_\textit{out}[a,c_\textit{out},i_1,\dots,i_n] = \sum_{c_\textit{in}=0,j_1=0,\dots,j_n=0}^{c_\textit{in}=C_\textit{in}-1,j_1=d^f_1-1,\dots,j_n=d^f_n-1} (T_\textit{filt}[c_\textit{out},c_\textit{in},j_1,\dots,j_n] \cdot T_\textit{pad}[a,c_\textit{in},s_1i_1+l_1j_1,\dots,s_ni_n+l_nj_n]) /// T_\textit{out}[a,c_\textit{out},i_1,\dots,i_n] = \sum_{c_\textit{in}=0,j_1=0,\dots,j_n=0}^{c_\textit{in}=C_\textit{in}-1,j_1=d^f_1-1,\dots,j_n=d^f_n-1} (T_\textit{filt}[c_\textit{out},c_\textit{in},j_1,\dots,j_n] \cdot T_\textit{trans}[a,c_\textit{in},s_1i_1+l_1j_1,\dots,s_ni_n+l_nj_n])
/// \f] /// \f]
/// ///
class Convolution : public RequiresTensorViewArgs class Convolution : public RequiresTensorViewArgs
...@@ -62,6 +63,23 @@ namespace ngraph ...@@ -62,6 +63,23 @@ namespace ngraph
/// \param window_dilation_strides The window dilation strides. /// \param window_dilation_strides The window dilation strides.
/// \param padding_below The padding-below sizes. /// \param padding_below The padding-below sizes.
/// \param padding_above The padding-above sizes. /// \param padding_above The padding-above sizes.
/// \param image_dilation_strides The image dilation strides.
Convolution(const std::shared_ptr<Node>& image_batch,
const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const Shape& padding_below,
const Shape& padding_above,
const Strides& image_dilation_strides);
/// \brief Constructs a batched convolution operation with no image dilation (i.e., all image dilation strides are 1).
///
/// \param image_batch The node producing the input image batch tensor.
/// \param filters The node producing the filters tensor.
/// \param window_movement_strides The window movement strides.
/// \param window_dilation_strides The window dilation strides.
/// \param padding_below The padding-below sizes.
/// \param padding_above The padding-above sizes.
Convolution(const std::shared_ptr<Node>& image_batch, Convolution(const std::shared_ptr<Node>& image_batch,
const std::shared_ptr<Node>& filters, const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
...@@ -69,7 +87,7 @@ namespace ngraph ...@@ -69,7 +87,7 @@ namespace ngraph
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above); const Shape& padding_above);
/// \brief Constructs a batched convolution operation with no padding (i.e., padding above and below are 0 everywhere). /// \brief Constructs a batched convolution operation with no padding or image dilation (i.e., padding above and below are 0 everywhere, and all image dilation strides are 1).
/// ///
/// \param image_batch The node producing the input image batch tensor. /// \param image_batch The node producing the input image batch tensor.
/// \param filters The node producing the filters tensor. /// \param filters The node producing the filters tensor.
...@@ -80,7 +98,7 @@ namespace ngraph ...@@ -80,7 +98,7 @@ namespace ngraph
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides); const Strides& window_dilation_strides);
/// \brief Constructs a batched convolution operation with no window dilation (i.e., all dilation strides are 1). /// \brief Constructs a batched convolution operation with no window dilation, padding, or image dilation (i.e., padding above and below are 0 everywhere, and all window/image dilation strides are 1).
/// ///
/// \param image_batch The node producing the input image batch tensor. /// \param image_batch The node producing the input image batch tensor.
/// \param filters The node producing the filters tensor. /// \param filters The node producing the filters tensor.
...@@ -89,7 +107,7 @@ namespace ngraph ...@@ -89,7 +107,7 @@ namespace ngraph
const std::shared_ptr<Node>& filters, const std::shared_ptr<Node>& filters,
const Strides& window_movement_strides); const Strides& window_movement_strides);
/// \brief Constructs a batched convolution operation with no window dilation or movement stride (i.e., all dilation and movement strides are 1). /// \brief Constructs a batched convolution operation with no window dilation or movement stride (i.e., padding above and below are 0 everywhere, and all window/image dilation strides and window movement strides are 1).
/// ///
/// \param image_batch The node producing the input image batch tensor. /// \param image_batch The node producing the input image batch tensor.
/// \param filters The node producing the filters tensor. /// \param filters The node producing the filters tensor.
...@@ -107,14 +125,22 @@ namespace ngraph ...@@ -107,14 +125,22 @@ namespace ngraph
const Shape& get_padding_below() const { return m_padding_below; } const Shape& get_padding_below() const { return m_padding_below; }
/// \return The padding-above sizes. /// \return The padding-above sizes.
const Strides& get_padding_above() const { return m_padding_above; } const Strides& get_padding_above() const { return m_padding_above; }
/// \return The input image dilation strides.
const Strides& get_image_dilation_strides() const { return m_image_dilation_strides; }
/// \return The number of input channels. /// \return The number of input channels.
size_t get_input_channel_count() const { return m_input_channel_count; } size_t get_input_channel_count() const { return m_input_channel_count; }
/// \return The number of output channels. /// \return The number of output channels.
size_t get_output_channel_count() const { return m_output_channel_count; } size_t get_output_channel_count() const { return m_output_channel_count; }
/// \return The input image shape, not including padding. /// \return The physical input image shape, not including padding and dilation.
const Shape& get_input_image_shape() const { return m_input_image_shape; } const Shape& get_input_image_physical_shape() const
/// \return The input image shape, including padding. {
const Shape& get_padded_input_image_shape() const { return m_padded_input_image_shape; } return m_input_image_physical_shape;
}
/// \return The virtual input image shape, including padding and dilation.
const Shape& get_input_image_virtual_shape() const
{
return m_input_image_virtual_shape;
}
/// \return The output image shape. /// \return The output image shape.
const Shape& get_output_image_shape() const { return m_output_image_shape; } const Shape& get_output_image_shape() const { return m_output_image_shape; }
/// \return The physical window shape. /// \return The physical window shape.
...@@ -132,12 +158,13 @@ namespace ngraph ...@@ -132,12 +158,13 @@ namespace ngraph
Strides m_window_dilation_strides; Strides m_window_dilation_strides;
Shape m_padding_below; Shape m_padding_below;
Shape m_padding_above; Shape m_padding_above;
Strides m_image_dilation_strides;
// TODO: Some of these values should probably be computed dynamically rather than stored here. // TODO: Some or all of these values should probably be computed dynamically rather than stored here.
size_t m_input_channel_count; size_t m_input_channel_count;
size_t m_output_channel_count; size_t m_output_channel_count;
Shape m_input_image_shape; Shape m_input_image_physical_shape;
Shape m_padded_input_image_shape; Shape m_input_image_virtual_shape;
Shape m_output_image_shape; Shape m_output_image_shape;
Shape m_window_physical_shape; Shape m_window_physical_shape;
Shape m_window_virtual_shape; Shape m_window_virtual_shape;
...@@ -145,8 +172,8 @@ namespace ngraph ...@@ -145,8 +172,8 @@ namespace ngraph
size_t m_image_dimension_count; size_t m_image_dimension_count;
private: private:
static Shape default_padding(const std::shared_ptr<Node>& image_batch);
static Strides default_strides(const std::shared_ptr<Node>& image_batch); static Strides default_strides(const std::shared_ptr<Node>& image_batch);
static Shape default_padding(const std::shared_ptr<Node>& image_batch);
}; };
} }
} }
...@@ -1678,7 +1678,9 @@ void runtime::cpu::CPU_Emitter::EmitConvolution(const ngraph::Node* n, ...@@ -1678,7 +1678,9 @@ void runtime::cpu::CPU_Emitter::EmitConvolution(const ngraph::Node* n,
m_out << " {" << join(convolution->get_window_dilation_strides()) m_out << " {" << join(convolution->get_window_dilation_strides())
<< "},\n"; << "},\n";
m_out << " {" << join(convolution->get_padding_below()) << "},\n"; m_out << " {" << join(convolution->get_padding_below()) << "},\n";
m_out << " {" << join(convolution->get_padding_above()) << "});\n"; m_out << " {" << join(convolution->get_padding_above()) << "},\n";
m_out << " {" << join(convolution->get_image_dilation_strides())
<< "});\n";
} }
void runtime::cpu::CPU_Emitter::EmitNot(const ngraph::Node* n, void runtime::cpu::CPU_Emitter::EmitNot(const ngraph::Node* n,
......
...@@ -303,7 +303,8 @@ private: ...@@ -303,7 +303,8 @@ private:
c->get_window_movement_strides(), c->get_window_movement_strides(),
c->get_window_dilation_strides(), c->get_window_dilation_strides(),
c->get_padding_below(), c->get_padding_below(),
c->get_padding_above()); c->get_padding_above(),
c->get_image_dilation_strides());
} }
else if (node_op == "Cos") else if (node_op == "Cos")
{ {
......
...@@ -36,7 +36,8 @@ namespace ngraph ...@@ -36,7 +36,8 @@ namespace ngraph
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above) const Shape& padding_above,
const Strides& image_dilation_strides)
{ {
// At the outermost level we will walk over every output coordinate O. // At the outermost level we will walk over every output coordinate O.
CoordinateTransform output_transform(out_shape); CoordinateTransform output_transform(out_shape);
...@@ -64,17 +65,18 @@ namespace ngraph ...@@ -64,17 +65,18 @@ namespace ngraph
// //
// (1,l_1,...,l_n). // (1,l_1,...,l_n).
// //
// Note that we are iterating within the *padded* image batch, so further down we must check // Note that we are iterating within the *padded* and *dilated* image batch, so further
// the current coordinate is in the padding. // down we must check the current coordinate is in the padding or dilation gap.
size_t n_image_dimensions = arg0_shape.size() - 2; size_t n_image_dimensions = arg0_shape.size() - 2;
size_t n_input_channels = arg0_shape[1]; size_t n_input_channels = arg0_shape[1];
Shape input_batch_transform_start(2 + n_image_dimensions); Shape input_batch_transform_start(2 + n_image_dimensions);
Shape input_batch_transform_end(2 + n_image_dimensions); Shape input_batch_transform_end(2 + n_image_dimensions);
Shape input_batch_transform_strides(2 + n_image_dimensions, 1); Shape input_batch_transform_movement_strides(2 + n_image_dimensions, 1);
Shape input_batch_padding_below(2 + n_image_dimensions, 0); Shape input_batch_transform_padding_below(2 + n_image_dimensions, 0);
Shape input_batch_padding_above(2 + n_image_dimensions, 0); Shape input_batch_transform_padding_above(2 + n_image_dimensions, 0);
Shape input_batch_transform_dilation_strides(2 + n_image_dimensions, 1);
input_batch_transform_start[0] = img_index; input_batch_transform_start[0] = img_index;
input_batch_transform_end[0] = img_index + 1; input_batch_transform_end[0] = img_index + 1;
...@@ -83,32 +85,37 @@ namespace ngraph ...@@ -83,32 +85,37 @@ namespace ngraph
for (size_t i = 2; i < n_image_dimensions + 2; i++) for (size_t i = 2; i < n_image_dimensions + 2; i++)
{ {
size_t dilation_stride = window_dilation_strides[i - 2]; size_t window_dilation_stride = window_dilation_strides[i - 2];
size_t movement_stride = window_movement_strides[i - 2]; size_t window_movement_stride = window_movement_strides[i - 2];
size_t below_pad = padding_below[i - 2]; size_t below_pad = padding_below[i - 2];
size_t above_pad = padding_above[i - 2]; size_t above_pad = padding_above[i - 2];
size_t image_dilation_stride = image_dilation_strides[i - 2];
input_batch_transform_start[i] = movement_stride * out_coord[i];
input_batch_transform_end[i] = input_batch_transform_start[i] + input_batch_transform_start[i] = window_movement_stride * out_coord[i];
(arg1_shape[i] - 1) * dilation_stride + 1; input_batch_transform_end[i] =
input_batch_transform_strides[i] = dilation_stride; input_batch_transform_start[i] +
input_batch_padding_below[i] = below_pad; (arg1_shape[i] - 1) * window_dilation_stride + 1;
input_batch_padding_above[i] = above_pad; input_batch_transform_movement_strides[i] = window_dilation_stride;
input_batch_transform_padding_below[i] = below_pad;
input_batch_transform_padding_above[i] = above_pad;
input_batch_transform_dilation_strides[i] = image_dilation_stride;
} }
AxisVector input_batch_axis_order(2 + n_image_dimensions); AxisVector input_batch_transform_axis_order(2 + n_image_dimensions);
size_t n = 0; size_t n = 0;
std::generate(input_batch_axis_order.begin(), std::generate(input_batch_transform_axis_order.begin(),
input_batch_axis_order.end(), input_batch_transform_axis_order.end(),
[&n]() -> size_t { return n++; }); [&n]() -> size_t { return n++; });
CoordinateTransform input_batch_transform(arg0_shape, CoordinateTransform input_batch_transform(
input_batch_transform_start, arg0_shape,
input_batch_transform_end, input_batch_transform_start,
input_batch_transform_strides, input_batch_transform_end,
input_batch_axis_order, input_batch_transform_movement_strides,
input_batch_padding_below, input_batch_transform_axis_order,
input_batch_padding_above); input_batch_transform_padding_below,
input_batch_transform_padding_above,
input_batch_transform_dilation_strides);
// Simultaneously with iterating I, for the filters we need to iterate the coordinate: // Simultaneously with iterating I, for the filters we need to iterate the coordinate:
// //
...@@ -151,10 +158,13 @@ namespace ngraph ...@@ -151,10 +158,13 @@ namespace ngraph
{ {
const Coordinate& input_batch_coord = *input_it; const Coordinate& input_batch_coord = *input_it;
const Coordinate& filter_coord = *filter_it; const Coordinate& filter_coord = *filter_it;
T v = input_batch_transform.in_padding(input_batch_coord)
? 0 T v = input_batch_transform.has_source_coordinate(input_batch_coord)
: arg0[input_batch_transform.index(input_batch_coord)]; ? arg0[input_batch_transform.index(input_batch_coord)]
: 0;
result += v * arg1[filter_transform.index(filter_coord)]; result += v * arg1[filter_transform.index(filter_coord)];
++input_it; ++input_it;
++filter_it; ++filter_it;
} }
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -67,9 +67,10 @@ def tuple_times(t1,t2): ...@@ -67,9 +67,10 @@ def tuple_times(t1,t2):
# img_batch : [N ][Ci][D1]...[Dn], n > 0 # img_batch : [N ][Ci][D1]...[Dn], n > 0
# filter : [Co][Ci][W1]...[Wn] # filter : [Co][Ci][W1]...[Wn]
# move_strides = (s1,...,sn) # move_strides = (s1,...,sn)
# dilation_strides = (l1,...,ln) # filter_dilation = (l1,...,ln)
# below_pads = (p1,...,pn) # below_pads = (p1,...,pn)
# above_pads = (q1,...,qn) # above_pads = (q1,...,qn)
# image_dilation = (g1,...,gn)
# #
# Returns: # Returns:
# output_batch : [N ][Co][D'1]...[D'n] # output_batch : [N ][Co][D'1]...[D'n]
...@@ -77,12 +78,31 @@ def tuple_times(t1,t2): ...@@ -77,12 +78,31 @@ def tuple_times(t1,t2):
# Where the D's are computed according to TensorFlow-style "valid" convolution rules, but *after* padding. # Where the D's are computed according to TensorFlow-style "valid" convolution rules, but *after* padding.
# See https://www.tensorflow.org/api_docs/python/tf/nn/convolution. # See https://www.tensorflow.org/api_docs/python/tf/nn/convolution.
# #
def convolution_ref(img_batch, filter, move_strides, dilation_strides, below_pads, above_pads): def convolution_ref(img_batch, filter, move_strides, filter_dilation, below_pads, above_pads, image_dilation):
assert(len(img_batch.shape) == len(filter.shape)) assert(len(img_batch.shape) == len(filter.shape))
assert(len(img_batch.shape) > 2) assert(len(img_batch.shape) > 2)
assert(img_batch.shape[1] == filter.shape[1]) assert(img_batch.shape[1] == filter.shape[1])
assert(len(move_strides) == len(img_batch.shape) - 2) assert(len(move_strides) == len(img_batch.shape) - 2)
assert(len(dilation_strides) == len(img_batch.shape) - 2) assert(len(filter_dilation) == len(img_batch.shape) - 2)
assert(len(image_dilation) == len(img_batch.shape) - 2)
# dilate the input batch
new_img_shape = (np.array(img_batch.shape[2:]) - 1) * image_dilation + 1
new_img_batch_shape = list(np.array(img_batch.shape[:2])) + list(new_img_shape)
new_img_batch = np.zeros(new_img_batch_shape)
for n in range(0, new_img_batch_shape[0]) :
for c in range(0, new_img_batch_shape[1]) :
if new_img_batch.ndim == 4:
new_img_batch[n, c, 0::image_dilation[0], 0::image_dilation[1]] = img_batch[n][c]
elif new_img_batch.ndim == 5:
new_img_batch[n, c, 0::image_dilation[0], 0::image_dilation[1], 0::image_dilation[2]] = img_batch[n][c]
elif new_img_batch.ndim == 6:
new_img_batch[n, c, 0::image_dilation[0], 0::image_dilation[1], 0::image_dilation[2], 0::image_dilation[3]] = img_batch[n][c]
else:
assert(False)
img_batch = new_img_batch
# Pad the input batch. # Pad the input batch.
below_pads = (0,0) + below_pads # Have to add values for the image and channel dims. below_pads = (0,0) + below_pads # Have to add values for the image and channel dims.
...@@ -97,13 +117,13 @@ def convolution_ref(img_batch, filter, move_strides, dilation_strides, below_pad ...@@ -97,13 +117,13 @@ def convolution_ref(img_batch, filter, move_strides, dilation_strides, below_pad
# This is not used in computation but we will calculate it for a check to make sure the window fits. # This is not used in computation but we will calculate it for a check to make sure the window fits.
window_physical_shape = [] window_physical_shape = []
for (d_in,d_virt,dil) in zip(input_img_shape,window_virtual_shape,dilation_strides): for (d_in,d_virt,dil) in zip(input_img_shape,window_virtual_shape,filter_dilation):
d_phys = (d_virt - 1) * dil + 1 d_phys = (d_virt - 1) * dil + 1
assert(d_phys <= input_img_shape) assert(d_phys <= input_img_shape)
window_physical_shape.append(d_phys) window_physical_shape.append(d_phys)
output_img_shape = [] # D'1,...,D'n output_img_shape = [] # D'1,...,D'n
for (d_in,d_win,dil,mov) in zip (input_img_shape,window_virtual_shape,dilation_strides,move_strides): for (d_in,d_win,dil,mov) in zip (input_img_shape,window_virtual_shape,filter_dilation,move_strides):
d_out = int(math.ceil((float(d_in) - (float(d_win) - 1.0) * float(dil))/float(mov))) # Formula is taken from TF's definition for VALID convolution. d_out = int(math.ceil((float(d_in) - (float(d_win) - 1.0) * float(dil))/float(mov))) # Formula is taken from TF's definition for VALID convolution.
assert(d_out > 0) assert(d_out > 0)
output_img_shape.append(d_out) output_img_shape.append(d_out)
...@@ -126,7 +146,7 @@ def convolution_ref(img_batch, filter, move_strides, dilation_strides, below_pad ...@@ -126,7 +146,7 @@ def convolution_ref(img_batch, filter, move_strides, dilation_strides, below_pad
ci, filter_pos = filter_index[0], filter_index[1:] ci, filter_pos = filter_index[0], filter_index[1:]
# Build up the coordinate within the space N,Ci,D1,...,Dn that we need to read from in the input batch. # Build up the coordinate within the space N,Ci,D1,...,Dn that we need to read from in the input batch.
input_index = (img,ci) + (tuple_plus(tuple_times(output_pos,move_strides),tuple_times(filter_pos,dilation_strides))) input_index = (img,ci) + (tuple_plus(tuple_times(output_pos,move_strides),tuple_times(filter_pos,filter_dilation)))
# Add to the sum-of-products. # Add to the sum-of-products.
output_batch[output_index] = output_batch[output_index] + filter[(co,) + filter_index] * img_batch[input_index] output_batch[output_index] = output_batch[output_index] + filter[(co,) + filter_index] * img_batch[input_index]
...@@ -160,11 +180,11 @@ def data_str(data): ...@@ -160,11 +180,11 @@ def data_str(data):
return result return result
def emit_test(t,f): def emit_test(t,f):
test_name, input_batch_data, filter_data, move_strides, dilation_strides, below_pads, above_pads = t test_name, input_batch_data, filter_data, move_strides, filter_dilation, below_pads, above_pads, image_dilation = t
print ("Generating convolution test '%s'..." % test_name) print ("Generating convolution test '%s'..." % test_name)
output_batch_data = convolution_ref(input_batch_data,filter_data,move_strides,dilation_strides,below_pads,above_pads) output_batch_data = convolution_ref(input_batch_data,filter_data,move_strides,filter_dilation,below_pads,above_pads,image_dilation)
template = ''' template = '''
TEST (${BACKEND_NAME}, %s) TEST (${BACKEND_NAME}, %s)
...@@ -175,7 +195,13 @@ TEST (${BACKEND_NAME}, %s) ...@@ -175,7 +195,13 @@ TEST (${BACKEND_NAME}, %s)
auto B = make_shared<op::Parameter>(element::f64, shape_b); auto B = make_shared<op::Parameter>(element::f64, shape_b);
auto shape_r = Shape{%s}; auto shape_r = Shape{%s};
auto f = make_shared<Function>( auto f = make_shared<Function>(
make_shared<op::Convolution>(A, B, Strides{%s}, Strides{%s}, Shape{%s}, Shape{%s}), op::Parameters{A, B}); make_shared<op::Convolution>(A, B,
Strides{%s}, // move_strides
Strides{%s}, // filter_dilation
Shape{%s}, // below_pads
Shape{%s}, // above_pads
Strides{%s}), // image_dilation
op::Parameters{A, B});
auto manager = runtime::Manager::get("${BACKEND_NAME}"); auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f); auto external = manager->compile(f);
...@@ -201,35 +227,56 @@ TEST (${BACKEND_NAME}, %s) ...@@ -201,35 +227,56 @@ TEST (${BACKEND_NAME}, %s)
shape_str(filter_data.shape), shape_str(filter_data.shape),
shape_str(output_batch_data.shape), shape_str(output_batch_data.shape),
shape_str(move_strides), shape_str(move_strides),
shape_str(dilation_strides), shape_str(filter_dilation),
shape_str(below_pads), shape_str(below_pads),
shape_str(above_pads), shape_str(above_pads),
shape_str(image_dilation),
data_str(input_batch_data), data_str(input_batch_data),
data_str(filter_data), data_str(filter_data),
data_str(output_batch_data))); data_str(output_batch_data)));
# test name input image batch filters stride dilation below-pads above-pads # filter image
# test name input image batch filters stride dilation below-pads above-pads dilation
tests = [ tests = [
("convolution_2d_1image", shaped_linspace((1,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (1,1), (0,0), (0,0)), ("convolution_2d_1image", shaped_linspace((1,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (1,1), (0,0), (0,0), (1,1)),
("convolution_2d_1image_padded_1_1x1_1", shaped_linspace((1,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (1,1), (1,1), (1,1)), ("convolution_2d_1image_padded_1_1x1_1", shaped_linspace((1,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (1,1), (1,1), (1,1), (1,1)),
("convolution_2d_1image_padded_2_3x4_5", shaped_linspace((1,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (1,1), (2,3), (4,5)), ("convolution_2d_1image_padded_2_3x4_5", shaped_linspace((1,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (1,1), (2,3), (4,5), (1,1)),
("convolution_2d_2images", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (1,1), (0,0), (0,0)), ("convolution_2d_2images", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (1,1), (0,0), (0,0), (1,1)),
("convolution_2d_2images_strided", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (2,2), (1,1), (0,0), (0,0)), ("convolution_2d_2images_strided", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (2,2), (1,1), (0,0), (0,0), (1,1)),
("convolution_2d_2images_strided_padded", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (2,2), (1,1), (4,2), (5,7)), ("convolution_2d_2images_strided_padded", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (2,2), (1,1), (4,2), (5,7), (1,1)),
("convolution_2d_2images_strided_padded_same", ("convolution_2d_2images_strided_padded_same",
shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (2,2), (1,1), (2,2), (2,2)), shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (2,2), (1,1), (2,2), (2,2), (1,1)),
("convolution_2d_2images_dilated", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (2,2), (0,0), (0,0)), ("convolution_2d_2images_dilated", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (2,2), (0,0), (0,0), (1,1)),
("convolution_2d_2images_dilated_padded", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (2,2), (4,2), (5,7)), ("convolution_2d_2images_dilated_padded", shaped_linspace((2,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (2,2), (4,2), (5,7), (1,1)),
("convolution_3d_2images", shaped_linspace((2,1,3,5,8)), shaped_linspace((2,1,2,2,3)), (1,1,1), (1,1,1), (0,0,0), (0,0,0)), ("convolution_3d_2images", shaped_linspace((2,1,3,5,8)), shaped_linspace((2,1,2,2,3)), (1,1,1), (1,1,1), (0,0,0), (0,0,0), (1,1,1)),
("convolution_4d_2images", shaped_linspace((2,1,3,5,8,7)),shaped_linspace((2,1,2,2,3,1)),(1,1,1,1),(1,1,1,1),(0,0,0,0), (0,0,0,0)), ("convolution_4d_2images", shaped_linspace((2,1,3,5,8,7)),shaped_linspace((2,1,2,2,3,1)),(1,1,1,1),(1,1,1,1),(0,0,0,0), (0,0,0,0), (1,1,1,1)),
("convolution_4d_4images", shaped_linspace((4,3,3,5,8,7)),shaped_linspace((4,3,2,2,3,1)),(1,1,1,1),(1,1,1,1),(0,0,0,0), (0,0,0,0)), ("convolution_4d_4images", shaped_linspace((4,3,3,5,8,7)),shaped_linspace((4,3,2,2,3,1)),(1,1,1,1),(1,1,1,1),(0,0,0,0), (0,0,0,0), (1,1,1,1)),
("convolution_4d_4images_strided", shaped_linspace((4,3,3,5,8,7)),shaped_linspace((4,3,2,2,3,1)),(2,1,3,2),(1,1,1,1),(0,0,0,0), (0,0,0,0)), ("convolution_4d_4images_strided", shaped_linspace((4,3,3,5,8,7)),shaped_linspace((4,3,2,2,3,1)),(2,1,3,2),(1,1,1,1),(0,0,0,0), (0,0,0,0), (1,1,1,1)),
("convolution_4d_4images_dilated", shaped_linspace((4,3,3,5,8,7)),shaped_linspace((4,3,2,2,3,1)),(1,1,1,1),(2,1,3,2),(0,0,0,0), (0,0,0,0)), ("convolution_4d_4images_dilated", shaped_linspace((4,3,3,5,8,7)),shaped_linspace((4,3,2,2,3,1)),(1,1,1,1),(2,1,3,2),(0,0,0,0), (0,0,0,0), (1,1,1,1)),
("convolution_4d_4images_strided_dilated",shaped_linspace((4,3,8,8,8,8)),shaped_linspace((4,3,2,2,3,1)),(3,2,2,3),(2,1,3,2),(0,0,0,0), (0,0,0,0)), ("convolution_4d_4images_strided_dilated",shaped_linspace((4,3,8,8,8,8)),shaped_linspace((4,3,2,2,3,1)),(3,2,2,3),(2,1,3,2),(0,0,0,0), (0,0,0,0), (1,1,1,1)),
("convolution_4d_4images_strided_dilated_padded", ("convolution_4d_4images_strided_dilated_padded",
shaped_linspace((4,3,8,8,8,8)),shaped_linspace((4,3,2,2,3,1)),(3,2,2,3),(2,1,3,2),(2,4,6,8), (1,3,5,7)), shaped_linspace((4,3,8,8,8,8)),shaped_linspace((4,3,2,2,3,1)),(3,2,2,3),(2,1,3,2),(2,4,6,8), (1,3,5,7), (1,1,1,1)),
("convolution_4d_4images_strided_dilated_padded_same", ("convolution_4d_4images_strided_dilated_padded_same",
shaped_linspace((4,3,8,8,8,8)),shaped_linspace((4,3,2,2,3,1)),(3,2,2,3),(2,1,3,2),(3,3,3,3), (3,3,3,3)), shaped_linspace((4,3,8,8,8,8)),shaped_linspace((4,3,2,2,3,1)),(3,2,2,3),(2,1,3,2),(3,3,3,3), (3,3,3,3), (1,1,1,1)),
("convolution_2d_1image_1o1i_img_dilated",shaped_linspace((1,1,3,5)), shaped_linspace((1,1,2,2)), (1,1), (1,1), (0,0), (0,0), (2,2)),
("convolution_2d_1image_2o1i_img_dilated",shaped_linspace((1,1,3,5)), shaped_linspace((2,1,2,2)), (1,1), (1,1), (0,0), (0,0), (2,2)),
("convolution_2d_1image_2o2i_img_dilated",shaped_linspace((1,2,3,5)), shaped_linspace((2,2,2,2)), (1,1), (1,1), (0,0), (0,0), (2,2)),
("convolution_2d_1image_5o3i_img_dilated",shaped_linspace((1,3,3,5)), shaped_linspace((5,3,2,2)), (1,1), (1,1), (0,0), (0,0), (2,2)),
("convolution_2d_8image_5o3i_img_dilated",shaped_linspace((8,3,3,5)), shaped_linspace((5,3,2,2)), (1,1), (1,1), (0,0), (0,0), (2,2)),
("convolution_2d_8image_large_5o3i_img_dilated",
shaped_linspace((8,3,16,16)), shaped_linspace((5,3,2,2)), (1,1), (1,1), (0,0), (0,0), (2,2)),
("convolution_2d_8image_large_5o3i_uneven_filter_img_dilated",
shaped_linspace((8,3,16,16)), shaped_linspace((5,3,2,3)), (1,1), (1,1), (0,0), (0,0), (2,2)),
("convolution_2d_8image_large_5o3i_uneven_filter_uneven_img_dilation_img_dilated",
shaped_linspace((8,3,16,16)), shaped_linspace((5,3,2,3)), (1,1), (1,1), (0,0), (0,0), (2,3)),
("convolution_3d_2image_large_5o3i_uneven_filter_uneven_img_dilation_img_dilated",
shaped_linspace((2,3,8,8,8)), shaped_linspace((5,3,2,3,4)), (1,1,1), (1,1,1), (0,0,0), (0,0,0), (2,3,2)),
("convolution_3d_1image_large_5o3i_padded_uneven_filter_uneven_img_dilation_img_dilated",
shaped_linspace((1,3,8,8,8)), shaped_linspace((5,3,2,3,4)), (1,1,1), (1,1,1), (2,1,2), (1,2,3), (2,3,2)),
("convolution_3d_2image_large_5o3i_padded_strided_uneven_filter_uneven_img_dilation_img_dilated",
shaped_linspace((2,3,8,8,8)), shaped_linspace((5,3,2,3,4)), (2,3,2), (1,1,1), (2,1,2), (1,2,3), (2,3,2)),
("convolution_3d_2image_large_5o3i_padded_strided_uneven_filter_uneven_img_dilation_filter_dilated_img_dilated",
shaped_linspace((2,3,8,8,8)), shaped_linspace((5,3,2,3,4)), (2,3,2), (3,2,2), (2,1,2), (1,2,3), (2,3,2)),
] ]
def main(): def main():
...@@ -289,6 +336,7 @@ static bool all_close_d(const std::vector<double>& a, ...@@ -289,6 +336,7 @@ static bool all_close_d(const std::vector<double>& a,
double atol = 1e-8) double atol = 1e-8)
{ {
assert(a.size() == b.size()); assert(a.size() == b.size());
for (size_t i = 0; i < a.size(); ++i) for (size_t i = 0; i < a.size(); ++i)
{ {
if (std::abs(a[i] - b[i]) > atol + rtol * std::abs(b[i])) if (std::abs(a[i] - b[i]) > atol + rtol * std::abs(b[i]))
......
...@@ -1765,6 +1765,7 @@ TEST(type_prop, conv_1d_deduce) ...@@ -1765,6 +1765,7 @@ TEST(type_prop, conv_1d_deduce)
EXPECT_EQ(conv->get_window_movement_strides(), Strides{1}); EXPECT_EQ(conv->get_window_movement_strides(), Strides{1});
EXPECT_EQ(conv->get_window_dilation_strides(), Strides{1}); EXPECT_EQ(conv->get_window_dilation_strides(), Strides{1});
EXPECT_EQ(conv->get_image_dilation_strides(), Strides{1});
EXPECT_EQ(conv->get_padding_below(), Shape{0}); EXPECT_EQ(conv->get_padding_below(), Shape{0});
EXPECT_EQ(conv->get_padding_above(), Shape{0}); EXPECT_EQ(conv->get_padding_above(), Shape{0});
...@@ -1772,8 +1773,8 @@ TEST(type_prop, conv_1d_deduce) ...@@ -1772,8 +1773,8 @@ TEST(type_prop, conv_1d_deduce)
EXPECT_EQ(conv->get_input_channel_count(), 3); EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128); EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_shape(), Shape{100}); EXPECT_EQ(conv->get_input_image_physical_shape(), Shape{100});
EXPECT_EQ(conv->get_padded_input_image_shape(), Shape{100}); EXPECT_EQ(conv->get_input_image_virtual_shape(), Shape{100});
EXPECT_EQ(conv->get_output_image_shape(), Shape{91}); EXPECT_EQ(conv->get_output_image_shape(), Shape{91});
EXPECT_EQ(conv->get_window_physical_shape(), Shape{10}); EXPECT_EQ(conv->get_window_physical_shape(), Shape{10});
...@@ -1799,6 +1800,7 @@ TEST(type_prop, conv_1d_deduce_padded) ...@@ -1799,6 +1800,7 @@ TEST(type_prop, conv_1d_deduce_padded)
EXPECT_EQ(conv->get_window_movement_strides(), Strides{1}); EXPECT_EQ(conv->get_window_movement_strides(), Strides{1});
EXPECT_EQ(conv->get_window_dilation_strides(), Strides{1}); EXPECT_EQ(conv->get_window_dilation_strides(), Strides{1});
EXPECT_EQ(conv->get_image_dilation_strides(), Strides{1});
EXPECT_EQ(conv->get_padding_below(), Shape{2}); EXPECT_EQ(conv->get_padding_below(), Shape{2});
EXPECT_EQ(conv->get_padding_above(), Shape{3}); EXPECT_EQ(conv->get_padding_above(), Shape{3});
...@@ -1806,8 +1808,8 @@ TEST(type_prop, conv_1d_deduce_padded) ...@@ -1806,8 +1808,8 @@ TEST(type_prop, conv_1d_deduce_padded)
EXPECT_EQ(conv->get_input_channel_count(), 3); EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128); EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_shape(), Shape{100}); EXPECT_EQ(conv->get_input_image_physical_shape(), Shape{100});
EXPECT_EQ(conv->get_padded_input_image_shape(), Shape{105}); EXPECT_EQ(conv->get_input_image_virtual_shape(), Shape{105});
EXPECT_EQ(conv->get_output_image_shape(), Shape{96}); EXPECT_EQ(conv->get_output_image_shape(), Shape{96});
EXPECT_EQ(conv->get_window_physical_shape(), Shape{10}); EXPECT_EQ(conv->get_window_physical_shape(), Shape{10});
...@@ -1829,6 +1831,7 @@ TEST(type_prop, conv_1d_deduce_strided) ...@@ -1829,6 +1831,7 @@ TEST(type_prop, conv_1d_deduce_strided)
EXPECT_EQ(conv->get_window_movement_strides(), Strides{2}); EXPECT_EQ(conv->get_window_movement_strides(), Strides{2});
EXPECT_EQ(conv->get_window_dilation_strides(), Strides{1}); EXPECT_EQ(conv->get_window_dilation_strides(), Strides{1});
EXPECT_EQ(conv->get_image_dilation_strides(), Strides{1});
EXPECT_EQ(conv->get_padding_below(), Shape{0}); EXPECT_EQ(conv->get_padding_below(), Shape{0});
EXPECT_EQ(conv->get_padding_above(), Shape{0}); EXPECT_EQ(conv->get_padding_above(), Shape{0});
...@@ -1836,8 +1839,8 @@ TEST(type_prop, conv_1d_deduce_strided) ...@@ -1836,8 +1839,8 @@ TEST(type_prop, conv_1d_deduce_strided)
EXPECT_EQ(conv->get_input_channel_count(), 3); EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128); EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_shape(), Shape{100}); EXPECT_EQ(conv->get_input_image_physical_shape(), Shape{100});
EXPECT_EQ(conv->get_padded_input_image_shape(), Shape{100}); EXPECT_EQ(conv->get_input_image_virtual_shape(), Shape{100});
EXPECT_EQ(conv->get_output_image_shape(), Shape{46}); EXPECT_EQ(conv->get_output_image_shape(), Shape{46});
EXPECT_EQ(conv->get_window_physical_shape(), Shape{10}); EXPECT_EQ(conv->get_window_physical_shape(), Shape{10});
...@@ -1863,6 +1866,7 @@ TEST(type_prop, conv_1d_deduce_strided_padded) ...@@ -1863,6 +1866,7 @@ TEST(type_prop, conv_1d_deduce_strided_padded)
EXPECT_EQ(conv->get_window_movement_strides(), Strides{2}); EXPECT_EQ(conv->get_window_movement_strides(), Strides{2});
EXPECT_EQ(conv->get_window_dilation_strides(), Strides{1}); EXPECT_EQ(conv->get_window_dilation_strides(), Strides{1});
EXPECT_EQ(conv->get_image_dilation_strides(), Strides{1});
EXPECT_EQ(conv->get_padding_below(), Shape{2}); EXPECT_EQ(conv->get_padding_below(), Shape{2});
EXPECT_EQ(conv->get_padding_above(), Shape{3}); EXPECT_EQ(conv->get_padding_above(), Shape{3});
...@@ -1870,8 +1874,8 @@ TEST(type_prop, conv_1d_deduce_strided_padded) ...@@ -1870,8 +1874,8 @@ TEST(type_prop, conv_1d_deduce_strided_padded)
EXPECT_EQ(conv->get_input_channel_count(), 3); EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128); EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_shape(), Shape{100}); EXPECT_EQ(conv->get_input_image_physical_shape(), Shape{100});
EXPECT_EQ(conv->get_padded_input_image_shape(), Shape{105}); EXPECT_EQ(conv->get_input_image_virtual_shape(), Shape{105});
EXPECT_EQ(conv->get_output_image_shape(), Shape{48}); EXPECT_EQ(conv->get_output_image_shape(), Shape{48});
EXPECT_EQ(conv->get_window_physical_shape(), Shape{10}); EXPECT_EQ(conv->get_window_physical_shape(), Shape{10});
...@@ -1893,6 +1897,7 @@ TEST(type_prop, conv_1d_deduce_strided_small_uneven) ...@@ -1893,6 +1897,7 @@ TEST(type_prop, conv_1d_deduce_strided_small_uneven)
EXPECT_EQ(conv->get_window_movement_strides(), Strides{2}); EXPECT_EQ(conv->get_window_movement_strides(), Strides{2});
EXPECT_EQ(conv->get_window_dilation_strides(), Strides{1}); EXPECT_EQ(conv->get_window_dilation_strides(), Strides{1});
EXPECT_EQ(conv->get_image_dilation_strides(), Strides{1});
EXPECT_EQ(conv->get_padding_below(), Shape{0}); EXPECT_EQ(conv->get_padding_below(), Shape{0});
EXPECT_EQ(conv->get_padding_above(), Shape{0}); EXPECT_EQ(conv->get_padding_above(), Shape{0});
...@@ -1900,8 +1905,8 @@ TEST(type_prop, conv_1d_deduce_strided_small_uneven) ...@@ -1900,8 +1905,8 @@ TEST(type_prop, conv_1d_deduce_strided_small_uneven)
EXPECT_EQ(conv->get_input_channel_count(), 3); EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128); EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_shape(), Shape{5}); EXPECT_EQ(conv->get_input_image_physical_shape(), Shape{5});
EXPECT_EQ(conv->get_padded_input_image_shape(), Shape{5}); EXPECT_EQ(conv->get_input_image_virtual_shape(), Shape{5});
EXPECT_EQ(conv->get_output_image_shape(), Shape{2}); EXPECT_EQ(conv->get_output_image_shape(), Shape{2});
EXPECT_EQ(conv->get_window_physical_shape(), Shape{2}); EXPECT_EQ(conv->get_window_physical_shape(), Shape{2});
...@@ -1923,6 +1928,7 @@ TEST(type_prop, conv_1d_deduce_strided_small_even) ...@@ -1923,6 +1928,7 @@ TEST(type_prop, conv_1d_deduce_strided_small_even)
EXPECT_EQ(conv->get_window_movement_strides(), Strides{2}); EXPECT_EQ(conv->get_window_movement_strides(), Strides{2});
EXPECT_EQ(conv->get_window_dilation_strides(), Strides{1}); EXPECT_EQ(conv->get_window_dilation_strides(), Strides{1});
EXPECT_EQ(conv->get_image_dilation_strides(), Strides{1});
EXPECT_EQ(conv->get_padding_below(), Shape{0}); EXPECT_EQ(conv->get_padding_below(), Shape{0});
EXPECT_EQ(conv->get_padding_above(), Shape{0}); EXPECT_EQ(conv->get_padding_above(), Shape{0});
...@@ -1930,8 +1936,8 @@ TEST(type_prop, conv_1d_deduce_strided_small_even) ...@@ -1930,8 +1936,8 @@ TEST(type_prop, conv_1d_deduce_strided_small_even)
EXPECT_EQ(conv->get_input_channel_count(), 3); EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128); EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_shape(), Shape{6}); EXPECT_EQ(conv->get_input_image_physical_shape(), Shape{6});
EXPECT_EQ(conv->get_padded_input_image_shape(), Shape{6}); EXPECT_EQ(conv->get_input_image_virtual_shape(), Shape{6});
EXPECT_EQ(conv->get_output_image_shape(), Shape{3}); EXPECT_EQ(conv->get_output_image_shape(), Shape{3});
EXPECT_EQ(conv->get_window_physical_shape(), Shape{2}); EXPECT_EQ(conv->get_window_physical_shape(), Shape{2});
...@@ -1941,7 +1947,7 @@ TEST(type_prop, conv_1d_deduce_strided_small_even) ...@@ -1941,7 +1947,7 @@ TEST(type_prop, conv_1d_deduce_strided_small_even)
EXPECT_EQ(conv->get_image_dimension_count(), 1); EXPECT_EQ(conv->get_image_dimension_count(), 1);
} }
TEST(type_prop, conv_1d_deduce_dilated) TEST(type_prop, conv_1d_deduce_window_dilated)
{ {
// Deduce type // Deduce type
auto param0 = make_shared<op::Parameter>(element::f32, Shape{64, 3, 100}); auto param0 = make_shared<op::Parameter>(element::f32, Shape{64, 3, 100});
...@@ -1954,6 +1960,7 @@ TEST(type_prop, conv_1d_deduce_dilated) ...@@ -1954,6 +1960,7 @@ TEST(type_prop, conv_1d_deduce_dilated)
EXPECT_EQ(conv->get_window_movement_strides(), Strides{1}); EXPECT_EQ(conv->get_window_movement_strides(), Strides{1});
EXPECT_EQ(conv->get_window_dilation_strides(), Strides{2}); EXPECT_EQ(conv->get_window_dilation_strides(), Strides{2});
EXPECT_EQ(conv->get_image_dilation_strides(), Strides{1});
EXPECT_EQ(conv->get_padding_below(), Shape{0}); EXPECT_EQ(conv->get_padding_below(), Shape{0});
EXPECT_EQ(conv->get_padding_above(), Shape{0}); EXPECT_EQ(conv->get_padding_above(), Shape{0});
...@@ -1961,18 +1968,18 @@ TEST(type_prop, conv_1d_deduce_dilated) ...@@ -1961,18 +1968,18 @@ TEST(type_prop, conv_1d_deduce_dilated)
EXPECT_EQ(conv->get_input_channel_count(), 3); EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128); EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_shape(), Shape{100}); EXPECT_EQ(conv->get_input_image_physical_shape(), Shape{100});
EXPECT_EQ(conv->get_padded_input_image_shape(), Shape{100}); EXPECT_EQ(conv->get_input_image_virtual_shape(), Shape{100});
EXPECT_EQ(conv->get_output_image_shape(), Shape{82}); EXPECT_EQ(conv->get_output_image_shape(), Shape{82});
EXPECT_EQ(conv->get_window_physical_shape(), Shape{19}); EXPECT_EQ(conv->get_window_physical_shape(), Shape{10});
EXPECT_EQ(conv->get_window_virtual_shape(), Shape{10}); EXPECT_EQ(conv->get_window_virtual_shape(), Shape{19});
EXPECT_EQ(conv->get_batch_size(), 64); EXPECT_EQ(conv->get_batch_size(), 64);
EXPECT_EQ(conv->get_image_dimension_count(), 1); EXPECT_EQ(conv->get_image_dimension_count(), 1);
} }
TEST(type_prop, conv_1d_deduce_dilated_padded) TEST(type_prop, conv_1d_deduce_window_dilated_padded)
{ {
// Deduce type // Deduce type
auto param0 = make_shared<op::Parameter>(element::f32, Shape{64, 3, 100}); auto param0 = make_shared<op::Parameter>(element::f32, Shape{64, 3, 100});
...@@ -1988,6 +1995,7 @@ TEST(type_prop, conv_1d_deduce_dilated_padded) ...@@ -1988,6 +1995,7 @@ TEST(type_prop, conv_1d_deduce_dilated_padded)
EXPECT_EQ(conv->get_window_movement_strides(), Strides{1}); EXPECT_EQ(conv->get_window_movement_strides(), Strides{1});
EXPECT_EQ(conv->get_window_dilation_strides(), Strides{2}); EXPECT_EQ(conv->get_window_dilation_strides(), Strides{2});
EXPECT_EQ(conv->get_image_dilation_strides(), Strides{1});
EXPECT_EQ(conv->get_padding_below(), Shape{2}); EXPECT_EQ(conv->get_padding_below(), Shape{2});
EXPECT_EQ(conv->get_padding_above(), Shape{3}); EXPECT_EQ(conv->get_padding_above(), Shape{3});
...@@ -1995,12 +2003,53 @@ TEST(type_prop, conv_1d_deduce_dilated_padded) ...@@ -1995,12 +2003,53 @@ TEST(type_prop, conv_1d_deduce_dilated_padded)
EXPECT_EQ(conv->get_input_channel_count(), 3); EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128); EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_shape(), Shape{100}); EXPECT_EQ(conv->get_input_image_physical_shape(), Shape{100});
EXPECT_EQ(conv->get_padded_input_image_shape(), Shape{105}); EXPECT_EQ(conv->get_input_image_virtual_shape(), Shape{105});
EXPECT_EQ(conv->get_output_image_shape(), Shape{87}); EXPECT_EQ(conv->get_output_image_shape(), Shape{87});
EXPECT_EQ(conv->get_window_physical_shape(), Shape{19}); EXPECT_EQ(conv->get_window_physical_shape(), Shape{10});
EXPECT_EQ(conv->get_window_virtual_shape(), Shape{10}); EXPECT_EQ(conv->get_window_virtual_shape(), Shape{19});
EXPECT_EQ(conv->get_batch_size(), 64);
EXPECT_EQ(conv->get_image_dimension_count(), 1);
}
TEST(type_prop, conv_1d_deduce_window_dilated_images_dilated_padded)
{
// Deduce type
auto param0 = make_shared<op::Parameter>(element::f32, Shape{64, 3, 100});
auto param1 = make_shared<op::Parameter>(element::f32, Shape{128, 3, 10});
auto move_strides = Strides{1};
auto dilate_strides = Strides{2};
auto padding_below = Shape{2};
auto padding_above = Shape{3};
auto img_dilate_strides = Strides{3};
auto conv = make_shared<op::Convolution>(param0,
param1,
move_strides,
dilate_strides,
padding_below,
padding_above,
img_dilate_strides);
EXPECT_EQ(conv->get_element_type(), element::f32);
EXPECT_EQ(conv->get_shape(), (Shape{64, 128, 285}));
EXPECT_EQ(conv->get_window_movement_strides(), Strides{1});
EXPECT_EQ(conv->get_window_dilation_strides(), Strides{2});
EXPECT_EQ(conv->get_image_dilation_strides(), Strides{3});
EXPECT_EQ(conv->get_padding_below(), Shape{2});
EXPECT_EQ(conv->get_padding_above(), Shape{3});
EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_physical_shape(), Shape{100});
EXPECT_EQ(conv->get_input_image_virtual_shape(), Shape{303});
EXPECT_EQ(conv->get_output_image_shape(), Shape{285});
EXPECT_EQ(conv->get_window_physical_shape(), Shape{10});
EXPECT_EQ(conv->get_window_virtual_shape(), Shape{19});
EXPECT_EQ(conv->get_batch_size(), 64); EXPECT_EQ(conv->get_batch_size(), 64);
EXPECT_EQ(conv->get_image_dimension_count(), 1); EXPECT_EQ(conv->get_image_dimension_count(), 1);
...@@ -2017,6 +2066,7 @@ TEST(type_prop, conv_2d_deduce) ...@@ -2017,6 +2066,7 @@ TEST(type_prop, conv_2d_deduce)
EXPECT_EQ(conv->get_window_movement_strides(), (Strides{1, 1})); EXPECT_EQ(conv->get_window_movement_strides(), (Strides{1, 1}));
EXPECT_EQ(conv->get_window_dilation_strides(), (Strides{1, 1})); EXPECT_EQ(conv->get_window_dilation_strides(), (Strides{1, 1}));
EXPECT_EQ(conv->get_image_dilation_strides(), (Strides{1, 1}));
EXPECT_EQ(conv->get_padding_below(), (Shape{0, 0})); EXPECT_EQ(conv->get_padding_below(), (Shape{0, 0}));
EXPECT_EQ(conv->get_padding_above(), (Shape{0, 0})); EXPECT_EQ(conv->get_padding_above(), (Shape{0, 0}));
...@@ -2024,8 +2074,8 @@ TEST(type_prop, conv_2d_deduce) ...@@ -2024,8 +2074,8 @@ TEST(type_prop, conv_2d_deduce)
EXPECT_EQ(conv->get_input_channel_count(), 3); EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128); EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_shape(), (Shape{100, 150})); EXPECT_EQ(conv->get_input_image_physical_shape(), (Shape{100, 150}));
EXPECT_EQ(conv->get_padded_input_image_shape(), (Shape{100, 150})); EXPECT_EQ(conv->get_input_image_virtual_shape(), (Shape{100, 150}));
EXPECT_EQ(conv->get_output_image_shape(), (Shape{91, 131})); EXPECT_EQ(conv->get_output_image_shape(), (Shape{91, 131}));
EXPECT_EQ(conv->get_window_physical_shape(), (Shape{10, 20})); EXPECT_EQ(conv->get_window_physical_shape(), (Shape{10, 20}));
...@@ -2051,6 +2101,7 @@ TEST(type_prop, conv_2d_deduce_padded) ...@@ -2051,6 +2101,7 @@ TEST(type_prop, conv_2d_deduce_padded)
EXPECT_EQ(conv->get_window_movement_strides(), (Strides{1, 1})); EXPECT_EQ(conv->get_window_movement_strides(), (Strides{1, 1}));
EXPECT_EQ(conv->get_window_dilation_strides(), (Strides{1, 1})); EXPECT_EQ(conv->get_window_dilation_strides(), (Strides{1, 1}));
EXPECT_EQ(conv->get_image_dilation_strides(), (Strides{1, 1}));
EXPECT_EQ(conv->get_padding_below(), (Shape{2, 3})); EXPECT_EQ(conv->get_padding_below(), (Shape{2, 3}));
EXPECT_EQ(conv->get_padding_above(), (Shape{3, 4})); EXPECT_EQ(conv->get_padding_above(), (Shape{3, 4}));
...@@ -2058,8 +2109,8 @@ TEST(type_prop, conv_2d_deduce_padded) ...@@ -2058,8 +2109,8 @@ TEST(type_prop, conv_2d_deduce_padded)
EXPECT_EQ(conv->get_input_channel_count(), 3); EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128); EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_shape(), (Shape{100, 150})); EXPECT_EQ(conv->get_input_image_physical_shape(), (Shape{100, 150}));
EXPECT_EQ(conv->get_padded_input_image_shape(), (Shape{105, 157})); EXPECT_EQ(conv->get_input_image_virtual_shape(), (Shape{105, 157}));
EXPECT_EQ(conv->get_output_image_shape(), (Shape{96, 138})); EXPECT_EQ(conv->get_output_image_shape(), (Shape{96, 138}));
EXPECT_EQ(conv->get_window_physical_shape(), (Shape{10, 20})); EXPECT_EQ(conv->get_window_physical_shape(), (Shape{10, 20}));
...@@ -2081,6 +2132,7 @@ TEST(type_prop, conv_2d_deduce_strided) ...@@ -2081,6 +2132,7 @@ TEST(type_prop, conv_2d_deduce_strided)
EXPECT_EQ(conv->get_window_movement_strides(), (Strides{2, 3})); EXPECT_EQ(conv->get_window_movement_strides(), (Strides{2, 3}));
EXPECT_EQ(conv->get_window_dilation_strides(), (Strides{1, 1})); EXPECT_EQ(conv->get_window_dilation_strides(), (Strides{1, 1}));
EXPECT_EQ(conv->get_image_dilation_strides(), (Strides{1, 1}));
EXPECT_EQ(conv->get_padding_below(), (Shape{0, 0})); EXPECT_EQ(conv->get_padding_below(), (Shape{0, 0}));
EXPECT_EQ(conv->get_padding_above(), (Shape{0, 0})); EXPECT_EQ(conv->get_padding_above(), (Shape{0, 0}));
...@@ -2088,8 +2140,8 @@ TEST(type_prop, conv_2d_deduce_strided) ...@@ -2088,8 +2140,8 @@ TEST(type_prop, conv_2d_deduce_strided)
EXPECT_EQ(conv->get_input_channel_count(), 3); EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128); EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_shape(), (Shape{100, 150})); EXPECT_EQ(conv->get_input_image_physical_shape(), (Shape{100, 150}));
EXPECT_EQ(conv->get_padded_input_image_shape(), (Shape{100, 150})); EXPECT_EQ(conv->get_input_image_virtual_shape(), (Shape{100, 150}));
EXPECT_EQ(conv->get_output_image_shape(), (Shape{46, 44})); EXPECT_EQ(conv->get_output_image_shape(), (Shape{46, 44}));
EXPECT_EQ(conv->get_window_physical_shape(), (Shape{10, 20})); EXPECT_EQ(conv->get_window_physical_shape(), (Shape{10, 20}));
...@@ -2099,7 +2151,7 @@ TEST(type_prop, conv_2d_deduce_strided) ...@@ -2099,7 +2151,7 @@ TEST(type_prop, conv_2d_deduce_strided)
EXPECT_EQ(conv->get_image_dimension_count(), 2); EXPECT_EQ(conv->get_image_dimension_count(), 2);
} }
TEST(type_prop, conv_2d_deduce_strided_dilated) TEST(type_prop, conv_2d_deduce_strided_window_dilated)
{ {
// Deduce type // Deduce type
auto param0 = make_shared<op::Parameter>(element::f32, Shape{64, 3, 100, 150}); auto param0 = make_shared<op::Parameter>(element::f32, Shape{64, 3, 100, 150});
...@@ -2112,6 +2164,7 @@ TEST(type_prop, conv_2d_deduce_strided_dilated) ...@@ -2112,6 +2164,7 @@ TEST(type_prop, conv_2d_deduce_strided_dilated)
EXPECT_EQ(conv->get_window_movement_strides(), (Strides{2, 3})); EXPECT_EQ(conv->get_window_movement_strides(), (Strides{2, 3}));
EXPECT_EQ(conv->get_window_dilation_strides(), (Strides{3, 2})); EXPECT_EQ(conv->get_window_dilation_strides(), (Strides{3, 2}));
EXPECT_EQ(conv->get_image_dilation_strides(), (Strides{1, 1}));
EXPECT_EQ(conv->get_padding_below(), (Shape{0, 0})); EXPECT_EQ(conv->get_padding_below(), (Shape{0, 0}));
EXPECT_EQ(conv->get_padding_above(), (Shape{0, 0})); EXPECT_EQ(conv->get_padding_above(), (Shape{0, 0}));
...@@ -2119,18 +2172,59 @@ TEST(type_prop, conv_2d_deduce_strided_dilated) ...@@ -2119,18 +2172,59 @@ TEST(type_prop, conv_2d_deduce_strided_dilated)
EXPECT_EQ(conv->get_input_channel_count(), 3); EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128); EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_shape(), (Shape{100, 150})); EXPECT_EQ(conv->get_input_image_physical_shape(), (Shape{100, 150}));
EXPECT_EQ(conv->get_padded_input_image_shape(), (Shape{100, 150})); EXPECT_EQ(conv->get_input_image_virtual_shape(), (Shape{100, 150}));
EXPECT_EQ(conv->get_output_image_shape(), (Shape{37, 38})); EXPECT_EQ(conv->get_output_image_shape(), (Shape{37, 38}));
EXPECT_EQ(conv->get_window_physical_shape(), (Shape{28, 39})); EXPECT_EQ(conv->get_window_physical_shape(), (Shape{10, 20}));
EXPECT_EQ(conv->get_window_virtual_shape(), (Shape{10, 20})); EXPECT_EQ(conv->get_window_virtual_shape(), (Shape{28, 39}));
EXPECT_EQ(conv->get_batch_size(), 64);
EXPECT_EQ(conv->get_image_dimension_count(), 2);
}
TEST(type_prop, conv_2d_deduce_strided_window_dilated_images_dilated)
{
// Deduce type
auto param0 = make_shared<op::Parameter>(element::f32, Shape{64, 3, 100, 150});
auto param1 = make_shared<op::Parameter>(element::f32, Shape{128, 3, 10, 20});
auto move_strides = Strides{2, 3};
auto dilate_strides = Strides{3, 2};
auto padding_below = Shape{0, 0};
auto padding_above = Shape{0, 0};
auto img_dilate_strides = Strides{2, 3};
auto conv = make_shared<op::Convolution>(param0,
param1,
move_strides,
dilate_strides,
padding_below,
padding_above,
img_dilate_strides);
EXPECT_EQ(conv->get_element_type(), element::f32);
EXPECT_EQ(conv->get_shape(), (Shape{64, 128, 86, 137}));
EXPECT_EQ(conv->get_window_movement_strides(), (Strides{2, 3}));
EXPECT_EQ(conv->get_window_dilation_strides(), (Strides{3, 2}));
EXPECT_EQ(conv->get_image_dilation_strides(), (Strides{2, 3}));
EXPECT_EQ(conv->get_padding_below(), (Shape{0, 0}));
EXPECT_EQ(conv->get_padding_above(), (Shape{0, 0}));
EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_physical_shape(), (Shape{100, 150}));
EXPECT_EQ(conv->get_input_image_virtual_shape(), (Shape{199, 448}));
EXPECT_EQ(conv->get_output_image_shape(), (Shape{86, 137}));
EXPECT_EQ(conv->get_window_physical_shape(), (Shape{10, 20}));
EXPECT_EQ(conv->get_window_virtual_shape(), (Shape{28, 39}));
EXPECT_EQ(conv->get_batch_size(), 64); EXPECT_EQ(conv->get_batch_size(), 64);
EXPECT_EQ(conv->get_image_dimension_count(), 2); EXPECT_EQ(conv->get_image_dimension_count(), 2);
} }
TEST(type_prop, conv_2d_deduce_strided_dilated_small) TEST(type_prop, conv_2d_deduce_strided_window_dilated_small)
{ {
// Deduce type // Deduce type
auto param0 = make_shared<op::Parameter>(element::f32, Shape{64, 3, 7, 8}); auto param0 = make_shared<op::Parameter>(element::f32, Shape{64, 3, 7, 8});
...@@ -2143,6 +2237,7 @@ TEST(type_prop, conv_2d_deduce_strided_dilated_small) ...@@ -2143,6 +2237,7 @@ TEST(type_prop, conv_2d_deduce_strided_dilated_small)
EXPECT_EQ(conv->get_window_movement_strides(), (Strides{2, 3})); EXPECT_EQ(conv->get_window_movement_strides(), (Strides{2, 3}));
EXPECT_EQ(conv->get_window_dilation_strides(), (Strides{3, 2})); EXPECT_EQ(conv->get_window_dilation_strides(), (Strides{3, 2}));
EXPECT_EQ(conv->get_image_dilation_strides(), (Strides{1, 1}));
EXPECT_EQ(conv->get_padding_below(), (Shape{0, 0})); EXPECT_EQ(conv->get_padding_below(), (Shape{0, 0}));
EXPECT_EQ(conv->get_padding_above(), (Shape{0, 0})); EXPECT_EQ(conv->get_padding_above(), (Shape{0, 0}));
...@@ -2150,18 +2245,18 @@ TEST(type_prop, conv_2d_deduce_strided_dilated_small) ...@@ -2150,18 +2245,18 @@ TEST(type_prop, conv_2d_deduce_strided_dilated_small)
EXPECT_EQ(conv->get_input_channel_count(), 3); EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128); EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_shape(), (Shape{7, 8})); EXPECT_EQ(conv->get_input_image_physical_shape(), (Shape{7, 8}));
EXPECT_EQ(conv->get_padded_input_image_shape(), (Shape{7, 8})); EXPECT_EQ(conv->get_input_image_virtual_shape(), (Shape{7, 8}));
EXPECT_EQ(conv->get_output_image_shape(), (Shape{2, 2})); EXPECT_EQ(conv->get_output_image_shape(), (Shape{2, 2}));
EXPECT_EQ(conv->get_window_physical_shape(), (Shape{4, 5})); EXPECT_EQ(conv->get_window_physical_shape(), (Shape{2, 3}));
EXPECT_EQ(conv->get_window_virtual_shape(), (Shape{2, 3})); EXPECT_EQ(conv->get_window_virtual_shape(), (Shape{4, 5}));
EXPECT_EQ(conv->get_batch_size(), 64); EXPECT_EQ(conv->get_batch_size(), 64);
EXPECT_EQ(conv->get_image_dimension_count(), 2); EXPECT_EQ(conv->get_image_dimension_count(), 2);
} }
TEST(type_prop, conv_3d_deduce_strided_dilated_small) TEST(type_prop, conv_3d_deduce_strided_window_dilated_small)
{ {
// Deduce type // Deduce type
auto param0 = make_shared<op::Parameter>(element::f32, Shape{64, 3, 7, 8, 10}); auto param0 = make_shared<op::Parameter>(element::f32, Shape{64, 3, 7, 8, 10});
...@@ -2174,6 +2269,7 @@ TEST(type_prop, conv_3d_deduce_strided_dilated_small) ...@@ -2174,6 +2269,7 @@ TEST(type_prop, conv_3d_deduce_strided_dilated_small)
EXPECT_EQ(conv->get_window_movement_strides(), (Strides{2, 3, 4})); EXPECT_EQ(conv->get_window_movement_strides(), (Strides{2, 3, 4}));
EXPECT_EQ(conv->get_window_dilation_strides(), (Strides{3, 2, 2})); EXPECT_EQ(conv->get_window_dilation_strides(), (Strides{3, 2, 2}));
EXPECT_EQ(conv->get_image_dilation_strides(), (Strides{1, 1, 1}));
EXPECT_EQ(conv->get_padding_below(), (Shape{0, 0, 0})); EXPECT_EQ(conv->get_padding_below(), (Shape{0, 0, 0}));
EXPECT_EQ(conv->get_padding_above(), (Shape{0, 0, 0})); EXPECT_EQ(conv->get_padding_above(), (Shape{0, 0, 0}));
...@@ -2181,12 +2277,53 @@ TEST(type_prop, conv_3d_deduce_strided_dilated_small) ...@@ -2181,12 +2277,53 @@ TEST(type_prop, conv_3d_deduce_strided_dilated_small)
EXPECT_EQ(conv->get_input_channel_count(), 3); EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128); EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_shape(), (Shape{7, 8, 10})); EXPECT_EQ(conv->get_input_image_physical_shape(), (Shape{7, 8, 10}));
EXPECT_EQ(conv->get_padded_input_image_shape(), (Shape{7, 8, 10})); EXPECT_EQ(conv->get_input_image_virtual_shape(), (Shape{7, 8, 10}));
EXPECT_EQ(conv->get_output_image_shape(), (Shape{2, 2, 2})); EXPECT_EQ(conv->get_output_image_shape(), (Shape{2, 2, 2}));
EXPECT_EQ(conv->get_window_physical_shape(), (Shape{4, 5, 3})); EXPECT_EQ(conv->get_window_physical_shape(), (Shape{2, 3, 2}));
EXPECT_EQ(conv->get_window_virtual_shape(), (Shape{2, 3, 2})); EXPECT_EQ(conv->get_window_virtual_shape(), (Shape{4, 5, 3}));
EXPECT_EQ(conv->get_batch_size(), 64);
EXPECT_EQ(conv->get_image_dimension_count(), 3);
}
TEST(type_prop, conv_3d_deduce_strided_window_dilated_image_dilated_small)
{
// Deduce type
auto param0 = make_shared<op::Parameter>(element::f32, Shape{64, 3, 7, 8, 10});
auto param1 = make_shared<op::Parameter>(element::f32, Shape{128, 3, 2, 3, 2});
auto move_strides = Strides{2, 3, 4};
auto dilate_strides = Strides{3, 2, 2};
auto below_padding = Shape{0, 0, 0};
auto above_padding = Shape{0, 0, 0};
auto img_dilate_strides = Strides{2, 3, 2};
auto conv = make_shared<op::Convolution>(param0,
param1,
move_strides,
dilate_strides,
below_padding,
above_padding,
img_dilate_strides);
EXPECT_EQ(conv->get_element_type(), element::f32);
EXPECT_EQ(conv->get_shape(), (Shape{64, 128, 5, 6, 5}));
EXPECT_EQ(conv->get_window_movement_strides(), (Strides{2, 3, 4}));
EXPECT_EQ(conv->get_window_dilation_strides(), (Strides{3, 2, 2}));
EXPECT_EQ(conv->get_image_dilation_strides(), (Strides{2, 3, 2}));
EXPECT_EQ(conv->get_padding_below(), (Shape{0, 0, 0}));
EXPECT_EQ(conv->get_padding_above(), (Shape{0, 0, 0}));
EXPECT_EQ(conv->get_input_channel_count(), 3);
EXPECT_EQ(conv->get_output_channel_count(), 128);
EXPECT_EQ(conv->get_input_image_physical_shape(), (Shape{7, 8, 10}));
EXPECT_EQ(conv->get_input_image_virtual_shape(), (Shape{13, 22, 19}));
EXPECT_EQ(conv->get_output_image_shape(), (Shape{5, 6, 5}));
EXPECT_EQ(conv->get_window_physical_shape(), (Shape{2, 3, 2}));
EXPECT_EQ(conv->get_window_virtual_shape(), (Shape{4, 5, 3}));
EXPECT_EQ(conv->get_batch_size(), 64); EXPECT_EQ(conv->get_batch_size(), 64);
EXPECT_EQ(conv->get_image_dimension_count(), 3); EXPECT_EQ(conv->get_image_dimension_count(), 3);
...@@ -2429,7 +2566,7 @@ TEST(type_prop, conv_invalid_movement_stride_rank) ...@@ -2429,7 +2566,7 @@ TEST(type_prop, conv_invalid_movement_stride_rank)
} }
} }
TEST(type_prop, conv_invalid_dilation_stride_rank) TEST(type_prop, conv_invalid_window_dilation_stride_rank)
{ {
// Deduce type // Deduce type
auto param0 = make_shared<op::Parameter>(element::f32, Shape{6, 2, 10, 10}); auto param0 = make_shared<op::Parameter>(element::f32, Shape{6, 2, 10, 10});
...@@ -2439,7 +2576,7 @@ TEST(type_prop, conv_invalid_dilation_stride_rank) ...@@ -2439,7 +2576,7 @@ TEST(type_prop, conv_invalid_dilation_stride_rank)
auto conv = make_shared<op::Convolution>(param0, param1, Strides{2, 3}, Strides{2, 3, 8}); auto conv = make_shared<op::Convolution>(param0, param1, Strides{2, 3}, Strides{2, 3, 8});
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong dilation stride rank not detected"; FAIL() << "Invalid input with wrong window dilation stride rank not detected";
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
{ {
...@@ -2453,6 +2590,36 @@ TEST(type_prop, conv_invalid_dilation_stride_rank) ...@@ -2453,6 +2590,36 @@ TEST(type_prop, conv_invalid_dilation_stride_rank)
} }
} }
TEST(type_prop, conv_invalid_image_dilation_stride_rank)
{
// Deduce type
auto param0 = make_shared<op::Parameter>(element::f32, Shape{6, 2, 10, 10});
auto param1 = make_shared<op::Parameter>(element::f32, Shape{6, 2, 3, 3});
try
{
auto conv = make_shared<op::Convolution>(param0,
param1,
Strides{2, 3},
Strides{2, 3},
Shape{0, 0},
Shape{0, 0},
Strides{2, 3, 8});
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong image dilation stride rank not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(),
std::string("Convolution image dilation stride rank does not "
"match number of image dimensions."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, conv_invalid_padding_below_rank) TEST(type_prop, conv_invalid_padding_below_rank)
{ {
// Deduce type // Deduce type
...@@ -2517,8 +2684,10 @@ TEST(type_prop, conv_invalid_input_image_size_0) ...@@ -2517,8 +2684,10 @@ TEST(type_prop, conv_invalid_input_image_size_0)
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
{ {
EXPECT_EQ(error.what(), EXPECT_EQ(
std::string("Convolution input image dimension is zero even with padding.")); error.what(),
std::string(
"Convolution input image dimension after dilation is zero even with padding."));
} }
catch (...) catch (...)
{ {
...@@ -2548,7 +2717,7 @@ TEST(type_prop, conv_invalid_window_size_0) ...@@ -2548,7 +2717,7 @@ TEST(type_prop, conv_invalid_window_size_0)
} }
} }
TEST(type_prop, conv_invalid_dilation_stride_0) TEST(type_prop, conv_invalid_window_dilation_stride_0)
{ {
// Deduce type // Deduce type
auto param0 = make_shared<op::Parameter>(element::f32, Shape{6, 2, 10, 10}); auto param0 = make_shared<op::Parameter>(element::f32, Shape{6, 2, 10, 10});
...@@ -2558,7 +2727,7 @@ TEST(type_prop, conv_invalid_dilation_stride_0) ...@@ -2558,7 +2727,7 @@ TEST(type_prop, conv_invalid_dilation_stride_0)
auto conv = make_shared<op::Convolution>(param0, param1, Strides{2, 3}, Strides{2, 0}); auto conv = make_shared<op::Convolution>(param0, param1, Strides{2, 3}, Strides{2, 0});
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong 0-length dilation stride axis not detected"; FAIL() << "Invalid input with wrong 0-length window dilation stride axis not detected";
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
{ {
...@@ -2570,6 +2739,29 @@ TEST(type_prop, conv_invalid_dilation_stride_0) ...@@ -2570,6 +2739,29 @@ TEST(type_prop, conv_invalid_dilation_stride_0)
} }
} }
TEST(type_prop, conv_invalid_image_dilation_stride_0)
{
// Deduce type
auto param0 = make_shared<op::Parameter>(element::f32, Shape{6, 2, 10, 10});
auto param1 = make_shared<op::Parameter>(element::f32, Shape{6, 2, 3, 3});
try
{
auto conv = make_shared<op::Convolution>(
param0, param1, Strides{2, 3}, Strides{2, 3}, Shape{0, 0}, Shape{0, 0}, Strides{2, 0});
// Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong 0-length image dilation stride axis not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Convolution image dilation stride is zero."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, conv_invalid_dilated_window_too_large) TEST(type_prop, conv_invalid_dilated_window_too_large)
{ {
// Deduce type // Deduce type
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment