Unverified Commit e9443168 authored by Adam Procter's avatar Adam Procter Committed by GitHub

Error messages for Convolution, BatchNorm, MaxPool (#1535)

parent c386da90
......@@ -36,77 +36,55 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
void ngraph::op::BatchNorm::validate_and_infer_types()
{
m_bn_input_shape = get_input_shape(INPUT);
if (m_bn_input_shape.size() < 2)
{
throw ngraph_error("input tensor to batchnorm must have tensor of at least rank 2");
}
if (m_bn_input_shape[1] == 0)
{
throw ngraph_error("input tensor must have at least one channel for batch normalization");
}
NODE_VALIDATION_ASSERT(this, m_bn_input_shape.size() >= 2)
<< "Input argument must have rank of at least 2 (input argument shape: " << m_bn_input_shape
<< ").";
NODE_VALIDATION_ASSERT(this, m_bn_input_shape[1] != 0)
<< "Input argument's channel dimension must have size of at least 1 (input argument shape: "
<< m_bn_input_shape << ").";
auto& et = get_input_element_type(INPUT);
auto in_size = get_input_size();
NODE_VALIDATION_ASSERT(this, in_size == 3 || in_size == 5)
<< "Argument count must be either 3 or 5 (received argument count: " << in_size << ").";
Shape channel_shape{m_bn_input_shape[1]};
if (in_size == 3)
{
set_output_size(3);
this->m_bn_mean_shape.push_back(m_bn_input_shape[1]);
m_bn_mean_shape = channel_shape;
set_output_type(1, et, m_bn_mean_shape);
this->m_bn_variance_shape.push_back(m_bn_input_shape[1]);
m_bn_variance_shape = channel_shape;
set_output_type(2, et, m_bn_variance_shape);
}
else if (in_size == 5)
{
set_output_size(1);
}
else
{
throw ngraph_error("Invalid BatchNorm args");
set_output_size(1);
}
set_output_type(0, et, m_bn_input_shape);
Shape channel_shape{m_bn_input_shape[1]};
const char* input_names[]{"gamma", "beta", "input", "mean", "variance"};
for (size_t i = 0; i < get_input_size(); i++)
{
if (i == 2)
if (i == INPUT)
{
continue;
}
if (get_input_element_type(i) != et)
{
std::stringstream err_msg;
err_msg << "The element type " << get_input_element_type(i) << " of input "
<< input_names[i] << " isn't equal to the input data's type " << et;
throw ngraph_error(err_msg.str());
}
if (get_input_shape(i) != channel_shape)
{
std::stringstream err_msg;
err_msg << "The shape " << get_input_shape(i) << " of " << input_names[i]
<< " isn't equal to input channel's shape " << channel_shape;
throw ngraph_error(err_msg.str());
}
}
for (size_t index = 0; index < get_input_size(); index++)
{
if (index != INPUT && get_input_shape(index).size() != 1)
{
auto err_msg = std::string(input_names[index]) + " should have rank of 1";
throw ngraph_error(err_msg.c_str());
}
NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == et)
<< "Element type of " << input_names[i] << " (" << get_input_element_type(i)
<< ") is not equal to the element type of input (" << et << ").";
if (index != INPUT && get_input_shape(index)[0] != m_bn_input_shape[1])
{
auto err_msg = std::string(input_names[index]) +
" shape should match the input channel size (" +
std::to_string(m_bn_input_shape[1]) + ",)";
throw ngraph_error(err_msg.c_str());
}
NODE_VALIDATION_ASSERT(this, get_input_shape(i) == channel_shape)
<< "Shape of " << input_names[i] << " must match the channel dimension of the "
<< "input data (expected shape: " << channel_shape << ", actual shape of "
<< input_names[i] << ": " << get_input_shape(i)
<< ", shape of input: " << m_bn_input_shape << ").";
}
}
......@@ -127,14 +105,19 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
std::shared_ptr<ngraph::Node>
ngraph::op::BatchNorm::copy_with_new_args(const NodeVector& new_args) const
{
if (this->m_training)
check_new_args_count(this, new_args);
if (m_training)
{
// FIXME(amprocte): is this redundant?
NODE_VALIDATION_ASSERT(this, new_args.size() == 3 || new_args.size() == 5);
if (new_args.size() == 3)
{
return std::make_shared<BatchNorm>(
m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2));
}
else if (new_args.size() == 5)
else
{
return std::make_shared<BatchNorm>(m_epsilon,
new_args.at(0),
......@@ -144,17 +127,11 @@ std::shared_ptr<ngraph::Node>
new_args.at(4),
true);
}
else
{
throw ngraph_error("Incorrect number of new arguments");
}
}
else
{
if (new_args.size() != 5)
{
throw ngraph_error("Incorrect number of new arguments");
}
NODE_VALIDATION_ASSERT(this, new_args.size() == 5);
return std::make_shared<BatchNorm>(m_epsilon,
new_args.at(0),
new_args.at(1),
......@@ -183,45 +160,37 @@ void ngraph::op::BatchNormBackprop::validate_and_infer_types()
{
set_output_size(3);
if (get_input_shape(INPUT).size() != 4)
{
throw ngraph_error("Input expected to be a 4D tensor");
}
NODE_VALIDATION_ASSERT(this, get_input_shape(INPUT).size() == 4)
<< "Input data shape is not a 4D tensor (input data shape: " << get_input_shape(INPUT)
<< ").";
auto et = get_input_element_type(INPUT);
const char* input_names[] = {"gamma", "beta", "input", "mean", "variance", "delta"};
for (size_t i = 0; i < get_input_size(); i++)
{
if (get_input_element_type(i) != et)
{
auto err_msg = std::string("The element type of ") + input_names[i] +
" isn't equal to input data's type";
throw ngraph_error(err_msg.c_str());
}
}
Shape channel_shape{get_input_shape(INPUT).at(1)};
Shape channel_shape{get_input_shape(INPUT)[1]};
for (size_t i = 0; i < get_input_size(); i++)
{
if (i == 2 || i == 5) // don't check input and delta
NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == et)
<< "Element type of " << input_names[i] << " (" << get_input_element_type(i)
<< ") is not equal to the element type of input (" << et << ").";
// Note that the shape of delta, a special case, will be checked after the loop.
if (i == DELTA || i == INPUT)
{
continue;
}
if (get_argument(i)->get_shape() != channel_shape)
{
auto err_msg = std::string("The shape of ") + input_names[i] +
" isn't equal to input channel's shape";
throw ngraph_error(err_msg.c_str());
}
NODE_VALIDATION_ASSERT(this, get_input_shape(i) == channel_shape)
<< "Shape of " << input_names[i] << " must match the channel dimension of the "
<< "input data (expected shape: " << channel_shape << ", actual shape of "
<< input_names[i] << ": " << get_input_shape(i)
<< ", shape of input: " << get_input_shape(INPUT) << ").";
}
if (get_input_shape(DELTA) != get_input_shape(INPUT))
{
throw ngraph_error("delta shape is expected to be equal to input shape");
}
NODE_VALIDATION_ASSERT(this, get_input_shape(DELTA) == get_input_shape(INPUT))
<< "Shape of delta must match the shape of the input data (expected shape: "
<< get_input_shape(INPUT) << ", actual shape of delta: " << get_input_shape(DELTA) << ").";
set_output_type(0, get_input_element_type(INPUT), get_input_shape(INPUT));
set_output_type(1, get_input_element_type(GAMMA), get_input_shape(GAMMA));
......@@ -231,10 +200,7 @@ void ngraph::op::BatchNormBackprop::validate_and_infer_types()
std::shared_ptr<ngraph::Node>
ngraph::op::BatchNormBackprop::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 6)
{
throw ngraph_error("Incorrect number of new arguments");
}
check_new_args_count(this, new_args);
return std::make_shared<op::BatchNormBackprop>(epsilon,
new_args.at(0),
new_args.at(1),
......
This diff is collapsed.
......@@ -155,8 +155,8 @@ namespace ngraph
Strides m_data_dilation_strides;
private:
static Strides default_strides(const Shape& data_batch_shape);
static CoordinateDiff default_padding(const Shape& data_batch_shape);
static Strides default_strides(const Node* node, const Shape& data_batch_shape);
static CoordinateDiff default_padding(const Node* node, const Shape& data_batch_shape);
};
/// \brief Data batch backprop for batched convolution operation.
......@@ -356,7 +356,8 @@ namespace ngraph
namespace util
{
Shape infer_convolution_output_shape(const Shape& data_batch_shape,
Shape infer_convolution_output_shape(const Node* node,
const Shape& data_batch_shape,
const Shape& filters_shape,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
......@@ -368,8 +369,7 @@ namespace ngraph
size_t input_channel_axis_filters,
size_t output_channel_axis_filters,
size_t batch_axis_result,
size_t output_channel_axis_result,
const std::string& error_prefix);
size_t output_channel_axis_result);
}
}
}
This diff is collapsed.
......@@ -109,7 +109,8 @@ op::ConvolutionBias::ConvolutionBias(const shared_ptr<Node>& data_batch,
set_output_type(0,
data_batch_et,
util::infer_convolution_output_shape(data_batch_shape,
util::infer_convolution_output_shape(this,
data_batch_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
......@@ -121,8 +122,8 @@ op::ConvolutionBias::ConvolutionBias(const shared_ptr<Node>& data_batch,
1, /* input_channel_axis_filters, */
0, /* output_channel_axis_filters, */
0, /* batch_axis_result, */
1, /* output_channel_axis_result, */
""));
1 /* output_channel_axis_result, */
));
}
shared_ptr<Node> op::ConvolutionBias::copy_with_new_args(const NodeVector& new_args) const
......@@ -322,7 +323,8 @@ op::ConvolutionBiasAdd::ConvolutionBiasAdd(const std::shared_ptr<Node>& data_bat
util::validate_convbias_shapes(data_batch_shape, filters_shape, bias->get_shape());
set_output_type(0,
data_batch_et,
util::infer_convolution_output_shape(data_batch_shape,
util::infer_convolution_output_shape(this,
data_batch_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
......@@ -334,8 +336,8 @@ op::ConvolutionBiasAdd::ConvolutionBiasAdd(const std::shared_ptr<Node>& data_bat
1, /* input_channel_axis_filters, */
0, /* output_channel_axis_filters, */
0, /* batch_axis_result, */
1, /* output_channel_axis_result, */
""));
1 /* output_channel_axis_result, */
));
}
std::shared_ptr<Node> op::ConvolutionBiasAdd::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -68,7 +68,8 @@ op::ConvolutionRelu::ConvolutionRelu(const std::shared_ptr<Node>& data_batch,
set_output_type(0,
data_batch_et,
util::infer_convolution_output_shape(data_batch_shape,
util::infer_convolution_output_shape(this,
data_batch_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
......@@ -80,8 +81,8 @@ op::ConvolutionRelu::ConvolutionRelu(const std::shared_ptr<Node>& data_batch,
1, /* input_channel_axis_filters, */
0, /* output_channel_axis_filters, */
0, /* batch_axis_result, */
1, /* output_channel_axis_result, */
""));
1 /* output_channel_axis_result, */
));
}
std::shared_ptr<Node> op::ConvolutionRelu::copy_with_new_args(const NodeVector& new_args) const
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment