Commit 61be3814 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Refactor convolution and pooling type prop (#1817)

* WIP

* More WIP

* More chiseling

* Move conv validation utils to a separate file; update unit tests

* Fix invalid attributes in pattern containing ConvolutionBackpropFilters

* Remove zero_const_conv test (it's no longer possible to construct the graph being tested)

* Rename infer_convolution_output_item_shape to infer_windowed_reduction_output_shape and add a boolean flag to control whether window-all-in-padding is allowed

* Add generalized function for inferring pooling fprop, use it in AvgPool/AvgPoolBackprop

* Update MaxPool to use new utility functions

* Fix comment

* Remove faulty and redundant check for window shape relative to pre-padding data shape

* Revert change to pattern construction in cpu_fusion

* Update unit test for maxpool

* Restore unjustly eliminated tests; move some computation to ptrdiff_t for safety; fix wording on some error messages

* Formatting
parent 05aa1be8
...@@ -155,6 +155,7 @@ set (SRC ...@@ -155,6 +155,7 @@ set (SRC
strides.cpp strides.cpp
type/element_type.cpp type/element_type.cpp
util.cpp util.cpp
validation_util.cpp
graph_util.cpp graph_util.cpp
placement.cpp placement.cpp
cpio.cpp cpio.cpp
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "ngraph/op/avg_pool.hpp" #include "ngraph/op/avg_pool.hpp"
#include "ngraph/assertion.hpp" #include "ngraph/assertion.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -41,6 +42,10 @@ void op::AvgPool::validate_and_infer_types() ...@@ -41,6 +42,10 @@ void op::AvgPool::validate_and_infer_types()
{ {
auto& arg_shape = get_input_shape(0); auto& arg_shape = get_input_shape(0);
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
<< "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape
<< ").";
if (0 == m_window_movement_strides.size() && arg_shape.size() > 2) if (0 == m_window_movement_strides.size() && arg_shape.size() > 2)
{ {
m_window_movement_strides = Strides(arg_shape.size() - 2, 1); m_window_movement_strides = Strides(arg_shape.size() - 2, 1);
...@@ -56,145 +61,20 @@ void op::AvgPool::validate_and_infer_types() ...@@ -56,145 +61,20 @@ void op::AvgPool::validate_and_infer_types()
m_padding_above = Shape(arg_shape.size() - 2, 0); m_padding_above = Shape(arg_shape.size() - 2, 0);
} }
// // infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// Make sure batch size and channel count are not zero, and that we have at least one spatial // now still take Shape (no negative padding).
// dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0). CoordinateDiff padding_below(m_padding_below.begin(), m_padding_below.end());
// CoordinateDiff padding_above(m_padding_above.begin(), m_padding_above.end());
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
<< "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape set_output_type(0,
<< ")."; get_input_element_type(0),
infer_batched_pooling_forward(this,
size_t batch_size = arg_shape[0]; arg_shape,
NODE_VALIDATION_ASSERT(this, batch_size != 0) padding_below,
<< "Data batch size is zero (data input shape: " << arg_shape << ")."; padding_above,
m_window_shape,
size_t channel_count = arg_shape[1]; m_window_movement_strides,
NODE_VALIDATION_ASSERT(this, channel_count != 0) m_include_padding_in_avg_computation));
<< "Channel count is zero (data input shape: " << arg_shape << ").";
size_t spatial_dimension_count = arg_shape.size() - 2;
//
// Make sure window shape, window movement strides, and padding have same rank as Di.
//
NODE_VALIDATION_ASSERT(this, m_window_shape.size() == spatial_dimension_count)
<< "Window shape rank does not match number of spatial dimensions (window shape: "
<< m_window_shape << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_window_movement_strides.size() == spatial_dimension_count)
<< "Window movement stride rank does not match number of spatial dimensions (window "
"movement strides: "
<< m_window_movement_strides << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_padding_below.size() == spatial_dimension_count)
<< "Below-padding rank does not match number of spatial dimensions (padding below: "
<< m_padding_below << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_padding_above.size() == spatial_dimension_count)
<< "Above-padding rank does not match number of spatial dimensions (padding above: "
<< m_padding_above << ", data input shape: " << arg_shape << ").";
//
// Extract input item shape Di and make sure all dimensions are larger than 0.
//
Shape input_item_virtual_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{
size_t dim_size = arg_shape[1 + 1 + i];
size_t virtual_dim_size = m_padding_below[i] + dim_size + m_padding_above[i];
input_item_virtual_shape.push_back(virtual_dim_size);
}
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, input_item_virtual_shape[i] != 0)
<< "Data input spatial dimension " << i
<< " has zero length even after padding (virtual shape of input item: "
<< input_item_virtual_shape << ").";
}
//
// Make sure window shape dimensions are all larger than 0.
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_shape[i] != 0)
<< "Window shape dimension " << i
<< " has zero length (window shape: " << m_window_shape << ").";
}
//
// Make sure the pooling window fits within the spatial dimensions.
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_shape[i] <= input_item_virtual_shape[i])
<< "Window shape after padding is larger than the spatial dimensions (window shape: "
<< m_window_shape << ", virtual shape of input item: " << input_item_virtual_shape
<< ").";
}
//
// Compute output item shape Do, checking at the same time that all window movement strides are larger than 0.
//
Shape output_item_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_movement_strides[i] != 0)
<< "Window movement strides dimension " << i
<< " has zero length (window movement strides: " << m_window_movement_strides << ").";
output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - m_window_shape[i] + 1,
m_window_movement_strides[i]));
}
//
// Make sure we're not going to have to compute average over an empty set of tensor elements.
// That will happen if the sliding window ever resides entirely over the padding area AND
// we're planning to disregard padding when computing the window's average.
//
if (!m_include_padding_in_avg_computation)
{
for (size_t i = 0; i < spatial_dimension_count; i++)
{
const size_t dim_virtual_size = input_item_virtual_shape[i];
const size_t dim_window_size = m_window_shape[i];
const size_t dim_stride = m_window_movement_strides[i];
const size_t dim_padding_below = m_padding_below[i];
const size_t dim_padding_above = m_padding_above[i];
// Checking the lower edge of each dimension is easy, because there's no mystery
// regarding the window's lower-edge placement...
NODE_VALIDATION_ASSERT(this,
dim_padding_below == 0 || dim_window_size > dim_padding_below)
<< "Window will sometimes reside entirely within the below-padding region, but"
<< " include_padding_in_avg_computation was not set (padding below: "
<< m_padding_below << ", window shape: " << m_window_shape << ").";
// Now check the upper-bound...
{
const size_t dim_num_strides = (dim_virtual_size - dim_window_size) / dim_stride;
const size_t dim_window_max_lower_offset = dim_num_strides * dim_stride;
const size_t dim_padding_above_start_offset = dim_virtual_size - dim_padding_above;
NODE_VALIDATION_ASSERT(this,
dim_padding_above == 0 ||
dim_window_max_lower_offset <
dim_padding_above_start_offset)
<< "Window will sometimes reside entirely within the above-padding region, but"
<< " include_padding_in_avg_computation was not set (padding above: "
<< m_padding_above << ", window shape: " << m_window_shape << ").";
}
}
}
//
// Construct result shape: NCDo.
//
Shape result_shape(1 + 1 + spatial_dimension_count);
result_shape[0] = batch_size;
result_shape[1] = channel_count;
copy(output_item_shape.begin(), output_item_shape.end(), result_shape.begin() + 2);
set_output_type(0, get_input_element_type(0), result_shape);
} }
op::AvgPool::AvgPool(const shared_ptr<Node>& arg, op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
...@@ -240,154 +120,25 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape, ...@@ -240,154 +120,25 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
void op::AvgPoolBackprop::validate_and_infer_types() void op::AvgPoolBackprop::validate_and_infer_types()
{ {
// --
// TODO: de-duplicate this code from AvgPool::AvgPool.
// --
auto& delta_shape = get_input_shape(0); auto& delta_shape = get_input_shape(0);
// // infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// Make sure batch size and channel count are not zero, and that we have at least one spatial // now still take Shape (no negative padding).
// dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0). CoordinateDiff padding_below(m_padding_below.begin(), m_padding_below.end());
// CoordinateDiff padding_above(m_padding_above.begin(), m_padding_above.end());
NODE_VALIDATION_ASSERT(this, m_forward_arg_shape.size() >= 3)
<< "Forward input shape does not have rank of at least 3 (forward input shape: "
<< m_forward_arg_shape << ").";
size_t batch_size = m_forward_arg_shape[0];
NODE_VALIDATION_ASSERT(this, batch_size != 0)
<< "Data batch size is zero (forward input shape: " << m_forward_arg_shape << ").";
size_t channel_count = m_forward_arg_shape[1];
NODE_VALIDATION_ASSERT(this, channel_count != 0)
<< "Channel count is zero (forward input shape: " << m_forward_arg_shape << ").";
size_t spatial_dimension_count = m_forward_arg_shape.size() - 2;
//
// Make sure window shape, window movement strides, and padding have same rank as Di.
//
NODE_VALIDATION_ASSERT(this, m_window_shape.size() == spatial_dimension_count)
<< "Window shape rank does not match number of spatial dimensions (window shape: "
<< m_window_shape << ", forward input shape: " << m_forward_arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_window_movement_strides.size() == spatial_dimension_count)
<< "Window movement stride rank does not match number of spatial dimensions (window "
"movement strides: "
<< m_window_movement_strides << ", forward input shape: " << m_forward_arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_padding_below.size() == spatial_dimension_count)
<< "Below-padding rank does not match number of spatial dimensions (padding below: "
<< m_padding_below << ", forward input shape: " << m_forward_arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_padding_above.size() == spatial_dimension_count)
<< "Above-padding rank does not match number of spatial dimensions (padding above: "
<< m_padding_above << ", forward input shape: " << m_forward_arg_shape << ").";
//
// Extract input item shape Di and make sure all dimensions are larger than 0.
//
Shape input_item_virtual_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{
size_t dim_size = m_forward_arg_shape[1 + 1 + i];
size_t virtual_dim_size = m_padding_below[i] + dim_size + m_padding_above[i];
input_item_virtual_shape.push_back(virtual_dim_size);
}
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, input_item_virtual_shape[i] != 0)
<< "Forward input spatial dimension " << i
<< " has zero length even after padding (virtual shape of input item: "
<< input_item_virtual_shape << ").";
}
//
// Make sure window shape dimensions are all larger than 0.
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_shape[i] != 0)
<< "Window shape dimension " << i
<< " has zero length (window shape: " << m_window_shape << ").";
}
//
// Make sure the pooling window fits within the spatial dimensions.
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_shape[i] <= input_item_virtual_shape[i])
<< "Window shape after padding is larger than the spatial dimensions (window shape: "
<< m_window_shape << ", virtual shape of input item: " << input_item_virtual_shape
<< ").";
}
//
// Compute output item shape Do, checking at the same time that all window movement strides are larger than 0.
//
Shape output_item_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_movement_strides[i] != 0)
<< "Window movement strides dimension " << i
<< " has zero length (window movement strides: " << m_window_movement_strides << ").";
output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - m_window_shape[i] + 1,
m_window_movement_strides[i]));
}
//
// Make sure we're not going to have to compute average over an empty set of tensor elements.
// That will happen if the sliding window ever resides entirely over the padding area AND
// we're planning to disregard padding when computing the window's average.
//
if (!m_include_padding_in_avg_computation)
{
for (size_t i = 0; i < spatial_dimension_count; i++)
{
const size_t dim_virtual_size = input_item_virtual_shape[i];
const size_t dim_window_size = m_window_shape[i];
const size_t dim_stride = m_window_movement_strides[i];
const size_t dim_padding_below = m_padding_below[i];
const size_t dim_padding_above = m_padding_above[i];
// Checking the lower edge of each dimension is easy, because there's no mystery
// regarding the window's lower-edge placement...
NODE_VALIDATION_ASSERT(this,
dim_padding_below == 0 || dim_window_size > dim_padding_below)
<< "Window will sometimes reside entirely within the below-padding region, but"
<< " include_padding_in_avg_computation was not set (padding below: "
<< m_padding_below << ", window shape: " << m_window_shape << ").";
// Now check the upper-bound...
{
const size_t dim_num_strides = (dim_virtual_size - dim_window_size) / dim_stride;
const size_t dim_window_max_lower_offset = dim_num_strides * dim_stride;
const size_t dim_padding_above_start_offset = dim_virtual_size - dim_padding_above;
NODE_VALIDATION_ASSERT(this,
dim_padding_above == 0 ||
dim_window_max_lower_offset <
dim_padding_above_start_offset)
<< "Window will sometimes reside entirely within the above-padding region, but"
<< " include_padding_in_avg_computation was not set (padding above: "
<< m_padding_above << ", window shape: " << m_window_shape << ").";
}
}
}
// Shape forward_result_shape =
// Construct result shape: NCDo. infer_batched_pooling_forward(this,
// m_forward_arg_shape,
Shape forward_result_shape(1 + 1 + spatial_dimension_count); padding_below,
forward_result_shape[0] = batch_size; padding_above,
forward_result_shape[1] = channel_count; m_window_shape,
copy(output_item_shape.begin(), output_item_shape.end(), forward_result_shape.begin() + 2); m_window_movement_strides,
m_include_padding_in_avg_computation);
NODE_VALIDATION_ASSERT(this, forward_result_shape == delta_shape) NODE_VALIDATION_ASSERT(this, forward_result_shape == delta_shape)
<< "Inferred forward output shape does not match delta shape (inferred forward output " << "Inferred forward output shape does not match delta shape (inferred forward output "
"shape: " << "shape: " << forward_result_shape << ", delta shape: " << delta_shape << ").";
<< forward_result_shape << ", delta shape: " << delta_shape << ").";
set_output_type(0, get_input_element_type(0), m_forward_arg_shape); set_output_type(0, get_input_element_type(0), m_forward_arg_shape);
} }
......
...@@ -22,209 +22,11 @@ ...@@ -22,209 +22,11 @@
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
// Validates all attributes of a batched convolution and infers its output shape.
//
// The *_axis_* parameters select which of the first two axes carry the batch and
// channel dimensions on the data, filter, and result tensors (each must be 0 or 1).
// All NODE_VALIDATION_ASSERT failures produce node-attributed error messages.
//
// Returns the result shape NCoDo (or CoNDo, depending on *_axis_result).
Shape op::util::infer_convolution_output_shape(const Node* node,
                                               const Shape& data_batch_shape,
                                               const Shape& filters_shape,
                                               const Strides& window_movement_strides,
                                               const Strides& window_dilation_strides,
                                               const CoordinateDiff& padding_below,
                                               const CoordinateDiff& padding_above,
                                               const Strides& data_dilation_strides,
                                               size_t batch_axis_data,
                                               size_t input_channel_axis_data,
                                               size_t input_channel_axis_filters,
                                               size_t output_channel_axis_filters,
                                               size_t batch_axis_result,
                                               size_t output_channel_axis_result)
{
    // The axis-selector arguments are supplied by ngraph itself, never by users;
    // anything other than 0 or 1 indicates an internal caller bug.
    NODE_VALIDATION_ASSERT(node, batch_axis_data <= 1) << "(This is an internal nGraph error)";
    NODE_VALIDATION_ASSERT(node, input_channel_axis_data <= 1)
        << "(This is an internal nGraph error)";
    NODE_VALIDATION_ASSERT(node, input_channel_axis_filters <= 1)
        << "(This is an internal nGraph error)";
    NODE_VALIDATION_ASSERT(node, output_channel_axis_filters <= 1)
        << "(This is an internal nGraph error)";
    NODE_VALIDATION_ASSERT(node, batch_axis_result <= 1) << "(This is an internal nGraph error)";
    NODE_VALIDATION_ASSERT(node, output_channel_axis_result <= 1)
        << "(This is an internal nGraph error)";

    //
    // Make sure data_batch: NCiDi for some Di of rank>0, N != 0, Ci != 0.
    //
    NODE_VALIDATION_ASSERT(node, data_batch_shape.size() >= 3)
        << "Data batch input must have rank of at least 3 (one batch axis, "
        << "one input-channel axis, and at least one spatial dimension) "
        << "(data batch shape: " << data_batch_shape << ").";

    size_t batch_size = data_batch_shape[batch_axis_data];
    NODE_VALIDATION_ASSERT(node, batch_size != 0)
        << "Data batch size is zero (data batch shape: " << data_batch_shape << ", "
        << "batch axis is axis " << batch_axis_data << ").";

    size_t input_channel_count = data_batch_shape[input_channel_axis_data];
    NODE_VALIDATION_ASSERT(node, input_channel_count != 0)
        << "Input channel count is zero (data batch shape: " << data_batch_shape << ", "
        << "channel axis is axis " << input_channel_axis_data << ").";

    size_t spatial_dimension_count = data_batch_shape.size() - 2;

    //
    // Make sure filters: CoCiWv for some Co>0, rank of W = rank of Di.
    //
    NODE_VALIDATION_ASSERT(node, filters_shape.size() == 2 + spatial_dimension_count)
        << "Filter input must have rank equal to the data batch (one axis for output "
        << "channels, one axis for input channels, and the same number of spatial "
        << "dimensions as the data batch (filter input shape: " << filters_shape << ", "
        << "data batch shape: " << data_batch_shape << ").";

    size_t output_channel_count = filters_shape[output_channel_axis_filters];
    NODE_VALIDATION_ASSERT(node, output_channel_count != 0)
        << "Output channel count for filters is zero (filters shape: " << filters_shape << ", "
        << "output channels on axis " << output_channel_axis_filters << ").";

    NODE_VALIDATION_ASSERT(node, filters_shape[input_channel_axis_filters] == input_channel_count)
        << "Input channel count for filters (" << filters_shape[input_channel_axis_filters] << ") "
        << "does not match the number of channels in the data batch (" << input_channel_count
        << ") "
        << "(filter input shape: " << filters_shape << ", filter input channels on axis "
        << input_channel_axis_filters << "; data batch shape: " << data_batch_shape
        << ", data batch channels on axis " << batch_axis_data << ").";

    //
    // Make sure window movement strides, window dilation strides, and data dilation strides
    // have same rank as Di.
    //
    NODE_VALIDATION_ASSERT(node, window_movement_strides.size() == spatial_dimension_count)
        << "Rank of window movement strides does not match the number of spatial dimensions ("
        << spatial_dimension_count
        << ") in the data batch (window movement strides: " << window_movement_strides
        << ", data batch shape: " << data_batch_shape << ").";

    NODE_VALIDATION_ASSERT(node, window_dilation_strides.size() == spatial_dimension_count)
        << "Rank of window dilation strides does not match the number of spatial dimensions ("
        << spatial_dimension_count
        << ") in the data batch (window dilation strides: " << window_dilation_strides
        << ", data batch shape: " << data_batch_shape << ").";

    NODE_VALIDATION_ASSERT(node, data_dilation_strides.size() == spatial_dimension_count)
        << "Rank of data dilation strides does not match the number of spatial dimensions ("
        << spatial_dimension_count
        << ") in the data batch (data dilation strides: " << data_dilation_strides
        << ", data batch shape: " << data_batch_shape << ").";

    //
    // Make sure padding-below and padding-above shapes have same rank as Di.
    //
    NODE_VALIDATION_ASSERT(node, padding_below.size() == spatial_dimension_count)
        << "Rank of the padding below does not match the number of spatial dimensions ("
        << spatial_dimension_count << ") in the data batch (padding below: " << padding_below
        << ", data batch shape: " << data_batch_shape << ").";

    NODE_VALIDATION_ASSERT(node, padding_above.size() == spatial_dimension_count)
        << "Rank of the padding above does not match the number of spatial dimensions ("
        << spatial_dimension_count << ") in the data batch (padding above: " << padding_above
        << ", data batch shape: " << data_batch_shape << ").";

    //
    // Extract input item shape Di and make sure all dimensions are larger than 0 after padding and dilation.
    // The signed intermediate is needed because padding entries may be negative.
    //
    std::vector<ptrdiff_t> input_item_virtual_shape_signed;

    for (size_t i = 0; i < spatial_dimension_count; i++)
    {
        NODE_VALIDATION_ASSERT(node, data_dilation_strides[i] != 0)
            << "Data dilation stride at spatial dimension " << i << " is zero "
            << "(data dilation strides: " << data_dilation_strides << ").";

        size_t dim_size = data_batch_shape[1 + 1 + i];
        size_t dilated_dim_size = (dim_size - 1) * data_dilation_strides[i] + 1;

        ptrdiff_t padded_dilated_dim_size = padding_below[i] + dilated_dim_size + padding_above[i];

        input_item_virtual_shape_signed.push_back(padded_dilated_dim_size);
    }

    Shape input_item_virtual_shape;

    for (size_t i = 0; i < spatial_dimension_count; i++)
    {
        // NOTE(review): at this point input_item_virtual_shape only holds dims 0..i-1,
        // so the shape printed in this message is partial — confirm whether the full
        // signed shape should be reported instead.
        NODE_VALIDATION_ASSERT(node, input_item_virtual_shape_signed[i] > 0)
            << "Input dimension after padding and dilation is non-positive "
            << "at spatial axis " << i
            << " (post-padding/dilation input item shape: " << input_item_virtual_shape
            << ", data batch shape: " << data_batch_shape
            << ", data dilation strides: " << data_dilation_strides
            << ", padding below: " << padding_below << ", padding above: " << padding_above << ").";

        input_item_virtual_shape.push_back(size_t(input_item_virtual_shape_signed[i]));
    }

    //
    // Extract the physical shape Wp of the convolution window, *not* including dilation, from the filter dimensions.
    // At the same time, make sure window shape dimensions are all larger than 0.
    //
    Shape window_physical_shape;

    for (size_t i = 0; i < spatial_dimension_count; i++)
    {
        window_physical_shape.push_back(filters_shape[1 + 1 + i]);
        NODE_VALIDATION_ASSERT(node, window_physical_shape[i] != 0)
            << "Filters shape at spatial dimension " << i << " is zero "
            << "(filters shape: " << filters_shape << ").";
    }

    //
    // Compute virtual shape Wv of the convolution window, *including* dilation. At the same time, make sure all
    // window dilation strides are larger than 0, and that the dilated filter fits within the spatial dimensions.
    //
    Shape window_virtual_shape;

    for (size_t i = 0; i < spatial_dimension_count; i++)
    {
        NODE_VALIDATION_ASSERT(node, window_dilation_strides[i] != 0)
            << "Window dilation stride at spatial dimension " << i << " is zero "
            << "(window dilation strides: " << window_dilation_strides << ").";

        window_virtual_shape.push_back((window_physical_shape[i] - 1) * window_dilation_strides[i] +
                                       1);

        // Fix: this assert fires when the window exceeds the input, so the message must
        // say "larger", not "smaller"; also terminate the message with ").".
        NODE_VALIDATION_ASSERT(node, window_virtual_shape[i] <= input_item_virtual_shape[i])
            << "Post-dilation window shape is larger than the post-padding/dilation "
            << "input item shape at spatial dimension " << i << " (post-padding/dilation "
            << "input item shape: " << input_item_virtual_shape
            << ", data batch shape: " << data_batch_shape
            << ", data dilation strides: " << data_dilation_strides
            << ", padding below: " << padding_below << ", padding above: " << padding_above
            << ", post-dilation window shape: " << window_virtual_shape
            << ", filters shape: " << filters_shape
            << ", window dilation strides: " << window_dilation_strides << ").";
    }

    //
    // Construct result shape: NCoDo or CoNDo (depending on *_axis_result), checking at the same
    // time that all window movement strides are larger than 0.
    //
    Shape result_shape(spatial_dimension_count + 2);
    result_shape[batch_axis_result] = batch_size;
    result_shape[output_channel_axis_result] = output_channel_count;

    for (size_t i = 0; i < spatial_dimension_count; i++)
    {
        NODE_VALIDATION_ASSERT(node, window_movement_strides[i] != 0)
            << "Window movement stride at spatial dimension " << i << " is zero "
            << "(window movement strides: " << window_movement_strides << ").";

        result_shape[i + 2] = ceil_div(input_item_virtual_shape[i] - window_virtual_shape[i] + 1,
                                       window_movement_strides[i]);
    }

    return result_shape;
}
op::Convolution::Convolution(const shared_ptr<Node>& data_batch, op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters, const shared_ptr<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
...@@ -249,14 +51,21 @@ void op::Convolution::validate_and_infer_types() ...@@ -249,14 +51,21 @@ void op::Convolution::validate_and_infer_types()
auto& filters_shape = get_input_shape(1); auto& filters_shape = get_input_shape(1);
auto& filters_et = get_input_element_type(1); auto& filters_et = get_input_element_type(1);
NODE_VALIDATION_ASSERT(this, data_batch_shape.size() >= 3)
<< "Data batch must have rank of at least 3 (one batch axis, "
<< "one input-channel axis, and at least one spatial dimension) "
<< "(data batch shape: " << data_batch_shape << ").";
if (m_data_dilation_strides.size() == 0) if (m_data_dilation_strides.size() == 0)
{ {
m_data_dilation_strides = default_strides(this, data_batch_shape); m_data_dilation_strides = default_strides(this, data_batch_shape);
} }
if (m_window_movement_strides.size() == 0) if (m_window_movement_strides.size() == 0)
{ {
m_window_movement_strides = default_strides(this, data_batch_shape); m_window_movement_strides = default_strides(this, data_batch_shape);
} }
if (m_window_dilation_strides.size() == 0) if (m_window_dilation_strides.size() == 0)
{ {
m_window_dilation_strides = default_strides(this, data_batch_shape); m_window_dilation_strides = default_strides(this, data_batch_shape);
...@@ -272,39 +81,26 @@ void op::Convolution::validate_and_infer_types() ...@@ -272,39 +81,26 @@ void op::Convolution::validate_and_infer_types()
m_padding_above = default_padding(this, data_batch_shape); m_padding_above = default_padding(this, data_batch_shape);
} }
// element::Type result_et;
// Make sure data batch and filter element types match. Shape result_shape;
//
NODE_VALIDATION_ASSERT(this, data_batch_et == filters_et) std::tie(result_et, result_shape) = infer_convolution_forward(this,
<< "Element types for data batch and filters do not match (data batch element type: " data_batch_et,
<< data_batch_et << ", filters element type: " << filters_et << ")."; filters_et,
data_batch_shape,
set_output_type(0, m_data_dilation_strides,
data_batch_et, m_padding_below,
util::infer_convolution_output_shape(this, m_padding_above,
data_batch_shape, filters_shape,
filters_shape, m_window_movement_strides,
m_window_movement_strides, m_window_dilation_strides);
m_window_dilation_strides,
m_padding_below, set_output_type(0, result_et, result_shape);
m_padding_above,
m_data_dilation_strides,
0,
1,
1,
0,
0,
1));
} }
Strides op::Convolution::default_strides(const Node* node, const Shape& data_batch_shape) Strides op::Convolution::default_strides(const Node* node, const Shape& data_batch_shape)
{ {
// For consistency we should throw the same error message here that we throw in the constructor. NGRAPH_ASSERT(data_batch_shape.size() >= 2);
NODE_VALIDATION_ASSERT(node, data_batch_shape.size() >= 3)
<< "Data batch input must have rank of at least 3 (one batch axis, "
<< "one input-channel axis, and at least one spatial dimension) "
<< "(data batch shape: " << data_batch_shape << ").";
return Strides(data_batch_shape.size() - 2, 1); return Strides(data_batch_shape.size() - 2, 1);
} }
...@@ -326,12 +122,7 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch, ...@@ -326,12 +122,7 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
CoordinateDiff op::Convolution::default_padding(const Node* node, const Shape& data_batch_shape) CoordinateDiff op::Convolution::default_padding(const Node* node, const Shape& data_batch_shape)
{ {
// For consistency we should throw the same error message here that we throw in the constructor. NGRAPH_ASSERT(data_batch_shape.size() >= 2);
NODE_VALIDATION_ASSERT(node, data_batch_shape.size() >= 3)
<< "Data batch input must have rank of at least 3 (one batch axis, "
<< "one input-channel axis, and at least one spatial dimension) "
<< "(data batch shape: " << data_batch_shape << ").";
return CoordinateDiff(data_batch_shape.size() - 2, 0); return CoordinateDiff(data_batch_shape.size() - 2, 0);
} }
...@@ -429,65 +220,88 @@ op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_sha ...@@ -429,65 +220,88 @@ op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_sha
void op::ConvolutionBackpropData::validate_and_infer_types() void op::ConvolutionBackpropData::validate_and_infer_types()
{ {
// Backprop to data is itself convolution, with inputs/outputs/attributes transmogrified as
// follows.
//
// Forward Backward
// "N" axis for data batch 0 0
// "C" axis for data batch 1 1
// "Co" axis for filters 0 0
// "Ci" axis for filters 1 1
// "N" axis for output 0 0
// "C" axis for output 1 1
// Data batch x delta
// Data batch shape S_x S_o
// Filters f reverse(f) [on spatial axes]
// Filters shape S_f S_f
// Window movement strides q_x p_x
// Window dilation strides p_f p_f
// Padding below a_x (S_f - 1)p_f - a_x
// Padding above b_x (S_f - 1)p_f + ((a_x + (S_x - 1)p_x + b_x - (S_f - 1)p_f) % q_x) - b_x
// Data dilation strides p_x q_x
// Output shape S_o S_x
//
// To _validate_, we simply need to check/infer the output shape of the forward convolution,
// then check to make sure that the incoming delta has the same shape as the forward output.
//
// We will also compute and store the various parameters in the "backward" column above, since
// some backends need them. (TODO(amprocte): Is it just because of the way the reference works
// that this stuff is needed? If so, we can probably get rid of it and have conv_backprop
// reference kernels that do the calculations of the backward parameters internally, or supply
// utility functions to do it.)
auto& filters_shape = get_input_shape(0); auto& filters_shape = get_input_shape(0);
auto& filters_et = get_input_element_type(0); auto& filters_et = get_input_element_type(0);
auto& output_delta_shape = get_input_shape(1); auto& delta_shape = get_input_shape(1);
auto& output_delta_et = get_input_element_type(1); auto& delta_et = get_input_element_type(1);
element::Type forward_result_et;
Shape forward_result_shape;
std::tie(forward_result_et, forward_result_shape) =
infer_convolution_forward(this,
delta_et,
filters_et,
m_data_batch_shape,
m_data_dilation_strides_forward,
m_padding_below_forward,
m_padding_above_forward,
filters_shape,
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NODE_VALIDATION_ASSERT(this, forward_result_shape == delta_shape)
<< "Inferred forward output shape (" << forward_result_shape << ") does not match shape of "
<< "delta (" << delta_shape << ").";
set_output_type(0, delta_et, m_data_batch_shape);
// //
// Make sure filter and output delta element types match. // Compute parameters needed for backprop-as-convolution.
// //
NODE_VALIDATION_ASSERT(this, output_delta_et == filters_et) size_t spatial_dim_count = delta_shape.size() - 2;
<< "Element types for filters and output delta do not match (filters element type: "
<< filters_et << ", output delta element type: " << output_delta_et << ")."; m_window_movement_strides_backward = m_data_dilation_strides_forward;
m_window_dilation_strides_backward = m_window_dilation_strides_forward;
// Forward Backward m_data_dilation_strides_backward = m_window_movement_strides_forward;
// Window movement strides q p_x
// Window dilation strides p_f p_f m_padding_below_backward.resize(spatial_dim_count);
// Padding below a_x (S_F - 1)p_f - a_x m_padding_above_backward.resize(spatial_dim_count);
// Padding above b_x (S_f - 1)p_f + ((a_x + (S_x - 1)p_x + b_x - (S_f - 1)p_f) % q) - b_x
// Data dilation strides p_x q for (size_t i = 0; i < spatial_dim_count; i++)
for (size_t i = 0; i < m_data_batch_shape.size() - 2; i++)
{ {
m_window_movement_strides_backward.push_back(m_data_dilation_strides_forward[i]); m_padding_below_backward[i] =
m_window_dilation_strides_backward.push_back(m_window_dilation_strides_forward[i]); (filters_shape[i + 2] - 1) * m_window_dilation_strides_forward[i] -
m_padding_below_backward.push_back((filters_shape[i + 2] - 1) * m_padding_below_forward[i];
m_window_dilation_strides_forward[i] - m_padding_above_backward[i] =
m_padding_below_forward[i]);
m_padding_above_backward.push_back(
(filters_shape[i + 2] - 1) * m_window_dilation_strides_forward[i] + (filters_shape[i + 2] - 1) * m_window_dilation_strides_forward[i] +
((m_padding_below_forward[i] + ((m_padding_below_forward[i] +
(m_data_batch_shape[i + 2] - 1) * m_data_dilation_strides_forward[i] + (m_data_batch_shape[i + 2] - 1) * m_data_dilation_strides_forward[i] +
m_padding_above_forward[i] - m_padding_above_forward[i] -
(filters_shape[i + 2] - 1) * m_window_dilation_strides_forward[i]) % (filters_shape[i + 2] - 1) * m_window_dilation_strides_forward[i]) %
m_window_movement_strides_forward[i]) - m_window_movement_strides_forward[i]) -
m_padding_above_forward[i]); m_padding_above_forward[i];
m_data_dilation_strides_backward.push_back(m_window_movement_strides_forward[i]);
} }
Shape inferred_convolution_output_shape =
util::infer_convolution_output_shape(this,
output_delta_shape,
filters_shape,
m_window_movement_strides_backward,
m_window_dilation_strides_backward,
m_padding_below_backward,
m_padding_above_backward,
m_data_dilation_strides_backward,
0,
1,
0,
1,
0,
1);
NODE_VALIDATION_ASSERT(this, inferred_convolution_output_shape == m_data_batch_shape)
<< "Specified data batch shape does not match the inferred data batch shape "
<< "(specified shape: " << m_data_batch_shape
<< ", inferred data batch shape: " << inferred_convolution_output_shape;
set_output_type(0, filters_et, inferred_convolution_output_shape);
} }
void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints, void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints,
...@@ -596,62 +410,84 @@ op::ConvolutionBackpropFilters::ConvolutionBackpropFilters( ...@@ -596,62 +410,84 @@ op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
void op::ConvolutionBackpropFilters::validate_and_infer_types() void op::ConvolutionBackpropFilters::validate_and_infer_types()
{ {
// Backprop to filters is itself convolution, with inputs/outputs/attributes transmogrified as
// follows.
//
// Forward Backward
// "N" axis for data batch 0 1
// "C" axis for data batch 1 0
// "Co" axis for filters 0 0
// "Ci" axis for filters 1 1
// "N" axis for output 0 1
// "C" axis for output 1 0
// Data batch x x
// Data batch shape S_x S_x
// Filters f delta
// Filters shape S_f S_f
// Window movement strides q_x p_f
// Window dilation strides p_f q_x
// Padding below a_x a_x
// Padding above b_x b_x - (a_x + (S_x - 1)p_x + b_x - (S_f - 1)p_f) % q_x
// Data dilation strides p_x p_x
// Output shape S_o S_f
//
// To _validate_, we simply need to check/infer the output shape of the forward convolution,
// then check to make sure that the incoming delta has the same shape as the forward output.
//
// We will also compute and store the various parameters in the "backward" column above, since
// some backends need them. (TODO(amprocte): Is it just because of the way the reference works
// that this stuff is needed? If so, we can probably get rid of it and have conv_backprop
// reference kernels that do the calculations of the backward parameters internally, or supply
// utility functions to do it.)
auto& data_batch_shape = get_input_shape(0); auto& data_batch_shape = get_input_shape(0);
auto& data_batch_et = get_input_element_type(0); auto& data_batch_et = get_input_element_type(0);
auto& output_delta_shape = get_input_shape(1); auto& delta_shape = get_input_shape(1);
auto& output_delta_et = get_input_element_type(1); auto& delta_et = get_input_element_type(1);
element::Type forward_result_et;
Shape forward_result_shape;
std::tie(forward_result_et, forward_result_shape) =
infer_convolution_forward(this,
data_batch_et,
delta_et,
data_batch_shape,
m_data_dilation_strides_forward,
m_padding_below_forward,
m_padding_above_forward,
m_filters_shape,
m_window_movement_strides_forward,
m_window_dilation_strides_forward);
NODE_VALIDATION_ASSERT(this, forward_result_shape == delta_shape)
<< "Inferred forward output shape (" << forward_result_shape << ") does not match shape of "
<< "delta (" << delta_shape << ").";
set_output_type(0, delta_et, m_filters_shape);
// //
// Make sure data batch and output delta element types match. // Compute parameters needed for backprop-as-convolution.
// //
NODE_VALIDATION_ASSERT(this, output_delta_et == data_batch_et) size_t spatial_dim_count = delta_shape.size() - 2;
<< "Element types for data batch and output delta do not match (data batch element type: "
<< data_batch_et << ", output delta element type: " << output_delta_et << ")."; m_window_movement_strides_backward = m_window_dilation_strides_forward;
m_window_dilation_strides_backward = m_window_movement_strides_forward;
// Forward Backward m_padding_below_backward = m_padding_below_forward;
// Window movement strides q p_f m_data_dilation_strides_backward = m_data_dilation_strides_forward;
// Window dilation strides p_f q
// Padding below a_x a_x m_padding_above_backward.resize(spatial_dim_count);
// Padding above b_x b_x - (a_x + (S_x - 1)p_x + b_x - (S_f - 1)p_f) % q
// Data dilation strides p_x p_x for (size_t i = 0; i < spatial_dim_count; i++)
for (size_t i = 0; i < m_filters_shape.size() - 2; i++)
{ {
m_window_movement_strides_backward.push_back(m_window_dilation_strides_forward[i]); m_padding_above_backward[i] =
m_window_dilation_strides_backward.push_back(m_window_movement_strides_forward[i]);
m_padding_below_backward.push_back(m_padding_below_forward[i]);
m_padding_above_backward.push_back(
m_padding_above_forward[i] - m_padding_above_forward[i] -
(m_padding_below_forward[i] + (m_padding_below_forward[i] +
(data_batch_shape[i + 2] - 1) * m_data_dilation_strides_forward[i] + (data_batch_shape[i + 2] - 1) * m_data_dilation_strides_forward[i] +
m_padding_above_forward[i] - m_padding_above_forward[i] -
(m_filters_shape[i + 2] - 1) * m_window_dilation_strides_forward[i]) % (m_filters_shape[i + 2] - 1) * m_window_dilation_strides_forward[i]) %
m_window_movement_strides_forward[i]); m_window_movement_strides_forward[i];
m_data_dilation_strides_backward.push_back(m_data_dilation_strides_forward[i]);
} }
Shape inferred_convolution_output_shape =
util::infer_convolution_output_shape(this,
data_batch_shape,
output_delta_shape,
m_window_movement_strides_backward,
m_window_dilation_strides_backward,
m_padding_below_backward,
m_padding_above_backward,
m_data_dilation_strides_backward,
1,
0,
0,
1,
1,
0);
NODE_VALIDATION_ASSERT(this, inferred_convolution_output_shape == m_filters_shape)
<< "Specified filters shape does not match the inferred filters shape "
<< "(specified shape: " << m_filters_shape
<< ", inferred filters shape: " << inferred_convolution_output_shape;
set_output_type(0, data_batch_et, inferred_convolution_output_shape);
} }
shared_ptr<Node> shared_ptr<Node>
...@@ -667,3 +503,207 @@ shared_ptr<Node> ...@@ -667,3 +503,207 @@ shared_ptr<Node>
m_padding_above_forward, m_padding_above_forward,
m_data_dilation_strides_forward); m_data_dilation_strides_forward);
} }
//
// This is a legacy function, retained because the CPU backend uses it for now.
// TODO(amprocte): Update CPU backend to use the new stuff in validation_util.hpp, and remove this
// function.
//
// Validates convolution attributes and infers the output shape. The *_axis_* parameters select
// which of axes 0/1 carry the batch and channel dimensions on the data batch, filters, and
// result; this axis-swizzling is what lets the backprop ops express themselves as forward
// convolutions over transposed arguments.
Shape op::util::infer_convolution_output_shape(const Node* node,
                                               const Shape& data_batch_shape,
                                               const Shape& filters_shape,
                                               const Strides& window_movement_strides,
                                               const Strides& window_dilation_strides,
                                               const CoordinateDiff& padding_below,
                                               const CoordinateDiff& padding_above,
                                               const Strides& data_dilation_strides,
                                               size_t batch_axis_data,
                                               size_t input_channel_axis_data,
                                               size_t input_channel_axis_filters,
                                               size_t output_channel_axis_filters,
                                               size_t batch_axis_result,
                                               size_t output_channel_axis_result)
{
    // The axis selectors must each be 0 or 1; anything else is a caller bug, not a user error.
    NODE_VALIDATION_ASSERT(node, batch_axis_data <= 1) << "(This is an internal nGraph error)";
    NODE_VALIDATION_ASSERT(node, input_channel_axis_data <= 1)
        << "(This is an internal nGraph error)";
    NODE_VALIDATION_ASSERT(node, input_channel_axis_filters <= 1)
        << "(This is an internal nGraph error)";
    NODE_VALIDATION_ASSERT(node, output_channel_axis_filters <= 1)
        << "(This is an internal nGraph error)";
    NODE_VALIDATION_ASSERT(node, batch_axis_result <= 1) << "(This is an internal nGraph error)";
    NODE_VALIDATION_ASSERT(node, output_channel_axis_result <= 1)
        << "(This is an internal nGraph error)";

    //
    // Make sure data_batch: NCiDi for some Di of rank>0, N != 0, Ci != 0.
    //
    NODE_VALIDATION_ASSERT(node, data_batch_shape.size() >= 3)
        << "Data batch input must have rank of at least 3 (one batch axis, "
        << "one input-channel axis, and at least one spatial dimension) "
        << "(data batch shape: " << data_batch_shape << ").";

    size_t batch_size = data_batch_shape[batch_axis_data];
    NODE_VALIDATION_ASSERT(node, batch_size != 0)
        << "Data batch size is zero (data batch shape: " << data_batch_shape << ", "
        << "batch axis is axis " << batch_axis_data << ").";

    size_t input_channel_count = data_batch_shape[input_channel_axis_data];
    NODE_VALIDATION_ASSERT(node, input_channel_count != 0)
        << "Input channel count is zero (data batch shape: " << data_batch_shape << ", "
        << "channel axis is axis " << input_channel_axis_data << ").";

    size_t spatial_dimension_count = data_batch_shape.size() - 2;

    //
    // Make sure filters: CoCiWv for some Co>0, rank of W = rank of Di.
    //
    NODE_VALIDATION_ASSERT(node, filters_shape.size() == 2 + spatial_dimension_count)
        << "Filter input must have rank equal to the data batch (one axis for output "
        << "channels, one axis for input channels, and the same number of spatial "
        << "dimensions as the data batch (filter input shape: " << filters_shape << ", "
        << "data batch shape: " << data_batch_shape << ").";

    size_t output_channel_count = filters_shape[output_channel_axis_filters];
    NODE_VALIDATION_ASSERT(node, output_channel_count != 0)
        << "Output channel count for filters is zero (filters shape: " << filters_shape << ", "
        << "output channels on axis " << output_channel_axis_filters << ").";

    NODE_VALIDATION_ASSERT(node, filters_shape[input_channel_axis_filters] == input_channel_count)
        << "Input channel count for filters (" << filters_shape[input_channel_axis_filters] << ") "
        << "does not match the number of channels in the data batch (" << input_channel_count
        << ") "
        << "(filter input shape: " << filters_shape << ", filter input channels on axis "
        << input_channel_axis_filters << "; data batch shape: " << data_batch_shape
        << ", data batch channels on axis " << batch_axis_data << ").";

    //
    // Make sure window movement strides, window dilation strides, and data dilation strides
    // have same rank as Di.
    //
    NODE_VALIDATION_ASSERT(node, window_movement_strides.size() == spatial_dimension_count)
        << "Rank of window movement strides does not match the number of spatial dimensions ("
        << spatial_dimension_count
        << ") in the data batch (window movement strides: " << window_movement_strides
        << ", data batch shape: " << data_batch_shape << ").";

    NODE_VALIDATION_ASSERT(node, window_dilation_strides.size() == spatial_dimension_count)
        << "Rank of window dilation strides does not match the number of spatial dimensions ("
        << spatial_dimension_count
        << ") in the data batch (window dilation strides: " << window_dilation_strides
        << ", data batch shape: " << data_batch_shape << ").";

    NODE_VALIDATION_ASSERT(node, data_dilation_strides.size() == spatial_dimension_count)
        << "Rank of data dilation strides does not match the number of spatial dimensions ("
        << spatial_dimension_count
        << ") in the data batch (data dilation strides: " << data_dilation_strides
        << ", data batch shape: " << data_batch_shape << ").";

    //
    // Make sure padding-below and padding-above shapes have same rank as Di.
    //
    NODE_VALIDATION_ASSERT(node, padding_below.size() == spatial_dimension_count)
        << "Rank of the padding below does not match the number of spatial dimensions ("
        << spatial_dimension_count << ") in the data batch (padding below: " << padding_below
        << ", data batch shape: " << data_batch_shape << ").";

    NODE_VALIDATION_ASSERT(node, padding_above.size() == spatial_dimension_count)
        << "Rank of the padding above does not match the number of spatial dimensions ("
        << spatial_dimension_count << ") in the data batch (padding above: " << padding_above
        << ", data batch shape: " << data_batch_shape << ").";

    //
    // Extract input item shape Di and make sure all dimensions are larger than 0 after padding and dilation.
    //
    // The per-axis size is computed in ptrdiff_t because padding entries may be negative.
    std::vector<ptrdiff_t> input_item_virtual_shape_signed;

    for (size_t i = 0; i < spatial_dimension_count; i++)
    {
        NODE_VALIDATION_ASSERT(node, data_dilation_strides[i] != 0)
            << "Data dilation stride at spatial dimension " << i << " is zero "
            << "(data dilation strides: " << data_dilation_strides << ").";

        size_t dim_size = data_batch_shape[1 + 1 + i];
        size_t dilated_dim_size = (dim_size - 1) * data_dilation_strides[i] + 1;

        ptrdiff_t padded_dilated_dim_size = padding_below[i] + dilated_dim_size + padding_above[i];

        input_item_virtual_shape_signed.push_back(padded_dilated_dim_size);
    }

    Shape input_item_virtual_shape;

    for (size_t i = 0; i < spatial_dimension_count; i++)
    {
        // Report the offending per-axis size rather than the partially-built shape vector,
        // which would be truncated at this point.
        NODE_VALIDATION_ASSERT(node, input_item_virtual_shape_signed[i] > 0)
            << "Input dimension after padding and dilation is non-positive "
            << "at spatial axis " << i
            << " (post-padding/dilation input item size: " << input_item_virtual_shape_signed[i]
            << ", data batch shape: " << data_batch_shape
            << ", data dilation strides: " << data_dilation_strides
            << ", padding below: " << padding_below << ", padding above: " << padding_above << ").";

        input_item_virtual_shape.push_back(size_t(input_item_virtual_shape_signed[i]));
    }

    //
    // Extract the physical shape Wp of the convolution window, *not* including dilation, from the filter dimensions.
    // At the same time, make sure window shape dimensions are all larger than 0.
    //
    Shape window_physical_shape;

    for (size_t i = 0; i < spatial_dimension_count; i++)
    {
        window_physical_shape.push_back(filters_shape[1 + 1 + i]);
        NODE_VALIDATION_ASSERT(node, window_physical_shape[i] != 0)
            << "Filters shape at spatial dimension " << i << " is zero "
            << "(filters shape: " << filters_shape << ").";
    }

    //
    // Compute virtual shape Wp of the convolution window, *including* dilation. At the same time, make sure all
    // window dilation strides are larger than 0, and that the dilated filter fits within the spatial dimensions.
    //
    Shape window_virtual_shape;

    for (size_t i = 0; i < spatial_dimension_count; i++)
    {
        NODE_VALIDATION_ASSERT(node, window_dilation_strides[i] != 0)
            << "Window dilation stride at spatial dimension " << i << " is zero "
            << "(window dilation strides: " << window_dilation_strides << ").";

        window_virtual_shape.push_back((window_physical_shape[i] - 1) * window_dilation_strides[i] +
                                       1);

        // This assert fires when the dilated window does NOT fit, i.e. it is larger than the
        // padded/dilated input; the message wording reflects that (and is properly terminated).
        NODE_VALIDATION_ASSERT(node, window_virtual_shape[i] <= input_item_virtual_shape[i])
            << "Post-dilation window shape is larger than the post-padding/dilation "
            << "input item shape at spatial dimension " << i << " (post-padding/dilation "
            << "input item shape: " << input_item_virtual_shape
            << ", data batch shape: " << data_batch_shape
            << ", data dilation strides: " << data_dilation_strides
            << ", padding below: " << padding_below << ", padding above: " << padding_above
            << ", post-dilation window shape: " << window_virtual_shape
            << ", filters shape: " << filters_shape
            << ", window dilation strides: " << window_dilation_strides << ").";
    }

    //
    // Construct result shape: NCoDo or CoNDo (depending on *_axis_result), checking at the same
    // time that all window movement strides are larger than 0.
    //
    Shape result_shape(spatial_dimension_count + 2);
    result_shape[batch_axis_result] = batch_size;
    result_shape[output_channel_axis_result] = output_channel_count;

    for (size_t i = 0; i < spatial_dimension_count; i++)
    {
        NODE_VALIDATION_ASSERT(node, window_movement_strides[i] != 0)
            << "Window movement stride at spatial dimension " << i << " is zero "
            << "(window movement strides: " << window_movement_strides << ").";

        // Number of valid window positions along this axis, rounded up for strided movement.
        result_shape[i + 2] = ceil_div(input_item_virtual_shape[i] - window_virtual_shape[i] + 1,
                                       window_movement_strides[i]);
    }

    return result_shape;
}
...@@ -356,6 +356,9 @@ namespace ngraph ...@@ -356,6 +356,9 @@ namespace ngraph
namespace util namespace util
{ {
// This is a legacy function, retained because the CPU backend uses it for now.
// TODO: Update CPU backend to use the new stuff in validation_util.hpp, and remove
// this function.
Shape infer_convolution_output_shape(const Node* node, Shape infer_convolution_output_shape(const Node* node,
const Shape& data_batch_shape, const Shape& data_batch_shape,
const Shape& filters_shape, const Shape& filters_shape,
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
#include "ngraph/op/select_and_scatter.hpp" #include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -39,14 +40,14 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg, ...@@ -39,14 +40,14 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
// TODO(amprocte): This code is now *exactly* the same as AvgPool::validate_and_infer_types(),
// except that for AvgPool we also have an optional check that the pooling window is never
// entirely in the padding. Should unify in a utility function, but not sure where it belongs
// at this juncture.
void op::MaxPool::validate_and_infer_types() void op::MaxPool::validate_and_infer_types()
{ {
auto& arg_shape = get_input_shape(0); auto& arg_shape = get_input_shape(0);
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
<< "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape
<< ").";
if (0 == m_window_movement_strides.size() && arg_shape.size() > 2) if (0 == m_window_movement_strides.size() && arg_shape.size() > 2)
{ {
m_window_movement_strides = Strides(arg_shape.size() - 2, 1); m_window_movement_strides = Strides(arg_shape.size() - 2, 1);
...@@ -62,105 +63,20 @@ void op::MaxPool::validate_and_infer_types() ...@@ -62,105 +63,20 @@ void op::MaxPool::validate_and_infer_types()
m_padding_above = Shape(arg_shape.size() - 2, 0); m_padding_above = Shape(arg_shape.size() - 2, 0);
} }
// // infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// Make sure batch size and channel count are not zero, and that we have at least one spatial // now still take Shape (no negative padding).
// dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0). CoordinateDiff padding_below(m_padding_below.begin(), m_padding_below.end());
// CoordinateDiff padding_above(m_padding_above.begin(), m_padding_above.end());
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
<< "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape set_output_type(0,
<< ")."; get_input_element_type(0),
infer_batched_pooling_forward(this,
size_t batch_size = arg_shape[0]; arg_shape,
NODE_VALIDATION_ASSERT(this, batch_size != 0) padding_below,
<< "Data batch size is zero (data input shape: " << arg_shape << ")."; padding_above,
m_window_shape,
size_t channel_count = arg_shape[1]; m_window_movement_strides,
NODE_VALIDATION_ASSERT(this, channel_count != 0) true));
<< "Channel count is zero (data input shape: " << arg_shape << ").";
size_t spatial_dimension_count = arg_shape.size() - 2;
//
// Make sure window shape, window movement strides, and padding have same rank as Di.
//
NODE_VALIDATION_ASSERT(this, m_window_shape.size() == spatial_dimension_count)
<< "Window shape rank does not match number of spatial dimensions (window shape: "
<< m_window_shape << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_window_movement_strides.size() == spatial_dimension_count)
<< "Window movement stride rank does not match number of spatial dimensions (window "
"movement strides: "
<< m_window_movement_strides << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_padding_below.size() == spatial_dimension_count)
<< "Below-padding rank does not match number of spatial dimensions (padding below: "
<< m_padding_below << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_padding_above.size() == spatial_dimension_count)
<< "Above-padding rank does not match number of spatial dimensions (padding above: "
<< m_padding_above << ", data input shape: " << arg_shape << ").";
//
// Extract input item shape Di and make sure all dimensions are larger than 0.
//
Shape input_item_virtual_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{
size_t dim_size = arg_shape[1 + 1 + i];
size_t virtual_dim_size = m_padding_below[i] + dim_size + m_padding_above[i];
input_item_virtual_shape.push_back(virtual_dim_size);
}
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, input_item_virtual_shape[i] != 0)
<< "Data input spatial dimension " << i
<< " has zero length even after padding (virtual shape of input item: "
<< input_item_virtual_shape << ").";
}
//
// Make sure window shape dimensions are all larger than 0.
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_shape[i] != 0)
<< "Window shape dimension " << i
<< " has zero length (window shape: " << m_window_shape << ").";
}
//
// Make sure the pooling window fits within the spatial dimensions.
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_shape[i] <= input_item_virtual_shape[i])
<< "Window shape after padding is larger than the spatial dimensions (window shape: "
<< m_window_shape << ", virtual shape of input item: " << input_item_virtual_shape
<< ").";
}
//
// Compute output item shape Do, checking at the same time that all window movement strides are larger than 0.
//
Shape output_item_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_movement_strides[i] != 0)
<< "Window movement strides dimension " << i
<< " has zero length (window movement strides: " << m_window_movement_strides << ").";
output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - m_window_shape[i] + 1,
m_window_movement_strides[i]));
}
//
// Construct result shape: NCDo.
//
Shape result_shape(1 + 1 + spatial_dimension_count);
result_shape[0] = batch_size;
result_shape[1] = channel_count;
copy(output_item_shape.begin(), output_item_shape.end(), result_shape.begin() + 2);
set_output_type(0, get_input_element_type(0), result_shape);
} }
op::MaxPool::MaxPool(const shared_ptr<Node>& arg, op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
...@@ -204,121 +120,33 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward, ...@@ -204,121 +120,33 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
void op::MaxPoolBackprop::validate_and_infer_types() void op::MaxPoolBackprop::validate_and_infer_types()
{ {
NODE_VALIDATION_ASSERT(this, get_input_element_type(0) == get_input_element_type(1)) auto forward_arg_et = get_input_element_type(0);
<< "Data input and delta element types do not match (data input element type: " auto& forward_arg_shape = get_input_shape(0);
<< get_input_element_type(0) << ", delta element type: " << get_input_element_type(1) auto delta_et = get_input_element_type(1);
<< ")."; auto& delta_shape = get_input_shape(1);
// NODE_VALIDATION_ASSERT(this, forward_arg_et == delta_et)
// TODO(amprocte): de-duplicate almost all the rest of this code from << "Element types for forward argument (" << forward_arg_et << ") and delta (" << delta_et
// MaxPool::validate_and_infer_types(). << ") do not match.";
//
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
auto& arg_shape = get_input_shape(0); // now still take Shape (no negative padding).
CoordinateDiff padding_below(m_padding_below.begin(), m_padding_below.end());
// CoordinateDiff padding_above(m_padding_above.begin(), m_padding_above.end());
// Make sure batch size and channel count are not zero, and that we have at least one spatial
// dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0). Shape forward_result_shape = infer_batched_pooling_forward(this,
// forward_arg_shape,
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3) padding_below,
<< "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape padding_above,
<< ")."; m_window_shape,
m_window_movement_strides,
size_t batch_size = arg_shape[0]; true);
NODE_VALIDATION_ASSERT(this, batch_size != 0)
<< "Data batch size is zero (data input shape: " << arg_shape << ")."; NODE_VALIDATION_ASSERT(this, forward_result_shape == delta_shape)
<< "Inferred forward output shape does not match delta shape (inferred forward output "
size_t channel_count = arg_shape[1]; << "shape: " << forward_result_shape << ", delta shape: " << delta_shape << ").";
NODE_VALIDATION_ASSERT(this, channel_count != 0)
<< "Channel count is zero (data input shape: " << arg_shape << ")."; set_output_type(0, get_input_element_type(0), forward_arg_shape);
size_t spatial_dimension_count = arg_shape.size() - 2;
//
// Make sure window shape, window movement strides, and padding have same rank as Di.
//
NODE_VALIDATION_ASSERT(this, m_window_shape.size() == spatial_dimension_count)
<< "Window shape rank does not match number of spatial dimensions (window shape: "
<< m_window_shape << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_window_movement_strides.size() == spatial_dimension_count)
<< "Window movement stride rank does not match number of spatial dimensions (window "
"movement strides: "
<< m_window_movement_strides << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_padding_below.size() == spatial_dimension_count)
<< "Below-padding rank does not match number of spatial dimensions (padding below: "
<< m_padding_below << ", data input shape: " << arg_shape << ").";
NODE_VALIDATION_ASSERT(this, m_padding_above.size() == spatial_dimension_count)
<< "Above-padding rank does not match number of spatial dimensions (padding above: "
<< m_padding_above << ", data input shape: " << arg_shape << ").";
//
// Extract input item shape Di and make sure all dimensions are larger than 0.
//
Shape input_item_virtual_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{
size_t dim_size = arg_shape[1 + 1 + i];
size_t virtual_dim_size = m_padding_below[i] + dim_size + m_padding_above[i];
input_item_virtual_shape.push_back(virtual_dim_size);
}
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, input_item_virtual_shape[i] != 0)
<< "Data input spatial dimension " << i
<< " has zero length even after padding (virtual shape of input item: "
<< input_item_virtual_shape << ").";
}
//
// Make sure window shape dimensions are all larger than 0.
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_shape[i] != 0)
<< "Window shape dimension " << i
<< " has zero length (window shape: " << m_window_shape << ").";
}
//
// Make sure the pooling window fits within the spatial dimensions.
//
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_shape[i] <= input_item_virtual_shape[i])
<< "Window shape after padding is larger than the spatial dimensions (window shape: "
<< m_window_shape << ", virtual shape of input item: " << input_item_virtual_shape
<< ").";
}
//
// Compute output item shape Do, checking at the same time that all window movement strides are larger than 0.
//
Shape output_item_shape;
for (size_t i = 0; i < spatial_dimension_count; i++)
{
NODE_VALIDATION_ASSERT(this, m_window_movement_strides[i] != 0)
<< "Window movement strides dimension " << i
<< " has zero length (window movement strides: " << m_window_movement_strides << ").";
output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - m_window_shape[i] + 1,
m_window_movement_strides[i]));
}
//
// Construct result shape: NCDo.
//
Shape result_shape(1 + 1 + spatial_dimension_count);
result_shape[0] = batch_size;
result_shape[1] = channel_count;
copy(output_item_shape.begin(), output_item_shape.end(), result_shape.begin() + 2);
NODE_VALIDATION_ASSERT(this, get_input_shape(1) == result_shape)
<< "Forward result shape and delta shape do not match (forward result shape: "
<< result_shape << ", delta shape: " << get_input_shape(1) << ").";
set_output_type(0, get_input_element_type(0), arg_shape);
} }
shared_ptr<op::MaxPool> op::MaxPoolBackprop::get_forward_op() const shared_ptr<op::MaxPool> op::MaxPoolBackprop::get_forward_op() const
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/validation_util.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
//
// Infers the output shape of a windowed reduction operation, where the data may be dilated and/or
// padded, and the reduction window may be strided and/or dilated.
//
// Shared by convolution and pooling type propagation: each spatial axis is validated
// independently and its output extent is computed as
//   ceil((padded_dilated_data_dim - dilated_window_dim + 1) / window_stride).
//
// `is_window_all_in_padding_allowed` controls whether a window placement that lies entirely
// inside the padding region (never touching real data) is accepted; convolution allows this,
// average pooling generally does not.
//
// Throws NodeValidationError (via NODE_VALIDATION_ASSERT, attributed to `node`) on any
// rank mismatch, zero stride/dilation, or degenerate (< 1) data or window dimension.
//
Shape ngraph::infer_windowed_reduction_output_shape(const Node* node,
                                                    const Shape& data_shape,
                                                    const Strides& data_dilation,
                                                    const CoordinateDiff& data_padding_below,
                                                    const CoordinateDiff& data_padding_above,
                                                    const Shape& window_shape,
                                                    const Strides& window_strides,
                                                    const Strides& window_dilation,
                                                    bool is_window_all_in_padding_allowed)
{
    // Every per-axis attribute must have exactly one entry per data axis.
    NODE_VALIDATION_ASSERT(node, data_shape.size() == data_dilation.size())
        << "Data shape (" << data_shape << ") does not have same rank as "
        << "the data dilation (" << data_dilation << ").";

    NODE_VALIDATION_ASSERT(node, data_shape.size() == data_padding_below.size())
        << "Data shape (" << data_shape << ") does not have same rank as "
        << "the data padding below (" << data_padding_below << ").";

    NODE_VALIDATION_ASSERT(node, data_shape.size() == data_padding_above.size())
        << "Data shape (" << data_shape << ") does not have same rank as "
        << "the data padding above (" << data_padding_above << ").";

    NODE_VALIDATION_ASSERT(node, data_shape.size() == window_shape.size())
        << "Data shape (" << data_shape << ") does not have same rank as "
        << "the window shape (" << window_shape << ").";

    NODE_VALIDATION_ASSERT(node, data_shape.size() == window_strides.size())
        << "Data shape (" << data_shape << ") does not have same rank as "
        << "the window strides (" << window_strides << ").";

    NODE_VALIDATION_ASSERT(node, data_shape.size() == window_dilation.size())
        << "Data shape (" << data_shape << ") does not have same rank as "
        << "the window dilation (" << window_dilation << ").";

    Shape output_shape(data_shape.size());

    for (size_t i = 0; i < data_shape.size(); i++)
    {
        // Strides and dilations of zero would make the axis formulas meaningless.
        NODE_VALIDATION_ASSERT(node, data_dilation[i] > 0)
            << "Data dilation (" << data_dilation << ") has zero dimension at axis " << i << ".";

        NODE_VALIDATION_ASSERT(node, window_strides[i] > 0)
            << "Window strides (" << window_strides << ") has zero dimension at axis " << i << ".";

        NODE_VALIDATION_ASSERT(node, window_dilation[i] > 0)
            << "Window dilation (" << window_dilation << ") has zero dimension at axis " << i
            << ".";

        // Computed in ptrdiff_t: padding entries may be negative, so the padded/dilated
        // extent can legitimately go to zero or below, and we must detect that rather than
        // wrap around in unsigned arithmetic.
        ptrdiff_t data_padded_dilated_dim =
            (ptrdiff_t(data_dilation[i]) * (ptrdiff_t(data_shape[i]) - 1)) + 1 +
            data_padding_below[i] + data_padding_above[i];
        ptrdiff_t window_dilated_dim =
            ptrdiff_t(window_dilation[i]) * (ptrdiff_t(window_shape[i]) - 1) + 1;

        NODE_VALIDATION_ASSERT(node, data_padded_dilated_dim > 0)
            << "Data shape after padding and dilation has dimension less than 1 (dim: "
            << data_padded_dilated_dim << ") at axis " << i << ".";

        NODE_VALIDATION_ASSERT(node, window_dilated_dim > 0)
            << "Window after dilation has dimension less than 1 (dim: " << window_dilated_dim
            << ") at axis " << i << ".";

        NODE_VALIDATION_ASSERT(node, window_dilated_dim <= data_padded_dilated_dim)
            << "Window after dilation has dimension (dim: " << window_dilated_dim
            << ") larger than the data shape after padding (dim: " << data_padded_dilated_dim
            << ") at axis " << i << ".";

        // A window shorter than either padding extent can be positioned wholly inside the
        // padding, where it would see no real data at all; reject unless the caller opted in.
        // (Fixed: a space was missing between the axis number and the opening parenthesis.)
        NODE_VALIDATION_ASSERT(node,
                               is_window_all_in_padding_allowed ||
                                   (window_dilated_dim >= data_padding_below[i] &&
                                    window_dilated_dim >= data_padding_above[i]))
            << "Window after dilation is sometimes entirely in the padding area for axis " << i
            << " (dilated window dimension: " << window_dilated_dim
            << ", padding below dimension: " << data_padding_below[i]
            << ", padding above dimension: " << data_padding_above[i] << ") and this is not "
            << "allowed.";

        // Both quantities are known positive here, so the conversion back to size_t is safe.
        size_t output_dim = ceil_div(
            size_t(data_padded_dilated_dim) - size_t(window_dilated_dim) + 1, window_strides[i]);
        output_shape[i] = output_dim;
    }

    return output_shape;
}
//
// Infers the output batch shape and element type for convolution fprop.
//
std::tuple<element::Type, Shape>
ngraph::infer_convolution_forward(const Node* node,
element::Type et_batch,
element::Type et_filters,
const Shape& data_batch_shape,
const Strides& data_dilation,
const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above,
const Shape& filters_shape,
const Strides& filter_strides,
const Strides& filter_dilation)
{
NODE_VALIDATION_ASSERT(node, et_batch == et_filters)
<< "Element types for data batch and filters do not match (data batch element type: "
<< et_batch << ", filters element type: " << et_filters << ").";
NODE_VALIDATION_ASSERT(node, data_batch_shape.size() >= 3)
<< "Data batch must have rank of at least 3 (one batch axis, "
<< "one input-channel axis, and at least one spatial dimension) "
<< "(data batch shape: " << data_batch_shape << ").";
NODE_VALIDATION_ASSERT(node, filters_shape.size() >= 3)
<< "Filters must have rank of at least 3 (one output-channel axis, "
<< "one input-channel axis, and at least one spatial dimension) "
<< "(filters shape: " << filters_shape << ").";
size_t batch_size = data_batch_shape[0];
size_t data_channel_count = data_batch_shape[1];
Shape data_spatial_shape(data_batch_shape.begin() + 2, data_batch_shape.end());
size_t filter_output_channel_count = filters_shape[0];
size_t filter_input_channel_count = filters_shape[1];
Shape filter_spatial_shape(filters_shape.begin() + 2, filters_shape.end());
NODE_VALIDATION_ASSERT(node, batch_size > 0) << "Batch size is zero.";
NODE_VALIDATION_ASSERT(node, data_channel_count > 0) << "Data batch channel count is zero.";
NODE_VALIDATION_ASSERT(node, data_channel_count == filter_input_channel_count)
<< "Data batch channel count (" << data_channel_count << ") does not match filter input "
<< "channel count (" << filter_input_channel_count << ").";
NODE_VALIDATION_ASSERT(node, filter_output_channel_count > 0)
<< "Filter output channel count is zero.";
Shape data_output_shape = infer_windowed_reduction_output_shape(node,
data_spatial_shape,
data_dilation,
data_padding_below,
data_padding_above,
filter_spatial_shape,
filter_strides,
filter_dilation,
true);
Shape batch_output_shape(data_batch_shape.size());
batch_output_shape[0] = batch_size;
batch_output_shape[1] = filter_output_channel_count;
std::copy(data_output_shape.begin(), data_output_shape.end(), batch_output_shape.begin() + 2);
return std::make_tuple(et_batch, batch_output_shape);
}
//
// Infers the output batch shape and element type for batched pooling fprop.
//
// Validates an (N, C, d_1, ..., d_n) data batch shape and per-spatial-axis padding,
// window shape, and window strides, then delegates per-axis output computation to
// infer_windowed_reduction_output_shape with identity (all-1) data and window dilation,
// since pooling ops have no dilation attributes.
//
Shape ngraph::infer_batched_pooling_forward(const Node* node,
                                            const Shape& data_batch_shape,
                                            const CoordinateDiff& data_padding_below,
                                            const CoordinateDiff& data_padding_above,
                                            const Shape& window_shape,
                                            const Strides& window_strides,
                                            bool is_window_all_in_padding_allowed)
{
    NODE_VALIDATION_ASSERT(node, data_batch_shape.size() >= 3)
        << "Data batch must have rank of at least 3 (one batch axis, "
        << "one input-channel axis, and at least one spatial dimension) "
        << "(data batch shape: " << data_batch_shape << ").";

    size_t spatial_dimension_count = data_batch_shape.size() - 2;

    // Each per-axis attribute must have one entry per spatial axis.
    NODE_VALIDATION_ASSERT(node, data_padding_below.size() == spatial_dimension_count)
        << "Data padding below (" << data_padding_below << ") does not have required rank ("
        << spatial_dimension_count << ").";

    NODE_VALIDATION_ASSERT(node, data_padding_above.size() == spatial_dimension_count)
        << "Data padding above (" << data_padding_above << ") does not have required rank ("
        << spatial_dimension_count << ").";

    NODE_VALIDATION_ASSERT(node, window_shape.size() == spatial_dimension_count)
        << "Window shape (" << window_shape << ") does not have required rank ("
        << spatial_dimension_count << ").";

    // Fixed copy-paste defect: this message previously said "Window shape" while
    // printing the window strides, producing diagnostics like
    // "Window shape (Strides{2, 3, 8})".
    NODE_VALIDATION_ASSERT(node, window_strides.size() == spatial_dimension_count)
        << "Window strides (" << window_strides << ") does not have required rank ("
        << spatial_dimension_count << ").";

    size_t batch_size = data_batch_shape[0];
    size_t channel_count = data_batch_shape[1];
    Shape data_spatial_shape(data_batch_shape.begin() + 2, data_batch_shape.end());

    NODE_VALIDATION_ASSERT(node, batch_size > 0) << "Batch size is zero.";

    NODE_VALIDATION_ASSERT(node, channel_count > 0) << "Channel count is zero.";

    // For pooling ops we don't need dilation, so we fill in the identity value (all 1).
    Strides data_dilation(spatial_dimension_count, 1);
    Strides window_dilation(spatial_dimension_count, 1);

    Shape data_output_shape =
        infer_windowed_reduction_output_shape(node,
                                              data_spatial_shape,
                                              data_dilation,
                                              data_padding_below,
                                              data_padding_above,
                                              window_shape,
                                              window_strides,
                                              window_dilation,
                                              is_window_all_in_padding_allowed);

    // Assemble the NC{d_out} result shape.
    Shape batch_output_shape(data_batch_shape.size());
    batch_output_shape[0] = batch_size;
    batch_output_shape[1] = channel_count;
    std::copy(data_output_shape.begin(), data_output_shape.end(), batch_output_shape.begin() + 2);

    return batch_output_shape;
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <tuple>
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
    /// \brief Infers the output shape of a windowed reduction operation (the shared core of
    ///        convolution and pooling shape inference), where the data may be dilated and/or
    ///        padded, and the reduction window may be strided and/or dilated.
    ///
    /// \param node                             Node to attribute validation errors to.
    /// \param data_shape                       Shape of the (spatial) data.
    /// \param data_dilation                    Per-axis data dilation factors.
    /// \param data_padding_below               Per-axis below-padding (may be negative).
    /// \param data_padding_above               Per-axis above-padding (may be negative).
    /// \param window_shape                     Shape of the reduction window.
    /// \param window_strides                   Per-axis window movement strides.
    /// \param window_dilation                  Per-axis window dilation factors.
    /// \param is_window_all_in_padding_allowed Whether a window placement lying entirely in
    ///                                         the padding region is permitted.
    /// \return Per-axis output shape of the reduction.
    Shape infer_windowed_reduction_output_shape(const Node* node,
                                                const Shape& data_shape,
                                                const Strides& data_dilation,
                                                const CoordinateDiff& data_padding_below,
                                                const CoordinateDiff& data_padding_above,
                                                const Shape& window_shape,
                                                const Strides& window_strides,
                                                const Strides& window_dilation,
                                                bool is_window_all_in_padding_allowed);

    /// \brief Validates convolution fprop inputs and infers the output element type and
    ///        batch shape (N, C_out, d_1, ..., d_n).
    ///
    /// \param node               Node to attribute validation errors to.
    /// \param et_batch           Element type of the data batch.
    /// \param et_filters         Element type of the filters (must match et_batch).
    /// \param data_batch_shape   Data batch shape (N, C_in, d_1, ..., d_n).
    /// \param data_dilation      Per-spatial-axis data dilation factors.
    /// \param data_padding_below Per-spatial-axis below-padding.
    /// \param data_padding_above Per-spatial-axis above-padding.
    /// \param filters_shape      Filter shape (C_out, C_in, w_1, ..., w_n).
    /// \param filter_strides     Per-spatial-axis filter movement strides.
    /// \param filter_dilation    Per-spatial-axis filter dilation factors.
    /// \return Tuple of (output element type, output batch shape).
    std::tuple<element::Type, Shape>
        infer_convolution_forward(const Node* node,
                                  element::Type et_batch,
                                  element::Type et_filters,
                                  const Shape& data_batch_shape,
                                  const Strides& data_dilation,
                                  const CoordinateDiff& data_padding_below,
                                  const CoordinateDiff& data_padding_above,
                                  const Shape& filters_shape,
                                  const Strides& filter_strides,
                                  const Strides& filter_dilation);

    /// \brief Validates batched pooling fprop inputs and infers the output batch shape
    ///        (N, C, d_1, ..., d_n). Pooling ops have no dilation, so identity dilation
    ///        is supplied internally.
    ///
    /// \param node                             Node to attribute validation errors to.
    /// \param data_batch_shape                 Data batch shape (N, C, d_1, ..., d_n).
    /// \param data_padding_below               Per-spatial-axis below-padding.
    /// \param data_padding_above               Per-spatial-axis above-padding.
    /// \param window_shape                     Pooling window shape.
    /// \param window_strides                   Per-spatial-axis window movement strides.
    /// \param is_window_all_in_padding_allowed Whether a window placement lying entirely in
    ///                                         the padding region is permitted.
    /// \return Output batch shape of the pooling operation.
    Shape infer_batched_pooling_forward(const Node* node,
                                        const Shape& data_batch_shape,
                                        const CoordinateDiff& data_padding_below,
                                        const CoordinateDiff& data_padding_above,
                                        const Shape& window_shape,
                                        const Strides& window_strides,
                                        bool is_window_all_in_padding_allowed);
}
...@@ -2915,7 +2915,9 @@ TEST(type_prop, conv_invalid_0d_input) ...@@ -2915,7 +2915,9 @@ TEST(type_prop, conv_invalid_0d_input)
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data batch input must have rank of at least 3")); std::string("Data batch must have rank of at least 3 "
"(one batch axis, one input-channel axis, "
"and at least one spatial dimension)"));
} }
catch (...) catch (...)
{ {
...@@ -2938,7 +2940,9 @@ TEST(type_prop, conv_invalid_1d_input) ...@@ -2938,7 +2940,9 @@ TEST(type_prop, conv_invalid_1d_input)
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data batch input must have rank of at least 3")); std::string("Data batch must have rank of at least 3 "
"(one batch axis, one input-channel axis, "
"and at least one spatial dimension)"));
} }
catch (...) catch (...)
{ {
...@@ -2961,7 +2965,9 @@ TEST(type_prop, conv_invalid_2d_input) ...@@ -2961,7 +2965,9 @@ TEST(type_prop, conv_invalid_2d_input)
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data batch input must have rank of at least 3")); std::string("Data batch must have rank of at least 3 "
"(one batch axis, one input-channel axis, "
"and at least one spatial dimension)"));
} }
catch (...) catch (...)
{ {
...@@ -2983,7 +2989,7 @@ TEST(type_prop, conv_invalid_0_batch_size) ...@@ -2983,7 +2989,7 @@ TEST(type_prop, conv_invalid_0_batch_size)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch size is zero")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch size is zero"));
} }
catch (...) catch (...)
{ {
...@@ -3005,7 +3011,7 @@ TEST(type_prop, conv_invalid_0_input_channels) ...@@ -3005,7 +3011,7 @@ TEST(type_prop, conv_invalid_0_input_channels)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input channel count is zero")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch channel count is zero"));
} }
catch (...) catch (...)
{ {
...@@ -3028,7 +3034,8 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_many) ...@@ -3028,7 +3034,8 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_many)
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Filter input must have rank equal to the data batch")); std::string("Data shape (Shape{10, 10}) does not have same rank as "
"the window shape (Shape{3, 3, 3})"));
} }
catch (...) catch (...)
{ {
...@@ -3051,7 +3058,8 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_few) ...@@ -3051,7 +3058,8 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_few)
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Filter input must have rank equal to the data batch")); std::string("Data shape (Shape{10, 10}) does not have "
"same rank as the window shape (Shape{3})"));
} }
catch (...) catch (...)
{ {
...@@ -3073,7 +3081,7 @@ TEST(type_prop, conv_invalid_0_output_channels) ...@@ -3073,7 +3081,7 @@ TEST(type_prop, conv_invalid_0_output_channels)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Output channel count for filters is zero")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Filter output channel count is zero"));
} }
catch (...) catch (...)
{ {
...@@ -3095,9 +3103,10 @@ TEST(type_prop, conv_invalid_input_channel_mismatch) ...@@ -3095,9 +3103,10 @@ TEST(type_prop, conv_invalid_input_channel_mismatch)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Input channel count for filters (3) does not match the " error.what(),
"number of channels in the data batch (2)")); std::string(
"Data batch channel count (2) does not match filter input channel count (3)"));
} }
catch (...) catch (...)
{ {
...@@ -3119,10 +3128,9 @@ TEST(type_prop, conv_invalid_movement_stride_rank) ...@@ -3119,10 +3128,9 @@ TEST(type_prop, conv_invalid_movement_stride_rank)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Data shape (Shape{10, 10}) does not have same rank as "
std::string( "the window strides (Strides{2, 3, 8})"));
"Rank of window movement strides does not match the number of spatial dimensions"));
} }
catch (...) catch (...)
{ {
...@@ -3144,10 +3152,9 @@ TEST(type_prop, conv_invalid_window_dilation_stride_rank) ...@@ -3144,10 +3152,9 @@ TEST(type_prop, conv_invalid_window_dilation_stride_rank)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Data shape (Shape{10, 10}) does not have same rank as "
std::string( "the window dilation (Strides{2, 3, 8})"));
"Rank of window dilation strides does not match the number of spatial dimensions"));
} }
catch (...) catch (...)
{ {
...@@ -3175,10 +3182,9 @@ TEST(type_prop, conv_invalid_data_dilation_stride_rank) ...@@ -3175,10 +3182,9 @@ TEST(type_prop, conv_invalid_data_dilation_stride_rank)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Data shape (Shape{10, 10}) does not have same rank as "
std::string( "the data dilation (Strides{2, 3, 8})"));
"Rank of data dilation strides does not match the number of spatial dimensions"));
} }
catch (...) catch (...)
{ {
...@@ -3205,10 +3211,9 @@ TEST(type_prop, conv_invalid_padding_below_rank) ...@@ -3205,10 +3211,9 @@ TEST(type_prop, conv_invalid_padding_below_rank)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Data shape (Shape{10, 10}) does not have same rank as "
std::string( "the data padding below (CoordinateDiff{0, 0, 0})"));
"Rank of the padding below does not match the number of spatial dimensions"));
} }
catch (...) catch (...)
{ {
...@@ -3235,10 +3240,9 @@ TEST(type_prop, conv_invalid_padding_above_rank) ...@@ -3235,10 +3240,9 @@ TEST(type_prop, conv_invalid_padding_above_rank)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Data shape (Shape{10, 10}) does not have same rank as "
std::string( "the data padding above (CoordinateDiff{0, 0, 0})"));
"Rank of the padding above does not match the number of spatial dimensions"));
} }
catch (...) catch (...)
{ {
...@@ -3255,8 +3259,8 @@ TEST(type_prop, conv_invalid_input_spatial_size_negative_after_padding) ...@@ -3255,8 +3259,8 @@ TEST(type_prop, conv_invalid_input_spatial_size_negative_after_padding)
{ {
auto conv = make_shared<op::Convolution>(param0, auto conv = make_shared<op::Convolution>(param0,
param1, param1,
Strides{0, 0}, Strides{1, 1},
Strides{0, 0}, Strides{1, 1},
CoordinateDiff{-4, 0}, CoordinateDiff{-4, 0},
CoordinateDiff{-7, 0}); CoordinateDiff{-7, 0});
...@@ -3265,9 +3269,9 @@ TEST(type_prop, conv_invalid_input_spatial_size_negative_after_padding) ...@@ -3265,9 +3269,9 @@ TEST(type_prop, conv_invalid_input_spatial_size_negative_after_padding)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Data shape after padding and dilation has dimension less "
std::string("Input dimension after padding and dilation is non-positive")); "than 1 (dim: -1) at axis 0"));
} }
catch (...) catch (...)
{ {
...@@ -3284,8 +3288,8 @@ TEST(type_prop, conv_invalid_input_spatial_size_zero_after_padding) ...@@ -3284,8 +3288,8 @@ TEST(type_prop, conv_invalid_input_spatial_size_zero_after_padding)
{ {
auto conv = make_shared<op::Convolution>(param0, auto conv = make_shared<op::Convolution>(param0,
param1, param1,
Strides{0, 0}, Strides{1, 1},
Strides{0, 0}, Strides{1, 1},
CoordinateDiff{-4, 0}, CoordinateDiff{-4, 0},
CoordinateDiff{-6, 0}); CoordinateDiff{-6, 0});
...@@ -3294,9 +3298,9 @@ TEST(type_prop, conv_invalid_input_spatial_size_zero_after_padding) ...@@ -3294,9 +3298,9 @@ TEST(type_prop, conv_invalid_input_spatial_size_zero_after_padding)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Data shape after padding and dilation has dimension less "
std::string("Input dimension after padding and dilation is non-positive")); "than 1 (dim: 0) at axis 0"));
} }
catch (...) catch (...)
{ {
...@@ -3318,9 +3322,9 @@ TEST(type_prop, conv_invalid_input_spatial_size_0) ...@@ -3318,9 +3322,9 @@ TEST(type_prop, conv_invalid_input_spatial_size_0)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Data shape after padding and dilation has "
std::string("Input dimension after padding and dilation is non-positive")); "dimension less than 1 (dim: 0) at axis 0"));
} }
catch (...) catch (...)
{ {
...@@ -3342,8 +3346,9 @@ TEST(type_prop, conv_invalid_window_size_0) ...@@ -3342,8 +3346,9 @@ TEST(type_prop, conv_invalid_window_size_0)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Filters shape at spatial dimension 1 is zero")); error.what(),
std::string("Window after dilation has dimension less than 1 (dim: 0) at axis 1"));
} }
catch (...) catch (...)
{ {
...@@ -3365,8 +3370,9 @@ TEST(type_prop, conv_invalid_window_dilation_stride_0) ...@@ -3365,8 +3370,9 @@ TEST(type_prop, conv_invalid_window_dilation_stride_0)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Window dilation stride at spatial dimension 1 is zero")); error.what(),
std::string("Window dilation (Strides{2, 0}) has zero dimension at axis 1"));
} }
catch (...) catch (...)
{ {
...@@ -3394,8 +3400,9 @@ TEST(type_prop, conv_invalid_data_dilation_stride_0) ...@@ -3394,8 +3400,9 @@ TEST(type_prop, conv_invalid_data_dilation_stride_0)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Data dilation stride at spatial dimension 1 is zero")); error.what(),
std::string("Data dilation (Strides{2, 0}) has zero dimension at axis 1"));
} }
catch (...) catch (...)
{ {
...@@ -3418,8 +3425,8 @@ TEST(type_prop, conv_invalid_dilated_window_too_large) ...@@ -3418,8 +3425,8 @@ TEST(type_prop, conv_invalid_dilated_window_too_large)
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Post-dilation window shape is smaller than the " std::string("Window after dilation has dimension (dim: 9) larger than "
"post-padding/dilation input item shape")); "the data shape after padding (dim: 8) at axis 0"));
} }
catch (...) catch (...)
{ {
...@@ -3441,8 +3448,9 @@ TEST(type_prop, conv_invalid_movement_stride_0) ...@@ -3441,8 +3448,9 @@ TEST(type_prop, conv_invalid_movement_stride_0)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Window movement stride at spatial dimension 0 is zero")); error.what(),
std::string("Window strides (Strides{0, 1}) has zero dimension at axis 0"));
} }
catch (...) catch (...)
{ {
...@@ -3636,7 +3644,7 @@ TEST(type_prop, max_pool_invalid_0_batch_size) ...@@ -3636,7 +3644,7 @@ TEST(type_prop, max_pool_invalid_0_batch_size)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch size is zero")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch size is zero"));
} }
catch (...) catch (...)
{ {
...@@ -3682,7 +3690,7 @@ TEST(type_prop, max_pool_invalid_wrong_number_of_window_dimensions_too_many) ...@@ -3682,7 +3690,7 @@ TEST(type_prop, max_pool_invalid_wrong_number_of_window_dimensions_too_many)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Window shape rank does not match number of spatial dimensions")); std::string("Window shape (Shape{3, 3, 3}) does not have required rank (2)"));
} }
catch (...) catch (...)
{ {
...@@ -3705,8 +3713,7 @@ TEST(type_prop, max_pool_invalid_wrong_number_of_window_dimensions_too_few) ...@@ -3705,8 +3713,7 @@ TEST(type_prop, max_pool_invalid_wrong_number_of_window_dimensions_too_few)
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(), std::string("Window shape (Shape{3}) does not have required rank (2)"));
std::string("Window shape rank does not match number of spatial dimensions"));
} }
catch (...) catch (...)
{ {
...@@ -3731,7 +3738,7 @@ TEST(type_prop, max_pool_invalid_movement_stride_rank) ...@@ -3731,7 +3738,7 @@ TEST(type_prop, max_pool_invalid_movement_stride_rank)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Window movement stride rank does not match number of spatial dimensions")); std::string("Window shape (Strides{2, 3, 8}) does not have required rank (2)"));
} }
catch (...) catch (...)
{ {
...@@ -3753,9 +3760,9 @@ TEST(type_prop, max_pool_invalid_input_data_size_0) ...@@ -3753,9 +3760,9 @@ TEST(type_prop, max_pool_invalid_input_data_size_0)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Data shape after padding and dilation has "
std::string("Data input spatial dimension 0 has zero length even after padding")); "dimension less than 1 (dim: 0) at axis 0"));
} }
catch (...) catch (...)
{ {
...@@ -3777,7 +3784,9 @@ TEST(type_prop, max_pool_invalid_window_size_0) ...@@ -3777,7 +3784,9 @@ TEST(type_prop, max_pool_invalid_window_size_0)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Window shape dimension 1 has zero length")); EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Window after dilation has dimension less than 1 (dim: 0) at axis 1"));
} }
catch (...) catch (...)
{ {
...@@ -3799,9 +3808,9 @@ TEST(type_prop, max_pool_invalid_dilated_too_large) ...@@ -3799,9 +3808,9 @@ TEST(type_prop, max_pool_invalid_dilated_too_large)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Window after dilation has dimension (dim: 9) larger than "
std::string("Window shape after padding is larger than the spatial dimensions")); "the data shape after padding (dim: 8) at axis 0"));
} }
catch (...) catch (...)
{ {
...@@ -3824,8 +3833,9 @@ TEST(type_prop, max_pool_invalid_movement_stride_0) ...@@ -3824,8 +3833,9 @@ TEST(type_prop, max_pool_invalid_movement_stride_0)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Window movement strides dimension 0 has zero length")); error.what(),
std::string("Window strides (Strides{0, 1}) has zero dimension at axis 0"));
} }
catch (...) catch (...)
{ {
...@@ -5927,7 +5937,7 @@ TEST(type_prop, avg_pool_invalid_0_batch_size) ...@@ -5927,7 +5937,7 @@ TEST(type_prop, avg_pool_invalid_0_batch_size)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), "Data batch size is zero"); EXPECT_HAS_SUBSTRING(error.what(), "Batch size is zero");
} }
catch (...) catch (...)
{ {
...@@ -5972,7 +5982,7 @@ TEST(type_prop, avg_pool_invalid_wrong_number_of_window_dimensions_too_many) ...@@ -5972,7 +5982,7 @@ TEST(type_prop, avg_pool_invalid_wrong_number_of_window_dimensions_too_many)
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Window shape rank does not match number of spatial dimensions"); "Window shape (Shape{3, 3, 3}) does not have required rank (2)");
} }
catch (...) catch (...)
{ {
...@@ -5995,7 +6005,7 @@ TEST(type_prop, avg_pool_invalid_wrong_number_of_window_dimensions_too_few) ...@@ -5995,7 +6005,7 @@ TEST(type_prop, avg_pool_invalid_wrong_number_of_window_dimensions_too_few)
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Window shape rank does not match number of spatial dimensions"); "Window shape (Shape{3}) does not have required rank (2)");
} }
catch (...) catch (...)
{ {
...@@ -6018,9 +6028,8 @@ TEST(type_prop, avg_pool_invalid_movement_stride_rank) ...@@ -6018,9 +6028,8 @@ TEST(type_prop, avg_pool_invalid_movement_stride_rank)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), "Window shape (Strides{2, 3, 8}) does not have required rank (2)");
"Window movement stride rank does not match number of spatial dimensions");
} }
catch (...) catch (...)
{ {
...@@ -6046,8 +6055,9 @@ TEST(type_prop, avg_pool_invalid_padding_below_rank) ...@@ -6046,8 +6055,9 @@ TEST(type_prop, avg_pool_invalid_padding_below_rank)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(
"Below-padding rank does not match number of spatial dimensions"); error.what(),
"Data padding below (CoordinateDiff{1, 2, 3}) does not have required rank (2)");
} }
catch (...) catch (...)
{ {
...@@ -6073,8 +6083,9 @@ TEST(type_prop, avg_pool_invalid_padding_above_rank) ...@@ -6073,8 +6083,9 @@ TEST(type_prop, avg_pool_invalid_padding_above_rank)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(
"Above-padding rank does not match number of spatial dimensions"); error.what(),
"Data padding above (CoordinateDiff{1, 2, 3}) does not have required rank (2");
} }
catch (...) catch (...)
{ {
...@@ -6096,8 +6107,9 @@ TEST(type_prop, avg_pool_invalid_input_item_size_0) ...@@ -6096,8 +6107,9 @@ TEST(type_prop, avg_pool_invalid_input_item_size_0)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(
"Data input spatial dimension 0 has zero length even after padding"); error.what(),
"Data shape after padding and dilation has dimension less than 1 (dim: 0) at axis 0");
} }
catch (...) catch (...)
{ {
...@@ -6119,7 +6131,8 @@ TEST(type_prop, avg_pool_invalid_window_size_0) ...@@ -6119,7 +6131,8 @@ TEST(type_prop, avg_pool_invalid_window_size_0)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), "Window shape dimension 1 has zero length"); EXPECT_HAS_SUBSTRING(error.what(),
"Window after dilation has dimension less than 1 (dim: 0) at axis 1");
} }
catch (...) catch (...)
{ {
...@@ -6142,7 +6155,8 @@ TEST(type_prop, avg_pool_invalid_dilated_too_large) ...@@ -6142,7 +6155,8 @@ TEST(type_prop, avg_pool_invalid_dilated_too_large)
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
"Window shape after padding is larger than the spatial dimensions"); "Window after dilation has dimension (dim: 9) larger than the data "
"shape after padding (dim: 8) at axis 0");
} }
catch (...) catch (...)
{ {
...@@ -6150,6 +6164,20 @@ TEST(type_prop, avg_pool_invalid_dilated_too_large) ...@@ -6150,6 +6164,20 @@ TEST(type_prop, avg_pool_invalid_dilated_too_large)
} }
} }
TEST(type_prop, avg_pool_larger_than_pre_padding_but_fits_in_post_padding)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{6, 2, 8, 8});
Shape window_shape{9, 9};
Strides window_strides{1, 1};
Shape padding_below{0, 0};
Shape padding_above{1, 1};
auto avg_pool =
make_shared<op::AvgPool>(param, window_shape, window_strides, padding_below, padding_above);
ASSERT_EQ(avg_pool->get_output_element_type(0), element::f32);
ASSERT_EQ(avg_pool->get_output_shape(0), (Shape{6, 2, 1, 1}));
}
TEST(type_prop, avg_pool_invalid_movement_stride_0) TEST(type_prop, avg_pool_invalid_movement_stride_0)
{ {
// Deduce type // Deduce type
...@@ -6165,7 +6193,8 @@ TEST(type_prop, avg_pool_invalid_movement_stride_0) ...@@ -6165,7 +6193,8 @@ TEST(type_prop, avg_pool_invalid_movement_stride_0)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), "Window movement strides dimension 0 has zero length"); EXPECT_HAS_SUBSTRING(error.what(),
"Window strides (Strides{0, 1}) has zero dimension at axis 0");
} }
catch (...) catch (...)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment