Commit ccfcf4f9 authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Partial Shapes and Types, Part 4λ: Convolution and backprops (#1890)

* Implement partial shape/type propagation for Convolution; fail for want of unit tests

* Implement unit tests for partial shapes/types for Convolution
parent 0d693fc3
This diff is collapsed.
...@@ -155,8 +155,12 @@ namespace ngraph ...@@ -155,8 +155,12 @@ namespace ngraph
Strides m_data_dilation_strides; Strides m_data_dilation_strides;
private: private:
static Strides default_strides(const Node* node, const Shape& data_batch_shape); static Strides default_strides(const Node* node,
static CoordinateDiff default_padding(const Node* node, const Shape& data_batch_shape); const PartialShape& data_batch_shape,
const PartialShape& filters_shape);
static CoordinateDiff default_padding(const Node* node,
const PartialShape& data_batch_shape,
const PartialShape& filters_shape);
}; };
/// \brief Data batch backprop for batched convolution operation. /// \brief Data batch backprop for batched convolution operation.
......
...@@ -125,68 +125,123 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node, ...@@ -125,68 +125,123 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
// //
// Infers the output batch shape and element type for convolution fprop. // Infers the output batch shape and element type for convolution fprop.
// //
std::tuple<element::Type, Shape> std::tuple<element::Type, PartialShape>
ngraph::infer_convolution_forward(const Node* node, ngraph::infer_convolution_forward(const Node* node,
element::Type et_batch, element::Type et_batch,
element::Type et_filters, element::Type et_filters,
const Shape& data_batch_shape, const PartialShape& data_batch_shape,
const Strides& data_dilation, const Strides& data_dilation,
const CoordinateDiff& data_padding_below, const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above, const CoordinateDiff& data_padding_above,
const Shape& filters_shape, const PartialShape& filters_shape,
const Strides& filter_strides, const Strides& filter_strides,
const Strides& filter_dilation) const Strides& filter_dilation)
{ {
NODE_VALIDATION_ASSERT(node, et_batch == et_filters) element::Type et_result;
NODE_VALIDATION_ASSERT(node, element::Type::merge(et_result, et_batch, et_filters))
<< "Element types for data batch and filters do not match (data batch element type: " << "Element types for data batch and filters do not match (data batch element type: "
<< et_batch << ", filters element type: " << et_filters << ")."; << et_batch << ", filters element type: " << et_filters << ").";
NODE_VALIDATION_ASSERT(node, data_batch_shape.size() >= 3) Rank data_batch_filters_rank{Rank::dynamic()};
<< "Data batch must have rank of at least 3 (one batch axis, "
<< "one input-channel axis, and at least one spatial dimension) " NODE_VALIDATION_ASSERT(
<< "(data batch shape: " << data_batch_shape << ")."; node, Rank::merge(data_batch_filters_rank, data_batch_shape.rank(), filters_shape.rank()))
<< "Data batch and filters rank do not match (data batch shape: " << data_batch_shape
<< ", filters shape: " << filters_shape << ").";
NODE_VALIDATION_ASSERT(node, filters_shape.size() >= 3) NODE_VALIDATION_ASSERT(node,
<< "Filters must have rank of at least 3 (one output-channel axis, " data_batch_filters_rank.is_dynamic() ||
static_cast<size_t>(data_batch_filters_rank) >= 3)
<< "Data batch and filters must have rank of at least 3 (one batch axis, "
<< "one input-channel axis, and at least one spatial dimension) " << "one input-channel axis, and at least one spatial dimension) "
<< "(filters shape: " << filters_shape << ")."; << "(data batch shape: " << data_batch_shape << ", filters shape: " << filters_shape
<< ").";
size_t batch_size = data_batch_shape[0]; Rank spatial_rank{Rank::dynamic()};
size_t data_channel_count = data_batch_shape[1]; NODE_VALIDATION_ASSERT(node,
Shape data_spatial_shape(data_batch_shape.begin() + 2, data_batch_shape.end()); Rank::merge(spatial_rank, spatial_rank, data_batch_filters_rank - 2) &&
Rank::merge(spatial_rank, spatial_rank, data_dilation.size()) &&
Rank::merge(spatial_rank, spatial_rank, data_padding_below.size()) &&
Rank::merge(spatial_rank, spatial_rank, data_padding_above.size()) &&
Rank::merge(spatial_rank, spatial_rank, filter_strides.size()) &&
Rank::merge(spatial_rank, spatial_rank, filter_dilation.size()))
<< "Ranks for data item shape/filters shape (data batch has shape " << data_batch_shape
<< ", so data item rank is " << (data_batch_shape.rank() - 2) << " and filters have shape "
<< filters_shape << ", so filters spatial rank is " << (filters_shape.rank() - 2)
<< "), data dilation (" << data_dilation << "), padding below (" << data_padding_below
<< "), padding above (" << data_padding_above << "), filter strides (" << filter_strides
<< "), and filter dilation (" << filter_dilation << ") do not match.";
Dimension batch_size =
(data_batch_shape.rank().is_static() ? data_batch_shape[0] : Dimension::dynamic());
Dimension data_channel_count =
(data_batch_shape.rank().is_static() ? data_batch_shape[1] : Dimension::dynamic());
PartialShape data_spatial_shape(PartialShape::dynamic(spatial_rank));
Dimension filter_output_channel_count =
(filters_shape.rank().is_static() ? filters_shape[0] : Dimension::dynamic());
Dimension filter_input_channel_count =
(filters_shape.rank().is_static() ? filters_shape[1] : Dimension::dynamic());
PartialShape filter_spatial_shape(PartialShape::dynamic(spatial_rank));
//
// Note: spatial_rank is definitely static at this point.
//
for (size_t i = 0; i < static_cast<size_t>(spatial_rank); i++)
{
if (data_batch_shape.rank().is_static())
{
data_spatial_shape[i] = data_batch_shape[i + 2];
}
size_t filter_output_channel_count = filters_shape[0]; if (filters_shape.rank().is_static())
size_t filter_input_channel_count = filters_shape[1]; {
Shape filter_spatial_shape(filters_shape.begin() + 2, filters_shape.end()); filter_spatial_shape[i] = filters_shape[i + 2];
}
}
NODE_VALIDATION_ASSERT(node, batch_size > 0) << "Batch size is zero."; NODE_VALIDATION_ASSERT(node, batch_size.is_dynamic() || static_cast<size_t>(batch_size) > 0)
<< "Batch size is zero.";
NODE_VALIDATION_ASSERT(node, data_channel_count > 0) << "Data batch channel count is zero."; Dimension merged_channel_count;
NODE_VALIDATION_ASSERT(node, data_channel_count == filter_input_channel_count) NODE_VALIDATION_ASSERT(
node,
Dimension::merge(merged_channel_count, data_channel_count, filter_input_channel_count))
<< "Data batch channel count (" << data_channel_count << ") does not match filter input " << "Data batch channel count (" << data_channel_count << ") does not match filter input "
<< "channel count (" << filter_input_channel_count << ")."; << "channel count (" << filter_input_channel_count << ").";
NODE_VALIDATION_ASSERT(node, filter_output_channel_count > 0) NODE_VALIDATION_ASSERT(
node, merged_channel_count.is_dynamic() || static_cast<size_t>(merged_channel_count) > 0)
<< "Data batch channel count and/or filter input channel count is zero.";
NODE_VALIDATION_ASSERT(node,
filter_output_channel_count.is_dynamic() ||
static_cast<size_t>(filter_output_channel_count) > 0)
<< "Filter output channel count is zero."; << "Filter output channel count is zero.";
Shape data_output_shape = infer_windowed_reduction_output_shape(node, PartialShape data_output_shape = infer_windowed_reduction_output_shape(node,
data_spatial_shape, data_spatial_shape,
data_dilation, data_dilation,
data_padding_below, data_padding_below,
data_padding_above, data_padding_above,
filter_spatial_shape, filter_spatial_shape,
filter_strides, filter_strides,
filter_dilation, filter_dilation,
true) true);
.to_shape();
PartialShape batch_output_shape(PartialShape::dynamic(spatial_rank + 2));
Shape batch_output_shape(data_batch_shape.size());
batch_output_shape[0] = batch_size; batch_output_shape[0] = batch_size;
batch_output_shape[1] = filter_output_channel_count; batch_output_shape[1] = filter_output_channel_count;
std::copy(data_output_shape.begin(), data_output_shape.end(), batch_output_shape.begin() + 2);
return std::make_tuple(et_batch, batch_output_shape); for (size_t i = 0; i < static_cast<size_t>(spatial_rank); i++)
{
batch_output_shape[i + 2] = data_output_shape[i];
}
return std::make_tuple(et_result, batch_output_shape);
} }
// //
......
...@@ -33,15 +33,15 @@ namespace ngraph ...@@ -33,15 +33,15 @@ namespace ngraph
const Strides& window_dilation, const Strides& window_dilation,
bool is_window_all_in_padding_allowed); bool is_window_all_in_padding_allowed);
std::tuple<element::Type, Shape> std::tuple<element::Type, PartialShape>
infer_convolution_forward(const Node* node, infer_convolution_forward(const Node* node,
element::Type et_batch, element::Type et_batch,
element::Type et_filters, element::Type et_filters,
const Shape& data_batch_shape, const PartialShape& data_batch_shape,
const Strides& data_dilation, const Strides& data_dilation,
const CoordinateDiff& data_padding_below, const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above, const CoordinateDiff& data_padding_above,
const Shape& filters_shape, const PartialShape& filters_shape,
const Strides& filter_strides, const Strides& filter_strides,
const Strides& filter_dilation); const Strides& filter_dilation);
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment