Unverified Commit c682fbf4 authored by Adam Procter, committed by GitHub

Image batch dilation for convolution (#363)

Sub-PR: image dilation tests (#362) via @adstraw 
parent 74850150
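
For context, image dilation inserts s - 1 implicit zeros between adjacent input elements along each image axis before padding is applied, so an axis of size d grows to (d - 1) * s + 1. The standalone C++ sketch below simply replays the per-axis output-size arithmetic that the Convolution constructor in this diff performs; the concrete numbers and the local ceil_div helper are illustrative only and are not part of the commit.

#include <cstddef>
#include <iostream>

// Illustrative helper: rounds up when the window is stepped across the
// padded, dilated image.
static size_t ceil_div(size_t x, size_t y)
{
    return (x == 0) ? 0 : (1 + ((x - 1) / y));
}

int main()
{
    size_t d = 5;        // input image size along one axis
    size_t s_image = 2;  // image dilation stride
    size_t pb = 1;       // padding below
    size_t pa = 1;       // padding above
    size_t w = 3;        // filter (window) size along the same axis
    size_t s_window = 1; // window dilation stride
    size_t s_move = 1;   // window movement stride

    size_t dilated_dim_size = (d - 1) * s_image + 1;        // 9: zeros inserted between elements
    size_t virtual_dim_size = pb + dilated_dim_size + pa;   // 11: "virtual" image size after padding
    size_t window_virtual_size = (w - 1) * s_window + 1;    // 3: window extent after window dilation
    size_t out = ceil_div(virtual_dim_size - window_virtual_size + 1, s_move); // 9

    std::cout << "output size along this axis: " << out << "\n";
    return 0;
}
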
@@ -31,8 +31,17 @@ namespace ngraph
                             const Coordinate& source_end_corner,
                             const Strides& source_strides,
                             const AxisVector& source_axis_order,
-                            const Shape& source_padding_below,
-                            const Shape& source_padding_above);
+                            const Shape& target_padding_below,
+                            const Shape& target_padding_above,
+                            const Strides& source_dilation_strides);
+
+        CoordinateTransform(const Shape& source_shape,
+                            const Coordinate& source_start_corner,
+                            const Coordinate& source_end_corner,
+                            const Strides& source_strides,
+                            const AxisVector& source_axis_order,
+                            const Shape& target_padding_below,
+                            const Shape& target_padding_above);
 
         CoordinateTransform(const Shape& source_shape,
                             const Coordinate& source_start_corner,
@@ -52,8 +61,7 @@ namespace ngraph
         CoordinateTransform(const Shape& source_shape);
 
         size_t index(const Coordinate& c) const;
-        bool in_bounds(const Coordinate& c) const;
-        bool in_padding(const Coordinate& c) const;
+        bool has_source_coordinate(const Coordinate& c) const;
         Coordinate to_source_coordinate(const Coordinate& c) const;
         Coordinate get_target_shape() const;
 
@@ -62,6 +70,7 @@ namespace ngraph
         Coordinate get_source_end_corner() { return m_source_end_corner; }
         Strides get_source_strides() { return m_source_strides; }
         AxisVector get_source_axis_order() { return m_source_axis_order; }
+        Strides get_target_dilation_strides() { return m_target_dilation_strides; }
         class Iterator
         {
         public:
@@ -86,9 +95,9 @@ namespace ngraph
         Iterator end() noexcept { return Iterator(m_target_shape, true); }
 
     private:
         size_t index_source(const Coordinate& c) const;
-        static Strides default_strides(size_t n_axes);
         static Shape default_padding(size_t n_axes);
         static AxisVector default_axis_order(size_t n_axes);
+        static Strides default_source_strides(size_t n_axes);
         static Coordinate default_source_start_corner(size_t n_axes);
         static Coordinate default_source_end_corner(const Shape& source_shape);
@@ -97,8 +106,9 @@ namespace ngraph
         Shape m_source_end_corner;
         Strides m_source_strides;
         AxisVector m_source_axis_order;
-        Shape m_source_padding_below;
-        Shape m_source_padding_above;
+        Shape m_target_padding_below;
+        Shape m_target_padding_above;
+        Strides m_target_dilation_strides;
 
         Shape m_target_shape;
         size_t m_n_axes;
@@ -23,12 +23,14 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
                              const Strides& window_movement_strides,
                              const Strides& window_dilation_strides,
                              const Shape& padding_below,
-                             const Shape& padding_above)
+                             const Shape& padding_above,
+                             const Strides& image_dilation_strides)
     : RequiresTensorViewArgs("Convolution", {image_batch, filters})
     , m_window_movement_strides(window_movement_strides)
     , m_window_dilation_strides(window_dilation_strides)
     , m_padding_below(padding_below)
     , m_padding_above(padding_above)
+    , m_image_dilation_strides(image_dilation_strides)
 {
     auto& image_batch_shape = get_inputs().at(0).get_shape();
     auto& filters_shape = get_inputs().at(1).get_shape();
@@ -77,7 +79,8 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
     }
 
     //
-    // Make sure window movement strides and window dilation strades have same rank as Di.
+    // Make sure window movement strides, window dilation strides, and image dilation strides
+    // have same rank as Di.
     //
     if (m_window_movement_strides.size() != m_image_dimension_count)
     {
@@ -91,6 +94,12 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
             "Convolution window dilation stride rank does not match number of image dimensions.");
     }
 
+    if (m_image_dilation_strides.size() != m_image_dimension_count)
+    {
+        throw ngraph_error(
+            "Convolution image dilation stride rank does not match number of image dimensions.");
+    }
+
     //
     // Make sure padding-below and padding-above shapes have same rank as Di.
     //
@@ -107,28 +116,36 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
     }
 
     //
-    // Extract input image shape Di and make sure all dimensions are larger than 0 after padding.
+    // Extract input image shape Di and make sure all dimensions are larger than 0 after padding and dilation.
     //
     for (size_t i = 0; i < m_image_dimension_count; i++)
     {
-        m_input_image_shape.push_back(image_batch_shape[1 + 1 + i]);
-        m_padded_input_image_shape.push_back(padding_below[i] + image_batch_shape[1 + 1 + i] +
-                                             padding_above[i]);
-
-        if (m_padded_input_image_shape[i] == 0)
+        if (image_dilation_strides[i] == 0)
+        {
+            throw ngraph_error("Convolution image dilation stride is zero.");
+        }
+
+        size_t dim_size = image_batch_shape[1 + 1 + i];
+        m_input_image_physical_shape.push_back(dim_size);
+
+        size_t dilated_dim_size = (dim_size - 1) * image_dilation_strides[i] + 1;
+        size_t padded_dilated_dim_size = padding_below[i] + dilated_dim_size + padding_above[i];
+        m_input_image_virtual_shape.push_back(padded_dilated_dim_size);
+
+        if (m_input_image_virtual_shape[i] == 0)
         {
-            throw ngraph_error("Convolution input image dimension is zero even with padding.");
+            throw ngraph_error(
+                "Convolution input image dimension after dilation is zero even with padding.");
         }
     }
 
     //
-    // Extract the virtual shape Wv of the convolution window, *not* including dilation, from the filter dimensions.
+    // Extract the physical shape Wp of the convolution window, *not* including dilation, from the filter dimensions.
     // At the same time, make sure window shape dimensions are all larger than 0.
     //
     for (size_t i = 0; i < m_image_dimension_count; i++)
     {
-        m_window_virtual_shape.push_back(filters_shape[1 + 1 + i]);
-        if (m_window_virtual_shape[i] == 0)
+        m_window_physical_shape.push_back(filters_shape[1 + 1 + i]);
+        if (m_window_physical_shape[i] == 0)
         {
             throw ngraph_error("Convolution window shape has a zero-length axis.");
         }
@@ -145,10 +162,10 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
             throw ngraph_error("Convolution window axis dilation stride is zero.");
         }
 
-        m_window_physical_shape.push_back(
-            (m_window_virtual_shape[i] - 1) * m_window_dilation_strides[i] + 1);
+        m_window_virtual_shape.push_back(
+            (m_window_physical_shape[i] - 1) * m_window_dilation_strides[i] + 1);
 
-        if (m_window_physical_shape[i] > m_padded_input_image_shape[i])
+        if (m_window_virtual_shape[i] > m_input_image_virtual_shape[i])
         {
             throw ngraph_error(
                 "Convolution window after dilation is larger than the image even with padding.");
@@ -165,7 +182,7 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
            throw ngraph_error("Convolution window axis movement stride is zero.");
         }
 
         m_output_image_shape.push_back(
-            ceil_div(m_padded_input_image_shape[i] - m_window_physical_shape[i] + 1,
+            ceil_div(m_input_image_virtual_shape[i] - m_window_virtual_shape[i] + 1,
                      m_window_movement_strides[i]));
     }
@@ -180,7 +197,7 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
     set_value_type_checked(get_inputs().at(0).get_element_type(), result_shape);
 }
 
-Shape op::Convolution::default_padding(const std::shared_ptr<Node>& image_batch)
+Strides op::Convolution::default_strides(const std::shared_ptr<Node>& image_batch)
 {
     auto& image_batch_shape = image_batch->get_shape();
     if (image_batch_shape.size() < 3)
@@ -190,23 +207,26 @@ Shape op::Convolution::default_padding(const std::shared_ptr<Node>& image_batch)
             "Convolution image batch input must have rank of at least 3 (one batch axis, one "
             "input-channel axis, at least one image dimension).");
     }
-    return Shape(image_batch_shape.size() - 2, 0);
+    return Strides(image_batch_shape.size() - 2, 1);
 }
 
 op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
                              const std::shared_ptr<Node>& filters,
                              const Strides& window_movement_strides,
-                             const Strides& window_dilation_strides)
+                             const Strides& window_dilation_strides,
+                             const Shape& padding_below,
+                             const Shape& padding_above)
     : Convolution(image_batch,
                   filters,
                   window_movement_strides,
                   window_dilation_strides,
-                  default_padding(image_batch),
-                  default_padding(image_batch))
+                  padding_below,
+                  padding_above,
+                  default_strides(image_batch))
 {
 }
 
-Strides op::Convolution::default_strides(const std::shared_ptr<Node>& image_batch)
+Shape op::Convolution::default_padding(const std::shared_ptr<Node>& image_batch)
 {
     auto& image_batch_shape = image_batch->get_shape();
     if (image_batch_shape.size() < 3)
@@ -216,7 +236,20 @@ Strides op::Convolution::default_strides(const std::shared_ptr<Node>& image_batch)
             "Convolution image batch input must have rank of at least 3 (one batch axis, one "
             "input-channel axis, at least one image dimension).");
     }
-    return Strides(image_batch_shape.size() - 2, 1);
+    return Shape(image_batch_shape.size() - 2, 0);
+}
+
+op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
+                             const std::shared_ptr<Node>& filters,
+                             const Strides& window_movement_strides,
+                             const Strides& window_dilation_strides)
+    : Convolution(image_batch,
+                  filters,
+                  window_movement_strides,
+                  window_dilation_strides,
+                  default_padding(image_batch),
+                  default_padding(image_batch))
+{
 }
 
 op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
@@ -254,7 +287,8 @@ std::shared_ptr<Node>
                                          m_window_movement_strides,
                                          m_window_dilation_strides,
                                          m_padding_below,
-                                         m_padding_above);
+                                         m_padding_above,
+                                         m_image_dilation_strides);
 }
 
 bool op::Convolution::is_functionally_identical(const Node& other) const
@@ -265,9 +299,13 @@ bool op::Convolution::is_functionally_identical(const Node& other) const
     const Convolution& rhs = dynamic_cast<const Convolution&>(other);
     rc &= m_window_movement_strides == rhs.m_window_movement_strides;
     rc &= m_window_dilation_strides == rhs.m_window_dilation_strides;
+    rc &= m_padding_below == rhs.m_padding_below;
+    rc &= m_padding_above == rhs.m_padding_above;
+    rc &= m_image_dilation_strides == rhs.m_image_dilation_strides;
     rc &= m_input_channel_count == rhs.m_input_channel_count;
     rc &= m_output_channel_count == rhs.m_output_channel_count;
-    rc &= m_input_image_shape == rhs.m_input_image_shape;
+    rc &= m_input_image_physical_shape == rhs.m_input_image_physical_shape;
+    rc &= m_input_image_virtual_shape == rhs.m_input_image_virtual_shape;
     rc &= m_output_image_shape == rhs.m_output_image_shape;
     rc &= m_window_physical_shape == rhs.m_window_physical_shape;
     rc &= m_window_virtual_shape == rhs.m_window_virtual_shape;
@@ -1678,7 +1678,9 @@ void runtime::cpu::CPU_Emitter::EmitConvolution(const ngraph::Node* n,
     m_out << " {" << join(convolution->get_window_dilation_strides())
           << "},\n";
     m_out << " {" << join(convolution->get_padding_below()) << "},\n";
-    m_out << " {" << join(convolution->get_padding_above()) << "});\n";
+    m_out << " {" << join(convolution->get_padding_above()) << "},\n";
+    m_out << " {" << join(convolution->get_image_dilation_strides())
+          << "});\n";
 }
 
 void runtime::cpu::CPU_Emitter::EmitNot(const ngraph::Node* n,
@@ -303,7 +303,8 @@ private:
                            c->get_window_movement_strides(),
                            c->get_window_dilation_strides(),
                            c->get_padding_below(),
-                           c->get_padding_above());
+                           c->get_padding_above(),
+                           c->get_image_dilation_strides());
         }
         else if (node_op == "Cos")
         {
@@ -36,7 +36,8 @@ namespace ngraph
                             const Strides& window_movement_strides,
                             const Strides& window_dilation_strides,
                             const Shape& padding_below,
-                            const Shape& padding_above)
+                            const Shape& padding_above,
+                            const Strides& image_dilation_strides)
             {
                 // At the outermost level we will walk over every output coordinate O.
                 CoordinateTransform output_transform(out_shape);
@@ -64,17 +65,18 @@ namespace ngraph
                 //
                 //   (1,l_1,...,l_n).
                 //
-                // Note that we are iterating within the *padded* image batch, so further down we must check
-                // the current coordinate is in the padding.
+                // Note that we are iterating within the *padded* and *dilated* image batch, so further
+                // down we must check the current coordinate is in the padding or dilation gap.
 
                 size_t n_image_dimensions = arg0_shape.size() - 2;
                 size_t n_input_channels = arg0_shape[1];
 
                 Shape input_batch_transform_start(2 + n_image_dimensions);
                 Shape input_batch_transform_end(2 + n_image_dimensions);
-                Shape input_batch_transform_strides(2 + n_image_dimensions, 1);
-                Shape input_batch_padding_below(2 + n_image_dimensions, 0);
-                Shape input_batch_padding_above(2 + n_image_dimensions, 0);
+                Shape input_batch_transform_movement_strides(2 + n_image_dimensions, 1);
+                Shape input_batch_transform_padding_below(2 + n_image_dimensions, 0);
+                Shape input_batch_transform_padding_above(2 + n_image_dimensions, 0);
+                Shape input_batch_transform_dilation_strides(2 + n_image_dimensions, 1);
 
                 input_batch_transform_start[0] = img_index;
                 input_batch_transform_end[0] = img_index + 1;
@@ -83,32 +85,37 @@ namespace ngraph
                 for (size_t i = 2; i < n_image_dimensions + 2; i++)
                 {
-                    size_t dilation_stride = window_dilation_strides[i - 2];
-                    size_t movement_stride = window_movement_strides[i - 2];
+                    size_t window_dilation_stride = window_dilation_strides[i - 2];
+                    size_t window_movement_stride = window_movement_strides[i - 2];
                     size_t below_pad = padding_below[i - 2];
                     size_t above_pad = padding_above[i - 2];
+                    size_t image_dilation_stride = image_dilation_strides[i - 2];
 
-                    input_batch_transform_start[i] = movement_stride * out_coord[i];
-                    input_batch_transform_end[i] = input_batch_transform_start[i] +
-                                                   (arg1_shape[i] - 1) * dilation_stride + 1;
-                    input_batch_transform_strides[i] = dilation_stride;
-                    input_batch_padding_below[i] = below_pad;
-                    input_batch_padding_above[i] = above_pad;
+                    input_batch_transform_start[i] = window_movement_stride * out_coord[i];
+                    input_batch_transform_end[i] =
+                        input_batch_transform_start[i] +
+                        (arg1_shape[i] - 1) * window_dilation_stride + 1;
+                    input_batch_transform_movement_strides[i] = window_dilation_stride;
+                    input_batch_transform_padding_below[i] = below_pad;
+                    input_batch_transform_padding_above[i] = above_pad;
+                    input_batch_transform_dilation_strides[i] = image_dilation_stride;
                 }
 
-                AxisVector input_batch_axis_order(2 + n_image_dimensions);
+                AxisVector input_batch_transform_axis_order(2 + n_image_dimensions);
                 size_t n = 0;
-                std::generate(input_batch_axis_order.begin(),
-                              input_batch_axis_order.end(),
+                std::generate(input_batch_transform_axis_order.begin(),
+                              input_batch_transform_axis_order.end(),
                               [&n]() -> size_t { return n++; });
 
-                CoordinateTransform input_batch_transform(arg0_shape,
-                                                          input_batch_transform_start,
-                                                          input_batch_transform_end,
-                                                          input_batch_transform_strides,
-                                                          input_batch_axis_order,
-                                                          input_batch_padding_below,
-                                                          input_batch_padding_above);
+                CoordinateTransform input_batch_transform(
+                    arg0_shape,
+                    input_batch_transform_start,
+                    input_batch_transform_end,
+                    input_batch_transform_movement_strides,
+                    input_batch_transform_axis_order,
+                    input_batch_transform_padding_below,
+                    input_batch_transform_padding_above,
+                    input_batch_transform_dilation_strides);
 
                 // Simultaneously with iterating I, for the filters we need to iterate the coordinate:
                 //
@@ -151,10 +158,13 @@ namespace ngraph
                 {
                     const Coordinate& input_batch_coord = *input_it;
                     const Coordinate& filter_coord = *filter_it;
-                    T v = input_batch_transform.in_padding(input_batch_coord)
-                              ? 0
-                              : arg0[input_batch_transform.index(input_batch_coord)];
+
+                    T v = input_batch_transform.has_source_coordinate(input_batch_coord)
+                              ? arg0[input_batch_transform.index(input_batch_coord)]
+                              : 0;
+
                     result += v * arg1[filter_transform.index(filter_coord)];
 
                     ++input_it;
                     ++filter_it;
                 }
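
Worth noting for readers of the reference kernel above: once padding and image dilation are folded into the CoordinateTransform, the inner loop only has to ask has_source_coordinate and substitute zero whenever the iterator lands in padding or in a dilation gap. The hedged 1-D sketch below (standalone, with invented names, not the nGraph iterator itself) shows the same idea on a single axis.

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative 1-D analogue of the padded/dilated "virtual" axis: returns the
// real element when the virtual position maps back to a source element, and 0
// when it falls in the below-padding, a dilation gap, or past the last element.
static float read_virtual(const std::vector<float>& arg,
                          size_t pos,
                          size_t pad_below,
                          size_t dilation)
{
    if (pos < pad_below)
    {
        return 0; // below-padding region
    }
    size_t offset = pos - pad_below;
    if (offset % dilation != 0 || offset / dilation >= arg.size())
    {
        return 0; // dilation gap or above-padding region
    }
    return arg[offset / dilation];
}

int main()
{
    std::vector<float> arg{1, 2, 3};
    // With pad_below = 1 and dilation = 2 the virtual axis reads: 0 1 0 2 0 3 0
    for (size_t pos = 0; pos < 7; pos++)
    {
        std::cout << read_virtual(arg, pos, 1, 2) << " ";
    }
    std::cout << "\n";
    return 0;
}
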