Commit 324efb18 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Adam Procter

fix conv_relu's copy_with_new_args c-tor (#778)

parent 5b760fff
...@@ -71,21 +71,22 @@ op::ConvolutionBias::ConvolutionBias(const shared_ptr<Node>& data_batch, ...@@ -71,21 +71,22 @@ op::ConvolutionBias::ConvolutionBias(const shared_ptr<Node>& data_batch,
throw ngraph_error("Convolution data batch and filter element types do not match"); throw ngraph_error("Convolution data batch and filter element types do not match");
} }
set_value_type_checked(data_batch_et, set_value_type_checked(
util::infer_convolution_output_shape(data_batch_shape, data_batch_et,
filters_shape, util::infer_convolution_output_shape(data_batch_shape,
window_movement_strides, filters_shape,
window_dilation_strides, window_movement_strides,
padding_below, window_dilation_strides,
padding_above, padding_below,
data_dilation_strides, padding_above,
0, data_dilation_strides,
1, 0, /* batch_axis_data, */
1, 1, /* input_channel_axis_data, */
0, 1, /* input_channel_axis_filters, */
0, 0, /* output_channel_axis_filters, */
1, 0, /* batch_axis_result, */
"")); 1, /* output_channel_axis_result, */
""));
} }
shared_ptr<Node> op::ConvolutionBias::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ConvolutionBias::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -49,6 +49,35 @@ op::ConvolutionRelu::ConvolutionRelu(const std::shared_ptr<Node>& data_batch, ...@@ -49,6 +49,35 @@ op::ConvolutionRelu::ConvolutionRelu(const std::shared_ptr<Node>& data_batch,
, m_padding_above(padding_above) , m_padding_above(padding_above)
, m_data_dilation_strides(data_dilation_strides) , m_data_dilation_strides(data_dilation_strides)
{ {
auto& data_batch_shape = data_batch->get_shape();
auto& data_batch_et = data_batch->get_element_type();
auto& filters_shape = filters->get_shape();
auto& filters_et = filters->get_element_type();
//
// Make sure data batch and filter element types match.
//
if (data_batch_et != filters_et)
{
throw ngraph_error("Convolution data batch and filter element types do not match");
}
set_value_type_checked(
data_batch_et,
util::infer_convolution_output_shape(data_batch_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
0, /* batch_axis_data, */
1, /* input_channel_axis_data, */
1, /* input_channel_axis_filters, */
0, /* output_channel_axis_filters, */
0, /* batch_axis_result, */
1, /* output_channel_axis_result, */
""));
} }
std::shared_ptr<Node> op::ConvolutionRelu::copy_with_new_args(const NodeVector& new_args) const std::shared_ptr<Node> op::ConvolutionRelu::copy_with_new_args(const NodeVector& new_args) const
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment