Commit 211b70f8 authored by Louis Feng's avatar Louis Feng Committed by Scott Cyphers

Convolution Bias bug fix (#766)

* fixed two bugs.

* refactor infer shape.

* clang format.

* reverted some changes.

* moved infer_convolution_output_shape() to outside class

* moved infer_convolution_output_shape to op::util namespace
parent 23195230
......@@ -26,7 +26,7 @@
using namespace std;
using namespace ngraph;
static Shape infer_convolution_output_shape(const Shape& data_batch_shape,
Shape op::util::infer_convolution_output_shape(const Shape& data_batch_shape,
const Shape& filters_shape,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
......@@ -39,7 +39,7 @@ static Shape infer_convolution_output_shape(const Shape& data_batch_shape,
size_t output_channel_axis_filters,
size_t batch_axis_result,
size_t output_channel_axis_result,
string error_prefix)
const string& error_prefix)
{
if (batch_axis_data > 1 || input_channel_axis_data > 1 || input_channel_axis_filters > 1 ||
output_channel_axis_filters > 1 || batch_axis_result > 1 || output_channel_axis_result > 1)
......@@ -262,7 +262,7 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
}
set_value_type_checked(data_batch_et,
infer_convolution_output_shape(data_batch_shape,
util::infer_convolution_output_shape(data_batch_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
......@@ -455,7 +455,7 @@ op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_sha
}
Shape inferred_convolution_output_shape =
infer_convolution_output_shape(output_delta_shape,
util::infer_convolution_output_shape(output_delta_shape,
filters_shape,
m_window_movement_strides_backward,
m_window_dilation_strides_backward,
......@@ -552,7 +552,7 @@ op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
}
Shape inferred_convolution_output_shape =
infer_convolution_output_shape(data_batch_shape,
util::infer_convolution_output_shape(data_batch_shape,
output_delta_shape,
m_window_movement_strides_backward,
m_window_dilation_strides_backward,
......
......@@ -340,5 +340,23 @@ namespace ngraph
CoordinateDiff m_padding_above_backward;
Strides m_data_dilation_strides_backward;
};
namespace util
{
Shape infer_convolution_output_shape(const Shape& data_batch_shape,
const Shape& filters_shape,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
size_t batch_axis_data,
size_t input_channel_axis_data,
size_t input_channel_axis_filters,
size_t output_channel_axis_filters,
size_t batch_axis_result,
size_t output_channel_axis_result,
const std::string& error_prefix);
}
}
}
......@@ -16,9 +16,10 @@
#include <numeric>
#include "conv_bias.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/util.hpp"
using namespace std;
......@@ -57,11 +58,39 @@ op::ConvolutionBias::ConvolutionBias(const shared_ptr<Node>& data_batch,
, m_padding_above(padding_above)
, m_data_dilation_strides(data_dilation_strides)
{
auto& data_batch_shape = data_batch->get_shape();
auto& data_batch_et = data_batch->get_element_type();
auto& filters_shape = filters->get_shape();
auto& filters_et = filters->get_element_type();
//
// Make sure data batch and filter element types match.
//
if (data_batch_et != filters_et)
{
throw ngraph_error("Convolution data batch and filter element types do not match");
}
set_value_type_checked(data_batch_et,
util::infer_convolution_output_shape(data_batch_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
0,
1,
1,
0,
0,
1,
""));
}
shared_ptr<Node> op::ConvolutionBias::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment