Commit 256a8b6d authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Partial Shapes and Types, Part 4k: AvgPool/MaxPool and backprops (#1871)

* Add merge_rank function

* Update infer_windowed_reduction_output_shape to use PartialShape

* Minor simplification

* Some unit tests and (whaddaya know) fixes for infer_windowed_reduction_output_shape

* Update infer_batched_pooling_forward to use PartialShape

* Update pooling fprop ops for partial shapes

* Update pooling bprop ops for partial shapes

* Add test-failing reminders to implement unit tests for partial shape/type prop for pooling ops

* Add unit tests for partial shape propagation for pooling ops

* Nuke C-style casts for Dimensions/Ranks in validation_util.cpp
parent 7f6f07ee
......@@ -52,6 +52,11 @@ Dimension Dimension::operator+(const Dimension& dim) const
return (is_static() && dim.is_static() ? m_dimension + size_t(dim) : Dimension::dynamic());
}
// Subtraction for Dimension: if either operand is dynamic the result is
// dynamic; otherwise it is a static dimension holding the difference.
Dimension Dimension::operator-(const Dimension& dim) const
{
    if (is_dynamic() || dim.is_dynamic())
    {
        return Dimension::dynamic();
    }
    return m_dimension - size_t(dim);
}
Dimension Dimension::operator*(const Dimension& dim) const
{
return ((is_static() && dim.is_static())
......
......@@ -56,6 +56,16 @@ namespace ngraph
}
return m_dimension;
}
/// \brief Convert this dimension to `ptrdiff_t`. This dimension must be static.
/// \throws std::invalid_argument If this dimension is dynamic.
explicit operator ptrdiff_t() const
{
    if (is_static())
    {
        return static_cast<ptrdiff_t>(m_dimension);
    }
    throw std::invalid_argument("Cannot convert dynamic dimension to ptrdiff_t");
}
/// \brief Check whether this dimension represents the same scheme as the argument (both
/// dynamic, or equal).
......@@ -122,6 +132,12 @@ namespace ngraph
/// dimension with value `size_t(*this)+size_t(dim)`.
Dimension operator+(const Dimension& dim) const;
/// \brief Subtraction operator for Dimension.
/// \param dim Right operand for subtraction.
/// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static
/// dimension with value `size_t(*this)-size_t(dim)`.
Dimension operator-(const Dimension& dim) const;
/// \brief Multiplication operator for Dimension.
/// \param dim Right operand for multiplication.
/// \return 0 if either of `*this` or `dim` is static and 0; else, Dimension::dynamic() if
......
......@@ -40,31 +40,22 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
void op::AvgPool::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
if (0 == m_window_movement_strides.size())
{
return;
m_window_movement_strides = Strides(m_window_shape.size(), 1);
}
auto& arg_shape = get_input_shape(0);
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
<< "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape
<< ").";
if (0 == m_window_movement_strides.size() && arg_shape.size() > 2)
if (0 == m_padding_below.size())
{
m_window_movement_strides = Strides(arg_shape.size() - 2, 1);
m_padding_below = Shape(m_window_shape.size(), 0);
}
if (0 == m_padding_below.size() && arg_shape.size() > 2)
if (0 == m_padding_above.size())
{
m_padding_below = Shape(arg_shape.size() - 2, 0);
m_padding_above = Shape(m_window_shape.size(), 0);
}
if (0 == m_padding_above.size() && arg_shape.size() > 2)
{
m_padding_above = Shape(arg_shape.size() - 2, 0);
}
const PartialShape& arg_shape = get_input_partial_shape(0);
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// now still take Shape (no negative padding).
......@@ -125,19 +116,12 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
void op::AvgPoolBackprop::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& delta_shape = get_input_shape(0);
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// now still take Shape (no negative padding).
CoordinateDiff padding_below(m_padding_below.begin(), m_padding_below.end());
CoordinateDiff padding_above(m_padding_above.begin(), m_padding_above.end());
Shape forward_result_shape =
PartialShape forward_result_shape =
infer_batched_pooling_forward(this,
m_forward_arg_shape,
padding_below,
......@@ -146,10 +130,15 @@ void op::AvgPoolBackprop::validate_and_infer_types()
m_window_movement_strides,
m_include_padding_in_avg_computation);
NODE_VALIDATION_ASSERT(this, forward_result_shape == delta_shape)
const PartialShape& delta_shape = get_input_shape(0);
NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape))
<< "Inferred forward output shape does not match delta shape (inferred forward output "
<< "shape: " << forward_result_shape << ", delta shape: " << delta_shape << ").";
// TODO(amprocte): Once m_forward_arg_shape is allowed to be dynamic, we may technically be
// able to infer some extra information from forward_result_shape that was not present in the
// forward arg shape---namely batch size and channel count. Merge that info in.
set_output_type(0, get_input_element_type(0), m_forward_arg_shape);
}
......
......@@ -42,31 +42,22 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
void op::MaxPool::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
if (0 == m_window_movement_strides.size())
{
return;
m_window_movement_strides = Strides(m_window_shape.size(), 1);
}
auto& arg_shape = get_input_shape(0);
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
<< "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape
<< ").";
if (0 == m_window_movement_strides.size() && arg_shape.size() > 2)
if (0 == m_padding_below.size())
{
m_window_movement_strides = Strides(arg_shape.size() - 2, 1);
m_padding_below = Shape(m_window_shape.size(), 0);
}
if (0 == m_padding_below.size() && arg_shape.size() > 2)
if (0 == m_padding_above.size())
{
m_padding_below = Shape(arg_shape.size() - 2, 0);
m_padding_above = Shape(m_window_shape.size(), 0);
}
if (0 == m_padding_above.size() && arg_shape.size() > 2)
{
m_padding_above = Shape(arg_shape.size() - 2, 0);
}
const PartialShape& arg_shape = get_input_partial_shape(0);
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
// now still take Shape (no negative padding).
......@@ -125,17 +116,12 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
void op::MaxPoolBackprop::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
element::Type forward_arg_et = get_input_element_type(0);
element::Type delta_et = get_input_element_type(1);
auto forward_arg_et = get_input_element_type(0);
auto& forward_arg_shape = get_input_shape(0);
auto delta_et = get_input_element_type(1);
auto& delta_shape = get_input_shape(1);
element::Type result_et;
NODE_VALIDATION_ASSERT(this, forward_arg_et == delta_et)
NODE_VALIDATION_ASSERT(this, element::Type::merge(result_et, forward_arg_et, delta_et))
<< "Element types for forward argument (" << forward_arg_et << ") and delta (" << delta_et
<< ") do not match.";
......@@ -144,18 +130,25 @@ void op::MaxPoolBackprop::validate_and_infer_types()
CoordinateDiff padding_below(m_padding_below.begin(), m_padding_below.end());
CoordinateDiff padding_above(m_padding_above.begin(), m_padding_above.end());
Shape forward_result_shape = infer_batched_pooling_forward(this,
forward_arg_shape,
padding_below,
padding_above,
m_window_shape,
m_window_movement_strides,
true);
const PartialShape& forward_arg_shape = get_input_partial_shape(0);
PartialShape forward_result_shape = infer_batched_pooling_forward(this,
forward_arg_shape,
padding_below,
padding_above,
m_window_shape,
m_window_movement_strides,
true);
const PartialShape& delta_shape = get_input_partial_shape(1);
NODE_VALIDATION_ASSERT(this, forward_result_shape == delta_shape)
NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape))
<< "Inferred forward output shape does not match delta shape (inferred forward output "
<< "shape: " << forward_result_shape << ", delta shape: " << delta_shape << ").";
// TODO(amprocte): We may technically be able to infer some extra information from
// forward_result_shape that was not present in the forward arg shape---namely batch size and
// channel count. Merge that info in.
set_output_type(0, get_input_element_type(0), forward_arg_shape);
}
......
......@@ -79,6 +79,12 @@ std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape)
}
}
// Construct a PartialShape of rank `r`: when `r` is static, the shape has
// `r` dimensions, all dynamic; when `r` is dynamic, the rank itself is dynamic.
PartialShape PartialShape::dynamic(Rank r)
{
    std::vector<Dimension> dims;
    if (r.is_static())
    {
        dims.assign(size_t(r), Dimension::dynamic());
    }
    return PartialShape(r.is_static(), dims);
}
bool PartialShape::compatible(const PartialShape& s) const
{
// If we don't know *this's rank, or we don't know s's rank, they are compatible.
......@@ -182,6 +188,24 @@ bool PartialShape::refines(const PartialShape& s) const
}
}
bool PartialShape::merge_rank(Rank r)
{
    // A dynamic rank is compatible with anything and carries no new information.
    if (r.is_dynamic())
    {
        return true;
    }
    // Both ranks are static: compatible exactly when they agree.
    if (m_rank_is_static)
    {
        return m_dimensions.size() == size_t(r);
    }
    // Our rank was dynamic and `r` is static: adopt rank `r`, with every
    // dimension dynamic.
    m_rank_is_static = true;
    m_dimensions.assign(size_t(r), Dimension::dynamic());
    return true;
}
Shape PartialShape::to_shape() const
{
if (is_dynamic())
......
......@@ -88,9 +88,9 @@ namespace ngraph
/// \return The rank of the shape. This will be Rank::dynamic() if the rank of
/// the shape is dynamic.
Rank rank() const
{
    if (!m_rank_is_static)
    {
        return Rank::dynamic();
    }
    return Rank(m_dimensions.size());
}
/// \brief Construct a PartialShape with dynamic rank.
/// \return A PartialShape with dynamic rank.
static PartialShape dynamic() { return PartialShape(false, {}); }
/// \brief Construct a PartialShape with the given rank and all dimensions (if any) dynamic.
/// \return A PartialShape with the given rank, and all dimensions (if any) dynamic.
static PartialShape dynamic(Rank r = Rank::dynamic());
/// \brief Check whether this shape is compatible with the argument, i.e., whether it is
/// possible to merge them.
/// \param s The shape to be checked for compatibility with this shape.
......@@ -152,6 +152,12 @@ namespace ngraph
/// either `s2[i]` is dynamic, or `s1[i]` == `s2[i]`.
bool refines(const PartialShape& s) const;
/// \brief Checks that this shape's rank is compatible with `r`, and, if this shape's
/// rank is dynamic and `r` is static, updates this shape to have a rank of `r`
/// with dimensions all dynamic.
/// \return `true` if this shape's rank is compatible with `r`, else `false`.
bool merge_rank(Rank r);
/// \brief Convert a static PartialShape to a Shape.
/// \return A new Shape `s` where `s[i] = size_t((*this)[i])`.
/// \throws std::invalid_argument If this PartialShape is dynamic.
......@@ -199,11 +205,10 @@ namespace ngraph
static bool merge_into(PartialShape& dst, const PartialShape& src);
private:
// Private constructor so PartialShape::dynamic() can construct a shape with
// m_rank_is_static set to false.
PartialShape(bool rank_is_static, std::initializer_list<Dimension> init)
// Private constructor for PartialShape::dynamic().
PartialShape(bool rank_is_static, std::vector<Dimension> dimensions)
: m_rank_is_static(rank_is_static)
, m_dimensions(init)
, m_dimensions(dimensions)
{
}
......
This diff is collapsed.
......@@ -23,15 +23,15 @@
namespace ngraph
{
Shape infer_windowed_reduction_output_shape(const Node* node,
const Shape& data_shape,
const Strides& data_dilation,
const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above,
const Shape& window_shape,
const Strides& window_strides,
const Strides& window_dilation,
bool is_window_all_in_padding_allowed);
PartialShape infer_windowed_reduction_output_shape(const Node* node,
const PartialShape& data_shape,
const Strides& data_dilation,
const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above,
const PartialShape& window_shape,
const Strides& window_strides,
const Strides& window_dilation,
bool is_window_all_in_padding_allowed);
std::tuple<element::Type, Shape>
infer_convolution_forward(const Node* node,
......@@ -45,11 +45,11 @@ namespace ngraph
const Strides& filter_strides,
const Strides& filter_dilation);
Shape infer_batched_pooling_forward(const Node* node,
const Shape& data_batch_shape,
const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above,
const Shape& window_shape,
const Strides& window_strides,
bool is_window_all_in_padding_allowed);
PartialShape infer_batched_pooling_forward(const Node* node,
const PartialShape& data_batch_shape,
const CoordinateDiff& data_padding_below,
const CoordinateDiff& data_padding_above,
const PartialShape& window_shape,
const Strides& window_strides,
bool is_window_all_in_padding_allowed);
}
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment