Unverified Commit e9443168, authored by Adam Procter, committed by GitHub

Error messages for Convolution, BatchNorm, MaxPool (#1535)

parent c386da90
@@ -36,77 +36,55 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
 void ngraph::op::BatchNorm::validate_and_infer_types()
 {
     m_bn_input_shape = get_input_shape(INPUT);
-    if (m_bn_input_shape.size() < 2)
-    {
-        throw ngraph_error("input tensor to batchnorm must have tensor of at least rank 2");
-    }
+    NODE_VALIDATION_ASSERT(this, m_bn_input_shape.size() >= 2)
+        << "Input argument must have rank of at least 2 (input argument shape: " << m_bn_input_shape
+        << ").";

-    if (m_bn_input_shape[1] == 0)
-    {
-        throw ngraph_error("input tensor must have at least one channel for batch normalization");
-    }
+    NODE_VALIDATION_ASSERT(this, m_bn_input_shape[1] != 0)
+        << "Input argument's channel dimension must have size of at least 1 (input argument shape: "
+        << m_bn_input_shape << ").";

     auto& et = get_input_element_type(INPUT);
     auto in_size = get_input_size();

+    NODE_VALIDATION_ASSERT(this, in_size == 3 || in_size == 5)
+        << "Argument count must be either 3 or 5 (received argument count: " << in_size << ").";
+
+    Shape channel_shape{m_bn_input_shape[1]};
+
     if (in_size == 3)
     {
         set_output_size(3);
-        this->m_bn_mean_shape.push_back(m_bn_input_shape[1]);
+        m_bn_mean_shape = channel_shape;
         set_output_type(1, et, m_bn_mean_shape);
-        this->m_bn_variance_shape.push_back(m_bn_input_shape[1]);
+        m_bn_variance_shape = channel_shape;
         set_output_type(2, et, m_bn_variance_shape);
     }
-    else if (in_size == 5)
-    {
-        set_output_size(1);
-    }
     else
     {
-        throw ngraph_error("Invalid BatchNorm args");
+        set_output_size(1);
     }

     set_output_type(0, et, m_bn_input_shape);

-    Shape channel_shape{m_bn_input_shape[1]};
     const char* input_names[]{"gamma", "beta", "input", "mean", "variance"};
+
     for (size_t i = 0; i < get_input_size(); i++)
     {
-        if (i == 2)
+        if (i == INPUT)
         {
             continue;
         }
-        if (get_input_element_type(i) != et)
-        {
-            std::stringstream err_msg;
-            err_msg << "The element type " << get_input_element_type(i) << " of input "
-                    << input_names[i] << " isn't equal to the input data's type " << et;
-            throw ngraph_error(err_msg.str());
-        }
-        if (get_input_shape(i) != channel_shape)
-        {
-            std::stringstream err_msg;
-            err_msg << "The shape " << get_input_shape(i) << " of " << input_names[i]
-                    << " isn't equal to input channel's shape " << channel_shape;
-            throw ngraph_error(err_msg.str());
-        }
-    }
-    for (size_t index = 0; index < get_input_size(); index++)
-    {
-        if (index != INPUT && get_input_shape(index).size() != 1)
-        {
-            auto err_msg = std::string(input_names[index]) + " should have rank of 1";
-            throw ngraph_error(err_msg.c_str());
-        }
-        if (index != INPUT && get_input_shape(index)[0] != m_bn_input_shape[1])
-        {
-            auto err_msg = std::string(input_names[index]) +
-                           " shape should match the input channel size (" +
-                           std::to_string(m_bn_input_shape[1]) + ",)";
-            throw ngraph_error(err_msg.c_str());
-        }
+
+        NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == et)
+            << "Element type of " << input_names[i] << " (" << get_input_element_type(i)
+            << ") is not equal to the element type of input (" << et << ").";
+
+        NODE_VALIDATION_ASSERT(this, get_input_shape(i) == channel_shape)
+            << "Shape of " << input_names[i] << " must match the channel dimension of the "
+            << "input data (expected shape: " << channel_shape << ", actual shape of "
+            << input_names[i] << ": " << get_input_shape(i)
+            << ", shape of input: " << m_bn_input_shape << ").";
     }
 }
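The hunks in this commit all follow the pattern above: a hand-rolled if/throw ngraph_error(...) check becomes NODE_VALIDATION_ASSERT(node, cond) with the failure message assembled from streamed operands. The macro's definition is not part of this diff; what follows is only a minimal self-contained sketch of the idiom. The names Node::description and NodeValidationFailure are illustrative assumptions (the tree throws NodeValidationError), and the real macro presumably also avoids evaluating the streamed operands on success.

    #include <cstddef>
    #include <sstream>
    #include <stdexcept>
    #include <string>

    struct Node // stand-in for ngraph::Node; only a name is needed for the message
    {
        std::string description;
    };

    struct NodeValidationFailure : std::runtime_error
    {
        explicit NodeValidationFailure(const std::string& what)
            : std::runtime_error(what)
        {
        }
    };

    class NodeValidationAssertion
    {
    public:
        NodeValidationAssertion(const Node* node, bool holds)
            : m_holds(holds)
        {
            if (!m_holds)
            {
                m_stream << "Validation failed for node '" << node->description << "': ";
            }
        }
        // Throws at the end of the full expression if the condition failed;
        // noexcept(false) is required for a destructor that may throw.
        ~NodeValidationAssertion() noexcept(false)
        {
            if (!m_holds)
            {
                throw NodeValidationFailure(m_stream.str());
            }
        }
        template <typename T>
        NodeValidationAssertion& operator<<(const T& value)
        {
            if (!m_holds)
            {
                m_stream << value;
            }
            return *this;
        }

    private:
        bool m_holds;
        std::ostringstream m_stream;
    };

    #define NODE_VALIDATION_ASSERT(node, cond) NodeValidationAssertion(node, (cond))

    int main()
    {
        Node n{"BatchNorm_0"};
        std::size_t rank = 1;
        try
        {
            NODE_VALIDATION_ASSERT(&n, rank >= 2) << "Input argument must have rank of at least 2.";
        }
        catch (const NodeValidationFailure& e)
        {
            // e.what(): "Validation failed for node 'BatchNorm_0': Input argument must have ..."
        }
        return 0;
    }

The payoff of this design is visible throughout the diff: one statement both performs the check and attributes a fully contextualized message to the offending node.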
@@ -127,14 +105,19 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
 std::shared_ptr<ngraph::Node>
     ngraph::op::BatchNorm::copy_with_new_args(const NodeVector& new_args) const
 {
-    if (this->m_training)
+    check_new_args_count(this, new_args);
+
+    if (m_training)
     {
+        // FIXME(amprocte): is this redundant?
+        NODE_VALIDATION_ASSERT(this, new_args.size() == 3 || new_args.size() == 5);
+
         if (new_args.size() == 3)
         {
             return std::make_shared<BatchNorm>(
                 m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2));
         }
-        else if (new_args.size() == 5)
+        else
         {
             return std::make_shared<BatchNorm>(m_epsilon,
                                                new_args.at(0),
@@ -144,17 +127,11 @@ std::shared_ptr<ngraph::Node>
                                                new_args.at(4),
                                                true);
         }
-        else
-        {
-            throw ngraph_error("Incorrect number of new arguments");
-        }
     }
     else
     {
-        if (new_args.size() != 5)
-        {
-            throw ngraph_error("Incorrect number of new arguments");
-        }
+        NODE_VALIDATION_ASSERT(this, new_args.size() == 5);
+
         return std::make_shared<BatchNorm>(m_epsilon,
                                            new_args.at(0),
                                            new_args.at(1),
@@ -183,45 +160,37 @@ void ngraph::op::BatchNormBackprop::validate_and_infer_types()
 {
     set_output_size(3);
-    if (get_input_shape(INPUT).size() != 4)
-    {
-        throw ngraph_error("Input expected to be a 4D tensor");
-    }
+    NODE_VALIDATION_ASSERT(this, get_input_shape(INPUT).size() == 4)
+        << "Input data shape is not a 4D tensor (input data shape: " << get_input_shape(INPUT)
+        << ").";

     auto et = get_input_element_type(INPUT);
     const char* input_names[] = {"gamma", "beta", "input", "mean", "variance", "delta"};
-    for (size_t i = 0; i < get_input_size(); i++)
-    {
-        if (get_input_element_type(i) != et)
-        {
-            auto err_msg = std::string("The element type of ") + input_names[i] +
-                           " isn't equal to input data's type";
-            throw ngraph_error(err_msg.c_str());
-        }
-    }
+    Shape channel_shape{get_input_shape(INPUT)[1]};

-    Shape channel_shape{get_input_shape(INPUT).at(1)};
     for (size_t i = 0; i < get_input_size(); i++)
     {
-        if (i == 2 || i == 5) // don't check input and delta
+        NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == et)
+            << "Element type of " << input_names[i] << " (" << get_input_element_type(i)
+            << ") is not equal to the element type of input (" << et << ").";
+
+        // Note that the shape of delta, a special case, will be checked after the loop.
+        if (i == DELTA || i == INPUT)
         {
             continue;
         }
-        if (get_argument(i)->get_shape() != channel_shape)
-        {
-            auto err_msg = std::string("The shape of ") + input_names[i] +
-                           " isn't equal to input channel's shape";
-            throw ngraph_error(err_msg.c_str());
-        }
+
+        NODE_VALIDATION_ASSERT(this, get_input_shape(i) == channel_shape)
+            << "Shape of " << input_names[i] << " must match the channel dimension of the "
+            << "input data (expected shape: " << channel_shape << ", actual shape of "
+            << input_names[i] << ": " << get_input_shape(i)
+            << ", shape of input: " << get_input_shape(INPUT) << ").";
     }

-    if (get_input_shape(DELTA) != get_input_shape(INPUT))
-    {
-        throw ngraph_error("delta shape is expected to be equal to input shape");
-    }
+    NODE_VALIDATION_ASSERT(this, get_input_shape(DELTA) == get_input_shape(INPUT))
+        << "Shape of delta must match the shape of the input data (expected shape: "
+        << get_input_shape(INPUT) << ", actual shape of delta: " << get_input_shape(DELTA) << ").";

     set_output_type(0, get_input_element_type(INPUT), get_input_shape(INPUT));
     set_output_type(1, get_input_element_type(GAMMA), get_input_shape(GAMMA));
@@ -231,10 +200,7 @@ void ngraph::op::BatchNormBackprop::validate_and_infer_types()
 std::shared_ptr<ngraph::Node>
     ngraph::op::BatchNormBackprop::copy_with_new_args(const NodeVector& new_args) const
 {
-    if (new_args.size() != 6)
-    {
-        throw ngraph_error("Incorrect number of new arguments");
-    }
+    check_new_args_count(this, new_args);
     return std::make_shared<op::BatchNormBackprop>(epsilon,
                                                    new_args.at(0),
                                                    new_args.at(1),
......
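Both copy_with_new_args overloads above now delegate the argument-count check to check_new_args_count(this, new_args), whose definition is outside this diff. Presumably it compares the new argument list against the node's current argument count; a sketch under that assumption, reusing the NODE_VALIDATION_ASSERT stand-in from the earlier sketch (NodeVector and get_arguments() approximate nGraph's own names):

    #include <memory>
    #include <vector>

    using NodeVector = std::vector<std::shared_ptr<Node>>; // approximately nGraph's alias

    // Assumed behavior: reject argument lists whose size differs from the
    // node's current argument count, with a node-attributed message.
    void check_new_args_count(const Node* node, const NodeVector& new_args)
    {
        NODE_VALIDATION_ASSERT(node, new_args.size() == node->get_arguments().size())
            << "copy_with_new_args() expected " << node->get_arguments().size()
            << " argument(s) but got " << new_args.size() << ".";
    }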
@@ -26,7 +26,8 @@
 using namespace std;
 using namespace ngraph;

-Shape op::util::infer_convolution_output_shape(const Shape& data_batch_shape,
+Shape op::util::infer_convolution_output_shape(const Node* node,
+                                               const Shape& data_batch_shape,
                                                const Shape& filters_shape,
                                                const Strides& window_movement_strides,
                                                const Strides& window_dilation_strides,
@@ -38,140 +39,128 @@ Shape op::util::infer_convolution_output_shape(const Shape& data_batch_shape,
                                                size_t input_channel_axis_filters,
                                                size_t output_channel_axis_filters,
                                                size_t batch_axis_result,
-                                               size_t output_channel_axis_result,
-                                               const string& error_prefix)
+                                               size_t output_channel_axis_result)
 {
-    if (batch_axis_data > 1 || input_channel_axis_data > 1 || input_channel_axis_filters > 1 ||
-        output_channel_axis_filters > 1 || batch_axis_result > 1 || output_channel_axis_result > 1)
-    {
-        throw ngraph_error(
-            error_prefix +
-            "Internal nGraph error: infer_convolution_output_shape: batch_axis_data, "
-            "input_channel_axis_data, input_channel_axis_filters, "
-            "output_channel_axis_filters, "
-            "batch_axis_result, and output_channel_axis_result must all be 0 or 1.");
-    }
+    NODE_VALIDATION_ASSERT(node, batch_axis_data <= 1) << "(This is an internal nGraph error)";
+    NODE_VALIDATION_ASSERT(node, input_channel_axis_data <= 1)
+        << "(This is an internal nGraph error)";
+    NODE_VALIDATION_ASSERT(node, input_channel_axis_filters <= 1)
+        << "(This is an internal nGraph error)";
+    NODE_VALIDATION_ASSERT(node, output_channel_axis_filters <= 1)
+        << "(This is an internal nGraph error)";
+    NODE_VALIDATION_ASSERT(node, batch_axis_result <= 1) << "(This is an internal nGraph error)";
+    NODE_VALIDATION_ASSERT(node, output_channel_axis_result <= 1)
+        << "(This is an internal nGraph error)";

     //
     // Make sure data_batch: NCiDi for some Di of rank>0, N != 0, Ci != 0.
     //
-    if (data_batch_shape.size() < 3)
-    {
-        throw ngraph_error(
-            error_prefix +
-            "Convolution data batch input must have rank of at least 3 (one batch axis, one "
-            "input-channel axis, at least one spatial dimension).");
-    }
+    NODE_VALIDATION_ASSERT(node, data_batch_shape.size() >= 3)
+        << "Data batch input must have rank of at least 3 (one batch axis, "
+        << "one input-channel axis, and at least one spatial dimension) "
+        << "(data batch shape: " << data_batch_shape << ").";

     size_t batch_size = data_batch_shape[batch_axis_data];
-    if (batch_size == 0)
-    {
-        throw ngraph_error(error_prefix + "Convolution data batch size is zero.");
-    }
+    NODE_VALIDATION_ASSERT(node, batch_size != 0)
+        << "Data batch size is zero (data batch shape: " << data_batch_shape << ", "
+        << "batch axis is axis " << batch_axis_data << ").";

     size_t input_channel_count = data_batch_shape[input_channel_axis_data];
-    if (input_channel_count == 0)
-    {
-        throw ngraph_error(error_prefix + "Convolution requires at least one input channel.");
-    }
+    NODE_VALIDATION_ASSERT(node, input_channel_count != 0)
+        << "Input channel count is zero (data batch shape: " << data_batch_shape << ", "
+        << "channel axis is axis " << input_channel_axis_data << ").";

     size_t spatial_dimension_count = data_batch_shape.size() - 2;

     //
     // Make sure filters: CoCiWv for some Co>0, rank of W = rank of Di.
     //
-    if (filters_shape.size() != 2 + spatial_dimension_count)
-    {
-        throw ngraph_error(error_prefix +
-                           "Convolution filter input must have rank of 2 + n_spatial_dimensions.");
-    }
+    NODE_VALIDATION_ASSERT(node, filters_shape.size() == 2 + spatial_dimension_count)
+        << "Filter input must have rank equal to the data batch (one axis for output "
+        << "channels, one axis for input channels, and the same number of spatial "
+        << "dimensions as the data batch (filter input shape: " << filters_shape << ", "
+        << "data batch shape: " << data_batch_shape << ").";

     size_t output_channel_count = filters_shape[output_channel_axis_filters];
-    if (output_channel_count == 0)
-    {
-        throw ngraph_error(error_prefix + "Convolution requires at least one output channel.");
-    }
+    NODE_VALIDATION_ASSERT(node, output_channel_count != 0)
+        << "Output channel count for filters is zero (filters shape: " << filters_shape << ", "
+        << "output channels on axis " << output_channel_axis_filters << ").";

-    if (filters_shape[input_channel_axis_filters] != input_channel_count)
-    {
-        throw ngraph_error(error_prefix +
-                           "Convolution data batch and filter input channel counts do not match.");
-    }
+    NODE_VALIDATION_ASSERT(node, filters_shape[input_channel_axis_filters] == input_channel_count)
+        << "Input channel count for filters (" << filters_shape[input_channel_axis_filters] << ") "
+        << "does not match the number of channels in the data batch (" << input_channel_count
+        << ") "
+        << "(filter input shape: " << filters_shape << ", filter input channels on axis "
+        << input_channel_axis_filters << "; data batch shape: " << data_batch_shape
+        << ", data batch channels on axis " << batch_axis_data << ").";

     //
     // Make sure window movement strides, window dilation strides, and data dilation strides
     // have same rank as Di.
     //
-    if (window_movement_strides.size() != spatial_dimension_count)
-    {
-        throw ngraph_error(
-            error_prefix +
-            "Convolution window movement stride rank does not match number of spatial dimensions.");
-    }
+    NODE_VALIDATION_ASSERT(node, window_movement_strides.size() == spatial_dimension_count)
+        << "Rank of window movement strides does not match the number of spatial dimensions ("
+        << spatial_dimension_count
+        << ") in the data batch (window movement strides: " << window_movement_strides
+        << ", data batch shape: " << data_batch_shape << ").";

-    if (window_dilation_strides.size() != spatial_dimension_count)
-    {
-        throw ngraph_error(
-            error_prefix +
-            "Convolution window dilation stride rank does not match number of spatial dimensions.");
-    }
+    NODE_VALIDATION_ASSERT(node, window_dilation_strides.size() == spatial_dimension_count)
+        << "Rank of window dilation strides does not match the number of spatial dimensions ("
+        << spatial_dimension_count
+        << ") in the data batch (window dilation strides: " << window_dilation_strides
+        << ", data batch shape: " << data_batch_shape << ").";

-    if (data_dilation_strides.size() != spatial_dimension_count)
-    {
-        throw ngraph_error(
-            error_prefix +
-            "Convolution data dilation stride rank does not match number of spatial dimensions.");
-    }
+    NODE_VALIDATION_ASSERT(node, data_dilation_strides.size() == spatial_dimension_count)
+        << "Rank of data dilation strides does not match the number of spatial dimensions ("
+        << spatial_dimension_count
+        << ") in the data batch (data dilation strides: " << data_dilation_strides
+        << ", data batch shape: " << data_batch_shape << ").";

     //
     // Make sure padding-below and padding-above shapes have same rank as Di.
     //
-    if (padding_below.size() != spatial_dimension_count)
-    {
-        throw ngraph_error(
-            error_prefix +
-            "Convolution padding-below rank does not match number of spatial dimensions.");
-    }
+    NODE_VALIDATION_ASSERT(node, padding_below.size() == spatial_dimension_count)
+        << "Rank of the padding below does not match the number of spatial dimensions ("
+        << spatial_dimension_count << ") in the data batch (padding below: " << padding_below
+        << ", data batch shape: " << data_batch_shape << ").";

-    if (padding_above.size() != spatial_dimension_count)
-    {
-        throw ngraph_error(
-            error_prefix +
-            "Convolution padding-above rank does not match number of spatial dimensions.");
-    }
+    NODE_VALIDATION_ASSERT(node, padding_above.size() == spatial_dimension_count)
+        << "Rank of the padding above does not match the number of spatial dimensions ("
+        << spatial_dimension_count << ") in the data batch (padding above: " << padding_above
+        << ", data batch shape: " << data_batch_shape << ").";

     //
     // Extract input item shape Di and make sure all dimensions are larger than 0 after padding and dilation.
     //
-    Shape input_item_virtual_shape;
+    std::vector<ptrdiff_t> input_item_virtual_shape_signed;

     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        if (data_dilation_strides[i] == 0)
-        {
-            throw ngraph_error(error_prefix + "Convolution data dilation stride is zero.");
-        }
+        NODE_VALIDATION_ASSERT(node, data_dilation_strides[i] != 0)
+            << "Data dilation stride at spatial dimension " << i << " is zero "
+            << "(data dilation strides: " << data_dilation_strides << ").";

         size_t dim_size = data_batch_shape[1 + 1 + i];
         size_t dilated_dim_size = (dim_size - 1) * data_dilation_strides[i] + 1;

         ptrdiff_t padded_dilated_dim_size = padding_below[i] + dilated_dim_size + padding_above[i];

-        if (padded_dilated_dim_size < 0)
-        {
-            throw ngraph_error(
-                error_prefix +
-                "Convolution input spatial dimension after padding and dilation is negative.");
-        }
-
-        input_item_virtual_shape.push_back(padded_dilated_dim_size);
+        input_item_virtual_shape_signed.push_back(padded_dilated_dim_size);
+    }

-        if (input_item_virtual_shape[i] == 0)
-        {
-            throw ngraph_error(
-                error_prefix +
-                "Convolution input spatial dimension after dilation is zero even with padding.");
-        }
+    Shape input_item_virtual_shape;
+
+    for (size_t i = 0; i < spatial_dimension_count; i++)
+    {
+        NODE_VALIDATION_ASSERT(node, input_item_virtual_shape_signed[i] > 0)
+            << "Input dimension after padding and dilation is non-positive "
+            << "at spatial axis " << i
+            << " (post-padding/dilation input item shape: " << input_item_virtual_shape
+            << ", data batch shape: " << data_batch_shape
+            << ", data dilation strides: " << data_dilation_strides
+            << ", padding below: " << padding_below << ", padding above: " << padding_above << ").";
+
+        input_item_virtual_shape.push_back(size_t(input_item_virtual_shape_signed[i]));
     }

     //
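The loops above first dilate each spatial dimension and then apply (possibly negative) padding; the second loop rejects any dimension that ends up non-positive. A worked instance of that arithmetic, with illustrative values:

    #include <cassert>
    #include <cstddef>

    int main()
    {
        std::size_t dim_size = 5;             // spatial extent before dilation
        std::size_t data_dilation_stride = 2; // stretch factor between elements
        std::ptrdiff_t padding_below = 1;
        std::ptrdiff_t padding_above = -3;    // negative padding is legal, hence the signed type

        std::size_t dilated_dim_size = (dim_size - 1) * data_dilation_stride + 1; // 9
        std::ptrdiff_t padded_dilated_dim_size =
            padding_below + static_cast<std::ptrdiff_t>(dilated_dim_size) + padding_above; // 7

        // A value <= 0 here is exactly what the post-loop assertion rejects.
        assert(dilated_dim_size == 9 && padded_dilated_dim_size == 7);
        return 0;
    }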
@@ -183,34 +172,36 @@ Shape op::util::infer_convolution_output_shape(const Shape& data_batch_shape,
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
         window_physical_shape.push_back(filters_shape[1 + 1 + i]);
-        if (window_physical_shape[i] == 0)
-        {
-            throw ngraph_error(error_prefix + "Convolution window shape has a zero-length axis.");
-        }
+        NODE_VALIDATION_ASSERT(node, window_physical_shape[i] != 0)
+            << "Filters shape at spatial dimension " << i << " is zero "
+            << "(filters shape: " << filters_shape << ").";
     }

     //
-    // Compute physical shape Wp of the convolution window, *including* dilation. At the same time, make sure all
+    // Compute virtual shape Wp of the convolution window, *including* dilation. At the same time, make sure all
     // window dilation strides are larger than 0, and that the dilated filter fits within the spatial dimensions.
     //
     Shape window_virtual_shape;

     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        if (window_dilation_strides[i] == 0)
-        {
-            throw ngraph_error(error_prefix + "Convolution window axis dilation stride is zero.");
-        }
+        NODE_VALIDATION_ASSERT(node, window_dilation_strides[i] != 0)
+            << "Window dilation stride at spatial dimension " << i << " is zero "
+            << "(window dilation strides: " << window_dilation_strides << ").";

         window_virtual_shape.push_back((window_physical_shape[i] - 1) * window_dilation_strides[i] +
                                        1);

-        if (window_virtual_shape[i] > input_item_virtual_shape[i])
-        {
-            throw ngraph_error(error_prefix +
-                               "Convolution window after dilation is larger than the spatial "
-                               "dimensions even with padding.");
-        }
+        NODE_VALIDATION_ASSERT(node, window_virtual_shape[i] <= input_item_virtual_shape[i])
+            << "Post-dilation window shape is smaller than the post-padding/dilation "
+            << "input item shape at spatial dimension " << i << " (post-padding/dilation "
+            << "input item shape: " << input_item_virtual_shape
+            << ", data batch shape: " << data_batch_shape
+            << ", data dilation strides: " << data_dilation_strides
+            << ", padding below: " << padding_below << ", padding above: " << padding_above
+            << ", post-dilation window shape: " << window_virtual_shape
+            << ", filters shape: " << filters_shape
+            << ", window dilation strides: " << window_dilation_strides;
     }

     //
@@ -223,10 +214,10 @@ Shape op::util::infer_convolution_output_shape(const Shape& data_batch_shape,
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        if (window_movement_strides[i] == 0)
-        {
-            throw ngraph_error(error_prefix + "Convolution window axis movement stride is zero.");
-        }
+        NODE_VALIDATION_ASSERT(node, window_movement_strides[i] != 0)
+            << "Window movement stride at spatial dimension " << i << " is zero "
+            << "(window movement strides: " << window_movement_strides << ").";
+
         result_shape[i + 2] = ceil_div(input_item_virtual_shape[i] - window_virtual_shape[i] + 1,
                                        window_movement_strides[i]);
     }
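The output extent per spatial dimension is the number of valid window positions: ceil_div(input_virtual - window_virtual + 1, stride). Continuing the numbers from the previous sketch (input extent 7 after padding/dilation; a filter of extent 2 under window dilation 2; movement stride 2):

    #include <cassert>
    #include <cstddef>

    // Same rounding-up division the code above relies on (x, y > 0).
    std::size_t ceil_div(std::size_t x, std::size_t y)
    {
        return (x + y - 1) / y;
    }

    int main()
    {
        std::size_t input_item_virtual = 7;           // post-padding/dilation input extent
        std::size_t window_virtual = (2 - 1) * 2 + 1; // filter extent 2, window dilation 2 -> 3
        std::size_t window_movement_stride = 2;

        // Valid window start positions are 0, 2 and 4, so the output extent is 3.
        assert(ceil_div(input_item_virtual - window_virtual + 1, window_movement_stride) == 3);
        return 0;
    }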
@@ -260,38 +251,38 @@ void op::Convolution::validate_and_infer_types()
     if (m_data_dilation_strides.size() == 0)
     {
-        m_data_dilation_strides = default_strides(data_batch_shape);
+        m_data_dilation_strides = default_strides(this, data_batch_shape);
     }

     if (m_window_movement_strides.size() == 0)
     {
-        m_window_movement_strides = default_strides(data_batch_shape);
+        m_window_movement_strides = default_strides(this, data_batch_shape);
     }

     if (m_window_dilation_strides.size() == 0)
     {
-        m_window_dilation_strides = default_strides(data_batch_shape);
+        m_window_dilation_strides = default_strides(this, data_batch_shape);
     }

     if (m_padding_below.size() == 0)
     {
-        m_padding_below = default_padding(data_batch_shape);
+        m_padding_below = default_padding(this, data_batch_shape);
     }

     if (m_padding_above.size() == 0)
     {
-        m_padding_above = default_padding(data_batch_shape);
+        m_padding_above = default_padding(this, data_batch_shape);
     }

     //
     // Make sure data batch and filter element types match.
     //
-    if (data_batch_et != filters_et)
-    {
-        throw ngraph_error("Convolution data batch and filter element types do not match");
-    }
+    NODE_VALIDATION_ASSERT(this, data_batch_et == filters_et)
+        << "Element types for data batch and filters do not match (data batch element type: "
+        << data_batch_et << ", filters element type: " << filters_et << ").";

     set_output_type(0,
                     data_batch_et,
-                    util::infer_convolution_output_shape(data_batch_shape,
+                    util::infer_convolution_output_shape(this,
+                                                         data_batch_shape,
                                                          filters_shape,
                                                          m_window_movement_strides,
                                                          m_window_dilation_strides,
@@ -303,19 +294,17 @@ void op::Convolution::validate_and_infer_types()
                                                          1,
                                                          0,
                                                          0,
-                                                         1,
-                                                         ""));
+                                                         1));
 }

-Strides op::Convolution::default_strides(const Shape& data_batch_shape)
+Strides op::Convolution::default_strides(const Node* node, const Shape& data_batch_shape)
 {
-    if (data_batch_shape.size() < 3)
-    {
-        // For consistency we should throw the same error message here that we throw in the constructor.
-        throw ngraph_error(
-            "Convolution data batch input must have rank of at least 3 (one batch axis, one "
-            "input-channel axis, at least one spatial dimension).");
-    }
+    // For consistency we should throw the same error message here that we throw in the constructor.
+    NODE_VALIDATION_ASSERT(node, data_batch_shape.size() >= 3)
+        << "Data batch input must have rank of at least 3 (one batch axis, "
+        << "one input-channel axis, and at least one spatial dimension) "
+        << "(data batch shape: " << data_batch_shape << ").";
     return Strides(data_batch_shape.size() - 2, 1);
 }
@@ -335,15 +324,14 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
 {
 }

-CoordinateDiff op::Convolution::default_padding(const Shape& data_batch_shape)
+CoordinateDiff op::Convolution::default_padding(const Node* node, const Shape& data_batch_shape)
 {
-    if (data_batch_shape.size() < 3)
-    {
-        // For consistency we should throw the same error message here that we throw in the constructor.
-        throw ngraph_error(
-            "Convolution data batch input must have rank of at least 3 (one batch axis, one "
-            "input-channel axis, at least one spatial dimension).");
-    }
+    // For consistency we should throw the same error message here that we throw in the constructor.
+    NODE_VALIDATION_ASSERT(node, data_batch_shape.size() >= 3)
+        << "Data batch input must have rank of at least 3 (one batch axis, "
+        << "one input-channel axis, and at least one spatial dimension) "
+        << "(data batch shape: " << data_batch_shape << ").";
     return CoordinateDiff(data_batch_shape.size() - 2, 0);
 }
@@ -449,11 +437,9 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
     //
     // Make sure filter and output delta element types match.
     //
-    if (filters_et != output_delta_et)
-    {
-        throw ngraph_error(
-            "Convolution data batch backprop filter and output delta element types do not match");
-    }
+    NODE_VALIDATION_ASSERT(this, output_delta_et == filters_et)
+        << "Element types for filters and output delta do not match (filters element type: "
+        << filters_et << ", output delta element type: " << output_delta_et << ").";

     //                       Forward            Backward
     // Window movement strides q                p_x
@@ -481,7 +467,8 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
     }

     Shape inferred_convolution_output_shape =
-        util::infer_convolution_output_shape(output_delta_shape,
+        util::infer_convolution_output_shape(this,
+                                             output_delta_shape,
                                              filters_shape,
                                              m_window_movement_strides_backward,
                                              m_window_dilation_strides_backward,
@@ -493,17 +480,12 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
                                              0,
                                              1,
                                              0,
-                                             1,
-                                             "In ConvolutionBackpropData: ");
+                                             1);

-    // Not sure if this can ever actually happen (i.e., I think it will trip on something else
-    // inside infer_convolution_output_shape before we get here) but it seems worth checking.
-    if (inferred_convolution_output_shape != m_data_batch_shape)
-    {
-        throw ngraph_error(
-            "Convolution data batch backprop inferred output shape does not match "
-            "specified data batch shape");
-    }
+    NODE_VALIDATION_ASSERT(this, inferred_convolution_output_shape == m_data_batch_shape)
+        << "Specified data batch shape does not match the inferred data batch shape "
+        << "(specified shape: " << m_data_batch_shape
+        << ", inferred data batch shape: " << inferred_convolution_output_shape;

     set_output_type(0, filters_et, inferred_convolution_output_shape);
 }
@@ -622,11 +604,9 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
     //
     // Make sure data batch and output delta element types match.
     //
-    if (data_batch_et != output_delta_et)
-    {
-        throw ngraph_error(
-            "Convolution filter backprop data batch and output delta element types do not match");
-    }
+    NODE_VALIDATION_ASSERT(this, output_delta_et == data_batch_et)
+        << "Element types for data batch and output delta do not match (data batch element type: "
+        << data_batch_et << ", output delta element type: " << output_delta_et << ").";

     //                       Forward            Backward
     // Window movement strides q                p_f
@@ -651,7 +631,8 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
     }

     Shape inferred_convolution_output_shape =
-        util::infer_convolution_output_shape(data_batch_shape,
+        util::infer_convolution_output_shape(this,
+                                             data_batch_shape,
                                              output_delta_shape,
                                              m_window_movement_strides_backward,
                                              m_window_dilation_strides_backward,
@@ -663,17 +644,12 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
                                              0,
                                              1,
                                              1,
-                                             0,
-                                             "In ConvolutionBackpropFilters: ");
+                                             0);

-    // Not sure if this can ever actually happen (i.e., I think it will trip on something else
-    // inside infer_convolution_output_shape before we get here) but it seems worth checking.
-    if (inferred_convolution_output_shape != m_filters_shape)
-    {
-        throw ngraph_error(
-            "Convolution filter backprop inferred output shape does not match "
-            "specified filter shape");
-    }
+    NODE_VALIDATION_ASSERT(this, inferred_convolution_output_shape == m_filters_shape)
+        << "Specified filters shape does not match the inferred filters shape "
+        << "(specified shape: " << m_filters_shape
+        << ", inferred filters shape: " << inferred_convolution_output_shape;

     set_output_type(0, data_batch_et, inferred_convolution_output_shape);
 }
......
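With the node threaded through, a shape or type error in any of these constructors now surfaces as a NodeValidationError carrying the detailed message above rather than a bare ngraph_error. A hedged sketch of triggering the element-type check, following the same pattern as the type_prop tests at the end of this commit (the parameter shapes are illustrative, not from the diff):

    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 10, 10});
    auto filters = make_shared<op::Parameter>(element::f64, Shape{8, 3, 3, 3});
    try
    {
        auto conv = make_shared<op::Convolution>(data, filters);
        FAIL() << "Element type mismatch not detected";
    }
    catch (const NodeValidationError& error)
    {
        // Message reads roughly: "Element types for data batch and filters do not match ..."
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Element types for data batch and filters do not match"));
    }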
@@ -155,8 +155,8 @@ namespace ngraph
             Strides m_data_dilation_strides;

         private:
-            static Strides default_strides(const Shape& data_batch_shape);
-            static CoordinateDiff default_padding(const Shape& data_batch_shape);
+            static Strides default_strides(const Node* node, const Shape& data_batch_shape);
+            static CoordinateDiff default_padding(const Node* node, const Shape& data_batch_shape);
         };

         /// \brief Data batch backprop for batched convolution operation.
@@ -356,7 +356,8 @@ namespace ngraph
         namespace util
         {
-            Shape infer_convolution_output_shape(const Shape& data_batch_shape,
+            Shape infer_convolution_output_shape(const Node* node,
+                                                 const Shape& data_batch_shape,
                                                  const Shape& filters_shape,
                                                  const Strides& window_movement_strides,
                                                  const Strides& window_dilation_strides,
@@ -368,8 +369,7 @@ namespace ngraph
                                                  size_t input_channel_axis_filters,
                                                  size_t output_channel_axis_filters,
                                                  size_t batch_axis_result,
-                                                 size_t output_channel_axis_result,
-                                                 const std::string& error_prefix);
+                                                 size_t output_channel_axis_result);
         }
     }
 }
@@ -39,87 +39,63 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
     constructor_validate_and_infer_types();
 }

+// TODO(amprocte): This code is now *exactly* the same as AvgPool::validate_and_infer_types(),
+// except that for AvgPool we also have an optional check that the pooling window is never
+// entirely in the padding. Should unify in a utility function, but not sure where it belongs
+// at this juncture.
 void op::MaxPool::validate_and_infer_types()
 {
-    if (m_inputs.size() != 1)
-    {
-        throw ngraph_error("Max-pool data batch argument must have exactly one output");
-    }
-
     auto& arg_shape = get_input_shape(0);
-    if (arg_shape.size() < 3)
-    {
-        // For consistency we should throw the same error message here that we throw in the constructor.
-        throw ngraph_error(
-            "Max-pool data batch input must have rank of at least 3 (one batch axis, one "
-            "channel axis, at least one spatial dimension).");
-    }

-    Shape default_padding(arg_shape.size() - 2, 0);
-    if (m_padding_below.size() == 0)
+    if (0 == m_window_movement_strides.size() && arg_shape.size() > 2)
     {
-        m_padding_below = default_padding;
+        m_window_movement_strides = Strides(arg_shape.size() - 2, 1);
     }
-    if (m_padding_above.size() == 0)
+
+    if (0 == m_padding_below.size() && arg_shape.size() > 2)
     {
-        m_padding_above = default_padding;
+        m_padding_below = Shape(arg_shape.size() - 2, 0);
     }

-    if (m_window_movement_strides.size() == 0)
+    if (0 == m_padding_above.size() && arg_shape.size() > 2)
     {
-        m_window_movement_strides = Strides(arg_shape.size() - 2, 1);
+        m_padding_above = Shape(arg_shape.size() - 2, 0);
     }

     //
-    // Make sure arg: NCDi for some Di of rank>0, N != 0, C != 0.
+    // Make sure batch size and channel count are not zero, and that we have at least one spatial
+    // dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0).
     //
-    if (arg_shape.size() < 3)
-    {
-        throw ngraph_error(
-            "Max-pool data batch input must have rank of at least 3 (one batch axis, one "
-            "channel axis, at least one spatial dimension).");
-    }
+    NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
+        << "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape
+        << ").";

     size_t batch_size = arg_shape[0];
-    if (batch_size == 0)
-    {
-        throw ngraph_error("Max-pool data batch size is zero.");
-    }
+    NODE_VALIDATION_ASSERT(this, batch_size != 0)
+        << "Data batch size is zero (data input shape: " << arg_shape << ").";

     size_t channel_count = arg_shape[1];
-    if (channel_count == 0)
-    {
-        throw ngraph_error("Max-pool requires at least one feature channel.");
-    }
+    NODE_VALIDATION_ASSERT(this, channel_count != 0)
+        << "Channel count is zero (data input shape: " << arg_shape << ").";

     size_t spatial_dimension_count = arg_shape.size() - 2;

     //
     // Make sure window shape, window movement strides, and padding have same rank as Di.
     //
-    if (m_window_shape.size() != spatial_dimension_count)
-    {
-        throw ngraph_error(
-            "Max-pool window shape rank does not match number of spatial dimensions.");
-    }
-
-    if (m_window_movement_strides.size() != spatial_dimension_count)
-    {
-        throw ngraph_error(
-            "Max-pool window movement stride rank does not match number of spatial "
-            "dimensions.");
-    }
-
-    if (m_padding_below.size() != spatial_dimension_count)
-    {
-        throw ngraph_error(
-            "Max-pool below-padding rank does not match number of spatial dimensions.");
-    }
-
-    if (m_padding_above.size() != spatial_dimension_count)
-    {
-        throw ngraph_error(
-            "Max-pool above-padding rank does not match number of spatial dimensions.");
-    }
+    NODE_VALIDATION_ASSERT(this, m_window_shape.size() == spatial_dimension_count)
+        << "Window shape rank does not match number of spatial dimensions (window shape: "
+        << m_window_shape << ", data input shape: " << arg_shape << ").";
+
+    NODE_VALIDATION_ASSERT(this, m_window_movement_strides.size() == spatial_dimension_count)
+        << "Window movement stride rank does not match number of spatial dimensions (window "
+           "movement strides: "
+        << m_window_movement_strides << ", data input shape: " << arg_shape << ").";
+
+    NODE_VALIDATION_ASSERT(this, m_padding_below.size() == spatial_dimension_count)
+        << "Below-padding rank does not match number of spatial dimensions (padding below: "
+        << m_padding_below << ", data input shape: " << arg_shape << ").";
+
+    NODE_VALIDATION_ASSERT(this, m_padding_above.size() == spatial_dimension_count)
+        << "Above-padding rank does not match number of spatial dimensions (padding above: "
+        << m_padding_above << ", data input shape: " << arg_shape << ").";

     //
     // Extract input item shape Di and make sure all dimensions are larger than 0.
@@ -131,11 +107,14 @@ void op::MaxPool::validate_and_infer_types()
         size_t dim_size = arg_shape[1 + 1 + i];
         size_t virtual_dim_size = m_padding_below[i] + dim_size + m_padding_above[i];
         input_item_virtual_shape.push_back(virtual_dim_size);
+    }

-        if (virtual_dim_size == 0)
-        {
-            throw ngraph_error("Max-pool input spatial dimension is zero even after padding.");
-        }
+    for (size_t i = 0; i < spatial_dimension_count; i++)
+    {
+        NODE_VALIDATION_ASSERT(this, input_item_virtual_shape[i] != 0)
+            << "Data input spatial dimension " << i
+            << " has zero length even after padding (virtual shape of input item: "
+            << input_item_virtual_shape << ").";
     }

     //
@@ -143,35 +122,32 @@ void op::MaxPool::validate_and_infer_types()
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        if (m_window_shape[i] == 0)
-        {
-            throw ngraph_error("Max-pool window shape has a zero-length axis.");
-        }
+        NODE_VALIDATION_ASSERT(this, m_window_shape[i] != 0)
+            << "Window shape dimension " << i
+            << " has zero length (window shape: " << m_window_shape << ").";
     }

     //
-    // Make the max pooling window fits within the spatial dimensions.
+    // Make sure the pooling window fits within the spatial dimensions.
     //
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        if (m_window_shape[i] > input_item_virtual_shape[i])
-        {
-            throw ngraph_error(
-                "Max-pool window shape is larger than the spatial dimensions even after "
-                "padding.");
-        }
+        NODE_VALIDATION_ASSERT(this, m_window_shape[i] <= input_item_virtual_shape[i])
+            << "Window shape after padding is larger than the spatial dimensions (window shape: "
+            << m_window_shape << ", virtual shape of input item: " << input_item_virtual_shape
+            << ").";
     }

     //
     // Compute output item shape Do, checking at the same time that all window movement strides are larger than 0.
     //
     Shape output_item_shape;

     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        if (m_window_movement_strides[i] == 0)
-        {
-            throw ngraph_error("Max-pool window axis movement stride is zero.");
-        }
+        NODE_VALIDATION_ASSERT(this, m_window_movement_strides[i] != 0)
+            << "Window movement strides dimension " << i
+            << " has zero length (window movement strides: " << m_window_movement_strides << ").";
+
         output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - m_window_shape[i] + 1,
                                              m_window_movement_strides[i]));
     }
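The max-pool output arithmetic mirrors convolution's: each spatial extent becomes ceil_div(virtual_extent - window + 1, stride). A worked instance with illustrative numbers:

    #include <cassert>
    #include <cstddef>

    std::size_t ceil_div(std::size_t x, std::size_t y)
    {
        return (x + y - 1) / y;
    }

    int main()
    {
        std::size_t dim_size = 10, padding_below = 1, padding_above = 1;
        std::size_t window = 3, stride = 2;

        std::size_t virtual_dim = padding_below + dim_size + padding_above; // 12
        // Window start positions 0, 2, 4, 6, 8 all fit, so the output extent is 5.
        assert(ceil_div(virtual_dim - window + 1, stride) == 5);
        return 0;
    }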
@@ -228,72 +204,52 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
 void op::MaxPoolBackprop::validate_and_infer_types()
 {
-    // --
-    // TODO: de-duplicate this code from MaxPool::MaxPool.
-    // --
-
-    if (get_input_element_type(0) != get_input_element_type(1))
-    {
-        throw ngraph_error("Max-pool backprop: data batch and delta element types do not match.");
-    }
+    NODE_VALIDATION_ASSERT(this, get_input_element_type(0) == get_input_element_type(1))
+        << "Data input and delta element types do not match (data input element type: "
+        << get_input_element_type(0) << ", delta element type: " << get_input_element_type(1)
+        << ").";

-    auto& arg_forward_shape = get_input_shape(0);
-    auto& delta_shape = get_input_shape(1);
+    //
+    // TODO(amprocte): de-duplicate almost all the rest of this code from
+    // MaxPool::validate_and_infer_types().
+    //
+
+    auto& arg_shape = get_input_shape(0);

     //
-    // Make sure arg: NCDi for some Di of rank>0, N != 0, C != 0.
+    // Make sure batch size and channel count are not zero, and that we have at least one spatial
+    // dimension (in other words, that arg has shape NCDi for some Di of rank>0, N != 0, C != 0).
     //
-    if (arg_forward_shape.size() < 3)
-    {
-        throw ngraph_error(
-            "Max-pool backprop: data batch shape must have rank of at least 3 (one batch axis, "
-            "one channel axis, at least one spatial dimension).");
-    }
+    NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
+        << "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape
+        << ").";

-    size_t batch_size = arg_forward_shape[0];
-    if (batch_size == 0)
-    {
-        throw ngraph_error("Max-pool backprop: data batch size is zero.");
-    }
+    size_t batch_size = arg_shape[0];
+    NODE_VALIDATION_ASSERT(this, batch_size != 0)
+        << "Data batch size is zero (data input shape: " << arg_shape << ").";

-    size_t channel_count = arg_forward_shape[1];
-    if (channel_count == 0)
-    {
-        throw ngraph_error("Max-pool backprop: requires at least one feature channel.");
-    }
+    size_t channel_count = arg_shape[1];
+    NODE_VALIDATION_ASSERT(this, channel_count != 0)
+        << "Channel count is zero (data input shape: " << arg_shape << ").";

-    size_t spatial_dimension_count = arg_forward_shape.size() - 2;
+    size_t spatial_dimension_count = arg_shape.size() - 2;

     //
     // Make sure window shape, window movement strides, and padding have same rank as Di.
     //
-    if (m_window_shape.size() != spatial_dimension_count)
-    {
-        throw ngraph_error(
-            "Max-pool backprop: window shape rank does not match number of spatial "
-            "dimensions.");
-    }
-
-    if (m_window_movement_strides.size() != spatial_dimension_count)
-    {
-        throw ngraph_error(
-            "Max-pool backprop: window movement stride rank does not match number of spatial "
-            "dimensions.");
-    }
-
-    if (m_padding_below.size() != spatial_dimension_count)
-    {
-        throw ngraph_error(
-            "Max-pool backprop: below-padding rank does not match number of spatial "
-            "dimensions.");
-    }
-
-    if (m_padding_above.size() != spatial_dimension_count)
-    {
-        throw ngraph_error(
-            "Max-pool backprop: above-padding rank does not match number of spatial "
-            "dimensions.");
-    }
+    NODE_VALIDATION_ASSERT(this, m_window_shape.size() == spatial_dimension_count)
+        << "Window shape rank does not match number of spatial dimensions (window shape: "
+        << m_window_shape << ", data input shape: " << arg_shape << ").";
+
+    NODE_VALIDATION_ASSERT(this, m_window_movement_strides.size() == spatial_dimension_count)
+        << "Window movement stride rank does not match number of spatial dimensions (window "
+           "movement strides: "
+        << m_window_movement_strides << ", data input shape: " << arg_shape << ").";
+
+    NODE_VALIDATION_ASSERT(this, m_padding_below.size() == spatial_dimension_count)
+        << "Below-padding rank does not match number of spatial dimensions (padding below: "
+        << m_padding_below << ", data input shape: " << arg_shape << ").";
+
+    NODE_VALIDATION_ASSERT(this, m_padding_above.size() == spatial_dimension_count)
+        << "Above-padding rank does not match number of spatial dimensions (padding above: "
+        << m_padding_above << ", data input shape: " << arg_shape << ").";

     //
     // Extract input item shape Di and make sure all dimensions are larger than 0.
@@ -302,15 +258,17 @@ void op::MaxPoolBackprop::validate_and_infer_types()
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        size_t dim_size = arg_forward_shape[1 + 1 + i];
+        size_t dim_size = arg_shape[1 + 1 + i];
         size_t virtual_dim_size = m_padding_below[i] + dim_size + m_padding_above[i];
         input_item_virtual_shape.push_back(virtual_dim_size);
+    }

-        if (virtual_dim_size == 0)
-        {
-            throw ngraph_error(
-                "Max-pool backprop: data batch spatial dimension is zero even after padding.");
-        }
+    for (size_t i = 0; i < spatial_dimension_count; i++)
+    {
+        NODE_VALIDATION_ASSERT(this, input_item_virtual_shape[i] != 0)
+            << "Data input spatial dimension " << i
+            << " has zero length even after padding (virtual shape of input item: "
+            << input_item_virtual_shape << ").";
     }

     //
@@ -318,23 +276,20 @@ void op::MaxPoolBackprop::validate_and_infer_types()
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        if (m_window_shape[i] == 0)
-        {
-            throw ngraph_error("Max-pool backprop: window shape has a zero-length axis.");
-        }
+        NODE_VALIDATION_ASSERT(this, m_window_shape[i] != 0)
+            << "Window shape dimension " << i
+            << " has zero length (window shape: " << m_window_shape << ").";
     }

     //
-    // Make the max pooling window fits within the spatial dimensions.
+    // Make sure the pooling window fits within the spatial dimensions.
     //
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        if (m_window_shape[i] > input_item_virtual_shape[i])
-        {
-            throw ngraph_error(
-                "Max-pool backprop: window shape is larger than the spatial dimensions even after "
-                "padding.");
-        }
+        NODE_VALIDATION_ASSERT(this, m_window_shape[i] <= input_item_virtual_shape[i])
+            << "Window shape after padding is larger than the spatial dimensions (window shape: "
+            << m_window_shape << ", virtual shape of input item: " << input_item_virtual_shape
+            << ").";
     }

     //
@@ -344,10 +299,9 @@ void op::MaxPoolBackprop::validate_and_infer_types()
     for (size_t i = 0; i < spatial_dimension_count; i++)
     {
-        if (m_window_movement_strides[i] == 0)
-        {
-            throw ngraph_error("Max-pool backprop: window axis movement stride is zero.");
-        }
+        NODE_VALIDATION_ASSERT(this, m_window_movement_strides[i] != 0)
+            << "Window movement strides dimension " << i
+            << " has zero length (window movement strides: " << m_window_movement_strides << ").";
+
         output_item_shape.push_back(ceil_div(input_item_virtual_shape[i] - m_window_shape[i] + 1,
                                              m_window_movement_strides[i]));
     }
@@ -355,17 +309,16 @@ void op::MaxPoolBackprop::validate_and_infer_types()
     //
     // Construct result shape: NCDo.
     //
-    Shape forward_result_shape(1 + 1 + spatial_dimension_count);
-    forward_result_shape[0] = batch_size;
-    forward_result_shape[1] = channel_count;
-    copy(output_item_shape.begin(), output_item_shape.end(), forward_result_shape.begin() + 2);
+    Shape result_shape(1 + 1 + spatial_dimension_count);
+    result_shape[0] = batch_size;
+    result_shape[1] = channel_count;
+    copy(output_item_shape.begin(), output_item_shape.end(), result_shape.begin() + 2);

-    if (forward_result_shape != delta_shape)
-    {
-        throw ngraph_error("Max-pool backprop: forward result shape does not match delta shape.");
-    }
+    NODE_VALIDATION_ASSERT(this, get_input_shape(1) == result_shape)
+        << "Forward result shape and delta shape do not match (forward result shape: "
+        << result_shape << ", delta shape: " << get_input_shape(1) << ").";

-    set_output_type(0, get_input_element_type(0), arg_forward_shape);
+    set_output_type(0, get_input_element_type(0), arg_shape);
 }

 shared_ptr<op::MaxPool> op::MaxPoolBackprop::get_forward_op() const
......
@@ -109,7 +109,8 @@ op::ConvolutionBias::ConvolutionBias(const shared_ptr<Node>& data_batch,
     set_output_type(0,
                     data_batch_et,
-                    util::infer_convolution_output_shape(data_batch_shape,
+                    util::infer_convolution_output_shape(this,
+                                                         data_batch_shape,
                                                          filters_shape,
                                                          window_movement_strides,
                                                          window_dilation_strides,
@@ -121,8 +122,8 @@ op::ConvolutionBias::ConvolutionBias(const shared_ptr<Node>& data_batch,
                                                          1, /* input_channel_axis_filters, */
                                                          0, /* output_channel_axis_filters, */
                                                          0, /* batch_axis_result, */
-                                                         1, /* output_channel_axis_result, */
-                                                         ""));
+                                                         1 /* output_channel_axis_result, */
+                                                         ));
 }

 shared_ptr<Node> op::ConvolutionBias::copy_with_new_args(const NodeVector& new_args) const
@@ -322,7 +323,8 @@ op::ConvolutionBiasAdd::ConvolutionBiasAdd(const std::shared_ptr<Node>& data_bat
     util::validate_convbias_shapes(data_batch_shape, filters_shape, bias->get_shape());
     set_output_type(0,
                     data_batch_et,
-                    util::infer_convolution_output_shape(data_batch_shape,
+                    util::infer_convolution_output_shape(this,
+                                                         data_batch_shape,
                                                          filters_shape,
                                                          window_movement_strides,
                                                          window_dilation_strides,
@@ -334,8 +336,8 @@ op::ConvolutionBiasAdd::ConvolutionBiasAdd(const std::shared_ptr<Node>& data_bat
                                                          1, /* input_channel_axis_filters, */
                                                          0, /* output_channel_axis_filters, */
                                                          0, /* batch_axis_result, */
-                                                         1, /* output_channel_axis_result, */
-                                                         ""));
+                                                         1 /* output_channel_axis_result, */
+                                                         ));
 }

 std::shared_ptr<Node> op::ConvolutionBiasAdd::copy_with_new_args(const NodeVector& new_args) const
......
@@ -68,7 +68,8 @@ op::ConvolutionRelu::ConvolutionRelu(const std::shared_ptr<Node>& data_batch,
     set_output_type(0,
                     data_batch_et,
-                    util::infer_convolution_output_shape(data_batch_shape,
+                    util::infer_convolution_output_shape(this,
+                                                         data_batch_shape,
                                                          filters_shape,
                                                          window_movement_strides,
                                                          window_dilation_strides,
@@ -80,8 +81,8 @@ op::ConvolutionRelu::ConvolutionRelu(const std::shared_ptr<Node>& data_batch,
                                                          1, /* input_channel_axis_filters, */
                                                          0, /* output_channel_axis_filters, */
                                                          0, /* batch_axis_result, */
-                                                         1, /* output_channel_axis_result, */
-                                                         ""));
+                                                         1 /* output_channel_axis_result, */
+                                                         ));
 }

 std::shared_ptr<Node> op::ConvolutionRelu::copy_with_new_args(const NodeVector& new_args) const
......
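The test updates below switch from EXPECT_EQ on the whole message to EXPECT_HAS_SUBSTRING, so the tests stay stable as the now much longer messages gain extra context, and they catch the new NodeValidationError type instead of the base ngraph_error. EXPECT_HAS_SUBSTRING is not defined in this diff; one plausible definition, in terms of Google Test's built-in substring predicate, would be:

    // Assumption: a thin wrapper over gtest's IsSubstring predicate; the actual
    // definition in the nGraph test suite may differ.
    #define EXPECT_HAS_SUBSTRING(haystack, needle) \
        EXPECT_PRED_FORMAT2(::testing::IsSubstring, needle, haystack)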
@@ -49,10 +49,10 @@ TEST(type_prop, batchnorm_rank_less_than_2)
         auto bc = make_shared<op::BatchNorm>(0.001, dummy, dummy, dummy);
         FAIL() << "BatchNorm c-tor should throw for tensors whose rank is less than 2";
     }
-    catch (const ngraph_error& error)
+    catch (const NodeValidationError& error)
     {
-        EXPECT_EQ(error.what(),
-                  std::string("input tensor to batchnorm must have tensor of at least rank 2"));
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("Input argument must have rank of at least 2"));
     }
     catch (...)
     {
@@ -68,11 +68,11 @@ TEST(type_prop, batchnorm_zero_channel_check)
         auto bc = make_shared<op::BatchNorm>(0.001, dummy, dummy, dummy);
         FAIL() << "BatchNorm c-tor should throw for tensors w/ zero-dimension channels";
     }
-    catch (const ngraph_error& error)
+    catch (const NodeValidationError& error)
     {
-        EXPECT_EQ(
+        EXPECT_HAS_SUBSTRING(
             error.what(),
-            std::string("input tensor must have at least one channel for batch normalization"));
+            std::string("Input argument's channel dimension must have size of at least 1"));
     }
     catch (...)
     {
@@ -91,11 +91,11 @@ TEST(type_prop, batchnorm_et_check)
         auto bc = make_shared<op::BatchNorm>(0.001, dummy_f32, dummy_f64, param);
         FAIL() << "BatchNorm c-tor should throw for different element types";
     }
-    catch (const ngraph_error& error)
+    catch (const NodeValidationError& error)
     {
-        EXPECT_EQ(error.what(),
-                  std::string("The element type element::Type{64, 1, 1,double} of input beta isn't "
-                              "equal to the input data's type element::Type{32, 1, 1,float}"));
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Element type of beta"));
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("is not equal to the element type of input"));
     }
     catch (...)
     {
@@ -114,11 +114,11 @@ TEST(type_prop, batchnorm_shape_check)
         auto bc = make_shared<op::BatchNorm>(0.001, dummy_4, dummy_3, param);
         FAIL() << "BatchNorm c-tor should throw if gamma and beta shapes don't match";
     }
-    catch (const ngraph_error& error)
+    catch (const NodeValidationError& error)
     {
-        EXPECT_EQ(error.what(),
-                  std::string(
-                      "The shape Shape{4} of gamma isn't equal to input channel's shape Shape{3}"));
+        EXPECT_HAS_SUBSTRING(
+            error.what(),
+            std::string("Shape of gamma must match the channel dimension of the input data"));
     }
     catch (...)
     {
...@@ -137,9 +137,9 @@ TEST(type_prop, batchnorm_backprop_4d_check) ...@@ -137,9 +137,9 @@ TEST(type_prop, batchnorm_backprop_4d_check)
make_shared<op::BatchNormBackprop>(0.001, dummy, dummy, param, dummy, dummy, dummy); make_shared<op::BatchNormBackprop>(0.001, dummy, dummy, param, dummy, dummy, dummy);
FAIL() << "Deduced type should disagree with c-tor arguments"; FAIL() << "Deduced type should disagree with c-tor arguments";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), std::string("Input expected to be a 4D tensor")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Input data shape is not a 4D tensor"));
} }
catch (...) catch (...)
{ {
...@@ -159,10 +159,11 @@ TEST(type_prop, batchnorm_backprop_et_check) ...@@ -159,10 +159,11 @@ TEST(type_prop, batchnorm_backprop_et_check)
0.001, dummy_f32, dummy_f64, param, dummy_f32, dummy_f32, dummy_f32); 0.001, dummy_f32, dummy_f64, param, dummy_f32, dummy_f32, dummy_f32);
FAIL() << "Deduced type should disagree with c-tor arguments"; FAIL() << "Deduced type should disagree with c-tor arguments";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(), std::string("Element type of beta"));
std::string("The element type of beta isn't equal to input data's type")); EXPECT_HAS_SUBSTRING(error.what(),
std::string("is not equal to the element type of input"));
} }
catch (...) catch (...)
{ {
...@@ -182,10 +183,11 @@ TEST(type_prop, batchnorm_backprop_shape_check) ...@@ -182,10 +183,11 @@ TEST(type_prop, batchnorm_backprop_shape_check)
make_shared<op::BatchNormBackprop>(0.001, dummy, dummy2, param, dummy2, dummy2, dummy2); make_shared<op::BatchNormBackprop>(0.001, dummy, dummy2, param, dummy2, dummy2, dummy2);
FAIL() << "Deduced type should disagree with c-tor arguments"; FAIL() << "Deduced type should disagree with c-tor arguments";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(
std::string("The shape of beta isn't equal to input channel's shape")); error.what(),
std::string("Shape of beta must match the channel dimension of the input data"));
} }
catch (...) catch (...)
{ {
...@@ -206,9 +208,10 @@ TEST(type_prop, batchnorm_backprop_delta_check) ...@@ -206,9 +208,10 @@ TEST(type_prop, batchnorm_backprop_delta_check)
make_shared<op::BatchNormBackprop>(0.001, dummy, dummy, param, dummy, dummy, delta); make_shared<op::BatchNormBackprop>(0.001, dummy, dummy, param, dummy, dummy, delta);
FAIL() << "Deduced type should disagree with c-tor arguments"; FAIL() << "Deduced type should disagree with c-tor arguments";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), std::string("delta shape is expected to be equal to input shape")); EXPECT_HAS_SUBSTRING(error.what(),
std::string("Shape of delta must match the shape of the input data"));
} }
catch (...) catch (...)
{ {
...@@ -2887,10 +2890,10 @@ TEST(type_prop, conv_invalid_element_type_mismatch) ...@@ -2887,10 +2890,10 @@ TEST(type_prop, conv_invalid_element_type_mismatch)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with element type mismatch not detected"; FAIL() << "Invalid input with element type mismatch not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Convolution data batch and filter element types do not match")); std::string("Element types for data batch and filters do not match"));
} }
catch (...) catch (...)
{ {
...@@ -2910,12 +2913,10 @@ TEST(type_prop, conv_invalid_0d_input) ...@@ -2910,12 +2913,10 @@ TEST(type_prop, conv_invalid_0d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 0D input not detected"; FAIL() << "Invalid 0D input not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Convolution data batch input must have rank of at " std::string("Data batch input must have rank of at least 3"));
"least 3 (one batch axis, one input-channel axis, at "
"least one spatial dimension)."));
} }
catch (...) catch (...)
{ {
...@@ -2935,12 +2936,10 @@ TEST(type_prop, conv_invalid_1d_input) ...@@ -2935,12 +2936,10 @@ TEST(type_prop, conv_invalid_1d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 1D input not detected"; FAIL() << "Invalid 1D input not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Convolution data batch input must have rank of at " std::string("Data batch input must have rank of at least 3"));
"least 3 (one batch axis, one input-channel axis, at "
"least one spatial dimension)."));
} }
catch (...) catch (...)
{ {
...@@ -2960,12 +2959,10 @@ TEST(type_prop, conv_invalid_2d_input) ...@@ -2960,12 +2959,10 @@ TEST(type_prop, conv_invalid_2d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 2D input not detected"; FAIL() << "Invalid 2D input not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Convolution data batch input must have rank of at " std::string("Data batch input must have rank of at least 3"));
"least 3 (one batch axis, one input-channel axis, at "
"least one spatial dimension)."));
} }
catch (...) catch (...)
{ {
...@@ -2985,9 +2982,9 @@ TEST(type_prop, conv_invalid_0_batch_size) ...@@ -2985,9 +2982,9 @@ TEST(type_prop, conv_invalid_0_batch_size)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 batch size not detected"; FAIL() << "Invalid input with 0 batch size not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), std::string("Convolution data batch size is zero.")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch size is zero"));
} }
catch (...) catch (...)
{ {
...@@ -3007,9 +3004,9 @@ TEST(type_prop, conv_invalid_0_input_channels) ...@@ -3007,9 +3004,9 @@ TEST(type_prop, conv_invalid_0_input_channels)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 input channels not detected"; FAIL() << "Invalid input with 0 input channels not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), std::string("Convolution requires at least one input channel.")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Input channel count is zero"));
} }
catch (...) catch (...)
{ {
...@@ -3029,11 +3026,10 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_many) ...@@ -3029,11 +3026,10 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_many)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too many filter dimensions not detected"; FAIL() << "Invalid input with too many filter dimensions not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Filter input must have rank equal to the data batch"));
std::string("Convolution filter input must have rank of 2 + n_spatial_dimensions."));
} }
catch (...) catch (...)
{ {
...@@ -3053,11 +3049,10 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_few) ...@@ -3053,11 +3049,10 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_few)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too few filter dimensions not detected"; FAIL() << "Invalid input with too few filter dimensions not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Filter input must have rank equal to the data batch"));
std::string("Convolution filter input must have rank of 2 + n_spatial_dimensions."));
} }
catch (...) catch (...)
{ {
...@@ -3077,9 +3072,9 @@ TEST(type_prop, conv_invalid_0_output_channels) ...@@ -3077,9 +3072,9 @@ TEST(type_prop, conv_invalid_0_output_channels)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 output channels not detected"; FAIL() << "Invalid input with 0 output channels not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), std::string("Convolution requires at least one output channel.")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Output channel count for filters is zero"));
} }
catch (...) catch (...)
{ {
...@@ -3099,11 +3094,11 @@ TEST(type_prop, conv_invalid_input_channel_mismatch) ...@@ -3099,11 +3094,11 @@ TEST(type_prop, conv_invalid_input_channel_mismatch)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with channel count mismatch not detected"; FAIL() << "Invalid input with channel count mismatch not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Input channel count for filters (3) does not match the "
std::string("Convolution data batch and filter input channel counts do not match.")); "number of channels in the data batch (2)"));
} }
catch (...) catch (...)
{ {
...@@ -3123,11 +3118,12 @@ TEST(type_prop, conv_invalid_movement_stride_rank) ...@@ -3123,11 +3118,12 @@ TEST(type_prop, conv_invalid_movement_stride_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong movement stride rank not detected"; FAIL() << "Invalid input with wrong movement stride rank not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Convolution window movement stride rank does not " error.what(),
"match number of spatial dimensions.")); std::string(
"Rank of window movement strides does not match the number of spatial dimensions"));
} }
catch (...) catch (...)
{ {
...@@ -3147,11 +3143,12 @@ TEST(type_prop, conv_invalid_window_dilation_stride_rank) ...@@ -3147,11 +3143,12 @@ TEST(type_prop, conv_invalid_window_dilation_stride_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong window dilation stride rank not detected"; FAIL() << "Invalid input with wrong window dilation stride rank not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Convolution window dilation stride rank does not " error.what(),
"match number of spatial dimensions.")); std::string(
"Rank of window dilation strides does not match the number of spatial dimensions"));
} }
catch (...) catch (...)
{ {
...@@ -3177,11 +3174,12 @@ TEST(type_prop, conv_invalid_data_dilation_stride_rank) ...@@ -3177,11 +3174,12 @@ TEST(type_prop, conv_invalid_data_dilation_stride_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong data dilation stride rank not detected"; FAIL() << "Invalid input with wrong data dilation stride rank not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Convolution data dilation stride rank does not " error.what(),
"match number of spatial dimensions.")); std::string(
"Rank of data dilation strides does not match the number of spatial dimensions"));
} }
catch (...) catch (...)
{ {
...@@ -3206,11 +3204,12 @@ TEST(type_prop, conv_invalid_padding_below_rank) ...@@ -3206,11 +3204,12 @@ TEST(type_prop, conv_invalid_padding_below_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong padding-below rank not detected"; FAIL() << "Invalid input with wrong padding-below rank not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Convolution padding-below rank does not " error.what(),
"match number of spatial dimensions.")); std::string(
"Rank of the padding below does not match the number of spatial dimensions"));
} }
catch (...) catch (...)
{ {
...@@ -3235,11 +3234,12 @@ TEST(type_prop, conv_invalid_padding_above_rank) ...@@ -3235,11 +3234,12 @@ TEST(type_prop, conv_invalid_padding_above_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong padding-above rank not detected"; FAIL() << "Invalid input with wrong padding-above rank not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Convolution padding-above rank does not " error.what(),
"match number of spatial dimensions.")); std::string(
"Rank of the padding above does not match the number of spatial dimensions"));
} }
catch (...) catch (...)
{ {
...@@ -3264,12 +3264,11 @@ TEST(type_prop, conv_invalid_input_spatial_size_negative_after_padding) ...@@ -3264,12 +3264,11 @@ TEST(type_prop, conv_invalid_input_spatial_size_negative_after_padding)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with negative-length post-padding spatial axis not detected"; FAIL() << "Invalid input with negative-length post-padding spatial axis not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string( std::string("Input dimension after padding and dilation is non-positive"));
"Convolution input spatial dimension after padding and dilation is negative."));
} }
catch (...) catch (...)
{ {
...@@ -3294,11 +3293,11 @@ TEST(type_prop, conv_invalid_input_spatial_size_zero_after_padding) ...@@ -3294,11 +3293,11 @@ TEST(type_prop, conv_invalid_input_spatial_size_zero_after_padding)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length post-padding spatial axis not detected"; FAIL() << "Invalid input with zero-length post-padding spatial axis not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Convolution input spatial dimension after " error.what(),
"dilation is zero even with padding.")); std::string("Input dimension after padding and dilation is non-positive"));
} }
catch (...) catch (...)
{ {
...@@ -3318,11 +3317,11 @@ TEST(type_prop, conv_invalid_input_spatial_size_0) ...@@ -3318,11 +3317,11 @@ TEST(type_prop, conv_invalid_input_spatial_size_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length spatial axis not detected"; FAIL() << "Invalid input with zero-length spatial axis not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Convolution input spatial dimension after " error.what(),
"dilation is zero even with padding.")); std::string("Input dimension after padding and dilation is non-positive"));
} }
catch (...) catch (...)
{ {
...@@ -3342,9 +3341,10 @@ TEST(type_prop, conv_invalid_window_size_0) ...@@ -3342,9 +3341,10 @@ TEST(type_prop, conv_invalid_window_size_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length window axis not detected"; FAIL() << "Invalid input with zero-length window axis not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), std::string("Convolution window shape has a zero-length axis.")); EXPECT_HAS_SUBSTRING(error.what(),
std::string("Filters shape at spatial dimension 1 is zero"));
} }
catch (...) catch (...)
{ {
...@@ -3364,9 +3364,10 @@ TEST(type_prop, conv_invalid_window_dilation_stride_0) ...@@ -3364,9 +3364,10 @@ TEST(type_prop, conv_invalid_window_dilation_stride_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong 0-length window dilation stride axis not detected"; FAIL() << "Invalid input with wrong 0-length window dilation stride axis not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), std::string("Convolution window axis dilation stride is zero.")); EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window dilation stride at spatial dimension 1 is zero"));
} }
catch (...) catch (...)
{ {
...@@ -3392,9 +3393,10 @@ TEST(type_prop, conv_invalid_data_dilation_stride_0) ...@@ -3392,9 +3393,10 @@ TEST(type_prop, conv_invalid_data_dilation_stride_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong 0-length data dilation stride axis not detected"; FAIL() << "Invalid input with wrong 0-length data dilation stride axis not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), std::string("Convolution data dilation stride is zero.")); EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data dilation stride at spatial dimension 1 is zero"));
} }
catch (...) catch (...)
{ {
...@@ -3414,11 +3416,11 @@ TEST(type_prop, conv_invalid_dilated_window_too_large) ...@@ -3414,11 +3416,11 @@ TEST(type_prop, conv_invalid_dilated_window_too_large)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with oversized dilated window not detected"; FAIL() << "Invalid input with oversized dilated window not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Convolution window after dilation is larger than the " std::string("Post-dilation window shape is smaller than the "
"spatial dimensions even with padding.")); "post-padding/dilation input item shape"));
} }
catch (...) catch (...)
{ {
...@@ -3438,9 +3440,10 @@ TEST(type_prop, conv_invalid_movement_stride_0) ...@@ -3438,9 +3440,10 @@ TEST(type_prop, conv_invalid_movement_stride_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong 0-length movement stride axis not detected"; FAIL() << "Invalid input with wrong 0-length movement stride axis not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), std::string("Convolution window axis movement stride is zero.")); EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window movement stride at spatial dimension 0 is zero"));
} }
catch (...) catch (...)
{ {
...@@ -3563,12 +3566,10 @@ TEST(type_prop, max_pool_invalid_0d_input) ...@@ -3563,12 +3566,10 @@ TEST(type_prop, max_pool_invalid_0d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 0D input not detected"; FAIL() << "Invalid 0D input not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Max-pool data batch input must have rank of at " std::string("Data input shape does not have rank of at least 3"));
"least 3 (one batch axis, one channel axis, at "
"least one spatial dimension)."));
} }
catch (...) catch (...)
{ {
...@@ -3588,12 +3589,10 @@ TEST(type_prop, max_pool_invalid_1d_input) ...@@ -3588,12 +3589,10 @@ TEST(type_prop, max_pool_invalid_1d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 1D input not detected"; FAIL() << "Invalid 1D input not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Max-pool data batch input must have rank of at " std::string("Data input shape does not have rank of at least 3"));
"least 3 (one batch axis, one channel axis, at "
"least one spatial dimension)."));
} }
catch (...) catch (...)
{ {
...@@ -3613,12 +3612,10 @@ TEST(type_prop, max_pool_invalid_2d_input) ...@@ -3613,12 +3612,10 @@ TEST(type_prop, max_pool_invalid_2d_input)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid 2D input not detected"; FAIL() << "Invalid 2D input not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Max-pool data batch input must have rank of at " std::string("Data input shape does not have rank of at least 3"));
"least 3 (one batch axis, one channel axis, at "
"least one spatial dimension)."));
} }
catch (...) catch (...)
{ {
...@@ -3638,9 +3635,9 @@ TEST(type_prop, max_pool_invalid_0_batch_size) ...@@ -3638,9 +3635,9 @@ TEST(type_prop, max_pool_invalid_0_batch_size)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 batch size not detected"; FAIL() << "Invalid input with 0 batch size not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), std::string("Max-pool data batch size is zero.")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch size is zero"));
} }
catch (...) catch (...)
{ {
...@@ -3660,9 +3657,9 @@ TEST(type_prop, max_pool_invalid_0_channels) ...@@ -3660,9 +3657,9 @@ TEST(type_prop, max_pool_invalid_0_channels)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0 channels not detected"; FAIL() << "Invalid input with 0 channels not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), std::string("Max-pool requires at least one feature channel.")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Channel count is zero"));
} }
catch (...) catch (...)
{ {
...@@ -3682,11 +3679,11 @@ TEST(type_prop, max_pool_invalid_wrong_number_of_window_dimensions_too_many) ...@@ -3682,11 +3679,11 @@ TEST(type_prop, max_pool_invalid_wrong_number_of_window_dimensions_too_many)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too many window dimensions not detected"; FAIL() << "Invalid input with too many window dimensions not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Max-pool window shape rank does not match number of spatial dimensions.")); std::string("Window shape rank does not match number of spatial dimensions"));
} }
catch (...) catch (...)
{ {
...@@ -3706,11 +3703,11 @@ TEST(type_prop, max_pool_invalid_wrong_number_of_window_dimensions_too_few) ...@@ -3706,11 +3703,11 @@ TEST(type_prop, max_pool_invalid_wrong_number_of_window_dimensions_too_few)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with too few window dimensions not detected"; FAIL() << "Invalid input with too few window dimensions not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Max-pool window shape rank does not match number of spatial dimensions.")); std::string("Window shape rank does not match number of spatial dimensions"));
} }
catch (...) catch (...)
{ {
...@@ -3731,11 +3728,11 @@ TEST(type_prop, max_pool_invalid_movement_stride_rank) ...@@ -3731,11 +3728,11 @@ TEST(type_prop, max_pool_invalid_movement_stride_rank)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with wrong movement stride rank not detected"; FAIL() << "Invalid input with wrong movement stride rank not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Max-pool window movement stride rank does not " error.what(),
"match number of spatial dimensions.")); std::string("Window movement stride rank does not match number of spatial dimensions"));
} }
catch (...) catch (...)
{ {
...@@ -3755,10 +3752,11 @@ TEST(type_prop, max_pool_invalid_input_data_size_0) ...@@ -3755,10 +3752,11 @@ TEST(type_prop, max_pool_invalid_input_data_size_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length spatial axis not detected"; FAIL() << "Invalid input with zero-length spatial axis not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Max-pool input spatial dimension is zero even after padding.")); error.what(),
std::string("Data input spatial dimension 0 has zero length even after padding"));
} }
catch (...) catch (...)
{ {
...@@ -3778,9 +3776,9 @@ TEST(type_prop, max_pool_invalid_window_size_0) ...@@ -3778,9 +3776,9 @@ TEST(type_prop, max_pool_invalid_window_size_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with zero-length window axis not detected"; FAIL() << "Invalid input with zero-length window axis not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), std::string("Max-pool window shape has a zero-length axis.")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Window shape dimension 1 has zero length"));
} }
catch (...) catch (...)
{ {
...@@ -3800,11 +3798,11 @@ TEST(type_prop, max_pool_invalid_dilated_too_large) ...@@ -3800,11 +3798,11 @@ TEST(type_prop, max_pool_invalid_dilated_too_large)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with oversized window not detected"; FAIL() << "Invalid input with oversized window not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), EXPECT_HAS_SUBSTRING(
std::string("Max-pool window shape is larger than the spatial " error.what(),
"dimensions even after padding.")); std::string("Window shape after padding is larger than the spatial dimensions"));
} }
catch (...) catch (...)
{ {
...@@ -3825,9 +3823,10 @@ TEST(type_prop, max_pool_invalid_movement_stride_0) ...@@ -3825,9 +3823,10 @@ TEST(type_prop, max_pool_invalid_movement_stride_0)
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Invalid input with 0-length movement stride axis not detected"; FAIL() << "Invalid input with 0-length movement stride axis not detected";
} }
catch (const ngraph_error& error) catch (const NodeValidationError& error)
{ {
EXPECT_EQ(error.what(), std::string("Max-pool window axis movement stride is zero.")); EXPECT_HAS_SUBSTRING(error.what(),
std::string("Window movement strides dimension 0 has zero length"));
} }
catch (...) catch (...)
{ {
......
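For readers unfamiliar with the testing idiom this commit converts to, here is a minimal, self-contained sketch of the pattern: catch the specific validation exception type and assert on substrings of its message rather than comparing the whole string. The NodeValidationError class, the EXPECT_HAS_SUBSTRING macro, and the validate_rank_at_least_2 helper below are local stand-ins written for illustration only; in the real test suite the first two are provided by ngraph's headers and test utilities, and the validation logic lives in the ops' validate_and_infer_types methods.

// Minimal sketch of the substring-matching test idiom (assumptions noted above).
#include <gtest/gtest.h>

#include <cstddef>
#include <sstream>
#include <stdexcept>
#include <string>

// Stand-in for ngraph's NodeValidationError: assumed here to be a
// runtime_error whose what() carries the formatted validation message.
class NodeValidationError : public std::runtime_error
{
public:
    explicit NodeValidationError(const std::string& what_arg)
        : std::runtime_error(what_arg)
    {
    }
};

// Stand-in for the test suite's EXPECT_HAS_SUBSTRING macro, built on
// gtest's IsSubstring predicate-formatter.
#define EXPECT_HAS_SUBSTRING(haystack, needle)                                                     \
    EXPECT_PRED_FORMAT2(::testing::IsSubstring, needle, haystack)

// Toy validation in the spirit of the new checks: the message embeds
// dynamic detail (the offending rank) after a stable prefix.
static void validate_rank_at_least_2(std::size_t rank)
{
    if (rank < 2)
    {
        std::stringstream ss;
        ss << "Input argument must have rank of at least 2 (input argument rank: " << rank << ").";
        throw NodeValidationError(ss.str());
    }
}

TEST(error_message_sketch, rank_less_than_2)
{
    try
    {
        validate_rank_at_least_2(1);
        FAIL() << "Validation should throw for rank less than 2";
    }
    catch (const NodeValidationError& error)
    {
        // Substring matching: the assertion stays stable even though the
        // full message embeds the actual rank.
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Input argument must have rank of at least 2"));
    }
    catch (...)
    {
        FAIL() << "Validation failed for an unexpected reason";
    }
}

Matching substrings rather than whole messages is what lets the new error text include dynamic values (shapes, ranks, channel counts) without making the tests brittle; the long verbatim EXPECT_EQ literals in the old tests broke whenever the wording or the embedded details changed.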