Unverified Commit e9443168 authored by Adam Procter, committed by GitHub

Error messages for Convolution, BatchNorm, MaxPool (#1535)

parent c386da90
@@ -36,77 +36,55 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
 void ngraph::op::BatchNorm::validate_and_infer_types()
 {
     m_bn_input_shape = get_input_shape(INPUT);
-    if (m_bn_input_shape.size() < 2)
-    {
-        throw ngraph_error("input tensor to batchnorm must have tensor of at least rank 2");
-    }
-    if (m_bn_input_shape[1] == 0)
-    {
-        throw ngraph_error("input tensor must have at least one channel for batch normalization");
-    }
+    NODE_VALIDATION_ASSERT(this, m_bn_input_shape.size() >= 2)
+        << "Input argument must have rank of at least 2 (input argument shape: " << m_bn_input_shape
+        << ").";
+    NODE_VALIDATION_ASSERT(this, m_bn_input_shape[1] != 0)
+        << "Input argument's channel dimension must have size of at least 1 (input argument shape: "
+        << m_bn_input_shape << ").";
     auto& et = get_input_element_type(INPUT);
     auto in_size = get_input_size();
+    NODE_VALIDATION_ASSERT(this, in_size == 3 || in_size == 5)
+        << "Argument count must be either 3 or 5 (received argument count: " << in_size << ").";
+    Shape channel_shape{m_bn_input_shape[1]};
     if (in_size == 3)
     {
         set_output_size(3);
-        this->m_bn_mean_shape.push_back(m_bn_input_shape[1]);
+        m_bn_mean_shape = channel_shape;
         set_output_type(1, et, m_bn_mean_shape);
-        this->m_bn_variance_shape.push_back(m_bn_input_shape[1]);
+        m_bn_variance_shape = channel_shape;
         set_output_type(2, et, m_bn_variance_shape);
     }
-    else if (in_size == 5)
-    {
-        set_output_size(1);
-    }
     else
     {
-        throw ngraph_error("Invalid BatchNorm args");
+        set_output_size(1);
     }
     set_output_type(0, et, m_bn_input_shape);
-    Shape channel_shape{m_bn_input_shape[1]};
     const char* input_names[]{"gamma", "beta", "input", "mean", "variance"};
     for (size_t i = 0; i < get_input_size(); i++)
     {
-        if (i == 2)
+        if (i == INPUT)
         {
             continue;
         }
-        if (get_input_element_type(i) != et)
-        {
-            std::stringstream err_msg;
-            err_msg << "The element type " << get_input_element_type(i) << " of input "
-                    << input_names[i] << " isn't equal to the input data's type " << et;
-            throw ngraph_error(err_msg.str());
-        }
-        if (get_input_shape(i) != channel_shape)
-        {
-            std::stringstream err_msg;
-            err_msg << "The shape " << get_input_shape(i) << " of " << input_names[i]
-                    << " isn't equal to input channel's shape " << channel_shape;
-            throw ngraph_error(err_msg.str());
-        }
-    }
-    for (size_t index = 0; index < get_input_size(); index++)
-    {
-        if (index != INPUT && get_input_shape(index).size() != 1)
-        {
-            auto err_msg = std::string(input_names[index]) + " should have rank of 1";
-            throw ngraph_error(err_msg.c_str());
-        }
-        if (index != INPUT && get_input_shape(index)[0] != m_bn_input_shape[1])
-        {
-            auto err_msg = std::string(input_names[index]) +
-                           " shape should match the input channel size (" +
-                           std::to_string(m_bn_input_shape[1]) + ",)";
-            throw ngraph_error(err_msg.c_str());
-        }
+        NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == et)
+            << "Element type of " << input_names[i] << " (" << get_input_element_type(i)
+            << ") is not equal to the element type of input (" << et << ").";
+        NODE_VALIDATION_ASSERT(this, get_input_shape(i) == channel_shape)
+            << "Shape of " << input_names[i] << " must match the channel dimension of the "
+            << "input data (expected shape: " << channel_shape << ", actual shape of "
+            << input_names[i] << ": " << get_input_shape(i)
+            << ", shape of input: " << m_bn_input_shape << ").";
     }
 }
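The NODE_VALIDATION_ASSERT macro used above streams arbitrary context into the failure message and only raises once the whole expression has been evaluated. Below is a minimal sketch of how such a stream-into-an-assert helper can be built; the ValidationFailure class and VALIDATION_ASSERT name are illustrative stand-ins, not ngraph's actual implementation, which additionally prefixes the message with a description of the node being validated.

// Sketch of a stream-style validation assert (illustrative names, not
// ngraph's real NODE_VALIDATION_ASSERT).
#include <sstream>
#include <stdexcept>

class ValidationFailure
{
public:
    explicit ValidationFailure(bool ok) : m_ok(ok) {}

    // The temporary is destroyed at the end of the full expression, i.e.
    // after every operator<< has appended its context, so throwing here
    // delivers the completed message.
    ~ValidationFailure() noexcept(false)
    {
        if (!m_ok)
        {
            throw std::runtime_error(m_stream.str());
        }
    }

    template <typename T>
    ValidationFailure& operator<<(const T& t)
    {
        if (!m_ok)
        {
            m_stream << t; // formatting cost is only paid on failure
        }
        return *this;
    }

private:
    bool m_ok;
    std::stringstream m_stream;
};

#define VALIDATION_ASSERT(cond) ValidationFailure((cond))

Usage then reads like the hunk above: VALIDATION_ASSERT(shape.size() >= 2) << "Rank must be at least 2 (got " << shape.size() << ").";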
@@ -127,14 +105,19 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
 std::shared_ptr<ngraph::Node>
     ngraph::op::BatchNorm::copy_with_new_args(const NodeVector& new_args) const
 {
-    if (this->m_training)
+    check_new_args_count(this, new_args);
+
+    if (m_training)
     {
+        // FIXME(amprocte): is this redundant?
+        NODE_VALIDATION_ASSERT(this, new_args.size() == 3 || new_args.size() == 5);
+
         if (new_args.size() == 3)
         {
             return std::make_shared<BatchNorm>(
                 m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2));
         }
-        else if (new_args.size() == 5)
+        else
         {
             return std::make_shared<BatchNorm>(m_epsilon,
                                                new_args.at(0),
@@ -144,17 +127,11 @@ std::shared_ptr<ngraph::Node>
                                                new_args.at(4),
                                                true);
         }
-        else
-        {
-            throw ngraph_error("Incorrect number of new arguments");
-        }
     }
     else
     {
-        if (new_args.size() != 5)
-        {
-            throw ngraph_error("Incorrect number of new arguments");
-        }
+        NODE_VALIDATION_ASSERT(this, new_args.size() == 5);
+
         return std::make_shared<BatchNorm>(m_epsilon,
                                            new_args.at(0),
                                            new_args.at(1),
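check_new_args_count replaces the hand-rolled argument-count checks that each copy_with_new_args implementation used to carry. Its body is not part of this diff; a plausible sketch, reusing the streaming assert from the previous sketch and assuming the node reports its current arguments via get_arguments(), might look like this (hypothetical helper, not ngraph's actual code):

// Hypothetical sketch of a centralized arity check for copy_with_new_args.
void check_new_args_count(const Node* node, const NodeVector& new_args)
{
    // Compare the replacement argument list against the node's current
    // argument count so every op gets a uniform, descriptive error.
    VALIDATION_ASSERT(new_args.size() == node->get_arguments().size())
        << "copy_with_new_args() expected " << node->get_arguments().size()
        << " argument(s) but got " << new_args.size() << ".";
}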
@@ -183,45 +160,37 @@ void ngraph::op::BatchNormBackprop::validate_and_infer_types()
 {
     set_output_size(3);
-    if (get_input_shape(INPUT).size() != 4)
-    {
-        throw ngraph_error("Input expected to be a 4D tensor");
-    }
+    NODE_VALIDATION_ASSERT(this, get_input_shape(INPUT).size() == 4)
+        << "Input data shape is not a 4D tensor (input data shape: " << get_input_shape(INPUT)
+        << ").";
     auto et = get_input_element_type(INPUT);
     const char* input_names[] = {"gamma", "beta", "input", "mean", "variance", "delta"};
-    for (size_t i = 0; i < get_input_size(); i++)
-    {
-        if (get_input_element_type(i) != et)
-        {
-            auto err_msg = std::string("The element type of ") + input_names[i] +
-                           " isn't equal to input data's type";
-            throw ngraph_error(err_msg.c_str());
-        }
-    }
-    Shape channel_shape{get_input_shape(INPUT).at(1)};
+    Shape channel_shape{get_input_shape(INPUT)[1]};
     for (size_t i = 0; i < get_input_size(); i++)
     {
-        if (i == 2 || i == 5) // don't check input and delta
+        NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == et)
+            << "Element type of " << input_names[i] << " (" << get_input_element_type(i)
+            << ") is not equal to the element type of input (" << et << ").";
+        // Note that the shape of delta, a special case, will be checked after the loop.
+        if (i == DELTA || i == INPUT)
         {
             continue;
         }
-        if (get_argument(i)->get_shape() != channel_shape)
-        {
-            auto err_msg = std::string("The shape of ") + input_names[i] +
-                           " isn't equal to input channel's shape";
-            throw ngraph_error(err_msg.c_str());
-        }
+        NODE_VALIDATION_ASSERT(this, get_input_shape(i) == channel_shape)
+            << "Shape of " << input_names[i] << " must match the channel dimension of the "
+            << "input data (expected shape: " << channel_shape << ", actual shape of "
+            << input_names[i] << ": " << get_input_shape(i)
+            << ", shape of input: " << get_input_shape(INPUT) << ").";
     }
-    if (get_input_shape(DELTA) != get_input_shape(INPUT))
-    {
-        throw ngraph_error("delta shape is expected to be equal to input shape");
-    }
+    NODE_VALIDATION_ASSERT(this, get_input_shape(DELTA) == get_input_shape(INPUT))
+        << "Shape of delta must match the shape of the input data (expected shape: "
+        << get_input_shape(INPUT) << ", actual shape of delta: " << get_input_shape(DELTA) << ").";
     set_output_type(0, get_input_element_type(INPUT), get_input_shape(INPUT));
     set_output_type(1, get_input_element_type(GAMMA), get_input_shape(GAMMA));
@@ -231,10 +200,7 @@ void ngraph::op::BatchNormBackprop::validate_and_infer_types()
 std::shared_ptr<ngraph::Node>
     ngraph::op::BatchNormBackprop::copy_with_new_args(const NodeVector& new_args) const
 {
-    if (new_args.size() != 6)
-    {
-        throw ngraph_error("Incorrect number of new arguments");
-    }
+    check_new_args_count(this, new_args);
     return std::make_shared<op::BatchNormBackprop>(epsilon,
                                                    new_args.at(0),
                                                    new_args.at(1),
...
This diff is collapsed.
@@ -155,8 +155,8 @@ namespace ngraph
            Strides m_data_dilation_strides;

        private:
-           static Strides default_strides(const Shape& data_batch_shape);
-           static CoordinateDiff default_padding(const Shape& data_batch_shape);
+           static Strides default_strides(const Node* node, const Shape& data_batch_shape);
+           static CoordinateDiff default_padding(const Node* node, const Shape& data_batch_shape);
        };

        /// \brief Data batch backprop for batched convolution operation.
@@ -356,7 +356,8 @@ namespace ngraph

        namespace util
        {
-           Shape infer_convolution_output_shape(const Shape& data_batch_shape,
+           Shape infer_convolution_output_shape(const Node* node,
+                                                const Shape& data_batch_shape,
                                                 const Shape& filters_shape,
                                                 const Strides& window_movement_strides,
                                                 const Strides& window_dilation_strides,
@@ -368,8 +369,7 @@ namespace ngraph
                                                 size_t input_channel_axis_filters,
                                                 size_t output_channel_axis_filters,
                                                 size_t batch_axis_result,
-                                                size_t output_channel_axis_result,
-                                                const std::string& error_prefix);
+                                                size_t output_channel_axis_result);
        }
    }
 }
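Threading a const Node* through the shared shape-inference helpers is what replaces the old error_prefix string: the helper can attribute a failure to the offending node itself instead of relying on the caller to describe it. A sketch of the pattern follows; the function name is hypothetical, and the single rank check stands in for the real window/stride/dilation validation and output-shape arithmetic.

// Sketch only: the real infer_convolution_output_shape performs many more
// checks before computing the output shape; all of that is elided here.
Shape infer_conv_output_shape_sketch(const Node* node,
                                     const Shape& data_batch_shape,
                                     const Shape& filters_shape)
{
    // With the node in hand, the shared helper can emit errors that
    // identify the op, with no caller-supplied prefix string needed.
    NODE_VALIDATION_ASSERT(node, data_batch_shape.size() == filters_shape.size())
        << "Data batch and filters must have the same rank (data batch shape: "
        << data_batch_shape << ", filters shape: " << filters_shape << ").";
    // ... remaining checks and output-shape computation elided ...
    return Shape{};
}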
This diff is collapsed.
@@ -109,7 +109,8 @@ op::ConvolutionBias::ConvolutionBias(const shared_ptr<Node>& data_batch,
     set_output_type(0,
                     data_batch_et,
-                    util::infer_convolution_output_shape(data_batch_shape,
+                    util::infer_convolution_output_shape(this,
+                                                         data_batch_shape,
                                                          filters_shape,
                                                          window_movement_strides,
                                                          window_dilation_strides,
@@ -121,8 +122,8 @@ op::ConvolutionBias::ConvolutionBias(const shared_ptr<Node>& data_batch,
                                                          1, /* input_channel_axis_filters, */
                                                          0, /* output_channel_axis_filters, */
                                                          0, /* batch_axis_result, */
-                                                         1, /* output_channel_axis_result, */
-                                                         ""));
+                                                         1 /* output_channel_axis_result, */
+                                                         ));
 }

 shared_ptr<Node> op::ConvolutionBias::copy_with_new_args(const NodeVector& new_args) const
@@ -322,7 +323,8 @@ op::ConvolutionBiasAdd::ConvolutionBiasAdd(const std::shared_ptr<Node>& data_batch,
     util::validate_convbias_shapes(data_batch_shape, filters_shape, bias->get_shape());
     set_output_type(0,
                     data_batch_et,
-                    util::infer_convolution_output_shape(data_batch_shape,
+                    util::infer_convolution_output_shape(this,
+                                                         data_batch_shape,
                                                          filters_shape,
                                                          window_movement_strides,
                                                          window_dilation_strides,
@@ -334,8 +336,8 @@ op::ConvolutionBiasAdd::ConvolutionBiasAdd(const std::shared_ptr<Node>& data_batch,
                                                          1, /* input_channel_axis_filters, */
                                                          0, /* output_channel_axis_filters, */
                                                          0, /* batch_axis_result, */
-                                                         1, /* output_channel_axis_result, */
-                                                         ""));
+                                                         1 /* output_channel_axis_result, */
+                                                         ));
 }

 std::shared_ptr<Node> op::ConvolutionBiasAdd::copy_with_new_args(const NodeVector& new_args) const
...
@@ -68,7 +68,8 @@ op::ConvolutionRelu::ConvolutionRelu(const std::shared_ptr<Node>& data_batch,
     set_output_type(0,
                     data_batch_et,
-                    util::infer_convolution_output_shape(data_batch_shape,
+                    util::infer_convolution_output_shape(this,
+                                                         data_batch_shape,
                                                          filters_shape,
                                                          window_movement_strides,
                                                          window_dilation_strides,
@@ -80,8 +81,8 @@ op::ConvolutionRelu::ConvolutionRelu(const std::shared_ptr<Node>& data_batch,
                                                          1, /* input_channel_axis_filters, */
                                                          0, /* output_channel_axis_filters, */
                                                          0, /* batch_axis_result, */
-                                                         1, /* output_channel_axis_result, */
-                                                         ""));
+                                                         1 /* output_channel_axis_result, */
+                                                         ));
 }

 std::shared_ptr<Node> op::ConvolutionRelu::copy_with_new_args(const NodeVector& new_args) const
...
This diff is collapsed.