Commit ccfcf4f9 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Partial Shapes and Types, Part 4λ: Convolution and backprops (#1890)

* Implement partial shape/type propagation for Convolution; fail for want of unit tests

* Implement unit tests for partial shapes/types for Convolution
parent 0d693fc3
This diff is collapsed.
......@@ -155,8 +155,12 @@ namespace ngraph
Strides m_data_dilation_strides;
private:
static Strides default_strides(const Node* node, const Shape& data_batch_shape);
static CoordinateDiff default_padding(const Node* node, const Shape& data_batch_shape);
static Strides default_strides(const Node* node,
const PartialShape& data_batch_shape,
const PartialShape& filters_shape);
static CoordinateDiff default_padding(const Node* node,
const PartialShape& data_batch_shape,
const PartialShape& filters_shape);
};
/// \brief Data batch backprop for batched convolution operation.
......
......@@ -125,68 +125,123 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
//
// Infers the output batch shape and element type for convolution fprop.
//
std::tuple<element::Type, Shape>
std::tuple<element::Type, PartialShape>
ngraph::infer_convolution_forward(const Node* node,
element::Type et_batch,
element::Type et_filters,
const Shape& data_batch_shape,
const PartialShape& data_batch_shape,
const Strides& data_dilation,
const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above,
const Shape& filters_shape,
const PartialShape& filters_shape,
const Strides& filter_strides,
const Strides& filter_dilation)
{
NODE_VALIDATION_ASSERT(node, et_batch == et_filters)
element::Type et_result;
NODE_VALIDATION_ASSERT(node, element::Type::merge(et_result, et_batch, et_filters))
<< "Element types for data batch and filters do not match (data batch element type: "
<< et_batch << ", filters element type: " << et_filters << ").";
NODE_VALIDATION_ASSERT(node, data_batch_shape.size() >= 3)
<< "Data batch must have rank of at least 3 (one batch axis, "
<< "one input-channel axis, and at least one spatial dimension) "
<< "(data batch shape: " << data_batch_shape << ").";
Rank data_batch_filters_rank{Rank::dynamic()};
NODE_VALIDATION_ASSERT(
node, Rank::merge(data_batch_filters_rank, data_batch_shape.rank(), filters_shape.rank()))
<< "Data batch and filters rank do not match (data batch shape: " << data_batch_shape
<< ", filters shape: " << filters_shape << ").";
NODE_VALIDATION_ASSERT(node, filters_shape.size() >= 3)
<< "Filters must have rank of at least 3 (one output-channel axis, "
NODE_VALIDATION_ASSERT(node,
data_batch_filters_rank.is_dynamic() ||
static_cast<size_t>(data_batch_filters_rank) >= 3)
<< "Data batch and filters must have rank of at least 3 (one batch axis, "
<< "one input-channel axis, and at least one spatial dimension) "
<< "(filters shape: " << filters_shape << ").";
<< "(data batch shape: " << data_batch_shape << ", filters shape: " << filters_shape
<< ").";
size_t batch_size = data_batch_shape[0];
size_t data_channel_count = data_batch_shape[1];
Shape data_spatial_shape(data_batch_shape.begin() + 2, data_batch_shape.end());
Rank spatial_rank{Rank::dynamic()};
NODE_VALIDATION_ASSERT(node,
Rank::merge(spatial_rank, spatial_rank, data_batch_filters_rank - 2) &&
Rank::merge(spatial_rank, spatial_rank, data_dilation.size()) &&
Rank::merge(spatial_rank, spatial_rank, data_padding_below.size()) &&
Rank::merge(spatial_rank, spatial_rank, data_padding_above.size()) &&
Rank::merge(spatial_rank, spatial_rank, filter_strides.size()) &&
Rank::merge(spatial_rank, spatial_rank, filter_dilation.size()))
<< "Ranks for data item shape/filters shape (data batch has shape " << data_batch_shape
<< ", so data item rank is " << (data_batch_shape.rank() - 2) << " and filters have shape "
<< filters_shape << ", so filters spatial rank is " << (filters_shape.rank() - 2)
<< "), data dilation (" << data_dilation << "), padding below (" << data_padding_below
<< "), padding above (" << data_padding_above << "), filter strides (" << filter_strides
<< "), and filter dilation (" << filter_dilation << ") do not match.";
Dimension batch_size =
(data_batch_shape.rank().is_static() ? data_batch_shape[0] : Dimension::dynamic());
Dimension data_channel_count =
(data_batch_shape.rank().is_static() ? data_batch_shape[1] : Dimension::dynamic());
PartialShape data_spatial_shape(PartialShape::dynamic(spatial_rank));
Dimension filter_output_channel_count =
(filters_shape.rank().is_static() ? filters_shape[0] : Dimension::dynamic());
Dimension filter_input_channel_count =
(filters_shape.rank().is_static() ? filters_shape[1] : Dimension::dynamic());
PartialShape filter_spatial_shape(PartialShape::dynamic(spatial_rank));
//
// Note: spatial_rank is definitely static at this point.
//
for (size_t i = 0; i < static_cast<size_t>(spatial_rank); i++)
{
if (data_batch_shape.rank().is_static())
{
data_spatial_shape[i] = data_batch_shape[i + 2];
}
size_t filter_output_channel_count = filters_shape[0];
size_t filter_input_channel_count = filters_shape[1];
Shape filter_spatial_shape(filters_shape.begin() + 2, filters_shape.end());
if (filters_shape.rank().is_static())
{
filter_spatial_shape[i] = filters_shape[i + 2];
}
}
NODE_VALIDATION_ASSERT(node, batch_size > 0) << "Batch size is zero.";
NODE_VALIDATION_ASSERT(node, batch_size.is_dynamic() || static_cast<size_t>(batch_size) > 0)
<< "Batch size is zero.";
NODE_VALIDATION_ASSERT(node, data_channel_count > 0) << "Data batch channel count is zero.";
Dimension merged_channel_count;
NODE_VALIDATION_ASSERT(node, data_channel_count == filter_input_channel_count)
NODE_VALIDATION_ASSERT(
node,
Dimension::merge(merged_channel_count, data_channel_count, filter_input_channel_count))
<< "Data batch channel count (" << data_channel_count << ") does not match filter input "
<< "channel count (" << filter_input_channel_count << ").";
NODE_VALIDATION_ASSERT(node, filter_output_channel_count > 0)
NODE_VALIDATION_ASSERT(
node, merged_channel_count.is_dynamic() || static_cast<size_t>(merged_channel_count) > 0)
<< "Data batch channel count and/or filter input channel count is zero.";
NODE_VALIDATION_ASSERT(node,
filter_output_channel_count.is_dynamic() ||
static_cast<size_t>(filter_output_channel_count) > 0)
<< "Filter output channel count is zero.";
Shape data_output_shape = infer_windowed_reduction_output_shape(node,
data_spatial_shape,
data_dilation,
data_padding_below,
data_padding_above,
filter_spatial_shape,
filter_strides,
filter_dilation,
true)
.to_shape();
Shape batch_output_shape(data_batch_shape.size());
PartialShape data_output_shape = infer_windowed_reduction_output_shape(node,
data_spatial_shape,
data_dilation,
data_padding_below,
data_padding_above,
filter_spatial_shape,
filter_strides,
filter_dilation,
true);
PartialShape batch_output_shape(PartialShape::dynamic(spatial_rank + 2));
batch_output_shape[0] = batch_size;
batch_output_shape[1] = filter_output_channel_count;
std::copy(data_output_shape.begin(), data_output_shape.end(), batch_output_shape.begin() + 2);
return std::make_tuple(et_batch, batch_output_shape);
for (size_t i = 0; i < static_cast<size_t>(spatial_rank); i++)
{
batch_output_shape[i + 2] = data_output_shape[i];
}
return std::make_tuple(et_result, batch_output_shape);
}
//
......
......@@ -33,15 +33,15 @@ namespace ngraph
const Strides& window_dilation,
bool is_window_all_in_padding_allowed);
std::tuple<element::Type, Shape>
std::tuple<element::Type, PartialShape>
infer_convolution_forward(const Node* node,
element::Type et_batch,
element::Type et_filters,
const Shape& data_batch_shape,
const PartialShape& data_batch_shape,
const Strides& data_dilation,
const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above,
const Shape& filters_shape,
const PartialShape& filters_shape,
const Strides& filter_strides,
const Strides& filter_dilation);
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment