Commit 256a8b6d authored by Adam Procter, committed by Scott Cyphers

Partial Shapes and Types, Part 4k: AvgPool/MaxPool and backprops (#1871)

* Add merge_rank function

* Update infer_windowed_reduction_output_shape to use PartialShape

* Minor simplification

* Some unit tests and (whaddaya know) fixes for infer_windowed_reduction_output_shape

* Update infer_batched_pooling_forward to use PartialShape

* Update pooling fprop ops for partial shapes

* Update pooling bprop ops for partial shapes

* Add test-failing reminders to implement unit tests for partial shape/type prop for pooling ops

* Add unit tests for partial shape propagation for pooling ops

* Nuke C-style casts for Dimensions/Ranks in validation_util.cpp
parent 7f6f07ee
@@ -52,6 +52,11 @@ Dimension Dimension::operator+(const Dimension& dim) const
     return (is_static() && dim.is_static() ? m_dimension + size_t(dim) : Dimension::dynamic());
 }
 
+Dimension Dimension::operator-(const Dimension& dim) const
+{
+    return (is_static() && dim.is_static() ? m_dimension - size_t(dim) : Dimension::dynamic());
+}
+
 Dimension Dimension::operator*(const Dimension& dim) const
 {
     return ((is_static() && dim.is_static())
...
@@ -56,6 +56,16 @@ namespace ngraph
             }
             return m_dimension;
         }
 
+        /// \brief Convert this dimension to `ptrdiff_t`. This dimension must be static.
+        /// \throws std::invalid_argument If this dimension is dynamic.
+        explicit operator ptrdiff_t() const
+        {
+            if (is_dynamic())
+            {
+                throw std::invalid_argument("Cannot convert dynamic dimension to ptrdiff_t");
+            }
+            return static_cast<ptrdiff_t>(m_dimension);
+        }
+
         /// \brief Check whether this dimension represents the same scheme as the argument (both
         ///        dynamic, or equal).
@@ -122,6 +132,12 @@ namespace ngraph
         ///         dimension with value `size_t(*this)+size_t(dim)`.
         Dimension operator+(const Dimension& dim) const;
 
+        /// \brief Subtraction operator for Dimension.
+        /// \param dim Right operand for subtraction.
+        /// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static
+        ///         dimension with value `size_t(*this)-size_t(dim)`.
+        Dimension operator-(const Dimension& dim) const;
+
         /// \brief Multiplication operator for Dimension.
         /// \param dim Right operand for multiplication.
         /// \return 0 if either of `*this` or `dim` is static and 0; else, Dimension::dynamic() if
...
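For context, a minimal sketch (not part of this commit) of how the new Dimension subtraction and ptrdiff_t conversion behave, based on the definitions in the hunks above:

    #include "ngraph/dimension.hpp"
    #include <cstddef>
    #include <iostream>

    using ngraph::Dimension;

    int main()
    {
        Dimension a(5);
        Dimension b(2);

        Dimension c = a - b;                    // static: 5 - 2 = 3
        Dimension d = a - Dimension::dynamic(); // dynamic: any dynamic operand makes the result dynamic

        std::cout << size_t(c) << std::endl;    // prints 3
        std::cout << ptrdiff_t(a) << std::endl; // prints 5; ptrdiff_t(d) would throw
                                                // std::invalid_argument, since d is dynamic
        return 0;
    }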
@@ -40,31 +40,22 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
 
 void op::AvgPool::validate_and_infer_types()
 {
-    if (validate_punt_if_dynamic())
-    {
-        return;
-    }
-
-    auto& arg_shape = get_input_shape(0);
-
-    NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
-        << "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape
-        << ").";
-
-    if (0 == m_window_movement_strides.size() && arg_shape.size() > 2)
+    if (0 == m_window_movement_strides.size())
     {
-        m_window_movement_strides = Strides(arg_shape.size() - 2, 1);
+        m_window_movement_strides = Strides(m_window_shape.size(), 1);
     }
-    if (0 == m_padding_below.size() && arg_shape.size() > 2)
+    if (0 == m_padding_below.size())
     {
-        m_padding_below = Shape(arg_shape.size() - 2, 0);
+        m_padding_below = Shape(m_window_shape.size(), 0);
     }
-    if (0 == m_padding_above.size() && arg_shape.size() > 2)
+    if (0 == m_padding_above.size())
     {
-        m_padding_above = Shape(arg_shape.size() - 2, 0);
+        m_padding_above = Shape(m_window_shape.size(), 0);
     }
+
+    const PartialShape& arg_shape = get_input_partial_shape(0);
 
     // infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
     // now still take Shape (no negative padding).
@@ -125,19 +116,12 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
 
 void op::AvgPoolBackprop::validate_and_infer_types()
 {
-    if (validate_punt_if_dynamic())
-    {
-        return;
-    }
-
-    auto& delta_shape = get_input_shape(0);
-
     // infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
     // now still take Shape (no negative padding).
     CoordinateDiff padding_below(m_padding_below.begin(), m_padding_below.end());
     CoordinateDiff padding_above(m_padding_above.begin(), m_padding_above.end());
 
-    Shape forward_result_shape =
+    PartialShape forward_result_shape =
         infer_batched_pooling_forward(this,
                                       m_forward_arg_shape,
                                       padding_below,
@@ -146,10 +130,15 @@ void op::AvgPoolBackprop::validate_and_infer_types()
                                       m_window_movement_strides,
                                       m_include_padding_in_avg_computation);
 
-    NODE_VALIDATION_ASSERT(this, forward_result_shape == delta_shape)
+    const PartialShape& delta_shape = get_input_shape(0);
+
+    NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape))
         << "Inferred forward output shape does not match delta shape (inferred forward output "
         << "shape: " << forward_result_shape << ", delta shape: " << delta_shape << ").";
 
+    // TODO(amprocte): Once m_forward_arg_shape is allowed to be dynamic, we may technically be
+    // able to infer some extra information from forward_result_shape that was not present in the
+    // forward arg shape---namely batch size and channel count. Merge that info in.
+
     set_output_type(0, get_input_element_type(0), m_forward_arg_shape);
 }
...
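The assertion above now uses PartialShape::compatible rather than operator==, so a delta shape can pass validation even when some inferred dimensions are unknown. A minimal sketch of the difference (the shapes here are illustrative, not from this commit):

    #include "ngraph/partial_shape.hpp"

    using namespace ngraph;

    void example()
    {
        PartialShape inferred{64, 3, Dimension::dynamic(), 112};
        PartialShape delta{64, 3, 112, 112};

        bool mergeable = inferred.compatible(delta); // true: a dynamic dimension is
                                                     // compatible with any static one
        // A strict equality test would reject this pair, which is exactly why the
        // bprop validation above switched to compatible().
        (void)mergeable;
    }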
@@ -42,31 +42,22 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
 
 void op::MaxPool::validate_and_infer_types()
 {
-    if (validate_punt_if_dynamic())
-    {
-        return;
-    }
-
-    auto& arg_shape = get_input_shape(0);
-
-    NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
-        << "Data input shape does not have rank of at least 3 (data input shape: " << arg_shape
-        << ").";
-
-    if (0 == m_window_movement_strides.size() && arg_shape.size() > 2)
+    if (0 == m_window_movement_strides.size())
     {
-        m_window_movement_strides = Strides(arg_shape.size() - 2, 1);
+        m_window_movement_strides = Strides(m_window_shape.size(), 1);
     }
-    if (0 == m_padding_below.size() && arg_shape.size() > 2)
+    if (0 == m_padding_below.size())
     {
-        m_padding_below = Shape(arg_shape.size() - 2, 0);
+        m_padding_below = Shape(m_window_shape.size(), 0);
     }
-    if (0 == m_padding_above.size() && arg_shape.size() > 2)
+    if (0 == m_padding_above.size())
     {
-        m_padding_above = Shape(arg_shape.size() - 2, 0);
+        m_padding_above = Shape(m_window_shape.size(), 0);
     }
+
+    const PartialShape& arg_shape = get_input_partial_shape(0);
 
     // infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
     // now still take Shape (no negative padding).
@@ -125,17 +116,12 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
 
 void op::MaxPoolBackprop::validate_and_infer_types()
 {
-    if (validate_punt_if_dynamic())
-    {
-        return;
-    }
-
-    auto forward_arg_et = get_input_element_type(0);
-    auto& forward_arg_shape = get_input_shape(0);
-    auto delta_et = get_input_element_type(1);
-    auto& delta_shape = get_input_shape(1);
+    element::Type forward_arg_et = get_input_element_type(0);
+    element::Type delta_et = get_input_element_type(1);
 
-    NODE_VALIDATION_ASSERT(this, forward_arg_et == delta_et)
+    element::Type result_et;
+
+    NODE_VALIDATION_ASSERT(this, element::Type::merge(result_et, forward_arg_et, delta_et))
         << "Element types for forward argument (" << forward_arg_et << ") and delta (" << delta_et
         << ") do not match.";
 
@@ -144,18 +130,25 @@ void op::MaxPoolBackprop::validate_and_infer_types()
     CoordinateDiff padding_below(m_padding_below.begin(), m_padding_below.end());
     CoordinateDiff padding_above(m_padding_above.begin(), m_padding_above.end());
 
-    Shape forward_result_shape = infer_batched_pooling_forward(this,
-                                                               forward_arg_shape,
-                                                               padding_below,
-                                                               padding_above,
-                                                               m_window_shape,
-                                                               m_window_movement_strides,
-                                                               true);
+    const PartialShape& forward_arg_shape = get_input_partial_shape(0);
+
+    PartialShape forward_result_shape = infer_batched_pooling_forward(this,
+                                                                      forward_arg_shape,
+                                                                      padding_below,
+                                                                      padding_above,
+                                                                      m_window_shape,
+                                                                      m_window_movement_strides,
+                                                                      true);
+
+    const PartialShape& delta_shape = get_input_partial_shape(1);
 
-    NODE_VALIDATION_ASSERT(this, forward_result_shape == delta_shape)
+    NODE_VALIDATION_ASSERT(this, forward_result_shape.compatible(delta_shape))
        << "Inferred forward output shape does not match delta shape (inferred forward output "
        << "shape: " << forward_result_shape << ", delta shape: " << delta_shape << ").";
 
+    // TODO(amprocte): We may technically be able to infer some extra information from
+    // forward_result_shape that was not present in the forward arg shape---namely batch size and
+    // channel count. Merge that info in.
+
     set_output_type(0, get_input_element_type(0), forward_arg_shape);
 }
...
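Similarly, the element-type check above now goes through element::Type::merge, which writes the merged type into its first argument and fails only on a genuine mismatch. A hedged sketch of the intended semantics, assuming the merge helper as used in this commit series:

    #include "ngraph/type/element_type.hpp"

    using namespace ngraph;

    void example()
    {
        element::Type result;

        bool ok = element::Type::merge(result, element::f32, element::f32);
        // ok == true, result == element::f32

        bool mismatch = element::Type::merge(result, element::f32, element::i32);
        // mismatch == false: the types are both static and unequal, so the
        // NODE_VALIDATION_ASSERT above would fire

        (void)ok;
        (void)mismatch;
    }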
@@ -79,6 +79,12 @@ std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape)
     }
 }
 
+PartialShape PartialShape::dynamic(Rank r)
+{
+    return PartialShape(
+        r.is_static(), std::vector<Dimension>(r.is_static() ? size_t(r) : 0, Dimension::dynamic()));
+}
+
 bool PartialShape::compatible(const PartialShape& s) const
 {
     // If we don't know *this's rank, or we don't know s's rank, they are compatible.
@@ -182,6 +188,24 @@ bool PartialShape::refines(const PartialShape& s) const
     }
 }
 
+bool PartialShape::merge_rank(Rank r)
+{
+    if (r.is_dynamic())
+    {
+        return true;
+    }
+    else if (!m_rank_is_static)
+    {
+        m_rank_is_static = true;
+        m_dimensions = std::vector<Dimension>(size_t(r), Dimension::dynamic());
+        return true;
+    }
+    else
+    {
+        return (m_dimensions.size() == size_t(r));
+    }
+}
+
 Shape PartialShape::to_shape() const
 {
     if (is_dynamic())
...
@@ -88,9 +88,9 @@ namespace ngraph
         /// \return The rank of the shape. This will be Rank::dynamic() if the rank of
         ///         the shape is dynamic.
         Rank rank() const { return m_rank_is_static ? Rank(m_dimensions.size()) : Rank::dynamic(); }
-        /// \brief Construct a PartialShape with dynamic rank.
-        /// \return A PartialShape with dynamic rank.
-        static PartialShape dynamic() { return PartialShape(false, {}); }
+        /// \brief Construct a PartialShape with the given rank and all dimensions (if any) dynamic.
+        /// \return A PartialShape with the given rank, and all dimensions (if any) dynamic.
+        static PartialShape dynamic(Rank r = Rank::dynamic());
         /// \brief Check whether this shape is compatible with the argument, i.e., whether it is
         ///        possible to merge them.
         /// \param s The shape to be checked for compatibility with this shape.
@@ -152,6 +152,12 @@ namespace ngraph
         ///         either `s2[i]` is dynamic, or `s1[i]` == `s2[i]`.
         bool refines(const PartialShape& s) const;
 
+        /// \brief Checks that this shape's rank is compatible with `r`, and, if this shape's
+        ///        rank is dynamic and `r` is static, updates this shape to have a rank of `r`
+        ///        with dimensions all dynamic.
+        /// \return `true` if this shape's rank is compatible with `r`, else `false`.
+        bool merge_rank(Rank r);
+
         /// \brief Convert a static PartialShape to a Shape.
         /// \return A new Shape `s` where `s[i] = size_t((*this)[i])`.
         /// \throws std::invalid_argument If this PartialShape is dynamic.
@@ -199,11 +205,10 @@ namespace ngraph
         static bool merge_into(PartialShape& dst, const PartialShape& src);
 
     private:
-        // Private constructor so PartialShape::dynamic() can construct a shape with
-        // m_rank_is_static set to false.
-        PartialShape(bool rank_is_static, std::initializer_list<Dimension> init)
+        // Private constructor for PartialShape::dynamic().
+        PartialShape(bool rank_is_static, std::vector<Dimension> dimensions)
            : m_rank_is_static(rank_is_static)
-            , m_dimensions(init)
+            , m_dimensions(dimensions)
         {
         }
...
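Taken together, PartialShape::dynamic(Rank) and merge_rank give shape inference a way to pin down a shape's rank without committing to its dimensions. A minimal usage sketch based on the definitions above:

    #include "ngraph/partial_shape.hpp"

    using namespace ngraph;

    void example()
    {
        PartialShape s = PartialShape::dynamic();  // dynamic rank, no dimensions
        bool ok1 = s.merge_rank(4);                // true: s becomes {?,?,?,?}

        PartialShape t = PartialShape::dynamic(3); // static rank 3, all dims dynamic
        bool ok2 = t.merge_rank(Rank::dynamic());  // true: dynamic rank is always compatible
        bool ok3 = t.merge_rank(4);                // false: static ranks 3 and 4 conflict

        (void)ok1;
        (void)ok2;
        (void)ok3;
    }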
@@ -23,15 +23,15 @@
 namespace ngraph
 {
-    Shape infer_windowed_reduction_output_shape(const Node* node,
-                                                const Shape& data_shape,
-                                                const Strides& data_dilation,
-                                                const CoordinateDiff& data_padding_below,
-                                                const CoordinateDiff& data_padding_above,
-                                                const Shape& window_shape,
-                                                const Strides& window_strides,
-                                                const Strides& window_dilation,
-                                                bool is_window_all_in_padding_allowed);
+    PartialShape infer_windowed_reduction_output_shape(const Node* node,
+                                                       const PartialShape& data_shape,
+                                                       const Strides& data_dilation,
+                                                       const CoordinateDiff& data_padding_below,
+                                                       const CoordinateDiff& data_padding_above,
+                                                       const PartialShape& window_shape,
+                                                       const Strides& window_strides,
+                                                       const Strides& window_dilation,
+                                                       bool is_window_all_in_padding_allowed);
 
     std::tuple<element::Type, Shape>
         infer_convolution_forward(const Node* node,
@@ -45,11 +45,11 @@ namespace ngraph
                                   const Strides& filter_strides,
                                   const Strides& filter_dilation);
 
-    Shape infer_batched_pooling_forward(const Node* node,
-                                        const Shape& data_batch_shape,
-                                        const CoordinateDiff& data_padding_below,
-                                        const CoordinateDiff& data_padding_above,
-                                        const Shape& window_shape,
-                                        const Strides& window_strides,
-                                        bool is_window_all_in_padding_allowed);
+    PartialShape infer_batched_pooling_forward(const Node* node,
+                                               const PartialShape& data_batch_shape,
+                                               const CoordinateDiff& data_padding_below,
+                                               const CoordinateDiff& data_padding_above,
+                                               const PartialShape& window_shape,
+                                               const Strides& window_strides,
+                                               bool is_window_all_in_padding_allowed);
 }