Commit 0563a3cf authored by Adam Procter, committed by Robert Kimball

Partial Shapes and Types, Part 3: Framework for partial shape/element type validation (#1728)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goof-up in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item
parent 631d7253
......@@ -25,7 +25,7 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
const PartialShape& pshape,
const std::string& name)
: m_element_type(element_type)
, m_shape(pshape.is_complete() ? pshape.to_shape() : Shape{})
, m_shape(pshape.is_static() ? pshape.to_shape() : Shape{})
, m_partial_shape(pshape)
, m_name(name)
{
......@@ -34,7 +34,7 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
void descriptor::Tensor::set_tensor_type(const element::Type& element_type,
const PartialShape& pshape)
{
if (pshape.is_complete())
if (pshape.is_static())
{
m_shape = pshape.to_shape();
}
......@@ -48,14 +48,14 @@ void descriptor::Tensor::set_tensor_type(const element::Type& element_type,
const Shape& descriptor::Tensor::get_shape() const
{
if (m_partial_shape.is_complete())
if (m_partial_shape.is_static())
{
return m_shape;
}
else
{
throw std::invalid_argument(
"get_shape was called on a descriptor::Tensor with incomplete shape");
"get_shape was called on a descriptor::Tensor with dynamic shape");
}
}
......
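To make the new get_shape() contract concrete, here is a minimal standalone sketch (illustrative, not part of the diff; the tensor names are made up) showing that get_shape() succeeds only when the underlying partial shape is static:

#include <iostream>
#include <stdexcept>

#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/type/element_type.hpp"

using namespace ngraph;

int main()
{
    // Static partial shape: get_shape() succeeds.
    descriptor::Tensor t_static{element::f32, PartialShape{2, 3}, "t_static"};
    std::cout << t_static.get_shape().size() << std::endl; // prints 2

    // Dynamic partial shape: get_shape() throws.
    descriptor::Tensor t_dynamic{element::f32, PartialShape{2, Dimension::dynamic()}, "t_dynamic"};
    try
    {
        t_dynamic.get_shape();
    }
    catch (const std::invalid_argument& e)
    {
        std::cout << e.what() << std::endl; // "...dynamic shape"
    }
}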
......@@ -25,11 +25,11 @@ using namespace ngraph;
Dimension::Dimension(size_t dimension)
: m_dimension(dimension)
{
if (dimension == s_undetermined_val)
if (dimension == s_dynamic_val)
{
std::stringstream ss;
ss << "Cannot convert the value 0x" << std::uppercase << std::hex << s_undetermined_val
<< " to Dimension: this value is used internally to represent an undetermined "
ss << "Cannot convert the value 0x" << std::uppercase << std::hex << s_dynamic_val
<< " to Dimension: this value is used internally to represent a dynamic "
"dimension.";
throw std::invalid_argument(ss.str());
}
......@@ -37,7 +37,7 @@ Dimension::Dimension(size_t dimension)
std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
{
if (dimension.is_determined())
if (dimension.is_static())
{
return (str << size_t(dimension));
}
......@@ -49,11 +49,33 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
Dimension ngraph::operator+(const Dimension& d1, const Dimension& d2)
{
return (d1.is_determined() && d2.is_determined() ? size_t(d1) + size_t(d2)
: Dimension::undetermined());
return (d1.is_static() && d2.is_static() ? size_t(d1) + size_t(d2) : Dimension::dynamic());
}
bool Dimension::compatible(const Dimension& d) const
{
return (!is_determined() || !d.is_determined() || m_dimension == size_t(d));
return (is_dynamic() || d.is_dynamic() || m_dimension == size_t(d));
}
bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
{
if (d1.is_dynamic())
{
dst = d2;
return true;
}
else if (d2.is_dynamic())
{
dst = d1;
return true;
}
else if (size_t(d1) != size_t(d2))
{
return false;
}
else
{
dst = d1;
return true;
}
}
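The four branches above line up with the merge cases exercised by the new unit tests later in this change; a compact sketch (illustrative only):

#include <cassert>

#include "ngraph/dimension.hpp"

using namespace ngraph;

int main()
{
    Dimension d;

    // A dynamic operand defers to the other operand.
    assert(Dimension::merge(d, Dimension::dynamic(), 3) && size_t(d) == 3);
    assert(Dimension::merge(d, 3, Dimension::dynamic()) && size_t(d) == 3);

    // Equal static dimensions merge to that value.
    assert(Dimension::merge(d, 3, 3) && size_t(d) == 3);

    // Unequal static dimensions cannot be merged; d is left unchanged.
    assert(!Dimension::merge(d, 3, 4));
}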
......@@ -22,56 +22,97 @@
namespace ngraph
{
/// \brief Class representing a possibly-unknown dimension in a shape or shape-like object.
/// \brief Class representing a dimension, which may be dynamic (undetermined until runtime),
/// in a shape or shape-like object.
///
/// Known dimensions may be implicitly converted from size_t. An unknown dimension is
/// constructed with Dimension() or Dimension::undetermined().
/// Static dimensions may be implicitly converted from size_t. A dynamic dimension is
/// constructed with Dimension() or Dimension::dynamic().
///
/// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
class Dimension
{
public:
/// \brief Constructs a known dimension.
///
/// Requires that dimension != s_undetermined_val. If that condition does not hold,
/// throws std::invalid_argument.
/// \brief Construct a static dimension.
/// \param dimension Value of the dimension. Must not be equal to
/// Dimension::s_dynamic_val.
/// \throws std::invalid_argument If `dimension` == Dimension::s_dynamic_val.
Dimension(size_t dimension);
/// \brief Constructs an unknown dimension.
Dimension() { m_dimension = s_undetermined_val; }
/// \brief Returns true if this dimension is determined.
bool is_determined() const { return m_dimension != s_undetermined_val; }
/// \brief Construct a dynamic dimension.
Dimension() { m_dimension = s_dynamic_val; }
/// \brief Check whether this dimension is static.
/// \return `true` if the dimension is static, else `false`.
bool is_static() const { return m_dimension != s_dynamic_val; }
/// \brief Check whether this dimension is dynamic.
/// \return `false` if the dimension is static, else `true`.
bool is_dynamic() const { return !is_static(); }
/// \brief Converts this dimension to size_t. If the dimension is undetermined, throws
/// std::invalid_argument.
/// \brief Convert this dimension to `size_t`. This dimension must be static.
/// \throws std::invalid_argument If this dimension is dynamic.
explicit operator size_t() const
{
if (!is_determined())
if (is_dynamic())
{
throw std::invalid_argument("Cannot convert unknown dimension to size_t");
throw std::invalid_argument("Cannot convert dynamic dimension to size_t");
}
return m_dimension;
}
/// \brief Check whether this dimension represents the same scheme as the argument (both
/// dynamic, or equal).
/// \param dim The other dimension to compare this dimension to.
/// \return `true` if this dimension and `dim` are both dynamic, or if they are both
/// static and equal; otherwise, `false`.
bool same_scheme(const Dimension& dim) const
{
return (is_dynamic() && dim.is_dynamic()) ||
(is_static() && dim.is_static() && m_dimension == size_t(dim));
}
/// \brief Try to merge two Dimension objects together.
/// \param[out] dst Reference to write the merged Dimension into.
/// \param d1 First dimension to merge.
/// \param d2 Second dimension to merge.
/// \return `true` if merging succeeds, else `false`.
///
/// \li If `d1` is dynamic, writes `d2` to `dst` and returns `true`.
/// \li If `d2` is dynamic, writes `d1` to `dst` and returns `true`.
/// \li If `d1` and `d2` are static and equal, writes `d1` to `dst` and returns `true`.
/// \li If `d1` and `d2` are both static and unequal, leaves `dst` unchanged and
/// returns `false`.
static bool merge(Dimension& dst, const Dimension d1, const Dimension d2);
/// \brief Returns true if the dimensions are compatible, i.e. if one of the dimensions
/// is undetermined, or both dimensions are determined and equal.
/// \brief Check whether this dimension is capable of being merged with the argument
/// dimension.
/// \param d The dimension to compare this dimension with.
/// \return `true` if this dimension is compatible with `d`, else `false`.
///
/// Two dimensions are considered compatible if it is possible to merge them. (See
/// Dimension::merge.)
bool compatible(const Dimension& d) const;
/// \brief Constructs an unknown dimension.
static Dimension undetermined() { return Dimension(); }
/// \brief Constant for the value used internally to represent an unknown dimension.
static const size_t s_undetermined_val{std::numeric_limits<size_t>::max()};
/// \brief Create a dynamic dimension.
/// \return A dynamic dimension.
static Dimension dynamic() { return Dimension(); }
/// \brief Constant for the value used internally to represent a dynamic dimension.
static const size_t s_dynamic_val{std::numeric_limits<size_t>::max()};
private:
// The actual numerical value of the dimension. s_undetermined_val is a special case,
// representing an unknown dimension.
// The actual numerical value of the dimension. s_dynamic_val is a special case,
// representing a dynamic dimension.
size_t m_dimension;
};
/// \brief Inserts a human-readable representation of "dimension" into "str".
/// \brief Insert a human-readable representation of a dimension into an output stream.
/// \param str The output stream targeted for insertion.
/// \param dimension The dimension to be inserted into `str`.
/// \return A reference to `str` after insertion.
///
/// Inserts the string `?` if `dimension` is dynamic; else inserts `size_t(dimension)`.
std::ostream& operator<<(std::ostream& str, const Dimension& dimension);
/// \brief Addition operator for dimensions.
///
/// If d1 and d2 are both known, returns size_t(d1)+size_t(d2). Otherwise, returns
/// Dimension::undetermined().
/// \param d1 Left operand for addition.
/// \param d2 Right operand for addition.
/// \return Dimension::dynamic() if either of `d1` or `d2` is dynamic; else, a static
/// dimension with value `size_t(d1)+size_t(d2)`.
Dimension operator+(const Dimension& d1, const Dimension& d2);
}
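As a quick illustration of the stream-insertion and addition semantics documented above (a standalone sketch, not part of the diff):

#include <iostream>

#include "ngraph/dimension.hpp"

using namespace ngraph;

int main()
{
    Dimension d1{2};
    Dimension d2{Dimension::dynamic()};

    std::cout << d1 << std::endl; // prints "2"
    std::cout << d2 << std::endl; // prints "?"

    // Addition is dynamic-poisoning: any dynamic operand yields dynamic.
    std::cout << (d1 + Dimension{3}) << std::endl; // prints "5"
    std::cout << (d1 + d2) << std::endl;           // prints "?"
}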
......@@ -52,9 +52,7 @@ namespace ngraph
case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break;
case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break;
case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break;
case onnx::TensorProto_DataType_UNDEFINED:
elem_type = element::unspecified;
break;
case onnx::TensorProto_DataType_UNDEFINED: elem_type = element::dynamic; break;
default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type";
}
......
......@@ -75,7 +75,7 @@ void Node::set_output_size(size_t n)
for (size_t i = m_outputs.size(); i < n; ++i)
{
auto tensor_descriptor = make_shared<descriptor::Tensor>(
element::unspecified, Shape(), get_name() + "_" + to_string(i));
element::dynamic, PartialShape::dynamic(), get_name() + "_" + to_string(i));
m_outputs.emplace_back(this, i, tensor_descriptor);
}
}
......@@ -84,9 +84,9 @@ void Node::validate_and_infer_types()
{
}
void Node::set_output_type(size_t i, const element::Type& element_type, const Shape& shape)
void Node::set_output_type(size_t i, const element::Type& element_type, const PartialShape& pshape)
{
m_outputs.at(i).get_tensor_ptr()->set_tensor_type(element_type, shape);
m_outputs.at(i).get_tensor_ptr()->set_tensor_type(element_type, pshape);
}
std::deque<descriptor::Output>& Node::get_outputs()
......@@ -208,13 +208,27 @@ std::ostream& Node::write_short_description(std::ostream& out) const
return out << get_name();
}
static std::string pretty_element_type(const element::Type& et)
{
if (et.is_dynamic())
{
return "?";
}
else
{
return et.c_type_string();
}
}
std::ostream& Node::write_long_description(std::ostream& out) const
{
out << description() << '[' << get_name() << "](";
string sep = "";
for (auto arg : get_arguments())
{
out << sep << NodeDescription(*arg, true);
out << sep << NodeDescription(*arg, true) << ": "
<< pretty_element_type(arg->get_output_element_type(0))
<< arg->get_output_partial_shape(0) << "";
sep = ", ";
}
out << ")";
......@@ -404,39 +418,73 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args)
return args;
}
void Node::validate_and_infer_elementwise(element::Type result_type)
std::tuple<element::Type, PartialShape> Node::validate_and_infer_elementwise_args()
{
const element::Type& element_type = get_input_element_type(0);
const Shape& shape = get_input_shape(0);
element::Type element_type = get_input_element_type(0);
PartialShape pshape = get_input_partial_shape(0);
if (get_input_size() > 1)
{
for (size_t i = 1; i < get_input_size(); ++i)
{
NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == element_type)
<< "Argument 0 element type " << element_type
<< " differs in element type from argument " << i << " " << *get_argument(i)
<< " element type " << get_input_element_type(i);
NODE_VALIDATION_ASSERT(
this, element::Type::merge(element_type, element_type, get_input_element_type(i)))
<< "Argument element types are inconsistent.";
NODE_VALIDATION_ASSERT(this, get_input_shape(i) == shape)
<< "Argument 0 shape " << shape << " differs in shape from argument " << i << " "
<< *get_argument(i) << " shape " << get_input_shape(i);
NODE_VALIDATION_ASSERT(this,
PartialShape::merge_into(pshape, get_input_partial_shape(i)))
<< "Argument shapes are inconsistent.";
}
}
set_output_type(0, result_type, shape);
return std::make_tuple(element_type, pshape);
}
void Node::validate_and_infer_elementwise_arithmetic()
{
NODE_VALIDATION_ASSERT(this, get_input_element_type(0) != element::boolean)
<< "Arguments cannot have boolean element type (argument element type: "
<< get_input_element_type(0) << ").";
validate_and_infer_elementwise(get_input_element_type(0));
auto args_et_pshape = validate_and_infer_elementwise_args();
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et != element::boolean)
<< "Arguments cannot have boolean element type (argument element type: " << args_et << ").";
set_output_type(0, args_et, args_pshape);
}
void Node::validate_and_infer_elementwise_logical()
{
NODE_VALIDATION_ASSERT(this, get_input_element_type(0) == element::boolean)
auto args_et_pshape = validate_and_infer_elementwise_args();
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et == element::boolean)
<< "Operands for logical operators must have boolean element type but have element type "
<< get_input_element_type(0) << ".";
validate_and_infer_elementwise(get_input_element_type(0));
<< args_et << ".";
set_output_type(0, element::boolean, args_pshape);
}
bool Node::validate_punt_if_dynamic()
{
bool any_dynamic = false;
for (auto& input : m_inputs)
{
any_dynamic |= input.get_partial_shape().is_dynamic();
any_dynamic |= input.get_element_type().is_dynamic();
}
if (any_dynamic)
{
for (size_t i = 0; i < get_output_size(); i++)
{
set_output_type(i, element::dynamic, PartialShape::dynamic());
}
return true;
}
else
{
return false;
}
}
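As a sketch of what the elementwise merging above enables (assuming, as in this change, that elementwise ops such as op::Add route through validate_and_infer_elementwise_arithmetic), two partially-dynamic inputs can still produce a refined output:

#include <cassert>
#include <memory>

#include "ngraph/ngraph.hpp"

using namespace ngraph;

int main()
{
    // {2,?} merged with {?,3} yields {2,3}.
    auto a = std::make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic()});
    auto b = std::make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 3});
    auto sum = std::make_shared<op::Add>(a, b);
    assert(sum->get_output_partial_shape(0).same_scheme(PartialShape{2, 3}));

    // A dynamic element type merges with a static one.
    auto c = std::make_shared<op::Parameter>(element::dynamic, PartialShape{2, 3});
    auto sum2 = std::make_shared<op::Add>(a, c);
    assert(sum2->get_output_element_type(0) == element::f32);
}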
......@@ -22,6 +22,7 @@
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <typeindex>
#include <unordered_map>
#include <unordered_set>
......@@ -94,14 +95,21 @@ namespace ngraph
// Called in constructors during transition
void constructor_validate_and_infer_types();
void validate_and_infer_elementwise(element::Type result_type);
void validate_and_infer_elementwise()
{
validate_and_infer_elementwise(get_input_element_type(0));
}
std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args();
void validate_and_infer_elementwise_arithmetic();
void validate_and_infer_elementwise_logical();
// Temporary hack while partial shape propagation is being implemented. If any input has
// dynamic shape or dynamic element type, sets all outputs to have a shape of dynamic
// rank and dynamic element type. Ops where we haven't yet implemented partial shape
// propagation can add this boilerplate at the top of their validate_and_infer_types():
//
// if (validate_punt_if_dynamic())
// {
// return;
// }
bool validate_punt_if_dynamic();
Node(const std::string& node_type, const NodeVector& arguments, size_t output_size = 1);
virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) {}
......@@ -125,7 +133,9 @@ namespace ngraph
return std::type_index(typeid(*this)) == std::type_index(typeid(*n));
}
void set_output_type(size_t i, const element::Type& element_type, const Shape& shape);
void set_output_type(size_t i,
const element::Type& element_type,
const PartialShape& pshape);
bool is_parameter() const;
virtual bool is_output() const;
......
......@@ -27,6 +27,11 @@ op::AllReduce::AllReduce(const shared_ptr<Node>& arg)
void op::AllReduce::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
NODE_VALIDATION_ASSERT(this,
get_input_element_type(0) == element::f32 ||
get_input_element_type(0) == element::f64)
......
......@@ -40,6 +40,11 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
void op::AvgPool::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& arg_shape = get_input_shape(0);
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
......@@ -120,6 +125,11 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
void op::AvgPoolBackprop::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& delta_shape = get_input_shape(0);
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
......
......@@ -35,6 +35,11 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
void ngraph::op::BatchNorm::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
m_bn_input_shape = get_input_shape(INPUT);
NODE_VALIDATION_ASSERT(this, m_bn_input_shape.size() >= 2)
<< "Input argument must have rank of at least 2 (input argument shape: " << m_bn_input_shape
......@@ -158,6 +163,11 @@ ngraph::op::BatchNormBackprop::BatchNormBackprop(double eps,
void ngraph::op::BatchNormBackprop::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
set_output_size(3);
NODE_VALIDATION_ASSERT(this, get_input_shape(INPUT).size() == 4)
......
......@@ -40,6 +40,11 @@ op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
void op::Broadcast::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
infer_shape();
Shape target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
......
......@@ -32,6 +32,11 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
void op::Concat::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required.";
Shape first_input_shape = get_input_shape(0);
......
......@@ -46,6 +46,11 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
void op::Convolution::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& data_batch_shape = get_input_shape(0);
auto& data_batch_et = get_input_element_type(0);
auto& filters_shape = get_input_shape(1);
......@@ -220,6 +225,11 @@ op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_sha
void op::ConvolutionBackpropData::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
// Backprop to data is itself convolution, with inputs/outputs/attributes transmogrified as
// follows.
//
......@@ -410,6 +420,11 @@ op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
void op::ConvolutionBackpropFilters::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
// Backprop to filters is itself convolution, with inputs/outputs/attributes transmogrified as
// follows.
//
......
......@@ -34,6 +34,11 @@ op::Dequantize::Dequantize(shared_ptr<Node> input,
void op::Dequantize::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
enum
{
INPUT,
......
......@@ -47,6 +47,11 @@ op::Dot::Dot(const shared_ptr<Node>& arg0,
void op::Dot::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& input_0 = get_inputs().at(0);
auto& input_1 = get_inputs().at(1);
......
......@@ -42,6 +42,11 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
void op::MaxPool::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& arg_shape = get_input_shape(0);
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
......@@ -120,6 +125,11 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
void op::MaxPoolBackprop::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto forward_arg_et = get_input_element_type(0);
auto& forward_arg_shape = get_input_shape(0);
auto delta_et = get_input_element_type(1);
......
......@@ -26,9 +26,14 @@ op::Not::Not(const shared_ptr<Node>& arg)
constructor_validate_and_infer_types();
}
// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
void op::Not::validate_and_infer_types()
{
validate_and_infer_elementwise();
auto args_et_pshape = validate_and_infer_elementwise_args();
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
set_output_type(0, args_et, args_pshape);
}
shared_ptr<Node> op::Not::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -22,11 +22,11 @@ using namespace std;
using namespace ngraph;
op::Parameter::Parameter(const element::Type& element_type,
const Shape& shape,
const PartialShape& pshape,
const bool cacheable)
: Op("Parameter", {})
, m_cacheable(cacheable)
, m_shape(shape)
, m_partial_shape(pshape)
, m_element_type(element_type)
{
constructor_validate_and_infer_types();
......@@ -35,13 +35,13 @@ op::Parameter::Parameter(const element::Type& element_type,
void op::Parameter::validate_and_infer_types()
{
Op::validate_and_infer_types();
set_output_type(0, m_element_type, m_shape);
set_output_type(0, m_element_type, m_partial_shape);
}
shared_ptr<Node> op::Parameter::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Parameter>(m_element_type, m_shape);
return make_shared<Parameter>(m_element_type, m_partial_shape);
}
void op::Parameter::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
......
......@@ -38,10 +38,10 @@ namespace ngraph
/// \brief Constructs a tensor view-typed parameter node.
///
/// \param element_type The element type of the parameter.
/// \param shape The shape of the parameter.
/// \param pshape The partial shape of the parameter.
/// \param cacheable True if the parameter is not expected to be frequently updated.
Parameter(const ngraph::element::Type& element_type,
const Shape& shape,
const PartialShape& pshape,
const bool cacheable = false);
void validate_and_infer_types() override;
......@@ -52,7 +52,7 @@ namespace ngraph
protected:
bool m_cacheable;
Shape m_shape;
PartialShape m_partial_shape;
element::Type m_element_type;
};
}
......
......@@ -36,6 +36,11 @@ op::Quantize::Quantize(shared_ptr<Node> input,
void op::Quantize::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
enum
{
INPUT,
......
......@@ -35,6 +35,11 @@ op::Reshape::Reshape(const shared_ptr<Node>& arg,
void op::Reshape::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& input = get_inputs().at(0);
auto input_shape = input.get_shape();
auto input_rank = input_shape.size();
......
......@@ -37,7 +37,7 @@ void op::Result::validate_and_infer_types()
// always borrow the placement conf even the default one
set_placement(get_argument(0)->get_placement());
set_output_type(0, get_input_element_type(0), get_input_shape(0));
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
shared_ptr<Node> op::Result::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -32,6 +32,11 @@ op::Reverse::Reverse(const shared_ptr<Node>& arg, const AxisSet& reversed_axes)
void op::Reverse::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto input_shape = get_input_shape(0);
auto input_rank = input_shape.size();
......
......@@ -38,6 +38,11 @@ op::ReverseSequence::ReverseSequence(const std::shared_ptr<Node> arg,
void op::ReverseSequence::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
NODE_VALIDATION_ASSERT(this, get_input_shape(1).size() == 1)
<< "Sequence indices must be a 1-dimensional tensor (sequence indices shape: "
<< get_input_shape(1) << ").";
......
......@@ -44,6 +44,11 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
void op::Slice::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
if (0 == m_strides.size())
{
m_strides = Strides(m_lower_bounds.size(), 1);
......
......@@ -39,6 +39,11 @@ op::TopK::TopK(const shared_ptr<Node>& arg,
void op::TopK::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& input = get_inputs().at(0);
auto rank = input.get_shape().size();
......
......@@ -29,6 +29,11 @@ op::util::ArithmeticReduction::ArithmeticReduction(const std::string& node_type,
void op::util::ArithmeticReduction::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto input_shape = get_input_shape(0);
for (auto axis : m_reduction_axes)
......
......@@ -28,5 +28,8 @@ op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const string&
void op::util::BinaryElementwiseComparison::validate_and_infer_types()
{
validate_and_infer_elementwise(element::boolean);
auto args_et_pshape = validate_and_infer_elementwise_args();
PartialShape& args_pshape = std::get<1>(args_et_pshape);
set_output_type(0, element::boolean, args_pshape);
}
......@@ -28,19 +28,18 @@ PartialShape::PartialShape(const Shape& shape)
m_dimensions.assign(shape.begin(), shape.end());
}
bool ngraph::PartialShape::is_complete() const
bool ngraph::PartialShape::is_static() const
{
return m_rank_is_determined &&
std::all_of(m_dimensions.begin(), m_dimensions.end(), [](const Dimension& d) {
return d.is_determined();
});
return m_rank_is_static && std::all_of(m_dimensions.begin(),
m_dimensions.end(),
[](const Dimension& d) { return d.is_static(); });
}
PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2)
{
if (!s1.rank_is_determined() || !s2.rank_is_determined())
if (s1.rank().is_dynamic() || s2.rank().is_dynamic())
{
return PartialShape::undetermined();
return PartialShape::dynamic();
}
if (!s1.rank().compatible(s2.rank()))
......@@ -49,7 +48,7 @@ PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2)
}
PartialShape result{};
result.m_rank_is_determined = true;
result.m_rank_is_static = true;
for (size_t i = 0; i < s1.m_dimensions.size(); i++)
{
result.m_dimensions.push_back(s1.m_dimensions[i] + s2.m_dimensions[i]);
......@@ -59,7 +58,7 @@ PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2)
std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape)
{
if (shape.m_rank_is_determined)
if (shape.m_rank_is_static)
{
str << "{";
bool first = true;
......@@ -83,7 +82,7 @@ std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape)
bool PartialShape::compatible(const PartialShape& s) const
{
// If we don't know *this's rank, or we don't know s's rank, they are compatible.
if (!rank_is_determined() || !s.rank_is_determined())
if (!m_rank_is_static || s.rank().is_dynamic())
{
return true;
}
......@@ -109,12 +108,69 @@ bool PartialShape::compatible(const PartialShape& s) const
}
}
bool PartialShape::same_scheme(const PartialShape& s) const
{
if (rank().is_dynamic() && s.rank().is_dynamic())
{
return true;
}
else if (rank().is_static() && s.rank().is_static())
{
if (size_t(rank()) != size_t(s.rank()))
{
return false;
}
bool success = true;
for (size_t i = 0; i < size_t(rank()); i++)
{
success &= (*this)[i].same_scheme(s[i]);
}
return success;
}
else
{
return false;
}
}
Shape PartialShape::to_shape() const
{
if (!is_complete())
if (is_dynamic())
{
throw std::invalid_argument("to_shape was called on an incomplete shape.");
throw std::invalid_argument("to_shape was called on a dynamic shape.");
}
return Shape(m_dimensions.begin(), m_dimensions.end());
}
bool PartialShape::merge_into(PartialShape& dst, const PartialShape& src)
{
if (dst.rank().is_dynamic())
{
dst = src;
return true;
}
else if (src.rank().is_dynamic())
{
// No change to dst.
return true;
}
else if (size_t(dst.rank()) != size_t(src.rank()))
{
// Mismatching static ranks, cannot merge.
return false;
}
else
{
// Ranks are both static, and they match.
bool success = true;
for (size_t i = 0; i < size_t(dst.rank()); i++)
{
success &= Dimension::merge(dst[i], dst[i], src[i]);
}
return success;
}
}
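A small standalone sketch (illustrative, not part of the diff) of merge_into following the branches above:

#include <cassert>

#include "ngraph/partial_shape.hpp"

using namespace ngraph;

int main()
{
    // Dynamic rank on the destination: src wins outright.
    PartialShape s1{PartialShape::dynamic()};
    assert(PartialShape::merge_into(s1, PartialShape{1, Dimension::dynamic()}));
    assert(s1.same_scheme(PartialShape{1, Dimension::dynamic()}));

    // Equal static ranks: dimensions merge elementwise.
    PartialShape s2{1, Dimension::dynamic(), 3};
    assert(PartialShape::merge_into(s2, PartialShape{Dimension::dynamic(), 2, 3}));
    assert(s2.same_scheme(PartialShape{1, 2, 3}));

    // Mismatched static ranks cannot merge.
    PartialShape s3{1, 2};
    assert(!PartialShape::merge_into(s3, PartialShape{1, 2, 3}));
}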
......@@ -24,93 +24,200 @@
namespace ngraph
{
/// \brief Class representing a shape that may only be partially known.
/// \brief Class representing a shape that may be partially or totally dynamic.
///
/// XXX: THIS CLASS IS EXPERIMENTAL AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
///
/// A partially-known shape may have:
/// A PartialShape may have:
///
/// - Unknown rank.
/// - Known rank, but unknown dimensions on some or all axes.
/// - Known rank, and known dimensions on all axes.
/// \li Dynamic rank. (Informal notation: `?`)
/// \li Static rank, but dynamic dimensions on some or all axes.
/// (Informal notation examples: `{1,2,?,4}`, `{?,?,?}`)
/// \li Static rank, and static dimensions on all axes.
/// (Informal notation examples: `{1,2,3,4}`, `{6}`, `{}`)
class PartialShape
{
public:
/// \brief Constructs a shape with determined rank.
/// \brief Constructs a shape with static rank from an initializer list of Dimension.
/// \param init The Dimension values for the constructed shape.
///
/// Examples:
///
/// PartialShape s{2,3,4}; // rank=3, all dimensions determined
/// \code{.cpp}
/// PartialShape s{2,3,4}; // rank=3, all dimensions static
/// PartialShape s{}; // rank=0
/// PartialShape s{2,Dimension::undetermined(),3}; // rank=2, dimension 1 undetermined
/// PartialShape s{2,Dimension::dynamic(),3}; // rank=3, dimension 1 dynamic
/// \endcode
PartialShape(std::initializer_list<Dimension> init)
: PartialShape(true, init)
{
}
/// \brief Constructs a complete PartialShape from a Shape.
PartialShape(const Shape& shape);
/// \brief Returns true if the shape has determined rank.
bool rank_is_determined() const { return m_rank_is_determined; }
/// \brief Returns true if the shape has known rank and all dimensions of the shape
/// are determined.
bool is_complete() const;
/// \brief Constructs a PartialShape with static rank from a vector of Dimension.
/// \param dimensions The Dimension values for the constructed shape.
PartialShape(const std::vector<Dimension>& dimensions)
: m_rank_is_static(true)
, m_dimensions(dimensions)
{
}
/// \brief Returns the rank of the shape. Returns Rank::undetermined() if the rank is undetermined.
Rank rank() const
{
return m_rank_is_determined ? Rank(m_dimensions.size()) : Rank::undetermined();
}
/// \brief Constructs a static PartialShape with zero rank (the shape of a scalar).
PartialShape()
: PartialShape({})
{
}
/// \brief Appends another shape to this shape.
///
/// If "this" and "other" both have determined rank, returns a new shape
/// whose dimensions are the concatenation of the dimensions of "this" and "other".
/// If either "this" or "other" has undetermined rank, returns
/// PartialShape::undetermined().
PartialShape append(const PartialShape& other);
/// \brief Constructs a static PartialShape from a Shape.
/// \param shape The Shape to convert into PartialShape.
PartialShape(const Shape& shape);
/// \brief Check if this shape is static.
/// \return `true` if this shape is static, else `false`.
///
/// A shape is considered static if it has static rank, and all dimensions of the shape
/// are static.
bool is_static() const;
/// \brief Returns the undetermined shape.
static PartialShape undetermined() { return PartialShape(false, {}); }
/// \brief Check if this shape is dynamic.
/// \return `false` if this shape is static, else `true`.
///
/// A shape is considered static if it has static rank, and all dimensions of the shape
/// are static.
bool is_dynamic() const { return !is_static(); }
/// \brief Get the rank of the shape.
/// \return The rank of the shape. This will be Rank::dynamic() if the rank of
/// the shape is dynamic.
Rank rank() const { return m_rank_is_static ? Rank(m_dimensions.size()) : Rank::dynamic(); }
/// \brief Construct a PartialShape with dynamic rank.
/// \return A PartialShape with dynamic rank.
static PartialShape dynamic() { return PartialShape(false, {}); }
/// \brief Returns true if *this is compatible with s.
///
/// Two dimensions are compatible if one or both of them is undetermined, or if
/// they are both determined and equal.
/// \brief Check whether this shape is compatible with the argument, i.e., whether it is
/// possible to merge them.
/// \param s The shape to be checked for compatibility with this shape.
/// \return `true` if this shape is compatible with `s`, else `false`.
///
/// Two shapes are compatible if
/// \li one or both of them has dynamic rank, or
/// \li both shapes have static and equal rank, and their dimensions are elementwise
/// compatible (see Dimension::compatible()).
bool compatible(const PartialShape& s) const;
/// \brief Check whether this shape represents the same scheme as the argument.
/// \param s The shape whose scheme is being compared with this shape.
/// \return `true` if this shape represents the same scheme as `s`, else `false`.
///
/// Two shapes `s1` and `s2` represent the same scheme if
/// \li they both have dynamic rank, or
/// \li they both have static and equal rank `r`, and for every `i` from `0` to `r-1`,
/// `s1[i]` represents the same scheme as `s2[i]` (see Dimension::same_scheme()).
bool same_scheme(const PartialShape& s) const;
/// \brief Converts a complete PartialShape to a Shape.
///
/// Throws std::invalid_argument if the PartialShape is incomplete.
/// \brief Convert a static PartialShape to a Shape.
/// \return A new Shape `s` where `s[i] = size_t((*this)[i])`.
/// \throws std::invalid_argument If this PartialShape is dynamic.
Shape to_shape() const;
/// \brief Index operator for PartialShape.
/// \param i The index of the dimension being selected.
/// \return A reference to the `i`th Dimension of this shape.
const Dimension& operator[](size_t i) const { return m_dimensions[i]; }
/// \brief Index operator for PartialShape.
/// \param i The index of the dimension being selected.
/// \return A reference to the `i`th Dimension of this shape.
Dimension& operator[](size_t i) { return m_dimensions[i]; }
friend std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
friend PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
/// \brief Try to merge one shape into another.
/// \param[in,out] dst The shape that `src` will be merged into.
/// \param src The shape that will be merged into `dst`.
/// \return `true` if merging succeeds, else `false`.
///
/// Merges `src` into `dst`, returning `true` on success and `false` on failure. If
/// `false` is returned, the effect on `dst` is unspecified.
///
/// To merge two partial shapes `s1` and `s2` is to find the most permissive partial shape
/// `s` that is no more permissive than `s1` or `s2`, if `s` exists. For example:
///
/// \code
/// merge(?,?) -> ?
/// merge(?,{?,?}) -> {?,?}
/// merge({?,?},{?,?}) -> {?,?}
/// merge({1,2,3,4},?) -> {1,2,3,4}
/// merge({1,2},{1,?}) -> {1,2}
/// merge({1,2,?,?},{1,?,3,?}) -> {1,2,3,?}
/// merge({1,2,3},{1,2,3}) -> {1,2,3}
///
/// merge({1,?},{2,?}) fails [dimension 0 constraints are inconsistent]
/// merge({?,?},{?,?,?}) fails [ranks are inconsistent]
/// \endcode
///
/// This function (merge_into) performs the "merge" operation described above on `dst` and
/// `src`, but overwrites `dst` with the result and returns `true` if merging is
/// successful; if merging is unsuccessful, the function returns `false` and may make
/// unspecified changes to `dst`.
static bool merge_into(PartialShape& dst, const PartialShape& src);
private:
// Private constructor so PartialShape::undetermined() can construct an undetermined shape.
PartialShape(bool rank_is_determined, std::initializer_list<Dimension> init)
: m_rank_is_determined(rank_is_determined)
// Private constructor so PartialShape::dynamic() can construct a shape with
// m_rank_is_static set to false.
PartialShape(bool rank_is_static, std::initializer_list<Dimension> init)
: m_rank_is_static(rank_is_static)
, m_dimensions(init)
{
}
// True if the shape's rank is determined.
bool m_rank_is_determined;
// True if the shape's rank is static.
bool m_rank_is_static;
// Shape dimensions. This has no meaning if m_rank_is_determined is false.
// Shape dimensions. This has no meaning if m_rank_is_static is false.
std::vector<Dimension> m_dimensions;
};
/// \brief Elementwise addition of two shapes.
/// \brief Elementwise addition of two PartialShape objects.
/// \param s1 Left operand for addition.
/// \param s2 Right operand for addition.
/// \return The result of elementwise adding `s1` to `s2` (see description).
/// \throws std::invalid_argument If `s1` and `s2` have inconsistent ranks.
///
/// If s1 or s2 has undetermined rank, returns PartialShape::undetermined().
/// If s1 and s2 both have determined rank, and their ranks are unequal,
/// throws std::invalid_argument.
/// If s1 and s2 both have determined rank, and their ranks are equal,
/// returns a new shape whose ith dimension is s1[i] + s2[i].
/// \li If `s1` or `s2` has dynamic rank, returns PartialShape::dynamic().
/// \li If `s1` and `s2` both have static rank, and their ranks are unequal, throws
/// std::invalid_argument.
/// \li If `s1` and `s2` both have static rank, and their ranks are equal,
/// returns a new shape whose `i`th dimension is `s1[i] + s2[i]`.
PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
/// \brief Inserts a human-readable representation of "shape" into "str".
/// \brief Inserts a human-readable representation of a PartialShape into an output stream.
/// \param str The output stream targeted for insertion.
/// \param shape The shape to be inserted into `str`.
/// \return A reference to `str` after insertion.
///
/// The output to the stream is in "informal" notation. In other words:
///
/// \li If `shape` has dynamic rank, inserts the string `?`.
/// \li If `shape` has static rank, inserts the string `{`, then inserts each dimension
/// of `shape` into the output stream separated by commas, then inserts `}`.
///
/// Example:
///
/// \code{.cpp}
/// PartialShape s1{PartialShape::dynamic()};
/// PartialShape s2{};
/// PartialShape s3{1,Dimension::dynamic(),2,3};
/// PartialShape s4{2,3,4};
/// std::cout << s1 << std::endl
/// << s2 << std::endl
/// << s3 << std::endl
/// << s4 << std::endl;
/// \endcode
///
/// Output:
///
/// \code
/// ?
/// {}
/// {1,?,2,3}
/// {2,3,4}
/// \endcode
std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
}
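The difference between compatible and same_scheme is easy to miss; a brief sketch (illustrative only):

#include <cassert>

#include "ngraph/partial_shape.hpp"

using namespace ngraph;

int main()
{
    PartialShape a{1, Dimension::dynamic()};
    PartialShape b{1, 2};

    // Compatible: {1,?} can be merged with {1,2}...
    assert(a.compatible(b));
    // ...but they do not represent the same scheme, because dimension 1
    // is dynamic in one and static in the other.
    assert(!a.same_scheme(b));

    // same_scheme demands an exact structural match.
    assert(a.same_scheme(PartialShape{1, Dimension::dynamic()}));
}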
......@@ -20,9 +20,8 @@
namespace ngraph
{
/// \brief Alias for "Dimension". Should be used to when the value represents the number of
/// axes in a shape-like object, rather than the size of one dimension in a shape-like
/// object.
/// \brief Alias for Dimension, used when the value represents the number of axes in a shape,
/// rather than the size of one dimension in a shape.
///
/// XXX: THIS TYPE IS EXPERIMENTAL AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
using Rank = Dimension;
......
......@@ -149,6 +149,59 @@ static json write(const ngraph::Node&, bool binary_constant_data);
static string
serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data);
static json write_dimension(Dimension d)
{
if (d.is_static())
{
return size_t(d);
}
else
{
return nullptr;
}
}
static json write_partial_shape(const PartialShape& s)
{
if (s.rank().is_dynamic())
{
return nullptr;
}
else
{
std::vector<json> vals(size_t(s.rank()));
for (size_t i = 0; i < vals.size(); i++)
{
vals[i] = write_dimension(s[i]);
}
return vals;
}
}
static PartialShape read_partial_shape(const json& j)
{
if (j.is_null())
{
return PartialShape::dynamic();
}
else
{
std::vector<Dimension> dims(j.size());
for (size_t i = 0; i < j.size(); i++)
{
if (j[i].is_null())
{
dims[i] = Dimension::dynamic();
}
else
{
dims[i] = size_t(j[i]);
}
}
return PartialShape(dims);
}
}
static json write_element_type(const ngraph::element::Type& n)
{
json j;
......@@ -839,7 +892,8 @@ static shared_ptr<ngraph::Function>
auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape");
auto cacheable = get_or_default<bool>(node_js, "cacheable", false);
node = make_shared<op::Parameter>(element_type, shape, cacheable);
node =
make_shared<op::Parameter>(element_type, read_partial_shape(shape), cacheable);
break;
}
case OP_TYPEID::Power:
......@@ -1389,7 +1443,7 @@ static json write(const Node& n, bool binary_constant_data)
case OP_TYPEID::Parameter:
{
auto tmp = dynamic_cast<const op::Parameter*>(&n);
node["shape"] = tmp->get_shape();
node["shape"] = write_partial_shape(tmp->get_output_partial_shape(0));
node["cacheable"] = tmp->get_cacheable();
node["element_type"] = write_element_type(tmp->get_element_type());
break;
......
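To make the JSON encoding convention concrete, here is a standalone sketch (illustrative, not part of the diff) that mirrors the file-static write_dimension/write_partial_shape helpers above; it assumes nlohmann::json, which the serializer uses:

#include <iostream>
#include <vector>

#include <nlohmann/json.hpp>

#include "ngraph/partial_shape.hpp"

using json = nlohmann::json;
using namespace ngraph;

// Demonstration-only mirrors of the serializer's file-static helpers.
static json demo_write_dimension(Dimension d)
{
    // Dynamic dimensions serialize as JSON null; static ones as numbers.
    return d.is_static() ? json(size_t(d)) : json(nullptr);
}

static json demo_write_partial_shape(const PartialShape& s)
{
    if (s.rank().is_dynamic())
    {
        return nullptr; // dynamic rank serializes as JSON null
    }
    std::vector<json> vals(size_t(s.rank()));
    for (size_t i = 0; i < vals.size(); i++)
    {
        vals[i] = demo_write_dimension(s[i]);
    }
    return vals;
}

int main()
{
    std::cout << demo_write_partial_shape(PartialShape::dynamic()) << "\n";                  // null
    std::cout << demo_write_partial_shape(PartialShape{1, Dimension::dynamic(), 3}) << "\n"; // [1,null,3]
    std::cout << demo_write_partial_shape(PartialShape{2, 4}) << "\n";                       // [2,4]
}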
......@@ -21,7 +21,7 @@
using namespace ngraph;
const element::Type element::unspecified(0, false, false, false, "unspecified");
const element::Type element::dynamic(0, false, false, false, "dynamic");
const element::Type element::boolean(8, false, true, false, "char");
const element::Type element::f32(32, true, true, false, "float");
const element::Type element::f64(64, true, true, false, "double");
......@@ -184,3 +184,26 @@ std::ostream& element::operator<<(std::ostream& out, const element::Type& obj)
<< ", " << obj.m_is_quantized << ", \"" << obj.m_cname << "\"}";
return out;
}
bool element::Type::merge(element::Type& dst, const element::Type& t1, const element::Type& t2)
{
if (t1.is_dynamic())
{
dst = t2;
return true;
}
else if (t2.is_dynamic())
{
dst = t1;
return true;
}
else if (t1 == t2)
{
dst = t1;
return true;
}
else
{
return false;
}
}
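The branch structure above gives element::dynamic the role of a merge identity; a compact sketch (illustrative; the new unit tests below cover the same cases):

#include <cassert>

#include "ngraph/type/element_type.hpp"

using namespace ngraph;

int main()
{
    element::Type t;

    // dynamic acts as the identity for merging.
    assert(element::Type::merge(t, element::dynamic, element::i32) && t == element::i32);
    assert(element::Type::merge(t, element::i32, element::dynamic) && t == element::i32);

    // Equal static types merge; unequal static types do not.
    assert(element::Type::merge(t, element::f32, element::f32) && t == element::f32);
    assert(!element::Type::merge(t, element::f32, element::f64));
}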
......@@ -33,7 +33,7 @@ namespace ngraph
{
class Type;
extern const Type unspecified;
extern const Type dynamic;
extern const Type boolean;
extern const Type f32;
extern const Type f64;
......@@ -61,6 +61,8 @@ namespace ngraph
const std::string& c_type_string() const;
size_t size() const;
size_t hash() const;
bool is_static() const { return (*this != dynamic); }
bool is_dynamic() const { return !is_static(); }
bool is_real() const { return m_is_real; }
bool is_signed() const { return m_is_signed; }
bool is_quantized() const { return m_is_quantized; }
......@@ -73,12 +75,32 @@ namespace ngraph
/// Returns true if the type is floating point, else false.
bool get_is_real() const { return m_is_real; }
/// \brief Merges two element types t1 and t2, writing the result into dst and
/// returning true if successful, else returning false.
///
/// To "merge" two element types t1 and t2 is to find the least restrictive
/// element type t that is no less restrictive than t1 and t2, if t exists.
/// More simply:
///
/// merge(dst,element::Type::dynamic,t)
/// writes t to dst and returns true
///
/// merge(dst,t,element::Type::dynamic)
/// writes t to dst and returns true
///
/// merge(dst,t1,t2) where t1, t2 both static and equal
/// writes t1 to dst and returns true
///
/// merge(dst,t1,t2) where t1, t2 both static and unequal
/// does nothing to dst, and returns false
static bool merge(element::Type& dst, const element::Type& t1, const element::Type& t2);
private:
size_t m_bitwidth{0};
bool m_is_real{false};
bool m_is_signed{false};
bool m_is_quantized{false};
std::string m_cname{"unspecified"};
std::string m_cname{"dynamic"};
};
template <typename T>
......
......@@ -84,3 +84,42 @@ TEST(element_type, size)
EXPECT_EQ(2, t1.size());
}
}
TEST(element_type, merge_both_dynamic)
{
element::Type t;
ASSERT_TRUE(element::Type::merge(t, element::dynamic, element::dynamic));
ASSERT_TRUE(t.is_dynamic());
}
TEST(element_type, merge_left_dynamic)
{
element::Type t;
ASSERT_TRUE(element::Type::merge(t, element::dynamic, element::u64));
ASSERT_TRUE(t.is_static());
ASSERT_EQ(t, element::u64);
}
TEST(element_type, merge_right_dynamic)
{
element::Type t;
ASSERT_TRUE(element::Type::merge(t, element::i16, element::dynamic));
ASSERT_TRUE(t.is_static());
ASSERT_EQ(t, element::i16);
}
TEST(element_type, merge_both_static_equal)
{
element::Type t;
ASSERT_TRUE(element::Type::merge(t, element::f64, element::f64));
ASSERT_TRUE(t.is_static());
ASSERT_EQ(t, element::f64);
}
TEST(element_type, merge_both_static_unequal)
{
element::Type t = element::f32;
ASSERT_FALSE(element::Type::merge(t, element::i8, element::i16));
ASSERT_TRUE(t.is_static());
ASSERT_EQ(t, element::f32);
}
......@@ -24,110 +24,106 @@ using namespace ngraph;
TEST(partial_shape, ps_construction_empty)
{
auto ps = PartialShape{};
ASSERT_TRUE(ps.rank_is_determined());
ASSERT_TRUE(ps.rank().is_determined());
ASSERT_TRUE(ps.is_complete());
ASSERT_TRUE(ps.rank().is_static());
ASSERT_TRUE(ps.is_static());
ASSERT_EQ(size_t(ps.rank()), 0);
}
TEST(partial_shape, ps_construction_undetermined)
TEST(partial_shape, ps_construction_rank_dynamic)
{
auto ps = PartialShape::undetermined();
ASSERT_FALSE(ps.rank_is_determined());
ASSERT_FALSE(ps.rank().is_determined());
ASSERT_FALSE(ps.is_complete());
auto ps = PartialShape::dynamic();
ASSERT_TRUE(ps.rank().is_dynamic());
ASSERT_TRUE(ps.is_dynamic());
}
TEST(partial_shape, ps_construction_incomplete)
TEST(partial_shape, ps_construction_rank_static_shape_dynamic)
{
auto ps = PartialShape{2, Dimension::undetermined(), 3};
ASSERT_TRUE(ps.rank_is_determined());
ASSERT_TRUE(ps.rank().is_determined());
ASSERT_FALSE(ps.is_complete());
auto ps = PartialShape{2, Dimension::dynamic(), 3};
ASSERT_TRUE(ps.rank().is_static());
ASSERT_TRUE(ps.is_dynamic());
ASSERT_EQ(size_t(ps.rank()), 3);
}
TEST(partial_shape, ps_construction_complete)
TEST(partial_shape, ps_construction_static)
{
auto ps = PartialShape{2, 5, 3, 6};
ASSERT_TRUE(ps.rank_is_determined());
ASSERT_TRUE(ps.rank().is_determined());
ASSERT_TRUE(ps.is_complete());
ASSERT_TRUE(ps.rank().is_static());
ASSERT_TRUE(ps.is_static());
ASSERT_EQ(size_t(ps.rank()), 4);
}
TEST(partial_shape, dim_construction_determined)
TEST(partial_shape, dim_construction_static)
{
Dimension dim{3};
ASSERT_EQ(size_t(dim), 3);
ASSERT_TRUE(dim.is_determined());
ASSERT_TRUE(dim.is_static());
}
TEST(partial_shape, dim_construction_undetermined)
TEST(partial_shape, dim_construction_dynamic)
{
Dimension dim = Dimension::undetermined();
ASSERT_FALSE(dim.is_determined());
Dimension dim = Dimension::dynamic();
ASSERT_TRUE(dim.is_dynamic());
}
TEST(partial_shape, dim_construction_size_t_max)
{
EXPECT_ANY_THROW({ Dimension d{Dimension::s_undetermined_val}; });
EXPECT_ANY_THROW({ Dimension d{Dimension::s_dynamic_val}; });
}
TEST(partial_shape, dim_conversion_determined)
TEST(partial_shape, dim_conversion_static)
{
Dimension d{42};
size_t s{d};
ASSERT_EQ(s, 42);
}
TEST(partial_shape, dim_conversion_undetermined)
TEST(partial_shape, dim_conversion_dynamic)
{
EXPECT_ANY_THROW({
size_t s{Dimension::undetermined()};
size_t s{Dimension::dynamic()};
s = 0; // Silence compiler warning about unused s
});
}
TEST(partial_shape, rank_construction_determined)
TEST(partial_shape, rank_construction_static)
{
Rank r{4};
ASSERT_EQ(size_t(r), 4);
ASSERT_TRUE(r.is_determined());
ASSERT_TRUE(r.is_static());
}
TEST(partial_shape, rank_construction_undetermined)
TEST(partial_shape, rank_construction_dynamic)
{
Rank r = Rank::undetermined();
ASSERT_FALSE(r.is_determined());
Rank r = Rank::dynamic();
ASSERT_TRUE(r.is_dynamic());
}
TEST(partial_shape, dim_compatible_left_undetermined)
TEST(partial_shape, dim_compatible_left_dynamic)
{
Dimension d1{Dimension::undetermined()};
Dimension d1{Dimension::dynamic()};
Dimension d2{3};
ASSERT_TRUE(d1.compatible(d2));
}
TEST(partial_shape, dim_compatible_right_undetermined)
TEST(partial_shape, dim_compatible_right_dynamic)
{
Dimension d1{3};
Dimension d2{Dimension::undetermined()};
Dimension d2{Dimension::dynamic()};
ASSERT_TRUE(d1.compatible(d2));
}
TEST(partial_shape, dim_compatible_both_undetermined)
TEST(partial_shape, dim_compatible_both_dynamic)
{
Dimension d1{Dimension::undetermined()};
Dimension d2{Dimension::undetermined()};
Dimension d1{Dimension::dynamic()};
Dimension d2{Dimension::dynamic()};
ASSERT_TRUE(d1.compatible(d2));
}
TEST(partial_shape, dim_compatible_both_determined)
TEST(partial_shape, dim_compatible_both_static)
{
Dimension d1{3};
Dimension d2{8};
......@@ -137,25 +133,25 @@ TEST(partial_shape, dim_compatible_both_determined)
ASSERT_TRUE(d1.compatible(d3));
}
TEST(partial_shape, shapes_compatible_both_rank_undetermined)
TEST(partial_shape, shapes_compatible_both_rank_dynamic)
{
PartialShape ps1{PartialShape::undetermined()};
PartialShape ps2{PartialShape::undetermined()};
PartialShape ps1{PartialShape::dynamic()};
PartialShape ps2{PartialShape::dynamic()};
ASSERT_TRUE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_compatible_left_rank_undetermined)
TEST(partial_shape, shapes_compatible_left_rank_dynamic)
{
PartialShape ps1{3};
PartialShape ps2{PartialShape::undetermined()};
PartialShape ps2{PartialShape::dynamic()};
ASSERT_TRUE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_compatible_right_rank_undetermined)
TEST(partial_shape, shapes_compatible_right_rank_dynamic)
{
PartialShape ps1{PartialShape::undetermined()};
PartialShape ps1{PartialShape::dynamic()};
PartialShape ps2{4};
ASSERT_TRUE(ps1.compatible(ps2));
......@@ -163,21 +159,21 @@ TEST(partial_shape, shapes_compatible_right_rank_undetermined)
TEST(partial_shape, shapes_compatible_both_partial_all_known_equal)
{
PartialShape ps1{2, Dimension::undetermined(), 3, Dimension::undetermined(), 5};
PartialShape ps2{2, Dimension::undetermined(), Dimension::undetermined(), 4, 5};
PartialShape ps1{2, Dimension::dynamic(), 3, Dimension::dynamic(), 5};
PartialShape ps2{2, Dimension::dynamic(), Dimension::dynamic(), 4, 5};
ASSERT_TRUE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_compatible_both_partial_some_known_unequal)
{
PartialShape ps1{2, Dimension::undetermined(), 3, Dimension::undetermined(), 5};
PartialShape ps2{1, Dimension::undetermined(), Dimension::undetermined(), 4, 5};
PartialShape ps1{2, Dimension::dynamic(), 3, Dimension::dynamic(), 5};
PartialShape ps2{1, Dimension::dynamic(), Dimension::dynamic(), 4, 5};
ASSERT_FALSE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_compatible_both_complete_different_rank)
TEST(partial_shape, shapes_compatible_both_static_different_rank)
{
PartialShape ps1{2, 4, 6, 8};
PartialShape ps2{2, 4, 6, 8, 10};
......@@ -185,7 +181,7 @@ TEST(partial_shape, shapes_compatible_both_complete_different_rank)
ASSERT_FALSE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_equal_both_complete_same_rank_same_dims)
TEST(partial_shape, shapes_equal_both_static_same_rank_same_dims)
{
PartialShape ps1{2, 4, 6, 8};
PartialShape ps2{2, 4, 6, 8};
......@@ -193,7 +189,7 @@ TEST(partial_shape, shapes_equal_both_complete_same_rank_same_dims)
ASSERT_TRUE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_equal_both_complete_same_rank_different_dims)
TEST(partial_shape, shapes_equal_both_static_same_rank_different_dims)
{
PartialShape ps1{2, 4, 6, 8};
PartialShape ps2{2, 4, 3, 8};
......@@ -206,13 +202,16 @@ TEST(partial_shape, from_shape)
Shape s{2, 4, 6, 8};
PartialShape ps1{s};
// TODO(amprocte): No way to examine contents of ps1 yet.
ASSERT_TRUE(ps1.is_complete());
ASSERT_TRUE(ps1.rank_is_determined());
ASSERT_TRUE(ps1.rank().is_static());
ASSERT_EQ(size_t(ps1.rank()), s.size());
ASSERT_TRUE(ps1.is_static());
ASSERT_EQ(size_t(ps1[0]), 2);
ASSERT_EQ(size_t(ps1[1]), 4);
ASSERT_EQ(size_t(ps1[2]), 6);
ASSERT_EQ(size_t(ps1[3]), 8);
}
TEST(partial_shape, to_shape_complete)
TEST(partial_shape, to_shape_static)
{
PartialShape ps{2, 4, 6, 8};
Shape s{ps.to_shape()};
......@@ -220,15 +219,15 @@ TEST(partial_shape, to_shape_complete)
ASSERT_EQ(s, (Shape{2, 4, 6, 8}));
}
TEST(partial_shape, to_shape_dims_undetermined)
TEST(partial_shape, to_shape_dims_dynamic)
{
PartialShape ps{2, 4, Dimension::undetermined(), 8};
PartialShape ps{2, 4, Dimension::dynamic(), 8};
ASSERT_THROW({ ps.to_shape(); }, std::invalid_argument);
}
TEST(partial_shape, to_shape_rank_undetermined)
TEST(partial_shape, to_shape_rank_dynamic)
{
PartialShape ps{PartialShape::undetermined()};
PartialShape ps{PartialShape::dynamic()};
ASSERT_THROW({ ps.to_shape(); }, std::invalid_argument);
}
......@@ -238,28 +237,250 @@ TEST(partial_shape, tensor_descriptor_from_shape)
ASSERT_EQ(t.get_shape(), (Shape{1, 2, 3}));
ASSERT_EQ(size_t(t.get_partial_shape().rank()), 3);
ASSERT_TRUE(t.get_partial_shape().same_scheme(PartialShape{1, 2, 3}));
}
TEST(partial_shape, tensor_descriptor_from_complete_partial_shape)
TEST(partial_shape, tensor_descriptor_from_static_partial_shape)
{
descriptor::Tensor t{element::i32, PartialShape{1, 2, 3}, "Burnside"};
ASSERT_EQ(t.get_shape(), (Shape{1, 2, 3}));
ASSERT_EQ(size_t(t.get_partial_shape().rank()), 3);
ASSERT_TRUE(t.get_partial_shape().same_scheme(PartialShape{1, 2, 3}));
}
TEST(partial_shape, tensor_descriptor_from_incomplete_partial_shape)
TEST(partial_shape, tensor_descriptor_from_rank_static_dynamic_partial_shape)
{
descriptor::Tensor t{element::i32, PartialShape{1, Dimension::undetermined(), 3}, "Couch"};
descriptor::Tensor t{element::i32, PartialShape{1, Dimension::dynamic(), 3}, "Couch"};
ASSERT_EQ(size_t(t.get_partial_shape().rank()), 3);
ASSERT_THROW({ t.get_shape(); }, std::invalid_argument);
ASSERT_TRUE(t.get_partial_shape().same_scheme(PartialShape{1, Dimension::dynamic(), 3}));
}
TEST(partial_shape, tensor_descriptor_from_rankless_partial_shape)
TEST(partial_shape, tensor_descriptor_from_rank_dynamic_partial_shape)
{
descriptor::Tensor t{element::i32, PartialShape::undetermined(), "Davis"};
descriptor::Tensor t{element::i32, PartialShape::dynamic(), "Davis"};
ASSERT_FALSE(t.get_partial_shape().rank().is_determined());
ASSERT_TRUE(t.get_partial_shape().rank().is_dynamic());
ASSERT_THROW({ t.get_shape(); }, std::invalid_argument);
ASSERT_TRUE(t.get_partial_shape().same_scheme(PartialShape::dynamic()));
}
TEST(partial_shape, dim_same_scheme_both_dynamic)
{
ASSERT_TRUE(Dimension::dynamic().same_scheme(Dimension::dynamic()));
}
TEST(partial_shape, dim_same_scheme_left_dynamic)
{
ASSERT_FALSE(Dimension::dynamic().same_scheme(6));
}
TEST(partial_shape, dim_same_scheme_right_dynamic)
{
ASSERT_FALSE(Dimension(6).same_scheme(Dimension::dynamic()));
}
TEST(partial_shape, dim_same_scheme_both_static_same)
{
ASSERT_TRUE(Dimension(6).same_scheme(Dimension(6)));
}
TEST(partial_shape, dim_same_scheme_both_static_different)
{
ASSERT_FALSE(Dimension(6).same_scheme(Dimension(7)));
}
TEST(partial_shape, partial_shape_same_scheme_both_dynamic)
{
ASSERT_TRUE(PartialShape::dynamic().same_scheme(PartialShape::dynamic()));
}
TEST(partial_shape, partial_shape_same_scheme_left_dynamic_right_rank_static_dynamic)
{
ASSERT_FALSE(PartialShape::dynamic().same_scheme(PartialShape{1, Dimension::dynamic(), 3}));
}
TEST(partial_shape, partial_shape_same_scheme_left_dynamic_right_static)
{
ASSERT_FALSE(PartialShape::dynamic().same_scheme(PartialShape{1, 2, 3}));
}
TEST(partial_shape, partial_shape_same_scheme_right_dynamic_left_rank_static_dynamic)
{
ASSERT_FALSE((PartialShape{1, Dimension::dynamic(), 3}.same_scheme(PartialShape::dynamic())));
}
TEST(partial_shape, partial_shape_same_scheme_right_dynamic_left_static)
{
ASSERT_FALSE((PartialShape{1, 2, 3}.same_scheme(PartialShape::dynamic())));
}
TEST(partial_shape, partial_shape_same_scheme_both_static_different_rank)
{
ASSERT_FALSE((PartialShape{1, 2, 3}.same_scheme(PartialShape{1, 2, 3, 4})));
}
TEST(partial_shape, partial_shape_same_scheme_both_rank_static_dynamic_different_rank)
{
ASSERT_FALSE((PartialShape{1, Dimension::dynamic(), 3}.same_scheme(
PartialShape{1, Dimension::dynamic(), 3, 4})));
}
TEST(partial_shape, partial_shape_same_scheme_both_static_same_rank_different_dims)
{
ASSERT_FALSE((PartialShape{1, 2, 3}.same_scheme(PartialShape{1, 3, 3})));
}
TEST(partial_shape, partial_shape_same_scheme_both_rank_static_dynamic_same_rank_different_dims)
{
ASSERT_FALSE((PartialShape{1, 2, Dimension::dynamic()}.same_scheme(
PartialShape{1, 3, Dimension::dynamic()})));
}
TEST(partial_shape,
partial_shape_same_scheme_both_rank_static_dynamic_same_rank_compatible_not_same)
{
ASSERT_FALSE((PartialShape{1, 2, Dimension::dynamic()}.same_scheme(
PartialShape{1, Dimension::dynamic(), 3})));
}
TEST(partial_shape, partial_shape_same_scheme_both_rank_static_dynamic_same_rank_compatible_same)
{
ASSERT_TRUE((PartialShape{1, 2, Dimension::dynamic()}.same_scheme(
PartialShape{1, 2, Dimension::dynamic()})));
}
TEST(partial_shape, partial_shape_same_scheme_both_static_same_rank_same_dims)
{
ASSERT_TRUE((PartialShape{1, 2, 3}.same_scheme(PartialShape{1, 2, 3})));
}
TEST(partial_shape, partial_shape_same_scheme_scalar)
{
ASSERT_TRUE((PartialShape{}.same_scheme(PartialShape{})));
}
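// Taken together, the same_scheme cases above pin the relation down completely. A minimal
// sketch of the semantics they assume (illustrative only; dim_same_scheme_sketch and
// shape_same_scheme_sketch are hypothetical helpers relying on this file's existing ngraph
// includes, not library code):
//
// Two dimensions have the same scheme iff both are dynamic, or both are static and equal.
static bool dim_same_scheme_sketch(const Dimension& d1, const Dimension& d2)
{
    if (d1.is_dynamic() && d2.is_dynamic())
    {
        return true;
    }
    return d1.is_static() && d2.is_static() && size_t(d1) == size_t(d2);
}
// Two shapes have the same scheme iff both are rank-dynamic, or their ranks are equal and
// every pair of corresponding dimensions has the same scheme.
static bool shape_same_scheme_sketch(const PartialShape& s1, const PartialShape& s2)
{
    if (s1.rank().is_dynamic() && s2.rank().is_dynamic())
    {
        return true;
    }
    if (s1.rank().is_dynamic() || s2.rank().is_dynamic() ||
        size_t(s1.rank()) != size_t(s2.rank()))
    {
        return false;
    }
    for (size_t i = 0; i < size_t(s1.rank()); i++)
    {
        if (!dim_same_scheme_sketch(s1[i], s2[i]))
        {
            return false;
        }
    }
    return true;
}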
TEST(partial_shape, dim_merge_both_dynamic)
{
Dimension d;
ASSERT_TRUE(Dimension::merge(d, Dimension::dynamic(), Dimension::dynamic()));
ASSERT_TRUE(d.is_dynamic());
}
TEST(partial_shape, dim_merge_left_dynamic)
{
Dimension d;
ASSERT_TRUE(Dimension::merge(d, Dimension::dynamic(), 3));
ASSERT_TRUE(d.is_static());
ASSERT_EQ(size_t(d), 3);
}
TEST(partial_shape, dim_merge_right_dynamic)
{
Dimension d;
ASSERT_TRUE(Dimension::merge(d, 3, Dimension::dynamic()));
ASSERT_TRUE(d.is_static());
ASSERT_EQ(size_t(d), 3);
}
TEST(partial_shape, dim_merge_both_static_equal)
{
Dimension d;
ASSERT_TRUE(Dimension::merge(d, 3, 3));
ASSERT_TRUE(d.is_static());
ASSERT_EQ(size_t(d), 3);
}
TEST(partial_shape, dim_merge_both_static_unequal)
{
Dimension d = 163;
ASSERT_FALSE(Dimension::merge(d, 3, 4));
ASSERT_TRUE(d.is_static());
ASSERT_EQ(size_t(d), 163);
}
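// The five dim_merge cases above fix the merge rule. A sketch of the expected semantics
// (merge_sketch is a hypothetical stand-in for Dimension::merge, not the implementation):
static bool merge_sketch(Dimension& dst, const Dimension& d1, const Dimension& d2)
{
    if (d1.is_dynamic())
    {
        dst = d2; // a dynamic side contributes nothing; adopt the other side
        return true;
    }
    if (d2.is_dynamic())
    {
        dst = d1;
        return true;
    }
    if (size_t(d1) == size_t(d2))
    {
        dst = d1;
        return true;
    }
    return false; // unequal static dimensions: merge fails and dst is left untouched
}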
TEST(partial_shape, partial_shape_merge_both_rank_dynamic)
{
PartialShape s1{PartialShape::dynamic()};
const PartialShape s2{PartialShape::dynamic()};
ASSERT_TRUE(PartialShape::merge_into(s1, s2));
ASSERT_TRUE(s1.rank().is_dynamic());
}
TEST(partial_shape, partial_shape_merge_left_rank_dynamic_right_rank_static_dynamic)
{
PartialShape s1{PartialShape::dynamic()};
const PartialShape s2{1, 2, Dimension::dynamic()};
ASSERT_TRUE(PartialShape::merge_into(s1, s2));
ASSERT_TRUE(s1.same_scheme(PartialShape{1, 2, Dimension::dynamic()}));
}
TEST(partial_shape, partial_shape_merge_left_rank_dynamic_right_static)
{
PartialShape s1{PartialShape::dynamic()};
const PartialShape s2{1, 2, 3};
ASSERT_TRUE(PartialShape::merge_into(s1, s2));
ASSERT_TRUE(s1.same_scheme(PartialShape{1, 2, 3}));
}
TEST(partial_shape, partial_shape_merge_left_rank_static_dynamic_right_rank_dynamic)
{
PartialShape s1{1, 2, Dimension::dynamic()};
const PartialShape s2{PartialShape::dynamic()};
ASSERT_TRUE(PartialShape::merge_into(s1, s2));
ASSERT_TRUE(s1.same_scheme(PartialShape{1, 2, Dimension::dynamic()}));
}
TEST(partial_shape, partial_shape_merge_left_static_right_rank_dynamic)
{
PartialShape s1{1, 2, 3};
const PartialShape s2{PartialShape::dynamic()};
ASSERT_TRUE(PartialShape::merge_into(s1, s2));
ASSERT_TRUE(s1.same_scheme(PartialShape{1, 2, 3}));
}
TEST(partial_shape, partial_shape_merge_both_rank_static_dynamic_consistent)
{
PartialShape s1{1, Dimension::dynamic(), 3, Dimension::dynamic()};
const PartialShape s2{1, 2, Dimension::dynamic(), Dimension::dynamic()};
ASSERT_TRUE(PartialShape::merge_into(s1, s2));
ASSERT_TRUE(s1.same_scheme(PartialShape{1, 2, 3, Dimension::dynamic()}));
}
TEST(partial_shape, partial_shape_merge_both_rank_static_dynamic_same_rank_inconsistent)
{
PartialShape s1{1, Dimension::dynamic(), 3, Dimension::dynamic()};
const PartialShape s2{2, 2, Dimension::dynamic(), Dimension::dynamic()};
ASSERT_FALSE(PartialShape::merge_into(s1, s2));
}
TEST(partial_shape, partial_shape_merge_both_rank_static_dynamic_different_rank)
{
PartialShape s1{1, Dimension::dynamic(), 3, Dimension::dynamic()};
const PartialShape s2{1, 2, Dimension::dynamic()};
ASSERT_FALSE(PartialShape::merge_into(s1, s2));
}
TEST(partial_shape, partial_shape_merge_both_static_consistent)
{
PartialShape s1{1, 2, 3};
const PartialShape s2{1, 2, 3};
ASSERT_TRUE(PartialShape::merge_into(s1, s2));
ASSERT_TRUE(s1.same_scheme(PartialShape{1, 2, 3}));
}
TEST(partial_shape, partial_shape_merge_both_static_inconsistent)
{
PartialShape s1{1, 2, 3};
const PartialShape s2{1, 2, 4};
ASSERT_FALSE(PartialShape::merge_into(s1, s2));
}
TEST(partial_shape, partial_shape_merge_both_static_different_rank)
{
PartialShape s1{1, 2, 3};
const PartialShape s2{1, 2, 3, 4};
ASSERT_FALSE(PartialShape::merge_into(s1, s2));
}
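// Likewise, the merge_into cases reduce to a simple rule: a rank-dynamic destination adopts
// the source wholesale, a rank-dynamic source adds nothing, and otherwise the ranks must
// match and dimensions merge pairwise. A sketch under those assumptions
// (merge_into_sketch is hypothetical; writable indexing via operator[] is assumed):
static bool merge_into_sketch(PartialShape& dst, const PartialShape& src)
{
    if (dst.rank().is_dynamic())
    {
        dst = src;
        return true;
    }
    if (src.rank().is_dynamic())
    {
        return true; // nothing new to learn from the source
    }
    if (size_t(dst.rank()) != size_t(src.rank()))
    {
        return false;
    }
    for (size_t i = 0; i < size_t(dst.rank()); i++)
    {
        Dimension merged;
        if (!Dimension::merge(merged, dst[i], src[i]))
        {
            return false; // inconsistent dimension pair; dst may be partially updated
        }
        dst[i] = merged;
    }
    return true;
}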
......@@ -476,9 +476,7 @@ void test_binary(std::string node_type,
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Argument 0 shape Shape{2, 4} differs in shape from argument 1"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
}
catch (...)
{
......@@ -497,10 +495,8 @@ void test_binary(std::string node_type,
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Argument 0 element type element::Type{32, 1, "
"1, 0, \"float\"} differs in element type from argument 1"));
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument element types are inconsistent"));
}
catch (...)
{
......@@ -572,9 +568,7 @@ void test_binary_logical(std::string node_type,
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Argument 0 shape Shape{2, 4} differs in shape from argument 1"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
}
catch (...)
{
......@@ -593,10 +587,8 @@ void test_binary_logical(std::string node_type,
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Argument 0 element type element::Type{8, 0, 1, 0, \"char\"} "
"differs in element type from argument 1"));
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Argument element types are inconsistent"));
}
catch (...)
{
......@@ -624,7 +616,7 @@ void test_binary_logical(std::string node_type,
};
test_binary_differ_arguments_view_element_types(tv0_2_4_param_0, tv0_2_4_param_2);
test_binary_non_bool_arguments_view_element_types(tv0_2_4_param_2, tv0_2_4_param_0);
test_binary_differ_arguments_view_element_types(tv0_2_4_param_2, tv0_2_4_param_0);
test_binary_non_bool_arguments_view_element_types(tv0_2_4_param_2, tv0_2_4_param_3);
auto test_binary_good_arguments = [&](const shared_ptr<Node>& x, const shared_ptr<Node>& y) {
......@@ -6598,6 +6590,384 @@ TEST(type_prop, topk_invalid_k)
}
}
TEST(type_prop, param_partial_rank_dynamic)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto& pshape = a->get_output_partial_shape(0);
ASSERT_TRUE(pshape.is_dynamic());
ASSERT_TRUE(pshape.rank().is_dynamic());
}
TEST(type_prop, param_partial_rank_static)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3, 4});
auto& pshape = a->get_output_partial_shape(0);
ASSERT_TRUE(pshape.is_dynamic());
ASSERT_EQ(size_t(pshape.rank()), 4);
ASSERT_TRUE(pshape[0].is_static() && size_t(pshape[0]) == 2);
ASSERT_TRUE(pshape[1].is_dynamic());
ASSERT_TRUE(pshape[2].is_static() && size_t(pshape[2]) == 3);
ASSERT_TRUE(pshape[3].is_static() && size_t(pshape[3]) == 4);
}
TEST(type_prop, binary_elementwise_arithmetic_both_dynamic)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto b = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto add = make_shared<op::Add>(a, b);
ASSERT_TRUE(add->get_output_partial_shape(0).rank().is_dynamic());
}
TEST(type_prop, binary_elementwise_arithmetic_left_rank_dynamic_right_static)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto b = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
auto add = make_shared<op::Add>(a, b);
ASSERT_TRUE(add->get_output_partial_shape(0).is_static());
ASSERT_EQ(add->get_shape(), (Shape{1, 2, 3}));
}
TEST(type_prop, binary_elementwise_arithmetic_left_static_right_rank_dynamic)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
auto b = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto add = make_shared<op::Add>(a, b);
ASSERT_TRUE(add->get_output_partial_shape(0).is_static());
ASSERT_EQ(add->get_shape(), (Shape{1, 2, 3}));
}
TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_right_rank_dynamic)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3});
auto b = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto add = make_shared<op::Add>(a, b);
ASSERT_TRUE(add->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(add->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(
add->get_output_partial_shape(0).same_scheme(PartialShape{1, Dimension::dynamic(), 3}));
}
TEST(type_prop, binary_elementwise_arithmetic_left_rank_dynamic_right_rank_static_dynamic)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3});
auto add = make_shared<op::Add>(a, b);
ASSERT_TRUE(add->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(add->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(
add->get_output_partial_shape(0).same_scheme(PartialShape{1, Dimension::dynamic(), 3}));
}
TEST(type_prop,
binary_elementwise_arithmetic_left_rank_static_dynamic_right_rank_static_dynamic_result_static)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
auto add = make_shared<op::Add>(a, b);
ASSERT_TRUE(add->get_output_partial_shape(0).is_static());
ASSERT_EQ(add->get_shape(), (Shape{1, 2, 3}));
}
TEST(
type_prop,
binary_elementwise_arithmetic_left_rank_static_dynamic_right_rank_static_dynamic_result_rank_static_dynamic)
{
auto a = make_shared<op::Parameter>(
element::f32, PartialShape{1, Dimension::dynamic(), Dimension::dynamic()});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
auto add = make_shared<op::Add>(a, b);
ASSERT_TRUE(add->get_output_partial_shape(0).rank().is_static());
ASSERT_TRUE(add->get_output_partial_shape(0).is_dynamic());
ASSERT_TRUE(
add->get_output_partial_shape(0).same_scheme(PartialShape{1, 2, Dimension::dynamic()}));
}
TEST(type_prop, binary_elementwise_arithmetic_left_static_right_rank_static_dynamic)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
auto add = make_shared<op::Add>(a, b);
ASSERT_TRUE(add->get_output_partial_shape(0).is_static());
ASSERT_EQ(add->get_shape(), (Shape{1, 2, 3}));
}
TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_right_static)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
auto add = make_shared<op::Add>(a, b);
ASSERT_TRUE(add->get_output_partial_shape(0).is_static());
ASSERT_EQ(add->get_shape(), (Shape{1, 2, 3}));
}
TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_inconsistent)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 3});
try
{
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, binary_elementwise_arithmetic_right_rank_static_dynamic_inconsistent)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 3});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
try
{
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, binary_elementwise_arithmetic_both_rank_static_dynamic_inconsistent)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 3, 3});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
try
{
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_different_rank)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
try
{
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, binary_elementwise_arithmetic_right_rank_static_dynamic_different_rank)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
try
{
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, binary_elementwise_arithmetic_both_rank_static_dynamic_different_rank)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3, 4});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
try
{
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
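// The shape cases above all follow one pattern: binary elementwise shape inference is a
// merge of the two argument shapes, and any failure surfaces as a NodeValidationError
// containing "Argument shapes are inconsistent". A hedged sketch of that step
// (infer_elementwise_shape_sketch is a hypothetical helper, not the op's actual
// validate_and_infer_types code):
static PartialShape infer_elementwise_shape_sketch(const PartialShape& arg0,
                                                   const PartialShape& arg1)
{
    PartialShape result = arg0;
    if (!PartialShape::merge_into(result, arg1))
    {
        throw ngraph_error("Argument shapes are inconsistent");
    }
    return result;
}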
TEST(type_prop, binary_elementwise_arithmetic_both_et_dynamic)
{
auto a = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
auto b = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
auto add = make_shared<op::Add>(a, b);
ASSERT_TRUE(add->get_output_element_type(0).is_dynamic());
}
TEST(type_prop, binary_elementwise_arithmetic_left_et_dynamic)
{
auto a = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
auto b = make_shared<op::Parameter>(element::u32, Shape{1, 2, 3, 4});
auto add = make_shared<op::Add>(a, b);
ASSERT_EQ(add->get_output_element_type(0), element::u32);
}
TEST(type_prop, binary_elementwise_arithmetic_right_et_dynamic)
{
auto a = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
auto b = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
auto add = make_shared<op::Add>(a, b);
ASSERT_EQ(add->get_output_element_type(0), element::i64);
}
TEST(type_prop, logic_arith_compare_partial_et)
{
auto test_logic = [](element::Type et0, element::Type et1) -> std::shared_ptr<Node> {
auto param0 = std::make_shared<op::Parameter>(et0, Shape{1, 2, 3});
auto param1 = std::make_shared<op::Parameter>(et1, Shape{1, 2, 3});
return std::make_shared<op::And>(param0, param1);
};
auto test_arith = [](element::Type et0, element::Type et1) -> std::shared_ptr<Node> {
auto param0 = std::make_shared<op::Parameter>(et0, Shape{1, 2, 3});
auto param1 = std::make_shared<op::Parameter>(et1, Shape{1, 2, 3});
return std::make_shared<op::Add>(param0, param1);
};
auto test_compare = [](element::Type et0, element::Type et1) -> std::shared_ptr<Node> {
auto param0 = std::make_shared<op::Parameter>(et0, Shape{1, 2, 3});
auto param1 = std::make_shared<op::Parameter>(et1, Shape{1, 2, 3});
return std::make_shared<op::Greater>(param0, param1);
};
auto test_not = [](element::Type et) -> std::shared_ptr<Node> {
auto param = std::make_shared<op::Parameter>(et, Shape{1, 2, 3});
return std::make_shared<op::Not>(param);
};
// Logical ops (legend: int = i32, boo = boolean, dyn = dynamic, ! = type-check failure):
//
// int int -> !
// int boo -> !
// int dyn -> !
// boo int -> !
// boo boo -> boo
// boo dyn -> boo
// dyn int -> !
// dyn boo -> boo
// dyn dyn -> boo
ASSERT_ANY_THROW({ test_logic(element::i32, element::i32); });
ASSERT_ANY_THROW({ test_logic(element::i32, element::boolean); });
ASSERT_ANY_THROW({ test_logic(element::i32, element::dynamic); });
ASSERT_ANY_THROW({ test_logic(element::boolean, element::i32); });
ASSERT_EQ(test_logic(element::boolean, element::boolean)->get_element_type(), element::boolean);
ASSERT_EQ(test_logic(element::boolean, element::dynamic)->get_element_type(), element::boolean);
ASSERT_ANY_THROW({ test_logic(element::dynamic, element::i32); });
ASSERT_EQ(test_logic(element::dynamic, element::boolean)->get_element_type(), element::boolean);
ASSERT_EQ(test_logic(element::dynamic, element::dynamic)->get_element_type(), element::boolean);
// Arith ops:
//
// int int -> int
// int boo -> !
// int dyn -> int
// boo int -> !
// boo boo -> !
// boo dyn -> !
// dyn int -> int
// dyn boo -> !
// dyn dyn -> dyn
ASSERT_EQ(test_arith(element::i32, element::i32)->get_element_type(), element::i32);
ASSERT_ANY_THROW({ test_arith(element::i32, element::boolean); });
ASSERT_EQ(test_arith(element::i32, element::dynamic)->get_element_type(), element::i32);
ASSERT_ANY_THROW({ test_arith(element::boolean, element::i32); });
ASSERT_ANY_THROW({ test_arith(element::boolean, element::boolean); });
ASSERT_ANY_THROW({ test_arith(element::boolean, element::dynamic); });
ASSERT_EQ(test_arith(element::dynamic, element::i32)->get_element_type(), element::i32);
ASSERT_ANY_THROW({ test_arith(element::dynamic, element::boolean); });
ASSERT_EQ(test_arith(element::dynamic, element::dynamic)->get_element_type(), element::dynamic);
// Comparison ops:
//
// int int -> boo
// int boo -> !
// int dyn -> boo
// boo int -> !
// boo boo -> boo
// boo dyn -> boo
// dyn int -> boo
// dyn boo -> boo
// dyn dyn -> boo
ASSERT_EQ(test_compare(element::i32, element::i32)->get_element_type(), element::boolean);
ASSERT_ANY_THROW({ test_compare(element::i32, element::boolean); });
ASSERT_EQ(test_compare(element::i32, element::dynamic)->get_element_type(), element::boolean);
ASSERT_ANY_THROW({ test_compare(element::boolean, element::i32); });
ASSERT_EQ(test_compare(element::boolean, element::boolean)->get_element_type(),
element::boolean);
ASSERT_EQ(test_compare(element::boolean, element::dynamic)->get_element_type(),
element::boolean);
ASSERT_EQ(test_compare(element::dynamic, element::i32)->get_element_type(), element::boolean);
ASSERT_EQ(test_compare(element::dynamic, element::boolean)->get_element_type(),
element::boolean);
ASSERT_EQ(test_compare(element::dynamic, element::dynamic)->get_element_type(),
element::boolean);
// Logical negation op:
//
// Current behavior:
// int -> int
// boo -> boo
// dyn -> dyn
//
// TODO(amprocte): I believe the behavior should actually be:
// int -> !
// boo -> boo
// dyn -> boo
ASSERT_EQ(test_not(element::i32)->get_element_type(), element::i32);
ASSERT_EQ(test_not(element::boolean)->get_element_type(), element::boolean);
ASSERT_EQ(test_not(element::dynamic)->get_element_type(), element::dynamic);
}
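// The three truth tables above reduce to one unification rule plus per-op constraints:
// element types merge like dimensions (dynamic adopts the other side, equal static types
// pass, unequal static types fail); arithmetic ops additionally reject boolean, logical
// ops require the merged type to be boolean (narrowing a fully dynamic merge to boolean),
// and comparisons always produce boolean. A sketch of the unification step
// (merge_et_sketch is a hypothetical helper, not library code):
static bool merge_et_sketch(element::Type& dst, const element::Type& t1, const element::Type& t2)
{
    if (t1.is_dynamic())
    {
        dst = t2;
        return true;
    }
    if (t2.is_dynamic())
    {
        dst = t1;
        return true;
    }
    if (t1 == t2)
    {
        dst = t1;
        return true;
    }
    return false; // two different static element types are inconsistent
}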
TEST(type_prop, quantize_f32_to_i8_nchw_per_channel_ok)
{
Shape batch_shape{64, 3, 480, 640};
......@@ -7543,3 +7913,20 @@ TEST(type_prop, dequantize_offset_shape_mismatch_different_rank_fails)
FAIL() << "Deduced type check failed for unexpected reason";
}
}
//
// This tests a temporary fallback for ops that do not yet support partial-shape validation.
// The graph constructed here is bogus (note the out-of-range concatenation axis), but
// because some of the input shapes are partially dynamic, it should still pass validation,
// with the output shape and element type set to dynamic.
//
TEST(type_prop, validate_punt_if_dynamic)
{
auto a = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
auto b = make_shared<op::Parameter>(element::u32, PartialShape{1, Dimension::dynamic(), 3});
auto c = make_shared<op::Parameter>(element::i32, Shape{1, 8, 3});
auto concat = make_shared<op::Concat>(NodeVector{a, b, c}, /*concatenation axis=*/1234);
ASSERT_EQ(concat->get_output_size(), 1);
ASSERT_TRUE(concat->get_output_partial_shape(0).rank().is_dynamic());
ASSERT_TRUE(concat->get_output_element_type(0).is_dynamic());
}
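// For context, a rough sketch of the fallback pattern this test exercises (accessor names
// below are assumptions for illustration only; the test asserts just the observable
// outcome): if any input has a dynamic element type or a non-static shape, skip the op's
// real validation and mark every output fully dynamic.
//
//     bool punt_if_dynamic_sketch(Node& node)
//     {
//         bool must_punt = false;
//         for (size_t i = 0; i < node.get_input_size(); i++)
//         {
//             if (node.get_input_element_type(i).is_dynamic() ||
//                 node.get_input_partial_shape(i).is_dynamic())
//             {
//                 must_punt = true;
//             }
//         }
//         if (must_punt)
//         {
//             for (size_t i = 0; i < node.get_output_size(); i++)
//             {
//                 node.set_output_type(i, element::dynamic, PartialShape::dynamic());
//             }
//         }
//         return must_punt; // true => caller returns early, skipping real validation
//     }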