Commit ccfcf4f9 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Partial Shapes and Types, Part 4λ: Convolution and backprops (#1890)

* Implement partial shape/type propagation for Convolution; fail for want of unit tests

* Implement unit tests for partial shapes/types for Convolution
parent 0d693fc3
...@@ -46,48 +46,38 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch, ...@@ -46,48 +46,38 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
void op::Convolution::validate_and_infer_types() void op::Convolution::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic()) const PartialShape& data_batch_shape = get_input_partial_shape(0);
{ element::Type data_batch_et = get_input_element_type(0);
return; const PartialShape& filters_shape = get_input_partial_shape(1);
} element::Type filters_et = get_input_element_type(1);
auto& data_batch_shape = get_input_shape(0);
auto& data_batch_et = get_input_element_type(0);
auto& filters_shape = get_input_shape(1);
auto& filters_et = get_input_element_type(1);
NODE_VALIDATION_ASSERT(this, data_batch_shape.size() >= 3)
<< "Data batch must have rank of at least 3 (one batch axis, "
<< "one input-channel axis, and at least one spatial dimension) "
<< "(data batch shape: " << data_batch_shape << ").";
if (m_data_dilation_strides.size() == 0) if (m_data_dilation_strides.size() == 0)
{ {
m_data_dilation_strides = default_strides(this, data_batch_shape); m_data_dilation_strides = default_strides(this, data_batch_shape, filters_shape);
} }
if (m_window_movement_strides.size() == 0) if (m_window_movement_strides.size() == 0)
{ {
m_window_movement_strides = default_strides(this, data_batch_shape); m_window_movement_strides = default_strides(this, data_batch_shape, filters_shape);
} }
if (m_window_dilation_strides.size() == 0) if (m_window_dilation_strides.size() == 0)
{ {
m_window_dilation_strides = default_strides(this, data_batch_shape); m_window_dilation_strides = default_strides(this, data_batch_shape, filters_shape);
} }
if (m_padding_below.size() == 0) if (m_padding_below.size() == 0)
{ {
m_padding_below = default_padding(this, data_batch_shape); m_padding_below = default_padding(this, data_batch_shape, filters_shape);
} }
if (m_padding_above.size() == 0) if (m_padding_above.size() == 0)
{ {
m_padding_above = default_padding(this, data_batch_shape); m_padding_above = default_padding(this, data_batch_shape, filters_shape);
} }
element::Type result_et; element::Type result_et;
Shape result_shape; PartialShape result_shape;
std::tie(result_et, result_shape) = infer_convolution_forward(this, std::tie(result_et, result_shape) = infer_convolution_forward(this,
data_batch_et, data_batch_et,
...@@ -103,10 +93,26 @@ void op::Convolution::validate_and_infer_types() ...@@ -103,10 +93,26 @@ void op::Convolution::validate_and_infer_types()
set_output_type(0, result_et, result_shape); set_output_type(0, result_et, result_shape);
} }
Strides op::Convolution::default_strides(const Node* node, const Shape& data_batch_shape) Strides op::Convolution::default_strides(const Node* node,
const PartialShape& data_batch_shape,
const PartialShape& filters_shape)
{ {
NGRAPH_ASSERT(data_batch_shape.size() >= 2); size_t rank;
return Strides(data_batch_shape.size() - 2, 1);
if (data_batch_shape.rank().is_static() && static_cast<size_t>(data_batch_shape.rank()) >= 2)
{
rank = static_cast<size_t>(data_batch_shape.rank()) - 2;
}
else if (filters_shape.rank().is_static() && static_cast<size_t>(filters_shape.rank()) >= 2)
{
rank = static_cast<size_t>(filters_shape.rank()) - 2;
}
else
{
rank = 0;
}
return Strides(rank, 1);
} }
op::Convolution::Convolution(const shared_ptr<Node>& data_batch, op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
...@@ -125,10 +131,26 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch, ...@@ -125,10 +131,26 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{ {
} }
CoordinateDiff op::Convolution::default_padding(const Node* node, const Shape& data_batch_shape) CoordinateDiff op::Convolution::default_padding(const Node* node,
const PartialShape& data_batch_shape,
const PartialShape& filters_shape)
{ {
NGRAPH_ASSERT(data_batch_shape.size() >= 2); size_t rank;
return CoordinateDiff(data_batch_shape.size() - 2, 0);
if (data_batch_shape.rank().is_static() && static_cast<size_t>(data_batch_shape.rank()) >= 2)
{
rank = static_cast<size_t>(data_batch_shape.rank()) - 2;
}
else if (filters_shape.rank().is_static() && static_cast<size_t>(filters_shape.rank()) >= 2)
{
rank = static_cast<size_t>(filters_shape.rank()) - 2;
}
else
{
rank = 0;
}
return CoordinateDiff(rank, 0);
} }
op::Convolution::Convolution(const shared_ptr<Node>& data_batch, op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
...@@ -225,11 +247,6 @@ op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_sha ...@@ -225,11 +247,6 @@ op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_sha
void op::ConvolutionBackpropData::validate_and_infer_types() void op::ConvolutionBackpropData::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
// Backprop to data is itself convolution, with inputs/outputs/attributes transmogrified as // Backprop to data is itself convolution, with inputs/outputs/attributes transmogrified as
// follows. // follows.
// //
...@@ -260,13 +277,13 @@ void op::ConvolutionBackpropData::validate_and_infer_types() ...@@ -260,13 +277,13 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
// reference kernels that do the calculations of the backward parameters internally, or supply // reference kernels that do the calculations of the backward parameters internally, or supply
// utility functions to do it.) // utility functions to do it.)
auto& filters_shape = get_input_shape(0); const PartialShape& filters_shape = get_input_partial_shape(0);
auto& filters_et = get_input_element_type(0); element::Type filters_et = get_input_element_type(0);
auto& delta_shape = get_input_shape(1); const PartialShape& delta_shape = get_input_partial_shape(1);
auto& delta_et = get_input_element_type(1); element::Type delta_et = get_input_element_type(1);
element::Type forward_result_et; element::Type forward_result_et;
Shape forward_result_shape; PartialShape forward_result_shape;
std::tie(forward_result_et, forward_result_shape) = std::tie(forward_result_et, forward_result_shape) =
infer_convolution_forward(this, infer_convolution_forward(this,
...@@ -280,16 +297,20 @@ void op::ConvolutionBackpropData::validate_and_infer_types() ...@@ -280,16 +297,20 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
m_window_movement_strides_forward, m_window_movement_strides_forward,
m_window_dilation_strides_forward); m_window_dilation_strides_forward);
NODE_VALIDATION_ASSERT(this, forward_result_shape == delta_shape) NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape))
<< "Inferred forward output shape (" << forward_result_shape << ") does not match shape of " << "Inferred forward output shape (" << forward_result_shape << ") does not match shape of "
<< "delta (" << delta_shape << ")."; << "delta (" << delta_shape << ").";
set_output_type(0, delta_et, m_data_batch_shape); set_output_type(0, forward_result_et, m_data_batch_shape);
// //
// Compute parameters needed for backprop-as-convolution. // Compute parameters needed for backprop-as-convolution.
// //
size_t spatial_dim_count = delta_shape.size() - 2; // TODO(amprocte): Remove these fields, compute where needed.
//
if (delta_shape.is_static() && filters_shape.is_static())
{
size_t spatial_dim_count = static_cast<size_t>(delta_shape.rank()) - 2;
m_window_movement_strides_backward = m_data_dilation_strides_forward; m_window_movement_strides_backward = m_data_dilation_strides_forward;
m_window_dilation_strides_backward = m_window_dilation_strides_forward; m_window_dilation_strides_backward = m_window_dilation_strides_forward;
...@@ -300,18 +321,21 @@ void op::ConvolutionBackpropData::validate_and_infer_types() ...@@ -300,18 +321,21 @@ void op::ConvolutionBackpropData::validate_and_infer_types()
for (size_t i = 0; i < spatial_dim_count; i++) for (size_t i = 0; i < spatial_dim_count; i++)
{ {
m_padding_below_backward[i] = m_padding_below_backward[i] = (static_cast<ptrdiff_t>(filters_shape[i + 2]) - 1) *
(filters_shape[i + 2] - 1) * m_window_dilation_strides_forward[i] - m_window_dilation_strides_forward[i] -
m_padding_below_forward[i]; m_padding_below_forward[i];
m_padding_above_backward[i] = m_padding_above_backward[i] =
(filters_shape[i + 2] - 1) * m_window_dilation_strides_forward[i] + (static_cast<ptrdiff_t>(filters_shape[i + 2]) - 1) *
m_window_dilation_strides_forward[i] +
((m_padding_below_forward[i] + ((m_padding_below_forward[i] +
(m_data_batch_shape[i + 2] - 1) * m_data_dilation_strides_forward[i] + (m_data_batch_shape[i + 2] - 1) * m_data_dilation_strides_forward[i] +
m_padding_above_forward[i] - m_padding_above_forward[i] -
(filters_shape[i + 2] - 1) * m_window_dilation_strides_forward[i]) % (static_cast<ptrdiff_t>(filters_shape[i + 2]) - 1) *
m_window_dilation_strides_forward[i]) %
m_window_movement_strides_forward[i]) - m_window_movement_strides_forward[i]) -
m_padding_above_forward[i]; m_padding_above_forward[i];
} }
}
} }
void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints, void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints,
...@@ -420,11 +444,6 @@ op::ConvolutionBackpropFilters::ConvolutionBackpropFilters( ...@@ -420,11 +444,6 @@ op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
void op::ConvolutionBackpropFilters::validate_and_infer_types() void op::ConvolutionBackpropFilters::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
// Backprop to filters is itself convolution, with inputs/outputs/attributes transmogrified as // Backprop to filters is itself convolution, with inputs/outputs/attributes transmogrified as
// follows. // follows.
// //
...@@ -455,13 +474,13 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types() ...@@ -455,13 +474,13 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
// reference kernels that do the calculations of the backward parameters internally, or supply // reference kernels that do the calculations of the backward parameters internally, or supply
// utility functions to do it.) // utility functions to do it.)
auto& data_batch_shape = get_input_shape(0); const PartialShape& data_batch_shape = get_input_partial_shape(0);
auto& data_batch_et = get_input_element_type(0); element::Type data_batch_et = get_input_element_type(0);
auto& delta_shape = get_input_shape(1); const PartialShape& delta_shape = get_input_shape(1);
auto& delta_et = get_input_element_type(1); element::Type delta_et = get_input_element_type(1);
element::Type forward_result_et; element::Type forward_result_et;
Shape forward_result_shape; PartialShape forward_result_shape;
std::tie(forward_result_et, forward_result_shape) = std::tie(forward_result_et, forward_result_shape) =
infer_convolution_forward(this, infer_convolution_forward(this,
...@@ -475,16 +494,20 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types() ...@@ -475,16 +494,20 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
m_window_movement_strides_forward, m_window_movement_strides_forward,
m_window_dilation_strides_forward); m_window_dilation_strides_forward);
NODE_VALIDATION_ASSERT(this, forward_result_shape == delta_shape) NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape))
<< "Inferred forward output shape (" << forward_result_shape << ") does not match shape of " << "Inferred forward output shape (" << forward_result_shape << ") does not match shape of "
<< "delta (" << delta_shape << ")."; << "delta (" << delta_shape << ").";
set_output_type(0, delta_et, m_filters_shape); set_output_type(0, forward_result_et, m_filters_shape);
// //
// Compute parameters needed for backprop-as-convolution. // Compute parameters needed for backprop-as-convolution.
// //
size_t spatial_dim_count = delta_shape.size() - 2; // TODO(amprocte): Remove these fields, compute where needed.
//
if (delta_shape.is_static() && data_batch_shape.is_static())
{
size_t spatial_dim_count = static_cast<size_t>(delta_shape.rank()) - 2;
m_window_movement_strides_backward = m_window_dilation_strides_forward; m_window_movement_strides_backward = m_window_dilation_strides_forward;
m_window_dilation_strides_backward = m_window_movement_strides_forward; m_window_dilation_strides_backward = m_window_movement_strides_forward;
...@@ -498,11 +521,13 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types() ...@@ -498,11 +521,13 @@ void op::ConvolutionBackpropFilters::validate_and_infer_types()
m_padding_above_backward[i] = m_padding_above_backward[i] =
m_padding_above_forward[i] - m_padding_above_forward[i] -
(m_padding_below_forward[i] + (m_padding_below_forward[i] +
(data_batch_shape[i + 2] - 1) * m_data_dilation_strides_forward[i] + (static_cast<ptrdiff_t>(data_batch_shape[i + 2]) - 1) *
m_data_dilation_strides_forward[i] +
m_padding_above_forward[i] - m_padding_above_forward[i] -
(m_filters_shape[i + 2] - 1) * m_window_dilation_strides_forward[i]) % (m_filters_shape[i + 2] - 1) * m_window_dilation_strides_forward[i]) %
m_window_movement_strides_forward[i]; m_window_movement_strides_forward[i];
} }
}
} }
shared_ptr<Node> shared_ptr<Node>
......
...@@ -155,8 +155,12 @@ namespace ngraph ...@@ -155,8 +155,12 @@ namespace ngraph
Strides m_data_dilation_strides; Strides m_data_dilation_strides;
private: private:
static Strides default_strides(const Node* node, const Shape& data_batch_shape); static Strides default_strides(const Node* node,
static CoordinateDiff default_padding(const Node* node, const Shape& data_batch_shape); const PartialShape& data_batch_shape,
const PartialShape& filters_shape);
static CoordinateDiff default_padding(const Node* node,
const PartialShape& data_batch_shape,
const PartialShape& filters_shape);
}; };
/// \brief Data batch backprop for batched convolution operation. /// \brief Data batch backprop for batched convolution operation.
......
...@@ -125,52 +125,104 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node, ...@@ -125,52 +125,104 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
// //
// Infers the output batch shape and element type for convolution fprop. // Infers the output batch shape and element type for convolution fprop.
// //
std::tuple<element::Type, Shape> std::tuple<element::Type, PartialShape>
ngraph::infer_convolution_forward(const Node* node, ngraph::infer_convolution_forward(const Node* node,
element::Type et_batch, element::Type et_batch,
element::Type et_filters, element::Type et_filters,
const Shape& data_batch_shape, const PartialShape& data_batch_shape,
const Strides& data_dilation, const Strides& data_dilation,
const CoordinateDiff& data_padding_below, const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above, const CoordinateDiff& data_padding_above,
const Shape& filters_shape, const PartialShape& filters_shape,
const Strides& filter_strides, const Strides& filter_strides,
const Strides& filter_dilation) const Strides& filter_dilation)
{ {
NODE_VALIDATION_ASSERT(node, et_batch == et_filters) element::Type et_result;
NODE_VALIDATION_ASSERT(node, element::Type::merge(et_result, et_batch, et_filters))
<< "Element types for data batch and filters do not match (data batch element type: " << "Element types for data batch and filters do not match (data batch element type: "
<< et_batch << ", filters element type: " << et_filters << ")."; << et_batch << ", filters element type: " << et_filters << ").";
NODE_VALIDATION_ASSERT(node, data_batch_shape.size() >= 3) Rank data_batch_filters_rank{Rank::dynamic()};
<< "Data batch must have rank of at least 3 (one batch axis, "
<< "one input-channel axis, and at least one spatial dimension) " NODE_VALIDATION_ASSERT(
<< "(data batch shape: " << data_batch_shape << ")."; node, Rank::merge(data_batch_filters_rank, data_batch_shape.rank(), filters_shape.rank()))
<< "Data batch and filters rank do not match (data batch shape: " << data_batch_shape
<< ", filters shape: " << filters_shape << ").";
NODE_VALIDATION_ASSERT(node, filters_shape.size() >= 3) NODE_VALIDATION_ASSERT(node,
<< "Filters must have rank of at least 3 (one output-channel axis, " data_batch_filters_rank.is_dynamic() ||
static_cast<size_t>(data_batch_filters_rank) >= 3)
<< "Data batch and filters must have rank of at least 3 (one batch axis, "
<< "one input-channel axis, and at least one spatial dimension) " << "one input-channel axis, and at least one spatial dimension) "
<< "(filters shape: " << filters_shape << ")."; << "(data batch shape: " << data_batch_shape << ", filters shape: " << filters_shape
<< ").";
size_t batch_size = data_batch_shape[0]; Rank spatial_rank{Rank::dynamic()};
size_t data_channel_count = data_batch_shape[1]; NODE_VALIDATION_ASSERT(node,
Shape data_spatial_shape(data_batch_shape.begin() + 2, data_batch_shape.end()); Rank::merge(spatial_rank, spatial_rank, data_batch_filters_rank - 2) &&
Rank::merge(spatial_rank, spatial_rank, data_dilation.size()) &&
Rank::merge(spatial_rank, spatial_rank, data_padding_below.size()) &&
Rank::merge(spatial_rank, spatial_rank, data_padding_above.size()) &&
Rank::merge(spatial_rank, spatial_rank, filter_strides.size()) &&
Rank::merge(spatial_rank, spatial_rank, filter_dilation.size()))
<< "Ranks for data item shape/filters shape (data batch has shape " << data_batch_shape
<< ", so data item rank is " << (data_batch_shape.rank() - 2) << " and filters have shape "
<< filters_shape << ", so filters spatial rank is " << (filters_shape.rank() - 2)
<< "), data dilation (" << data_dilation << "), padding below (" << data_padding_below
<< "), padding above (" << data_padding_above << "), filter strides (" << filter_strides
<< "), and filter dilation (" << filter_dilation << ") do not match.";
Dimension batch_size =
(data_batch_shape.rank().is_static() ? data_batch_shape[0] : Dimension::dynamic());
Dimension data_channel_count =
(data_batch_shape.rank().is_static() ? data_batch_shape[1] : Dimension::dynamic());
PartialShape data_spatial_shape(PartialShape::dynamic(spatial_rank));
Dimension filter_output_channel_count =
(filters_shape.rank().is_static() ? filters_shape[0] : Dimension::dynamic());
Dimension filter_input_channel_count =
(filters_shape.rank().is_static() ? filters_shape[1] : Dimension::dynamic());
PartialShape filter_spatial_shape(PartialShape::dynamic(spatial_rank));
//
// Note: spatial_rank is definitely static at this point.
//
for (size_t i = 0; i < static_cast<size_t>(spatial_rank); i++)
{
if (data_batch_shape.rank().is_static())
{
data_spatial_shape[i] = data_batch_shape[i + 2];
}
size_t filter_output_channel_count = filters_shape[0]; if (filters_shape.rank().is_static())
size_t filter_input_channel_count = filters_shape[1]; {
Shape filter_spatial_shape(filters_shape.begin() + 2, filters_shape.end()); filter_spatial_shape[i] = filters_shape[i + 2];
}
}
NODE_VALIDATION_ASSERT(node, batch_size > 0) << "Batch size is zero."; NODE_VALIDATION_ASSERT(node, batch_size.is_dynamic() || static_cast<size_t>(batch_size) > 0)
<< "Batch size is zero.";
NODE_VALIDATION_ASSERT(node, data_channel_count > 0) << "Data batch channel count is zero."; Dimension merged_channel_count;
NODE_VALIDATION_ASSERT(node, data_channel_count == filter_input_channel_count) NODE_VALIDATION_ASSERT(
node,
Dimension::merge(merged_channel_count, data_channel_count, filter_input_channel_count))
<< "Data batch channel count (" << data_channel_count << ") does not match filter input " << "Data batch channel count (" << data_channel_count << ") does not match filter input "
<< "channel count (" << filter_input_channel_count << ")."; << "channel count (" << filter_input_channel_count << ").";
NODE_VALIDATION_ASSERT(node, filter_output_channel_count > 0) NODE_VALIDATION_ASSERT(
node, merged_channel_count.is_dynamic() || static_cast<size_t>(merged_channel_count) > 0)
<< "Data batch channel count and/or filter input channel count is zero.";
NODE_VALIDATION_ASSERT(node,
filter_output_channel_count.is_dynamic() ||
static_cast<size_t>(filter_output_channel_count) > 0)
<< "Filter output channel count is zero."; << "Filter output channel count is zero.";
Shape data_output_shape = infer_windowed_reduction_output_shape(node, PartialShape data_output_shape = infer_windowed_reduction_output_shape(node,
data_spatial_shape, data_spatial_shape,
data_dilation, data_dilation,
data_padding_below, data_padding_below,
...@@ -178,15 +230,18 @@ std::tuple<element::Type, Shape> ...@@ -178,15 +230,18 @@ std::tuple<element::Type, Shape>
filter_spatial_shape, filter_spatial_shape,
filter_strides, filter_strides,
filter_dilation, filter_dilation,
true) true);
.to_shape();
Shape batch_output_shape(data_batch_shape.size()); PartialShape batch_output_shape(PartialShape::dynamic(spatial_rank + 2));
batch_output_shape[0] = batch_size; batch_output_shape[0] = batch_size;
batch_output_shape[1] = filter_output_channel_count; batch_output_shape[1] = filter_output_channel_count;
std::copy(data_output_shape.begin(), data_output_shape.end(), batch_output_shape.begin() + 2);
return std::make_tuple(et_batch, batch_output_shape); for (size_t i = 0; i < static_cast<size_t>(spatial_rank); i++)
{
batch_output_shape[i + 2] = data_output_shape[i];
}
return std::make_tuple(et_result, batch_output_shape);
} }
// //
......
...@@ -33,15 +33,15 @@ namespace ngraph ...@@ -33,15 +33,15 @@ namespace ngraph
const Strides& window_dilation, const Strides& window_dilation,
bool is_window_all_in_padding_allowed); bool is_window_all_in_padding_allowed);
std::tuple<element::Type, Shape> std::tuple<element::Type, PartialShape>
infer_convolution_forward(const Node* node, infer_convolution_forward(const Node* node,
element::Type et_batch, element::Type et_batch,
element::Type et_filters, element::Type et_filters,
const Shape& data_batch_shape, const PartialShape& data_batch_shape,
const Strides& data_dilation, const Strides& data_dilation,
const CoordinateDiff& data_padding_below, const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above, const CoordinateDiff& data_padding_above,
const Shape& filters_shape, const PartialShape& filters_shape,
const Strides& filter_strides, const Strides& filter_strides,
const Strides& filter_dilation); const Strides& filter_dilation);
......
...@@ -4529,7 +4529,7 @@ TEST(type_prop, conv_invalid_0d_input) ...@@ -4529,7 +4529,7 @@ TEST(type_prop, conv_invalid_0d_input)
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data batch must have rank of at least 3 " std::string("Data batch and filters must have rank of at least 3 "
"(one batch axis, one input-channel axis, " "(one batch axis, one input-channel axis, "
"and at least one spatial dimension)")); "and at least one spatial dimension)"));
} }
...@@ -4554,7 +4554,7 @@ TEST(type_prop, conv_invalid_1d_input) ...@@ -4554,7 +4554,7 @@ TEST(type_prop, conv_invalid_1d_input)
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data batch must have rank of at least 3 " std::string("Data batch and filters must have rank of at least 3 "
"(one batch axis, one input-channel axis, " "(one batch axis, one input-channel axis, "
"and at least one spatial dimension)")); "and at least one spatial dimension)"));
} }
...@@ -4579,7 +4579,7 @@ TEST(type_prop, conv_invalid_2d_input) ...@@ -4579,7 +4579,7 @@ TEST(type_prop, conv_invalid_2d_input)
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Data batch must have rank of at least 3 " std::string("Data batch and filters must have rank of at least 3 "
"(one batch axis, one input-channel axis, " "(one batch axis, one input-channel axis, "
"and at least one spatial dimension)")); "and at least one spatial dimension)"));
} }
...@@ -4625,7 +4625,9 @@ TEST(type_prop, conv_invalid_0_input_channels) ...@@ -4625,7 +4625,9 @@ TEST(type_prop, conv_invalid_0_input_channels)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch channel count is zero")); EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Data batch channel count and/or filter input channel count is zero"));
} }
catch (...) catch (...)
{ {
...@@ -4647,12 +4649,7 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_many) ...@@ -4647,12 +4649,7 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_many)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch and filters rank do not match"));
error.what(),
std::string("Ranks for data shape ({10,10}), data dilation (Strides{1, 1}), padding "
"below (CoordinateDiff{0, 0}), padding above (CoordinateDiff{0, 0}), "
"window shape ({3,3,3}), window strides (Strides{1, 1}), and window "
"dilation (Strides{1, 1}) do not match"));
} }
catch (...) catch (...)
{ {
...@@ -4674,12 +4671,7 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_few) ...@@ -4674,12 +4671,7 @@ TEST(type_prop, conv_invalid_wrong_number_of_filter_dimensions_too_few)
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(), std::string("Data batch and filters rank do not match"));
error.what(),
std::string("Ranks for data shape ({10,10}), data dilation (Strides{1, 1}), padding "
"below (CoordinateDiff{0, 0}), padding above (CoordinateDiff{0, 0}), "
"window shape ({3}), window strides (Strides{1, 1}), and window dilation "
"(Strides{1, 1}) do not match"));
} }
catch (...) catch (...)
{ {
...@@ -4750,10 +4742,12 @@ TEST(type_prop, conv_invalid_movement_stride_rank) ...@@ -4750,10 +4742,12 @@ TEST(type_prop, conv_invalid_movement_stride_rank)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Ranks for data shape ({10,10}), data dilation (Strides{1, 1}), padding " std::string("Ranks for data item shape/filters shape (data batch has shape "
"below (CoordinateDiff{0, 0}), padding above (CoordinateDiff{0, 0}), " "{6,2,10,10}, so data item rank is 2 and filters have shape {6,2,3,3}, so "
"window shape ({3,3}), window strides (Strides{2, 3, 8}), and window " "filters spatial rank is 2), data dilation (Strides{1, 1}), padding below "
"dilation (Strides{1, 1}) do not match")); "(CoordinateDiff{0, 0}), padding above (CoordinateDiff{0, 0}), filter "
"strides (Strides{2, 3, 8}), and filter dilation (Strides{1, 1}) do not "
"match"));
} }
catch (...) catch (...)
{ {
...@@ -4777,10 +4771,12 @@ TEST(type_prop, conv_invalid_window_dilation_stride_rank) ...@@ -4777,10 +4771,12 @@ TEST(type_prop, conv_invalid_window_dilation_stride_rank)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Ranks for data shape ({10,10}), data dilation (Strides{1, 1}), padding " std::string("Ranks for data item shape/filters shape (data batch has shape "
"below (CoordinateDiff{0, 0}), padding above (CoordinateDiff{0, 0}), " "{6,2,10,10}, so data item rank is 2 and filters have shape {6,2,3,3}, so "
"window shape ({3,3}), window strides (Strides{2, 3}), and window dilation " "filters spatial rank is 2), data dilation (Strides{1, 1}), padding below "
"(Strides{2, 3, 8}) do not match")); "(CoordinateDiff{0, 0}), padding above (CoordinateDiff{0, 0}), filter "
"strides (Strides{2, 3}), and filter dilation (Strides{2, 3, 8}) do not "
"match"));
} }
catch (...) catch (...)
{ {
...@@ -4810,10 +4806,12 @@ TEST(type_prop, conv_invalid_data_dilation_stride_rank) ...@@ -4810,10 +4806,12 @@ TEST(type_prop, conv_invalid_data_dilation_stride_rank)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Ranks for data shape ({10,10}), data dilation (Strides{2, 3, 8}), padding " std::string("Ranks for data item shape/filters shape (data batch has shape "
"{6,2,10,10}, so data item rank is 2 and filters have shape {6,2,3,3}, so "
"filters spatial rank is 2), data dilation (Strides{2, 3, 8}), padding "
"below (CoordinateDiff{0, 0}), padding above (CoordinateDiff{0, 0}), " "below (CoordinateDiff{0, 0}), padding above (CoordinateDiff{0, 0}), "
"window shape ({3,3}), window strides (Strides{2, 3}), and window dilation " "filter strides (Strides{2, 3}), and filter dilation (Strides{2, 3}) do "
"(Strides{2, 3}) do not match")); "not match"));
} }
catch (...) catch (...)
{ {
...@@ -4842,10 +4840,12 @@ TEST(type_prop, conv_invalid_padding_below_rank) ...@@ -4842,10 +4840,12 @@ TEST(type_prop, conv_invalid_padding_below_rank)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Ranks for data shape ({10,10}), data dilation (Strides{1, 1}), padding " std::string(
"below (CoordinateDiff{0, 0, 0}), padding above (CoordinateDiff{0, 0}), " "Ranks for data item shape/filters shape (data batch has shape "
"window shape ({3,3}), window strides (Strides{2, 3}), and window dilation " "{6,2,10,10}, so data item rank is 2 and filters have shape {6,2,3,3}, so "
"(Strides{1, 1}) do not match")); "filters spatial rank is 2), data dilation (Strides{1, 1}), padding below "
"(CoordinateDiff{0, 0, 0}), padding above (CoordinateDiff{0, 0}), filter "
"strides (Strides{2, 3}), and filter dilation (Strides{1, 1}) do not match"));
} }
catch (...) catch (...)
{ {
...@@ -4874,10 +4874,12 @@ TEST(type_prop, conv_invalid_padding_above_rank) ...@@ -4874,10 +4874,12 @@ TEST(type_prop, conv_invalid_padding_above_rank)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(
error.what(), error.what(),
std::string("Ranks for data shape ({10,10}), data dilation (Strides{1, 1}), padding " std::string(
"below (CoordinateDiff{0, 0}), padding above (CoordinateDiff{0, 0, 0}), " "Ranks for data item shape/filters shape (data batch has shape "
"window shape ({3,3}), window strides (Strides{2, 3}), and window dilation " "{6,2,10,10}, so data item rank is 2 and filters have shape {6,2,3,3}, so "
"(Strides{2, 3}) do not match")); "filters spatial rank is 2), data dilation (Strides{1, 1}), padding below "
"(CoordinateDiff{0, 0}), padding above (CoordinateDiff{0, 0, 0}), filter "
"strides (Strides{2, 3}), and filter dilation (Strides{2, 3}) do not match"));
} }
catch (...) catch (...)
{ {
...@@ -5093,6 +5095,1181 @@ TEST(type_prop, conv_invalid_movement_stride_0) ...@@ -5093,6 +5095,1181 @@ TEST(type_prop, conv_invalid_movement_stride_0)
} }
} }
TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_ok)
{
PartialShape data_batch_shape{PartialShape::dynamic()};
PartialShape filters_shape{PartialShape::dynamic()};
Strides window_movement_strides{1, 1};
Strides window_dilation_strides{1, 1};
CoordinateDiff padding_below{0, 0};
CoordinateDiff padding_above{0, 0};
Strides data_dilation_strides{1, 1};
auto param0 = make_shared<op::Parameter>(element::f32, data_batch_shape);
auto param1 = make_shared<op::Parameter>(element::f32, filters_shape);
auto conv = make_shared<op::Convolution>(param0,
param1,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides);
ASSERT_EQ(conv->get_output_element_type(0), element::f32);
ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
// A window movement stride vector of the wrong rank (3 entries vs. the rank-2
// attributes) must be rejected even when both input shapes are rank-dynamic.
TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_strides_rank_wrong)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1, 1},     // window movement strides
                                                 Strides{1, 1},        // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 1});       // data dilation strides
        FAIL() << "Window stride rank mismatch not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("Ranks for data item shape/filters shape (data batch has shape ?, so data "
                        "item rank is ? and filters have shape ?, so filters spatial rank is ?), "
                        "data dilation (Strides{1, 1}), padding below (CoordinateDiff{0, 0}), "
                        "padding above (CoordinateDiff{0, 0}), filter strides (Strides{1, 1, 1}), "
                        "and filter dilation (Strides{1, 1}) do not match"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// A zero entry in the window movement strides must be rejected even when both
// input shapes are rank-dynamic.
TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_strides_dim_zero)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 0},        // window movement strides
                                                 Strides{1, 1},        // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 1});       // data dilation strides
        FAIL() << "Window stride with dimension zero not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("Window strides (Strides{1, 0}) has zero dimension at axis 1"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// A window dilation vector of the wrong rank (3 entries vs. the rank-2
// attributes) must be rejected even when both input shapes are rank-dynamic.
TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_dilation_rank_wrong)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},        // window movement strides
                                                 Strides{1, 1, 1},     // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 1});       // data dilation strides
        FAIL() << "Window dilation rank mismatch not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("Ranks for data item shape/filters shape (data batch has shape ?, so data "
                        "item rank is ? and filters have shape ?, so filters spatial rank is ?), "
                        "data dilation (Strides{1, 1}), padding below (CoordinateDiff{0, 0}), "
                        "padding above (CoordinateDiff{0, 0}), filter strides (Strides{1, 1}), and "
                        "filter dilation (Strides{1, 1, 1}) do not match"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// A zero entry in the window dilation strides must be rejected even when both
// input shapes are rank-dynamic.
TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_window_dilation_dim_zero)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},        // window movement strides
                                                 Strides{1, 0},        // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 1});       // data dilation strides
        FAIL() << "Window dilation with dimension zero not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("Window dilation (Strides{1, 0}) has zero dimension at axis 1"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// A below-padding vector of the wrong rank (3 entries vs. the rank-2
// attributes) must be rejected even when both input shapes are rank-dynamic.
TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_padding_below_rank_wrong)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},           // window movement strides
                                                 Strides{1, 1},           // window dilation strides
                                                 CoordinateDiff{0, 0, 0}, // padding below
                                                 CoordinateDiff{0, 0},    // padding above
                                                 Strides{1, 1});          // data dilation strides
        FAIL() << "Padding below rank mismatch not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("Ranks for data item shape/filters shape (data batch has shape ?, so data "
                        "item rank is ? and filters have shape ?, so filters spatial rank is ?), "
                        "data dilation (Strides{1, 1}), padding below (CoordinateDiff{0, 0, 0}), "
                        "padding above (CoordinateDiff{0, 0}), filter strides (Strides{1, 1}), and "
                        "filter dilation (Strides{1, 1}) do not match"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// An above-padding vector of the wrong rank (3 entries vs. the rank-2
// attributes) must be rejected even when both input shapes are rank-dynamic.
TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_padding_above_rank_wrong)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},           // window movement strides
                                                 Strides{1, 1},           // window dilation strides
                                                 CoordinateDiff{0, 0},    // padding below
                                                 CoordinateDiff{0, 0, 0}, // padding above
                                                 Strides{1, 1});          // data dilation strides
        FAIL() << "Padding above rank mismatch not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("Ranks for data item shape/filters shape (data batch has shape ?, so data "
                        "item rank is ? and filters have shape ?, so filters spatial rank is ?), "
                        "data dilation (Strides{1, 1}), padding below (CoordinateDiff{0, 0}), "
                        "padding above (CoordinateDiff{0, 0, 0}), filter strides (Strides{1, 1}), "
                        "and filter dilation (Strides{1, 1}) do not match"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// A data dilation vector of the wrong rank (3 entries vs. the rank-2
// attributes) must be rejected even when both input shapes are rank-dynamic.
TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_data_dilation_rank_wrong)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},        // window movement strides
                                                 Strides{1, 1},        // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 1, 1});    // data dilation strides
        FAIL() << "Data dilation rank mismatch not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("Ranks for data item shape/filters shape (data batch has shape ?, so data "
                        "item rank is ? and filters have shape ?, so filters spatial rank is ?), "
                        "data dilation (Strides{1, 1, 1}), padding below (CoordinateDiff{0, 0}), "
                        "padding above (CoordinateDiff{0, 0}), filter strides (Strides{1, 1}), and "
                        "filter dilation (Strides{1, 1}) do not match"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// A zero entry in the data dilation strides must be rejected even when both
// input shapes are rank-dynamic.
TEST(type_prop, conv_partial_rank_dynamic_rank_dynamic_data_dilation_dim_zero)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},        // window movement strides
                                                 Strides{1, 1},        // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 0});       // data dilation strides
        FAIL() << "Data dilation with dimension zero not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("Data dilation (Strides{1, 0}) has zero dimension at axis 1"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// Rank-4 (but otherwise dynamic) data batch with rank-dynamic filters should
// succeed, producing an f32 output of rank 4 with all dimensions dynamic.
TEST(type_prop, conv_partial_rank_static_dynamic_rank_dynamic_ok)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());

    auto conv = make_shared<op::Convolution>(data,
                                             filters,
                                             Strides{1, 1},        // window movement strides
                                             Strides{1, 1},        // window dilation strides
                                             CoordinateDiff{0, 0}, // padding below
                                             CoordinateDiff{0, 0}, // padding above
                                             Strides{1, 1});       // data dilation strides

    ASSERT_EQ(conv->get_output_element_type(0), element::f32);
    ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
// A rank-5 data batch implies a spatial rank of 3, which conflicts with the
// rank-2 attribute vectors; validation must fail.
TEST(type_prop, conv_partial_rank_static_dynamic_rank_dynamic_data_batch_rank_wrong)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(5));
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},        // window movement strides
                                                 Strides{1, 1},        // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 1});       // data dilation strides
        FAIL() << "Data batch rank mismatch not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("Ranks for data item shape/filters shape (data batch has shape "
                        "{?,?,?,?,?}, so data item rank is 3 and filters have shape ?, so filters "
                        "spatial rank is ?), data dilation (Strides{1, 1}), padding below "
                        "(CoordinateDiff{0, 0}), padding above (CoordinateDiff{0, 0}), filter "
                        "strides (Strides{1, 1}), and filter dilation (Strides{1, 1}) do not "
                        "match"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// When only the batch size (64) is statically known, it should propagate to
// the output while every other output dimension stays dynamic.
TEST(type_prop, conv_partial_rank_static_dynamic_rank_dynamic_batch_size_known_ok)
{
    auto data = make_shared<op::Parameter>(
        element::f32,
        PartialShape{64, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());

    auto conv = make_shared<op::Convolution>(data,
                                             filters,
                                             Strides{1, 1},        // window movement strides
                                             Strides{1, 1},        // window dilation strides
                                             CoordinateDiff{0, 0}, // padding below
                                             CoordinateDiff{0, 0}, // padding above
                                             Strides{1, 1});       // data dilation strides

    ASSERT_EQ(conv->get_output_element_type(0), element::f32);
    ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(
        PartialShape{64, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}));
}
// A statically-known batch size of zero must be rejected.
TEST(type_prop, conv_partial_rank_static_dynamic_rank_dynamic_batch_size_known_zero)
{
    auto data = make_shared<op::Parameter>(
        element::f32,
        PartialShape{0, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},        // window movement strides
                                                 Strides{1, 1},        // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 1});       // data dilation strides
        FAIL() << "Zero batch size not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Batch size is zero"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// A statically-known (nonzero) input channel count with rank-dynamic filters
// should succeed; the output is rank-4 fully dynamic.
TEST(type_prop, conv_partial_rank_static_dynamic_rank_dynamic_input_channel_count_known_ok)
{
    auto data = make_shared<op::Parameter>(
        element::f32,
        PartialShape{Dimension::dynamic(), 3, Dimension::dynamic(), Dimension::dynamic()});
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());

    auto conv = make_shared<op::Convolution>(data,
                                             filters,
                                             Strides{1, 1},        // window movement strides
                                             Strides{1, 1},        // window dilation strides
                                             CoordinateDiff{0, 0}, // padding below
                                             CoordinateDiff{0, 0}, // padding above
                                             Strides{1, 1});       // data dilation strides

    ASSERT_EQ(conv->get_output_element_type(0), element::f32);
    ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
// A statically-known data batch channel count of zero must be rejected.
TEST(type_prop, conv_partial_rank_static_dynamic_rank_dynamic_input_channel_count_known_zero)
{
    auto data = make_shared<op::Parameter>(
        element::f32,
        PartialShape{Dimension::dynamic(), 0, Dimension::dynamic(), Dimension::dynamic()});
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},        // window movement strides
                                                 Strides{1, 1},        // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 1});       // data dilation strides
        FAIL() << "Zero input channel count not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("Data batch channel count and/or filter input channel count is zero"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// A statically-known filter output channel count (32) should propagate to the
// output's channel dimension; the rest stays dynamic.
TEST(type_prop, conv_partial_rank_dynamic_rank_static_dynamic_output_channel_count_known_ok)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
    auto filters = make_shared<op::Parameter>(
        element::f32,
        PartialShape{32, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});

    auto conv = make_shared<op::Convolution>(data,
                                             filters,
                                             Strides{1, 1},        // window movement strides
                                             Strides{1, 1},        // window dilation strides
                                             CoordinateDiff{0, 0}, // padding below
                                             CoordinateDiff{0, 0}, // padding above
                                             Strides{1, 1});       // data dilation strides

    ASSERT_EQ(conv->get_output_element_type(0), element::f32);
    ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(
        PartialShape{Dimension::dynamic(), 32, Dimension::dynamic(), Dimension::dynamic()}));
}
// A statically-known filter output channel count of zero must be rejected.
TEST(type_prop, conv_partial_rank_dynamic_rank_static_dynamic_output_channel_count_known_zero)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
    auto filters = make_shared<op::Parameter>(
        element::f32,
        PartialShape{0, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},        // window movement strides
                                                 Strides{1, 1},        // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 1});       // data dilation strides
        FAIL() << "Zero output channel count not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Filter output channel count is zero"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// A statically-known (nonzero) filter input channel count with a rank-4
// dynamic data batch should succeed; the output is rank-4 fully dynamic.
TEST(type_prop, conv_partial_rank_dynamic_rank_static_dynamic_input_channel_count_known_ok)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
    auto filters = make_shared<op::Parameter>(
        element::f32,
        PartialShape{Dimension::dynamic(), 4, Dimension::dynamic(), Dimension::dynamic()});

    auto conv = make_shared<op::Convolution>(data,
                                             filters,
                                             Strides{1, 1},        // window movement strides
                                             Strides{1, 1},        // window dilation strides
                                             CoordinateDiff{0, 0}, // padding below
                                             CoordinateDiff{0, 0}, // padding above
                                             Strides{1, 1});       // data dilation strides

    ASSERT_EQ(conv->get_output_element_type(0), element::f32);
    ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
// A statically-known filter input channel count of zero must be rejected.
TEST(type_prop, conv_partial_rank_dynamic_rank_static_dynamic_input_channel_count_known_zero)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
    auto filters = make_shared<op::Parameter>(
        element::f32,
        PartialShape{Dimension::dynamic(), 0, Dimension::dynamic(), Dimension::dynamic()});

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},        // window movement strides
                                                 Strides{1, 1},        // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 1});       // data dilation strides
        FAIL() << "Zero input channel count not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("Data batch channel count and/or filter input channel count is zero"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// Both inputs rank-4 but fully dynamic: convolution succeeds with a rank-4
// fully dynamic output.
TEST(type_prop, conv_partial_rank_static_dynamic_rank_static_dynamic_ok)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));

    auto conv = make_shared<op::Convolution>(data,
                                             filters,
                                             Strides{1, 1},        // window movement strides
                                             Strides{1, 1},        // window dilation strides
                                             CoordinateDiff{0, 0}, // padding below
                                             CoordinateDiff{0, 0}, // padding above
                                             Strides{1, 1});       // data dilation strides

    ASSERT_EQ(conv->get_output_element_type(0), element::f32);
    ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
// Data batch rank (5) disagreeing with filters rank (4) must be rejected.
TEST(type_prop, conv_partial_rank_static_dynamic_rank_static_dynamic_arg_ranks_mismatch)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(5));
    auto filters = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},        // window movement strides
                                                 Strides{1, 1},        // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 1});       // data dilation strides
        FAIL() << "Argument rank mismatch not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Data batch and filters rank do not match (data batch "
                                         "shape: {?,?,?,?,?}, filters shape: {?,?,?,?})"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// Matching statically-known input channel counts (3 on both sides) should
// succeed; the output is rank-4 fully dynamic.
TEST(type_prop, conv_partial_rank_static_dynamic_rank_static_dynamic_input_channel_counts_known_ok)
{
    auto data = make_shared<op::Parameter>(
        element::f32,
        PartialShape{Dimension::dynamic(), 3, Dimension::dynamic(), Dimension::dynamic()});
    auto filters = make_shared<op::Parameter>(
        element::f32,
        PartialShape{Dimension::dynamic(), 3, Dimension::dynamic(), Dimension::dynamic()});

    auto conv = make_shared<op::Convolution>(data,
                                             filters,
                                             Strides{1, 1},        // window movement strides
                                             Strides{1, 1},        // window dilation strides
                                             CoordinateDiff{0, 0}, // padding below
                                             CoordinateDiff{0, 0}, // padding above
                                             Strides{1, 1});       // data dilation strides

    ASSERT_EQ(conv->get_output_element_type(0), element::f32);
    ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
// Conflicting statically-known input channel counts (3 vs. 22) must be
// rejected.
TEST(type_prop, conv_partial_rank_static_dynamic_rank_static_dynamic_input_channel_counts_mismatch)
{
    auto data = make_shared<op::Parameter>(
        element::f32,
        PartialShape{Dimension::dynamic(), 3, Dimension::dynamic(), Dimension::dynamic()});
    auto filters = make_shared<op::Parameter>(
        element::f32,
        PartialShape{Dimension::dynamic(), 22, Dimension::dynamic(), Dimension::dynamic()});

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},        // window movement strides
                                                 Strides{1, 1},        // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 1});       // data dilation strides
        FAIL() << "Input channel count mismatch not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string(
                "Data batch channel count (3) does not match filter input channel count (22)"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// All non-spatial dimensions known (batch 64, channels 3 -> 100): those values
// propagate to the output while the spatial dimensions stay dynamic.
TEST(type_prop, conv_partial_rank_static_dynamic_rank_static_dynamic_all_nonspatial_known_ok)
{
    auto data = make_shared<op::Parameter>(
        element::f32, PartialShape{64, 3, Dimension::dynamic(), Dimension::dynamic()});
    auto filters = make_shared<op::Parameter>(
        element::f32, PartialShape{100, 3, Dimension::dynamic(), Dimension::dynamic()});

    auto conv = make_shared<op::Convolution>(data,
                                             filters,
                                             Strides{1, 1},        // window movement strides
                                             Strides{1, 1},        // window dilation strides
                                             CoordinateDiff{0, 0}, // padding below
                                             CoordinateDiff{0, 0}, // padding above
                                             Strides{1, 1});       // data dilation strides

    ASSERT_EQ(conv->get_output_element_type(0), element::f32);
    ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(
        PartialShape{64, 100, Dimension::dynamic(), Dimension::dynamic()}));
}
// One known spatial dimension (200 with a size-5 filter -> 196) is computed
// statically while the unknown spatial dimension stays dynamic.
TEST(type_prop,
     conv_partial_rank_static_dynamic_rank_static_dynamic_all_nonspatial_some_spatial_known_ok)
{
    auto data = make_shared<op::Parameter>(element::f32,
                                           PartialShape{64, 3, 200, Dimension::dynamic()});
    auto filters = make_shared<op::Parameter>(element::f32,
                                              PartialShape{100, 3, 5, Dimension::dynamic()});

    auto conv = make_shared<op::Convolution>(data,
                                             filters,
                                             Strides{1, 1},        // window movement strides
                                             Strides{1, 1},        // window dilation strides
                                             CoordinateDiff{0, 0}, // padding below
                                             CoordinateDiff{0, 0}, // padding above
                                             Strides{1, 1});       // data dilation strides

    ASSERT_EQ(conv->get_output_element_type(0), element::f32);
    ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(
        PartialShape{64, 100, 196, Dimension::dynamic()}));
}
// A filter (201) larger than the corresponding data dimension (200) must be
// rejected.
TEST(
    type_prop,
    conv_partial_rank_static_dynamic_rank_static_dynamic_all_nonspatial_some_spatial_known_filters_too_big)
{
    auto data = make_shared<op::Parameter>(element::f32,
                                           PartialShape{64, 3, 200, Dimension::dynamic()});
    auto filters = make_shared<op::Parameter>(element::f32,
                                              PartialShape{100, 3, 201, Dimension::dynamic()});

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},        // window movement strides
                                                 Strides{1, 1},        // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 1});       // data dilation strides
        FAIL() << "Oversize filter not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Window after dilation has dimension (dim: 201) larger "
                                         "than the data shape after padding (dim: 200) at axis 0"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// An oversize filter (201 vs. 200) becomes legal once padding (2 below, -1
// above) grows the padded data dimension to 201, yielding an output of 1.
TEST(
    type_prop,
    conv_partial_rank_static_dynamic_rank_static_dynamic_all_nonspatial_some_spatial_known_filters_not_too_big_after_padding)
{
    auto data = make_shared<op::Parameter>(element::f32,
                                           PartialShape{64, 3, 200, Dimension::dynamic()});
    auto filters = make_shared<op::Parameter>(element::f32,
                                              PartialShape{100, 3, 201, Dimension::dynamic()});

    auto conv = make_shared<op::Convolution>(data,
                                             filters,
                                             Strides{1, 1},         // window movement strides
                                             Strides{1, 1},         // window dilation strides
                                             CoordinateDiff{2, 0},  // padding below
                                             CoordinateDiff{-1, 0}, // padding above
                                             Strides{1, 1});        // data dilation strides

    ASSERT_EQ(conv->get_output_element_type(0), element::f32);
    ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(
        PartialShape{64, 100, 1, Dimension::dynamic()}));
}
// An oversize filter (201 vs. 200) becomes legal once data dilation of 2
// grows the dilated data dimension to 399, yielding an output of 199.
TEST(
    type_prop,
    conv_partial_rank_static_dynamic_rank_static_dynamic_all_nonspatial_some_spatial_known_filters_not_too_big_after_data_dilation)
{
    auto data = make_shared<op::Parameter>(element::f32,
                                           PartialShape{64, 3, 200, Dimension::dynamic()});
    auto filters = make_shared<op::Parameter>(element::f32,
                                              PartialShape{100, 3, 201, Dimension::dynamic()});

    auto conv = make_shared<op::Convolution>(data,
                                             filters,
                                             Strides{1, 1},        // window movement strides
                                             Strides{1, 1},        // window dilation strides
                                             CoordinateDiff{0, 0}, // padding below
                                             CoordinateDiff{0, 0}, // padding above
                                             Strides{2, 1});       // data dilation strides

    ASSERT_EQ(conv->get_output_element_type(0), element::f32);
    ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(
        PartialShape{64, 100, 199, Dimension::dynamic()}));
}
// Same as the data-dilation case above, but with a window movement stride of
// 3 on the first spatial axis, shrinking the output dimension to 67.
TEST(
    type_prop,
    conv_partial_rank_static_dynamic_rank_static_dynamic_all_nonspatial_some_spatial_known_filters_not_too_big_after_data_dilation_strided)
{
    auto data = make_shared<op::Parameter>(element::f32,
                                           PartialShape{64, 3, 200, Dimension::dynamic()});
    auto filters = make_shared<op::Parameter>(element::f32,
                                              PartialShape{100, 3, 201, Dimension::dynamic()});

    auto conv = make_shared<op::Convolution>(data,
                                             filters,
                                             Strides{3, 1},        // window movement strides
                                             Strides{1, 1},        // window dilation strides
                                             CoordinateDiff{0, 0}, // padding below
                                             CoordinateDiff{0, 0}, // padding above
                                             Strides{2, 1});       // data dilation strides

    ASSERT_EQ(conv->get_output_element_type(0), element::f32);
    ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(
        PartialShape{64, 100, 67, Dimension::dynamic()}));
}
// A size-101 filter dilated by 2 has an effective size of 201, exceeding the
// size-200 data dimension; validation must fail.
TEST(
    type_prop,
    conv_partial_rank_static_dynamic_rank_static_dynamic_all_nonspatial_some_spatial_known_filters_too_big_after_filter_dilation)
{
    auto data = make_shared<op::Parameter>(element::f32,
                                           PartialShape{64, 3, 200, Dimension::dynamic()});
    auto filters = make_shared<op::Parameter>(element::f32,
                                              PartialShape{100, 3, 101, Dimension::dynamic()});

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},        // window movement strides
                                                 Strides{2, 1},        // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 1});       // data dilation strides
        FAIL() << "Oversize filter after window dilation not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Window after dilation has dimension (dim: 201) larger "
                                         "than the data shape after padding (dim: 200) at axis 0"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// A zero-sized spatial dimension in the data batch must be rejected.
TEST(
    type_prop,
    conv_partial_rank_static_dynamic_rank_static_dynamic_all_nonspatial_some_spatial_zero_data_batch_dim)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape{64, 3, 200, 0});
    auto filters = make_shared<op::Parameter>(element::f32,
                                              PartialShape{100, 3, 5, Dimension::dynamic()});

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},        // window movement strides
                                                 Strides{1, 1},        // window dilation strides
                                                 CoordinateDiff{0, 0}, // padding below
                                                 CoordinateDiff{0, 0}, // padding above
                                                 Strides{1, 1});       // data dilation strides
        FAIL() << "Zero dimension in data batch not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Data shape after padding and dilation has "
                                         "dimension less than 1 (dim: 0) at axis 1"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
// A zero-sized spatial dimension becomes legal when padding (2 below, -1
// above) makes the padded dimension positive.
TEST(
    type_prop,
    conv_partial_rank_static_dynamic_rank_static_dynamic_all_nonspatial_some_spatial_positive_data_batch_dim_after_padding)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape{64, 3, 200, 0});
    auto filters = make_shared<op::Parameter>(element::f32,
                                              PartialShape{100, 3, 5, Dimension::dynamic()});

    auto conv = make_shared<op::Convolution>(data,
                                             filters,
                                             Strides{1, 1},         // window movement strides
                                             Strides{1, 1},         // window dilation strides
                                             CoordinateDiff{0, 2},  // padding below
                                             CoordinateDiff{0, -1}, // padding above
                                             Strides{1, 1});        // data dilation strides

    ASSERT_EQ(conv->get_output_element_type(0), element::f32);
    ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(
        PartialShape{64, 100, 196, Dimension::dynamic()}));
}
// Negative above-padding (-20) reduces a size-20 spatial dimension to zero;
// validation must fail.
TEST(
    type_prop,
    conv_partial_rank_static_dynamic_rank_static_dynamic_all_nonspatial_some_spatial_zero_data_batch_dim_after_padding)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape{64, 3, 200, 20});
    auto filters = make_shared<op::Parameter>(element::f32,
                                              PartialShape{100, 3, 5, Dimension::dynamic()});

    try
    {
        auto conv = make_shared<op::Convolution>(data,
                                                 filters,
                                                 Strides{1, 1},          // window movement strides
                                                 Strides{1, 1},          // window dilation strides
                                                 CoordinateDiff{0, 0},   // padding below
                                                 CoordinateDiff{0, -20}, // padding above
                                                 Strides{1, 1});         // data dilation strides
        FAIL() << "Zero padded dimension in data batch not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Data shape after padding and dilation has "
                                         "dimension less than 1 (dim: 0) at axis 1"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(
    type_prop,
    conv_partial_rank_static_dynamic_rank_static_dynamic_all_nonspatial_some_spatial_negative_data_batch_dim_after_padding)
{
    // Combined padding of -21 on the second spatial dim (size 20) yields a negative
    // padded extent (-1); validation must reject this.
    const PartialShape data_batch_shape{64, 3, 200, 20};
    const PartialShape filters_shape{100, 3, 5, Dimension::dynamic()};
    const Strides window_movement_strides{1, 1};
    const Strides window_dilation_strides{1, 1};
    const CoordinateDiff padding_below{0, -1};
    const CoordinateDiff padding_above{0, -20};
    const Strides data_dilation_strides{1, 1};

    auto data_param = make_shared<op::Parameter>(element::f32, data_batch_shape);
    auto filters_param = make_shared<op::Parameter>(element::f32, filters_shape);

    try
    {
        auto conv = make_shared<op::Convolution>(data_param,
                                                 filters_param,
                                                 window_movement_strides,
                                                 window_dilation_strides,
                                                 padding_below,
                                                 padding_above,
                                                 data_dilation_strides);
        FAIL() << "Negative padded dimension in data batch not detected";
    }
    catch (const NodeValidationError& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Data shape after padding and dilation has dimension less "
                                         "than 1 (dim: -1) at axis 1"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(type_prop, conv_partial_dynamic_et)
{
    // Shape parameters are borrowed from a known-good test above; the point of this
    // test is only that dynamic input element types propagate to a dynamic output
    // element type while shape inference still proceeds.
    const PartialShape data_batch_shape{64, 3, 200, Dimension::dynamic()};
    const PartialShape filters_shape{100, 3, 201, Dimension::dynamic()};
    const Strides window_movement_strides{1, 1};
    const Strides window_dilation_strides{1, 1};
    const CoordinateDiff padding_below{2, 0};
    const CoordinateDiff padding_above{-1, 0};
    const Strides data_dilation_strides{1, 1};

    auto data_param = make_shared<op::Parameter>(element::dynamic, data_batch_shape);
    auto filters_param = make_shared<op::Parameter>(element::dynamic, filters_shape);

    auto conv = make_shared<op::Convolution>(data_param,
                                             filters_param,
                                             window_movement_strides,
                                             window_dilation_strides,
                                             padding_below,
                                             padding_above,
                                             data_dilation_strides);

    ASSERT_TRUE(conv->get_output_element_type(0).is_dynamic());
    ASSERT_TRUE(conv->get_output_partial_shape(0).same_scheme(
        PartialShape{64, 100, 1, Dimension::dynamic()}));
}
TEST(type_prop, max_pool_1d_deduce) TEST(type_prop, max_pool_1d_deduce)
{ {
// Deduce type // Deduce type
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment