Commit e42e5815 authored by Louis Feng's avatar Louis Feng Committed by Scott Cyphers

Conv+bias shape check for better error detection (#1176)

* Reshape bias to 1D for conv + bias bprop fusion

* Reshape goe2 back to 2D before replacing

* added shape checks to validate conv+bias op.

* removed conv+bias backprop merge for separate PR review.

* fixed conv_bias_bprop test.

* minor changes to error messages.
parent f243d035
...@@ -25,6 +25,30 @@ ...@@ -25,6 +25,30 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
void op::util::validate_convbias_shapes(const Shape& data_shape,
const Shape& filters_shape,
const Shape& bias_shape)
{
if (bias_shape.size() != 1)
{
throw ngraph_error("Convolution+bias bias is expected to be 1D, but has shape: " +
vector_to_string(bias_shape));
}
if (bias_shape[0] != filters_shape[0])
{
throw ngraph_error(
"Convolution+bias bias element size does not match number of filters. bias_size = " +
std::to_string(bias_shape[0]) + ", num_filters = " + std::to_string(filters_shape[0]));
}
if (data_shape[1] != filters_shape[1])
{
throw ngraph_error(
"Convolution+bias data and filter have different number of channels: data_channel=" +
std::to_string(data_shape[1]) + ", filter_channel= " +
std::to_string(filters_shape[1]));
}
}
op::ConvolutionBias::ConvolutionBias(const shared_ptr<op::Convolution>& conv, op::ConvolutionBias::ConvolutionBias(const shared_ptr<op::Convolution>& conv,
const shared_ptr<Node>& bias) const shared_ptr<Node>& bias)
: RequiresTensorViewArgs("ConvolutionBias", : RequiresTensorViewArgs("ConvolutionBias",
...@@ -40,6 +64,9 @@ op::ConvolutionBias::ConvolutionBias(const shared_ptr<op::Convolution>& conv, ...@@ -40,6 +64,9 @@ op::ConvolutionBias::ConvolutionBias(const shared_ptr<op::Convolution>& conv,
throw ngraph_error("Convolution's element type isn't equal to bias!"); throw ngraph_error("Convolution's element type isn't equal to bias!");
} }
util::validate_convbias_shapes(
conv->get_argument(0)->get_shape(), conv->get_argument(1)->get_shape(), bias->get_shape());
set_value_type_checked(conv->get_element_type(), conv->get_shape()); set_value_type_checked(conv->get_element_type(), conv->get_shape());
} }
...@@ -70,6 +97,7 @@ op::ConvolutionBias::ConvolutionBias(const shared_ptr<Node>& data_batch, ...@@ -70,6 +97,7 @@ op::ConvolutionBias::ConvolutionBias(const shared_ptr<Node>& data_batch,
{ {
throw ngraph_error("Convolution data batch and filter element types do not match"); throw ngraph_error("Convolution data batch and filter element types do not match");
} }
util::validate_convbias_shapes(data_batch_shape, filters_shape, bias->get_shape());
set_value_type_checked( set_value_type_checked(
data_batch_et, data_batch_et,
...@@ -180,6 +208,7 @@ op::ConvolutionBiasBackpropFiltersBias::ConvolutionBiasBackpropFiltersBias( ...@@ -180,6 +208,7 @@ op::ConvolutionBiasBackpropFiltersBias::ConvolutionBiasBackpropFiltersBias(
"match"); "match");
} }
util::validate_convbias_shapes(data_batch_shape, filters_shape, bias_shape);
// Forward Backward // Forward Backward
// Window movement strides q p_f // Window movement strides q p_f
// Window dilation strides p_f q // Window dilation strides p_f q
...@@ -237,6 +266,9 @@ op::ConvolutionBiasAdd::ConvolutionBiasAdd(const std::shared_ptr<op::Convolution ...@@ -237,6 +266,9 @@ op::ConvolutionBiasAdd::ConvolutionBiasAdd(const std::shared_ptr<op::Convolution
, m_data_dilation_strides(conv->get_data_dilation_strides()) , m_data_dilation_strides(conv->get_data_dilation_strides())
, m_with_relu(with_relu) , m_with_relu(with_relu)
{ {
util::validate_convbias_shapes(conv->get_argument(0)->get_shape(),
conv->get_argument(1)->get_shape(),
conv->get_argument(2)->get_shape());
set_value_type_checked(conv->get_element_type(), conv->get_shape()); set_value_type_checked(conv->get_element_type(), conv->get_shape());
} }
...@@ -271,6 +303,7 @@ op::ConvolutionBiasAdd::ConvolutionBiasAdd(const std::shared_ptr<Node>& data_bat ...@@ -271,6 +303,7 @@ op::ConvolutionBiasAdd::ConvolutionBiasAdd(const std::shared_ptr<Node>& data_bat
throw ngraph_error("Convolution data batch and filter element types do not match"); throw ngraph_error("Convolution data batch and filter element types do not match");
} }
util::validate_convbias_shapes(data_batch_shape, filters_shape, bias->get_shape());
set_value_type_checked( set_value_type_checked(
data_batch_et, data_batch_et,
util::infer_convolution_output_shape(data_batch_shape, util::infer_convolution_output_shape(data_batch_shape,
......
...@@ -187,5 +187,12 @@ namespace ngraph ...@@ -187,5 +187,12 @@ namespace ngraph
Strides m_data_dilation_strides; Strides m_data_dilation_strides;
bool m_with_relu; bool m_with_relu;
}; };
namespace util
{
void validate_convbias_shapes(const Shape& data_shape,
const Shape& filters_shape,
const Shape& bias_shape);
}
} }
} }
...@@ -1031,7 +1031,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_relu() ...@@ -1031,7 +1031,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_relu()
Shape shape{2, 2, 1, 1}; Shape shape{2, 2, 1, 1};
auto data_batch = std::make_shared<pattern::op::Label>(element::f32, shape); auto data_batch = std::make_shared<pattern::op::Label>(element::f32, shape);
auto filters = std::make_shared<pattern::op::Label>(element::f32, shape); auto filters = std::make_shared<pattern::op::Label>(element::f32, shape);
auto bias = std::make_shared<pattern::op::Label>(element::f32, Shape{1}); auto bias = std::make_shared<pattern::op::Label>(element::f32, Shape{shape[0]});
auto conv_bias = std::make_shared<op::ConvolutionBias>(data_batch, auto conv_bias = std::make_shared<op::ConvolutionBias>(data_batch,
filters, filters,
...@@ -1099,7 +1099,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_add() ...@@ -1099,7 +1099,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_add()
Shape shape{2, 2, 1, 1}; Shape shape{2, 2, 1, 1};
auto data_batch = std::make_shared<pattern::op::Label>(element::f32, shape); auto data_batch = std::make_shared<pattern::op::Label>(element::f32, shape);
auto filters = std::make_shared<pattern::op::Label>(element::f32, shape); auto filters = std::make_shared<pattern::op::Label>(element::f32, shape);
auto bias = std::make_shared<pattern::op::Label>(element::f32, Shape{1}); auto bias = std::make_shared<pattern::op::Label>(element::f32, Shape{shape[0]});
auto pconv = std::make_shared<op::ConvolutionBias>(data_batch, auto pconv = std::make_shared<op::ConvolutionBias>(data_batch,
filters, filters,
...@@ -1202,7 +1202,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_add_relu() ...@@ -1202,7 +1202,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias_add_relu()
Shape shape{2, 2, 1, 1}; Shape shape{2, 2, 1, 1};
auto data_batch = std::make_shared<pattern::op::Label>(element::f32, shape); auto data_batch = std::make_shared<pattern::op::Label>(element::f32, shape);
auto filters = std::make_shared<pattern::op::Label>(element::f32, shape); auto filters = std::make_shared<pattern::op::Label>(element::f32, shape);
auto bias = std::make_shared<pattern::op::Label>(element::f32, Shape{1}); auto bias = std::make_shared<pattern::op::Label>(element::f32, Shape{shape[0]});
auto add_input = std::make_shared<pattern::op::Label>(element::f32, shape); auto add_input = std::make_shared<pattern::op::Label>(element::f32, shape);
auto pconv = std::make_shared<op::ConvolutionBiasAdd>(data_batch, auto pconv = std::make_shared<op::ConvolutionBiasAdd>(data_batch,
......
...@@ -288,6 +288,8 @@ namespace ngraph ...@@ -288,6 +288,8 @@ namespace ngraph
if (use_bias) if (use_bias)
{ {
auto arg2_shape = node->get_input_shape(2); auto arg2_shape = node->get_input_shape(2);
ngraph::op::util::validate_convbias_shapes(
arg0_shape, arg1_shape, arg2_shape);
memory::dims mkldnn_arg2_shape(arg2_shape.begin(), arg2_shape.end()); memory::dims mkldnn_arg2_shape(arg2_shape.begin(), arg2_shape.end());
const memory::desc bias_desc(mkldnn_arg2_shape, et, memory::format::any); const memory::desc bias_desc(mkldnn_arg2_shape, et, memory::format::any);
try try
...@@ -611,6 +613,8 @@ namespace ngraph ...@@ -611,6 +613,8 @@ namespace ngraph
if (use_bias) if (use_bias)
{ {
auto bias_shape = node->get_output_shape(1); auto bias_shape = node->get_output_shape(1);
ngraph::op::util::validate_convbias_shapes(
data_shape, filters_shape, bias_shape);
memory::dims mkldnn_bias_shape(bias_shape.begin(), bias_shape.end()); memory::dims mkldnn_bias_shape(bias_shape.begin(), bias_shape.end());
const memory::desc bias_desc(mkldnn_bias_shape, et, memory::format::any); const memory::desc bias_desc(mkldnn_bias_shape, et, memory::format::any);
bwd_desc.reset( bwd_desc.reset(
......
...@@ -678,8 +678,8 @@ TEST(cpu_fusion, conv_bias_bprop) ...@@ -678,8 +678,8 @@ TEST(cpu_fusion, conv_bias_bprop)
auto data_batch = std::make_shared<op::Parameter>(element::f32, shape); auto data_batch = std::make_shared<op::Parameter>(element::f32, shape);
auto filters = std::make_shared<op::Parameter>(element::f32, shape); auto filters = std::make_shared<op::Parameter>(element::f32, shape);
auto delta = std::make_shared<op::Parameter>(element::f32, shape); auto delta = std::make_shared<op::Parameter>(element::f32, shape);
auto bias = make_shared<op::Parameter>(element::f32, Shape{}); auto bias = make_shared<op::Parameter>(element::f32, Shape{shape[0]});
auto pbroadcast = std::make_shared<op::Broadcast>(bias, shape, AxisSet{0, 1, 2, 3}); auto pbroadcast = std::make_shared<op::Broadcast>(bias, shape, AxisSet{1, 2, 3});
auto conv = std::make_shared<op::Convolution>(data_batch, filters); auto conv = std::make_shared<op::Convolution>(data_batch, filters);
auto conv_bias = std::make_shared<op::Add>(conv, pbroadcast); auto conv_bias = std::make_shared<op::Add>(conv, pbroadcast);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment