Commit 324efb18 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Adam Procter

fix conv_relu's copy_with_new_args c-tor (#778)

parent 5b760fff
......@@ -71,21 +71,22 @@ op::ConvolutionBias::ConvolutionBias(const shared_ptr<Node>& data_batch,
throw ngraph_error("Convolution data batch and filter element types do not match");
}
set_value_type_checked(data_batch_et,
util::infer_convolution_output_shape(data_batch_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
0,
1,
1,
0,
0,
1,
""));
set_value_type_checked(
data_batch_et,
util::infer_convolution_output_shape(data_batch_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
0, /* batch_axis_data, */
1, /* input_channel_axis_data, */
1, /* input_channel_axis_filters, */
0, /* output_channel_axis_filters, */
0, /* batch_axis_result, */
1, /* output_channel_axis_result, */
""));
}
shared_ptr<Node> op::ConvolutionBias::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -49,6 +49,35 @@ op::ConvolutionRelu::ConvolutionRelu(const std::shared_ptr<Node>& data_batch,
, m_padding_above(padding_above)
, m_data_dilation_strides(data_dilation_strides)
{
auto& data_batch_shape = data_batch->get_shape();
auto& data_batch_et = data_batch->get_element_type();
auto& filters_shape = filters->get_shape();
auto& filters_et = filters->get_element_type();
//
// Make sure data batch and filter element types match.
//
if (data_batch_et != filters_et)
{
throw ngraph_error("Convolution data batch and filter element types do not match");
}
set_value_type_checked(
data_batch_et,
util::infer_convolution_output_shape(data_batch_shape,
filters_shape,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
0, /* batch_axis_data, */
1, /* input_channel_axis_data, */
1, /* input_channel_axis_filters, */
0, /* output_channel_axis_filters, */
0, /* batch_axis_result, */
1, /* output_channel_axis_result, */
""));
}
std::shared_ptr<Node> op::ConvolutionRelu::copy_with_new_args(const NodeVector& new_args) const
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment