Unverified Commit 7b4be37e authored by Adam Procter's avatar Adam Procter Committed by GitHub

Partial Shapes and Types, Part 4a: Implement partial shape/type validation for…

Partial Shapes and Types, Part 4a: Implement partial shape/type validation for some existing ops (#1756)
parent a7f70651
......@@ -152,6 +152,7 @@ set (SRC
runtime/tensor.cpp
serializer.cpp
shape.cpp
shape_util.cpp
strides.cpp
type/element_type.cpp
util.cpp
......
......@@ -73,79 +73,5 @@ namespace ngraph
}
};
template <typename AXIS_VALUES>
AXIS_VALUES project(const AXIS_VALUES& axis_values, const AxisSet& axes)
{
AXIS_VALUES result;
for (size_t i = 0; i < axis_values.size(); i++)
{
if (axes.find(i) != axes.end())
{
result.push_back(axis_values[i]);
}
}
return result;
}
// Removes some values from a vector of axis values
template <typename AXIS_VALUES>
AXIS_VALUES reduce(const AXIS_VALUES& axis_values, const AxisSet& deleted_axes)
{
AxisSet axes;
for (size_t i = 0; i < axis_values.size(); i++)
{
if (deleted_axes.find(i) == deleted_axes.end())
{
axes.insert(i);
}
}
return project(axis_values, axes);
}
// TODO: check validity, i.e. that the new axis indices are all < axis_values.size()+num_new_axes.
// Add new values at particular axis positions
template <typename AXIS_VALUES>
AXIS_VALUES inject_pairs(const AXIS_VALUES& axis_values,
std::vector<std::pair<size_t, size_t>> new_axis_pos_value_pairs)
{
AXIS_VALUES result;
size_t original_pos = 0;
for (size_t result_pos = 0;
result_pos < axis_values.size() + new_axis_pos_value_pairs.size();
result_pos++)
{
auto search_it = std::find_if(
new_axis_pos_value_pairs.begin(),
new_axis_pos_value_pairs.end(),
[result_pos](std::pair<size_t, size_t> p) { return p.first == result_pos; });
if (search_it == new_axis_pos_value_pairs.end())
{
result.push_back(axis_values[original_pos++]);
}
else
{
result.push_back(search_it->second);
}
}
return result;
}
// Add a new value at a particular axis position
template <typename AXIS_VALUES>
AXIS_VALUES inject(const AXIS_VALUES& axis_values, size_t new_axis_pos, size_t new_axis_val)
{
return inject_pairs(axis_values,
std::vector<std::pair<size_t, size_t>>{
std::pair<size_t, size_t>(new_axis_pos, new_axis_val)});
}
std::ostream& operator<<(std::ostream& s, const Coordinate& coordinate);
}
......@@ -47,9 +47,18 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
}
}
Dimension ngraph::operator+(const Dimension& d1, const Dimension& d2)
Dimension Dimension::operator+(const Dimension& dim) const
{
return (d1.is_static() && d2.is_static() ? size_t(d1) + size_t(d2) : Dimension::dynamic());
return (is_static() && dim.is_static() ? m_dimension + size_t(dim) : Dimension::dynamic());
}
Dimension Dimension::operator*(const Dimension& dim) const
{
return ((is_static() && dim.is_static())
? m_dimension * size_t(dim)
: (is_static() && m_dimension == 0)
? 0
: (dim.is_static() && size_t(dim) == 0) ? 0 : Dimension::dynamic());
}
bool Dimension::compatible(const Dimension& d) const
......@@ -57,6 +66,16 @@ bool Dimension::compatible(const Dimension& d) const
return (is_dynamic() || d.is_dynamic() || m_dimension == size_t(d));
}
bool Dimension::relaxes(const Dimension& d) const
{
return (is_dynamic() || (d.is_static() && size_t(*this) == size_t(d)));
}
bool Dimension::refines(const Dimension& d) const
{
return (d.is_dynamic() || (is_static() && size_t(d) == size_t(*this)));
}
bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
{
if (d1.is_dynamic())
......
......@@ -89,12 +89,54 @@ namespace ngraph
/// Two dimensions are considered compatible if it is possible to merge them. (See
/// Dimension::merge.)
bool compatible(const Dimension& d) const;
/// \brief Check whether this dimension is a relaxation of the argument.
/// \param d The dimension to compare this dimension with.
/// \return `true` if this dimension relaxes `d`, else `false`.
///
/// A dimension `d1` _relaxes_ (or _is a relaxation of_) `d2` if `d1` and `d2` are static
/// and equal, or `d1` is dynamic.
///
/// `d1.relaxes(d2)` is equivalent to `d2.refines(d1)`.
bool relaxes(const Dimension& d) const;
/// \brief Check whether this dimension is a refinement of the argument.
/// \param d The dimension to compare this dimension with.
/// \return `true` if this dimension refines `d`, else `false`.
///
/// A dimension `d2` _refines_ (or _is a refinement of_) `d1` if `d1` and `d2` are static
/// and equal, or `d2` is dynamic.
///
/// `d1.refines(d2)` is equivalent to `d2.relaxes(d1)`.
bool refines(const Dimension& d) const;
/// \brief Create a dynamic dimension.
/// \return A dynamic dimension.
static Dimension dynamic() { return Dimension(); }
/// \brief Constant for the value used internally to represent a dynamic dimension.
static const size_t s_dynamic_val{std::numeric_limits<size_t>::max()};
/// \brief Addition operator for Dimension.
/// \param dim Right operand for addition.
/// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static
/// dimension with value `size_t(*this)+size_t(dim)`.
Dimension operator+(const Dimension& dim) const;
/// \brief Multiplication operator for Dimension.
/// \param dim Right operand for multiplication.
/// \return 0 if either of `*this` or `dim` is static and 0; else, Dimension::dynamic() if
/// either of `*this` or `dim` is dynamic; else, a static dimension with value
/// `size_t(*this)*size_t(dim)`.
Dimension operator*(const Dimension& dim) const;
/// \brief Add-into operator for Dimension.
/// \param dim Right operand for addition.
/// \return A reference to `*this`, after updating `*this` to the value `*this + dim`.
Dimension& operator+=(const Dimension& dim) { return (*this = *this + dim); }
/// \brief Multiply-into operator for Dimension.
/// \param dim Right operand for multiplication.
/// \return A reference to `*this`, after updating `*this` to the value `*this * dim`.
Dimension& operator*=(const Dimension& dim) { return (*this = *this * dim); }
private:
// The actual numerical value of the dimension. s_dynamic_val is a special case,
// representing a dynamic dimension.
......@@ -108,11 +150,4 @@ namespace ngraph
///
/// Inserts the string `?` if `dimension` is dynamic; else inserts `size_t(dimension)`.
std::ostream& operator<<(std::ostream& str, const Dimension& dimension);
/// \brief Addition operator for dimensions.
/// \param d1 Left operand for addition.
/// \param d2 Right operand for addition.
/// \return Dimension::dynamic() if either of `d1` or `d2` is dynamic; else, a static
/// dimension with value `size_t(d1)+size_t(d2)`.
Dimension operator+(const Dimension& d1, const Dimension& d2);
}
......@@ -133,4 +133,5 @@
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/shape_util.hpp"
#include "ngraph/type/element_type.hpp"
......@@ -33,7 +33,8 @@ void op::AllReduce::validate_and_infer_types()
}
NODE_VALIDATION_ASSERT(this,
get_input_element_type(0) == element::f32 ||
get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::f32 ||
get_input_element_type(0) == element::f64)
<< "Only element types f32 and f64 are supported (argument element type: "
<< get_input_element_type(0) << ").";
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/dequantize.hpp"
#include "ngraph/shape_util.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -30,7 +30,7 @@ op::GetOutputElement::GetOutputElement(const shared_ptr<Node>& arg, size_t n)
<< "Output at index " << m_n << " requested, but argument has only "
<< arg->get_output_size() << " outputs.";
set_output_type(0, arg->get_output_element_type(n), arg->get_output_shape(n));
set_output_type(0, arg->get_output_element_type(n), arg->get_output_partial_shape(n));
}
shared_ptr<Node> op::GetOutputElement::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/quantize.hpp"
#include "ngraph/shape_util.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -35,48 +35,55 @@ op::Reshape::Reshape(const shared_ptr<Node>& arg,
void op::Reshape::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& input = get_inputs().at(0);
auto input_shape = input.get_shape();
auto input_rank = input_shape.size();
auto& input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
NODE_VALIDATION_ASSERT(this, m_input_order.size() == input_rank)
<< "Input axis order is not a permutation of argument's axis indices (axis order: "
<< m_input_order << ", argument shape: " << input_shape << ").";
for (size_t i = 0; i < input_rank; i++)
// Check that the input axis order is a permutation of (0,...,n-1) for some n.
for (size_t i = 0; i < m_input_order.size(); i++)
{
auto it = find(begin(m_input_order), end(m_input_order), i);
NODE_VALIDATION_ASSERT(this, it != end(m_input_order))
NODE_VALIDATION_ASSERT(
this, find(begin(m_input_order), end(m_input_order), i) != end(m_input_order))
<< "Input axis order is not a permutation of argument's axis indices (axis order: "
<< m_input_order << ", argument shape: " << input_shape << ").";
}
size_t input_shape_product = 1;
for (auto i : input_shape)
// TODO(amprocte): should be possible to move around unknown dims in the input shape.
if (input_rank.is_static())
{
input_shape_product *= i;
}
NODE_VALIDATION_ASSERT(this, m_input_order.size() == size_t(input_rank))
<< "Input axis order is not a permutation of argument's axis indices (axis order: "
<< m_input_order << ", argument shape: " << input_shape << ").";
size_t output_shape_product = 1;
for (auto i : m_output_shape)
{
output_shape_product *= i;
}
for (size_t i = 0; i < size_t(input_rank); i++)
{
auto it = find(begin(m_input_order), end(m_input_order), i);
NODE_VALIDATION_ASSERT(this, it != end(m_input_order))
<< "Input axis order is not a permutation of argument's axis indices (axis order: "
<< m_input_order << ", argument shape: " << input_shape << ").";
}
// TODO(amprocte): make a partial_shape_size() analogous to shape_size().
Dimension input_shape_product = 1;
for (size_t i = 0; i < size_t(input_rank); i++)
{
input_shape_product *= input_shape[i];
}
NODE_VALIDATION_ASSERT(this, input_shape_product == output_shape_product)
<< "Product of output shape dimensions does not match product of argument shape dimensions "
<< "(output shape: " << m_output_shape << ", argument shape: " << input_shape << ").";
if (input_shape_product.is_static())
{
NODE_VALIDATION_ASSERT(this, size_t(input_shape_product) == shape_size(m_output_shape))
<< "Product of output shape dimensions does not match product of argument shape "
"dimensions "
<< "(output shape: " << m_output_shape << ", argument shape: " << input_shape
<< ").";
}
}
if (!std::is_sorted(m_input_order.begin(), m_input_order.end()))
{
m_is_transpose = true;
}
set_output_type(0, input.get_element_type(), m_output_shape);
set_output_type(0, get_input_element_type(0), m_output_shape);
}
shared_ptr<Node> op::Reshape::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -32,20 +32,18 @@ op::Reverse::Reverse(const shared_ptr<Node>& arg, const AxisSet& reversed_axes)
void op::Reverse::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto input_shape = get_input_shape(0);
auto input_rank = input_shape.size();
auto input_shape = get_input_partial_shape(0);
Dimension input_rank = input_shape.rank();
// Make sure all reversed axis indices are valid.
for (size_t axis : m_reversed_axes)
if (input_rank.is_static())
{
NODE_VALIDATION_ASSERT(this, axis < input_rank)
<< "Reverse axis (" << axis << ") is out of bounds (argument shape: " << input_shape
<< ").";
// Make sure all reversed axis indices are valid.
for (size_t axis : m_reversed_axes)
{
NODE_VALIDATION_ASSERT(this, axis < size_t(input_rank))
<< "Reverse axis (" << axis << ") is out of bounds (argument shape: " << input_shape
<< ").";
}
}
set_output_type(0, get_input_element_type(0), input_shape);
......
......@@ -38,30 +38,40 @@ op::ReverseSequence::ReverseSequence(const std::shared_ptr<Node> arg,
void op::ReverseSequence::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
NODE_VALIDATION_ASSERT(this, get_input_shape(1).size() == 1)
<< "Sequence indices must be a 1-dimensional tensor (sequence indices shape: "
<< get_input_shape(1) << ").";
auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
NODE_VALIDATION_ASSERT(this, m_batch_axis < get_input_shape(0).size())
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || m_batch_axis < size_t(input_rank))
<< "Batch axis index (" << m_batch_axis
<< ") is out of bounds (argument shape: " << get_input_shape(0) << ").";
<< ") is out of bounds (argument shape: " << input_shape << ").";
NODE_VALIDATION_ASSERT(this, m_seq_axis < get_input_shape(0).size())
NODE_VALIDATION_ASSERT(this, input_rank.is_dynamic() || m_seq_axis < size_t(input_rank))
<< "Sequence axis index (" << m_seq_axis
<< ") is out of bounds (argument shape: " << get_input_shape(0) << ").";
<< ") is out of bounds (argument shape: " << input_shape << ").";
auto indices_shape = get_input_partial_shape(1);
auto indices_rank = indices_shape.rank();
NODE_VALIDATION_ASSERT(this, get_input_shape(0)[m_batch_axis] == get_input_shape(1)[0])
<< "Sequence length (" << get_input_shape(1)[0] << ") is not equal to batch axis "
<< "dimension (" << get_input_shape(0)[m_batch_axis]
<< ") (argument shape: " << get_input_shape(0)
<< ", sequence indices shape: " << get_input_shape(1) << ").";
NODE_VALIDATION_ASSERT(this, indices_rank.is_dynamic() || size_t(indices_rank) == 1)
<< "Sequence indices must be a 1-dimensional tensor (sequence indices shape: "
<< get_input_partial_shape(1) << ").";
PartialShape output_shape{input_shape};
if (input_rank.is_static() && indices_rank.is_static())
{
Dimension merged_sequence_length;
NODE_VALIDATION_ASSERT(
this,
Dimension::merge(merged_sequence_length, input_shape[m_batch_axis], indices_shape[0]))
<< "Sequence length (" << indices_shape[0] << ") is not equal to batch axis "
<< "dimension (" << input_shape[m_batch_axis] << ") (argument shape: " << input_shape
<< ", sequence indices shape: " << indices_shape << ").";
output_shape[m_batch_axis] = merged_sequence_length;
}
set_output_type(0, get_input_element_type(0), get_input_shape(0));
set_output_type(0, get_input_element_type(0), output_shape);
}
shared_ptr<Node> op::ReverseSequence::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -29,29 +29,32 @@ op::util::ArithmeticReduction::ArithmeticReduction(const std::string& node_type,
void op::util::ArithmeticReduction::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto input_shape = get_input_partial_shape(0);
auto input_rank = input_shape.rank();
auto input_shape = get_input_shape(0);
PartialShape result_shape{PartialShape::dynamic()};
for (auto axis : m_reduction_axes)
if (input_rank.is_static())
{
NODE_VALIDATION_ASSERT(this, axis < input_shape.size())
<< "Reduction axis (" << axis << ") is out of bounds "
<< "(argument shape: " << input_shape << ", reduction axes: " << m_reduction_axes
<< ")";
}
std::vector<Dimension> dims;
Shape result_shape;
for (auto axis : m_reduction_axes)
{
NODE_VALIDATION_ASSERT(this, axis < size_t(input_rank))
<< "Reduction axis (" << axis << ") is out of bounds "
<< "(argument shape: " << input_shape << ", reduction axes: " << m_reduction_axes
<< ")";
}
for (size_t i = 0; i < input_shape.size(); i++)
{
if (m_reduction_axes.count(i) == 0)
for (size_t i = 0; i < size_t(input_rank); i++)
{
result_shape.push_back(input_shape.at(i));
if (m_reduction_axes.count(i) == 0)
{
dims.push_back(input_shape[i]);
}
}
result_shape = PartialShape(dims);
}
set_output_type(0, get_input_element_type(0), result_shape);
......
......@@ -136,6 +136,52 @@ bool PartialShape::same_scheme(const PartialShape& s) const
}
}
bool PartialShape::relaxes(const PartialShape& s) const
{
if (rank().is_dynamic())
{
return true;
}
else if (s.rank().is_static() && size_t(rank()) == size_t(s.rank()))
{
bool all_relax = true;
for (size_t i = 0; i < size_t(rank()); i++)
{
all_relax &= ((*this)[i].relaxes(s[i]));
}
return all_relax;
}
else
{
return false;
}
}
bool PartialShape::refines(const PartialShape& s) const
{
if (s.rank().is_dynamic())
{
return true;
}
else if (rank().is_static() && size_t(rank()) == size_t(s.rank()))
{
bool all_refine = true;
for (size_t i = 0; i < size_t(rank()); i++)
{
all_refine &= ((*this)[i].refines(s[i]));
}
return all_refine;
}
else
{
return false;
}
}
Shape PartialShape::to_shape() const
{
if (is_dynamic())
......
......@@ -112,6 +112,46 @@ namespace ngraph
/// `s1[i]` represents the same scheme as `s2[i]` (see Dimension::same_scheme()).
bool same_scheme(const PartialShape& s) const;
/// \brief Check whether this shape is a relaxation of the argument.
/// \param s The shape which is being compared against this shape.
/// \return `true` if this shape relaxes `s`, else `false`.
///
/// Intuitively, a PartialShape `s1` is said to _relax_ `s2` (or _is a
/// relaxation_ of `s2`) if it is "more permissive" than `s2`. In other
/// words, `s1` is a relaxation of `s2` if anything you can form by
/// plugging things into the dynamic dimensions of `s2` is also
/// something you can form by plugging things into the dynamic
/// dimensions of `s1`, but not necessarily the other way around.
///
/// `s1.relaxes(s2)` is equivalent to `s2.refines(s1)`.
///
/// Formally, PartialShape `s1` is said to _relax_ PartialShape `s2`
/// if:
/// \li `s1` has dynamic rank, or
/// \li `s1` and `s2` both have static rank `r`, and for every `i` from `0` to `r-1`,
/// either `s1[i]` is dynamic, or `s1[i]` == `s2[i]`.
bool relaxes(const PartialShape& s) const;
/// \brief Check whether this shape is a refinement of the argument.
/// \param s The shape which is being compared against this shape.
/// \return `true` if this shape refines `s`, else `false`.
///
/// Intuitively, a PartialShape `s1` is said to _refine_ `s2` (or _is a
/// refinement_ of `s2`) if it is "less permissive" than `s2`. In other
/// words, `s1` is a refinement of `s2` if anything you can form by
/// plugging things into the dynamic dimensions of `s1` is also
/// something you can form by plugging things into the dynamic
/// dimensions of `s2`, but not necessarily the other way around.
///
/// `s1.refines(s2)` is equivalent to `s2.relaxes(s1)`.
///
/// Formally, PartialShape `s1` is said to _refine_ PartialShape `s2`
/// if:
/// \li `s2` has dynamic rank, or
/// \li `s1` and `s2` both have static rank `r`, and for every `i` from `0` to `r-1`,
/// either `s2[i]` is dynamic, or `s1[i]` == `s2[i]`.
bool refines(const PartialShape& s) const;
/// \brief Convert a static PartialShape to a Shape.
/// \return A new Shape `s` where `s[i] = size_t((*this)[i])`.
/// \throws std::invalid_argument If this PartialShape is dynamic.
......
......@@ -19,6 +19,7 @@
#include <cmath>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
......
......@@ -19,6 +19,7 @@
#include <cmath>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
......
......@@ -19,6 +19,7 @@
#include <cmath>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
......
......@@ -20,6 +20,7 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
......
......@@ -20,6 +20,7 @@
#include <utility>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
......
......@@ -20,6 +20,7 @@
#include <limits>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
......
......@@ -20,6 +20,7 @@
#include <limits>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
#ifdef WIN32
#undef min
......
......@@ -19,6 +19,7 @@
#include <cmath>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
......
......@@ -19,6 +19,7 @@
#include <cmath>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
......
......@@ -20,6 +20,7 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
......
......@@ -20,6 +20,7 @@
#include <functional>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
......
......@@ -20,6 +20,7 @@
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
......
......@@ -19,6 +19,7 @@
#include <cmath>
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph
{
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include "ngraph/shape_util.hpp"
using namespace ngraph;
template <>
PartialShape ngraph::project(const PartialShape& shape, const AxisSet& axes)
{
    // A rank-dynamic shape has no axes to select from, so it projects to itself.
    if (shape.rank().is_static())
    {
        std::vector<Dimension> selected_dims;
        for (size_t axis = 0; axis < size_t(shape.rank()); ++axis)
        {
            // Keep only the dimensions whose axis index appears in `axes`,
            // preserving their relative order.
            if (axes.find(axis) != axes.end())
            {
                selected_dims.push_back(shape[axis]);
            }
        }
        return PartialShape(selected_dims);
    }
    else
    {
        return shape;
    }
}
template <>
PartialShape ngraph::reduce(const PartialShape& shape, const AxisSet& deleted_axes)
{
    // Dynamic rank: there are no axes to delete, so the shape is unchanged.
    if (shape.rank().is_dynamic())
    {
        return shape;
    }
    // Build the complement of `deleted_axes` over [0, rank) and delegate to
    // project(), which keeps exactly the retained axes.
    AxisSet retained_axes;
    for (size_t axis = 0; axis < size_t(shape.rank()); ++axis)
    {
        if (deleted_axes.find(axis) == deleted_axes.end())
        {
            retained_axes.insert(axis);
        }
    }
    return project(shape, retained_axes);
}
template <>
PartialShape
    ngraph::inject_pairs(const PartialShape& shape,
                         std::vector<std::pair<size_t, Dimension>> new_axis_pos_value_pairs)
{
    // Injecting into a rank-dynamic shape is a no-op: axis positions have no
    // meaning when the rank is unknown.
    if (shape.rank().is_dynamic())
    {
        return shape;
    }
    std::vector<Dimension> result_dims;
    size_t next_source_pos = 0;
    const size_t result_rank = size_t(shape.rank()) + new_axis_pos_value_pairs.size();
    for (size_t result_pos = 0; result_pos < result_rank; result_pos++)
    {
        // Does any (position, value) pair claim this output position?
        auto match =
            std::find_if(new_axis_pos_value_pairs.begin(),
                         new_axis_pos_value_pairs.end(),
                         [result_pos](const std::pair<size_t, Dimension>& p) {
                             return p.first == result_pos;
                         });
        if (match != new_axis_pos_value_pairs.end())
        {
            // Yes: emit the injected dimension here.
            result_dims.push_back(match->second);
        }
        else
        {
            // No: carry over the next dimension from the original shape.
            result_dims.push_back(shape[next_source_pos++]);
        }
    }
    return PartialShape{result_dims};
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/partial_shape.hpp"
namespace ngraph
{
template <typename AXIS_VALUES>
AXIS_VALUES project(const AXIS_VALUES& axis_values, const AxisSet& axes)
{
AXIS_VALUES result;
for (size_t i = 0; i < axis_values.size(); i++)
{
if (axes.find(i) != axes.end())
{
result.push_back(axis_values[i]);
}
}
return result;
}
template <>
PartialShape project(const PartialShape& shape, const AxisSet& axes);
// Removes some values from a vector of axis values
template <typename AXIS_VALUES>
AXIS_VALUES reduce(const AXIS_VALUES& axis_values, const AxisSet& deleted_axes)
{
AxisSet axes;
for (size_t i = 0; i < axis_values.size(); i++)
{
if (deleted_axes.find(i) == deleted_axes.end())
{
axes.insert(i);
}
}
return project(axis_values, axes);
}
template <>
PartialShape reduce(const PartialShape& shape, const AxisSet& deleted_axes);
// TODO: check validity, i.e. that the new axis indices are all < axis_values.size()+num_new_axes.
// Add new values at particular axis positions
template <typename AXIS_VALUES, typename AXIS_VALUE>
AXIS_VALUES inject_pairs(const AXIS_VALUES& axis_values,
std::vector<std::pair<size_t, AXIS_VALUE>> new_axis_pos_value_pairs)
{
AXIS_VALUES result;
size_t original_pos = 0;
for (size_t result_pos = 0;
result_pos < axis_values.size() + new_axis_pos_value_pairs.size();
result_pos++)
{
// Would be nice to use std::find_if here but would rather not #include <algorithm> in
// this header
auto search_it = new_axis_pos_value_pairs.begin();
while (search_it != new_axis_pos_value_pairs.end())
{
if (search_it->first == result_pos)
{
break;
}
++search_it;
}
if (search_it == new_axis_pos_value_pairs.end())
{
result.push_back(axis_values[original_pos++]);
}
else
{
result.push_back(search_it->second);
}
}
return result;
}
template <>
PartialShape inject_pairs(const PartialShape& shape,
std::vector<std::pair<size_t, Dimension>> new_axis_pos_value_pairs);
// Add a new value at a particular axis position
template <typename AXIS_VALUES, typename AXIS_VALUE>
AXIS_VALUES inject(const AXIS_VALUES& axis_values, size_t new_axis_pos, AXIS_VALUE new_axis_val)
{
return inject_pairs(axis_values,
std::vector<std::pair<size_t, AXIS_VALUE>>{
std::pair<size_t, AXIS_VALUE>(new_axis_pos, new_axis_val)});
}
}
......@@ -484,3 +484,309 @@ TEST(partial_shape, partial_shape_merge_both_static_different_rank)
const PartialShape s2{1, 2, 3, 4};
ASSERT_FALSE(PartialShape::merge_into(s1, s2));
}
// Dimension::operator+= : dynamic + static yields dynamic.
TEST(partial_shape, dim_pluseq_left_dynamic)
{
    Dimension d1{Dimension::dynamic()};
    Dimension d2{2};
    d1 += d2;
    ASSERT_TRUE(d1.is_dynamic());
}

// Dimension::operator+= : static + dynamic yields dynamic.
TEST(partial_shape, dim_pluseq_right_dynamic)
{
    Dimension d1{2};
    Dimension d2{Dimension::dynamic()};
    d1 += d2;
    ASSERT_TRUE(d1.is_dynamic());
}

// Dimension::operator+= : two static dimensions add numerically.
TEST(partial_shape, dim_pluseq_both_static)
{
    Dimension d1{3};
    Dimension d2{2};
    d1 += d2;
    ASSERT_TRUE(d1.is_static());
    ASSERT_EQ(size_t(d1), 5);
}

// Dimension::operator*= : dynamic times nonzero static stays dynamic.
TEST(partial_shape, dim_timeseq_left_dynamic_right_nonzero)
{
    Dimension d1{Dimension::dynamic()};
    Dimension d2{2};
    d1 *= d2;
    ASSERT_TRUE(d1.is_dynamic());
}

// Dimension::operator*= : dynamic times static zero collapses to static 0.
TEST(partial_shape, dim_timeseq_left_dynamic_right_zero)
{
    Dimension d1{Dimension::dynamic()};
    Dimension d2{0};
    d1 *= d2;
    ASSERT_TRUE(d1.is_static());
    ASSERT_EQ(size_t(d1), 0);
}

// Dimension::operator*= : nonzero static times dynamic stays dynamic.
TEST(partial_shape, dim_timeseq_right_dynamic_left_nonzero)
{
    Dimension d1{2};
    Dimension d2{Dimension::dynamic()};
    d1 *= d2;
    ASSERT_TRUE(d1.is_dynamic());
}

// Dimension::operator*= : static zero times dynamic collapses to static 0.
TEST(partial_shape, dim_timeseq_right_dynamic_left_zero)
{
    Dimension d1{0};
    Dimension d2{Dimension::dynamic()};
    d1 *= d2;
    ASSERT_TRUE(d1.is_static());
    ASSERT_EQ(size_t(d1), 0);
}

// Dimension::operator*= : two static dimensions multiply numerically.
TEST(partial_shape, dim_timeseq_both_static)
{
    Dimension d1{3};
    Dimension d2{2};
    d1 *= d2;
    ASSERT_TRUE(d1.is_static());
    ASSERT_EQ(size_t(d1), 6);
}
// Two dynamic dimensions mutually relax and refine each other.
TEST(partial_shape, dim_relaxes_refines_dyn_dyn)
{
    Dimension d1{Dimension::dynamic()};
    Dimension d2{Dimension::dynamic()};
    ASSERT_TRUE(d1.refines(d2));
    ASSERT_TRUE(d1.relaxes(d2));
    ASSERT_TRUE(d2.refines(d1));
    ASSERT_TRUE(d2.relaxes(d1));
}

// A dynamic dimension relaxes (but does not refine) a static one, and
// vice versa for the static dimension.
TEST(partial_shape, dim_relaxes_refines_dyn_static)
{
    Dimension d1{Dimension::dynamic()};
    Dimension d2{3};
    ASSERT_FALSE(d1.refines(d2));
    ASSERT_TRUE(d1.relaxes(d2));
    ASSERT_TRUE(d2.refines(d1));
    ASSERT_FALSE(d2.relaxes(d1));
}

// Equal static dimensions mutually relax and refine each other.
TEST(partial_shape, dim_relaxes_refines_static_static_eq)
{
    Dimension d1{3};
    Dimension d2{3};
    ASSERT_TRUE(d1.refines(d2));
    ASSERT_TRUE(d1.relaxes(d2));
    ASSERT_TRUE(d2.refines(d1));
    ASSERT_TRUE(d2.relaxes(d1));
}

// Unequal static dimensions neither relax nor refine each other.
TEST(partial_shape, dim_relaxes_refines_static_static_not_eq)
{
    Dimension d1{3};
    Dimension d2{4};
    ASSERT_FALSE(d1.refines(d2));
    ASSERT_FALSE(d1.relaxes(d2));
    ASSERT_FALSE(d2.refines(d1));
    ASSERT_FALSE(d2.relaxes(d1));
}
// Two rank-dynamic shapes mutually relax and refine each other.
TEST(partial_shape, partial_shape_relaxes_refines_rank_dynamic_rank_dynamic)
{
    PartialShape s1{PartialShape::dynamic()};
    PartialShape s2{PartialShape::dynamic()};
    ASSERT_TRUE(s1.refines(s2));
    ASSERT_TRUE(s1.relaxes(s2));
    ASSERT_TRUE(s2.refines(s1));
    ASSERT_TRUE(s2.relaxes(s1));
}

// A rank-dynamic shape relaxes any rank-static shape, even one with
// dynamic dimensions, and that shape refines it.
TEST(partial_shape, partial_shape_relaxes_refines_rank_dynamic_rank_static_dynamic)
{
    PartialShape s1{PartialShape::dynamic()};
    PartialShape s2{3, Dimension::dynamic(), 7, 9};
    ASSERT_FALSE(s1.refines(s2));
    ASSERT_TRUE(s1.relaxes(s2));
    ASSERT_TRUE(s2.refines(s1));
    ASSERT_FALSE(s2.relaxes(s1));
}

// A rank-dynamic shape relaxes a fully static shape; the static shape
// refines it.
TEST(partial_shape, partial_shape_relaxes_refines_rank_dynamic_static)
{
    PartialShape s1{PartialShape::dynamic()};
    PartialShape s2{3, 5, 7, 9};
    ASSERT_FALSE(s1.refines(s2));
    ASSERT_TRUE(s1.relaxes(s2));
    ASSERT_TRUE(s2.refines(s1));
    ASSERT_FALSE(s2.relaxes(s1));
}

// Shapes with a conflicting static dimension (3 vs. 4) are related in
// neither direction.
TEST(partial_shape,
     partial_shape_relaxes_refines_rank_dynamic_static_rank_dynamic_static_incompatible)
{
    PartialShape s1{3, 5, Dimension::dynamic(), 9};
    PartialShape s2{4, Dimension::dynamic(), 7, 9};
    ASSERT_FALSE(s1.refines(s2));
    ASSERT_FALSE(s1.relaxes(s2));
    ASSERT_FALSE(s2.refines(s1));
    ASSERT_FALSE(s2.relaxes(s1));
}

// Merge-compatible shapes where each side has a static dimension the
// other leaves dynamic: no relaxation/refinement in either direction.
TEST(partial_shape,
     partial_shape_relaxes_refines_rank_dynamic_static_rank_dynamic_static_compatible_neither)
{
    PartialShape s1{3, 5, Dimension::dynamic(), 9};
    PartialShape s2{3, Dimension::dynamic(), 7, 9};
    ASSERT_FALSE(s1.refines(s2));
    ASSERT_FALSE(s1.relaxes(s2));
    ASSERT_FALSE(s2.refines(s1));
    ASSERT_FALSE(s2.relaxes(s1));
}

// s1's dynamic dimensions are a superset of s2's: s1 relaxes s2 (one
// direction only).
TEST(partial_shape,
     partial_shape_relaxes_refines_rank_dynamic_static_rank_dynamic_static_compatible_one_way)
{
    PartialShape s1{3, Dimension::dynamic(), Dimension::dynamic(), 9};
    PartialShape s2{3, Dimension::dynamic(), 7, 9};
    ASSERT_FALSE(s1.refines(s2));
    ASSERT_TRUE(s1.relaxes(s2));
    ASSERT_TRUE(s2.refines(s1));
    ASSERT_FALSE(s2.relaxes(s1));
}

// Identical schemes (same static dims, dynamic in the same places):
// relaxation and refinement hold both ways.
TEST(partial_shape,
     partial_shape_relaxes_refines_rank_dynamic_static_rank_dynamic_static_compatible_both_ways)
{
    PartialShape s1{3, Dimension::dynamic(), 7, 9};
    PartialShape s2{3, Dimension::dynamic(), 7, 9};
    ASSERT_TRUE(s1.refines(s2));
    ASSERT_TRUE(s1.relaxes(s2));
    ASSERT_TRUE(s2.refines(s1));
    ASSERT_TRUE(s2.relaxes(s1));
}

// Partially dynamic vs. fully static with a conflicting dimension:
// unrelated in both directions.
TEST(partial_shape, partial_shape_relaxes_refines_rank_dynamic_static_static_incompatible)
{
    PartialShape s1{3, Dimension::dynamic(), 7, 9};
    PartialShape s2{4, 5, 7, 9};
    ASSERT_FALSE(s1.refines(s2));
    ASSERT_FALSE(s1.relaxes(s2));
    ASSERT_FALSE(s2.refines(s1));
    ASSERT_FALSE(s2.relaxes(s1));
}

// Partially dynamic vs. fully static, compatible: the dynamic shape
// relaxes the static one.
TEST(partial_shape, partial_shape_relaxes_refines_rank_dynamic_static_static_compatible)
{
    PartialShape s1{3, Dimension::dynamic(), 7, 9};
    PartialShape s2{3, 5, 7, 9};
    ASSERT_FALSE(s1.refines(s2));
    ASSERT_TRUE(s1.relaxes(s2));
    ASSERT_TRUE(s2.refines(s1));
    ASSERT_FALSE(s2.relaxes(s1));
}

// Equal fully static shapes: relaxation and refinement hold both ways.
TEST(partial_shape, partial_shape_relaxes_refines_static_static_eq)
{
    PartialShape s1{3, 5, 7, 9};
    PartialShape s2{3, 5, 7, 9};
    ASSERT_TRUE(s1.refines(s2));
    ASSERT_TRUE(s1.relaxes(s2));
    ASSERT_TRUE(s2.refines(s1));
    ASSERT_TRUE(s2.relaxes(s1));
}

// Unequal fully static shapes: unrelated in both directions.
TEST(partial_shape, partial_shape_relaxes_refines_static_static_not_eq)
{
    PartialShape s1{3, 5, 7, 9};
    PartialShape s2{4, 5, 7, 9};
    ASSERT_FALSE(s1.refines(s2));
    ASSERT_FALSE(s1.relaxes(s2));
    ASSERT_FALSE(s2.refines(s1));
    ASSERT_FALSE(s2.relaxes(s1));
}
// project() on a rank-dynamic shape is a no-op, even with out-of-range axes.
TEST(partial_shape, partial_shape_project_rank_dynamic)
{
    PartialShape s1{PartialShape::dynamic()};
    PartialShape s2 = project(s1, AxisSet{284, 0, 103});
    ASSERT_TRUE(s2.rank().is_dynamic());
}

// project() keeps only the requested axes, preserving dynamic dimensions.
TEST(partial_shape, partial_shape_project_rank_static_dynamic)
{
    PartialShape s1{Dimension::dynamic(), 2, Dimension::dynamic(), 3};
    PartialShape s2 = project(s1, AxisSet{0, 3});
    ASSERT_TRUE(s2.same_scheme(PartialShape{Dimension::dynamic(), 3}));
}

// reduce() on a rank-dynamic shape is a no-op, even with out-of-range axes.
TEST(partial_shape, partial_shape_reduce_rank_dynamic)
{
    PartialShape s1{PartialShape::dynamic()};
    PartialShape s2 = reduce(s1, AxisSet{284, 0, 103});
    ASSERT_TRUE(s2.rank().is_dynamic());
}

// reduce() deletes the requested axes, keeping the rest in order.
TEST(partial_shape, partial_shape_reduce_rank_static_dynamic)
{
    PartialShape s1{Dimension::dynamic(), 2, Dimension::dynamic(), 3};
    PartialShape s2 = reduce(s1, AxisSet{0, 3});
    ASSERT_TRUE(s2.same_scheme(PartialShape{2, Dimension::dynamic()}));
}

// inject_pairs() on a rank-dynamic shape is a no-op.
TEST(partial_shape, partial_shape_inject_pairs_rank_dynamic)
{
    PartialShape s1{PartialShape::dynamic()};
    PartialShape s2 = inject_pairs(
        s1, std::vector<std::pair<size_t, Dimension>>{{0, Dimension::dynamic()}, {207, 909}});
    ASSERT_TRUE(s2.rank().is_dynamic());
}

// inject_pairs() places new dimensions at the given positions and fills the
// remaining positions from the original shape, in order.
TEST(partial_shape, partial_shape_inject_pairs_rank_static)
{
    PartialShape s1{1, Dimension::dynamic()};
    PartialShape s2 =
        inject_pairs(s1,
                     std::vector<std::pair<size_t, Dimension>>{
                         {0, Dimension::dynamic()}, {2, 909}, {4, Dimension::dynamic()}});
    ASSERT_TRUE(s2.same_scheme(
        PartialShape{Dimension::dynamic(), 1, 909, Dimension::dynamic(), Dimension::dynamic()}));
}
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment