Commit 0563a3cf authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Partial Shapes and Types, Part 3: Framework for partial shape/element type validation (#1728)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goof-up in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item
parent 631d7253
...@@ -25,7 +25,7 @@ descriptor::Tensor::Tensor(const element::Type& element_type, ...@@ -25,7 +25,7 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
const PartialShape& pshape, const PartialShape& pshape,
const std::string& name) const std::string& name)
: m_element_type(element_type) : m_element_type(element_type)
, m_shape(pshape.is_complete() ? pshape.to_shape() : Shape{}) , m_shape(pshape.is_static() ? pshape.to_shape() : Shape{})
, m_partial_shape(pshape) , m_partial_shape(pshape)
, m_name(name) , m_name(name)
{ {
...@@ -34,7 +34,7 @@ descriptor::Tensor::Tensor(const element::Type& element_type, ...@@ -34,7 +34,7 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
void descriptor::Tensor::set_tensor_type(const element::Type& element_type, void descriptor::Tensor::set_tensor_type(const element::Type& element_type,
const PartialShape& pshape) const PartialShape& pshape)
{ {
if (pshape.is_complete()) if (pshape.is_static())
{ {
m_shape = pshape.to_shape(); m_shape = pshape.to_shape();
} }
...@@ -48,14 +48,14 @@ void descriptor::Tensor::set_tensor_type(const element::Type& element_type, ...@@ -48,14 +48,14 @@ void descriptor::Tensor::set_tensor_type(const element::Type& element_type,
const Shape& descriptor::Tensor::get_shape() const const Shape& descriptor::Tensor::get_shape() const
{ {
if (m_partial_shape.is_complete()) if (m_partial_shape.is_static())
{ {
return m_shape; return m_shape;
} }
else else
{ {
throw std::invalid_argument( throw std::invalid_argument(
"get_shape was called on a descriptor::Tensor with incomplete shape"); "get_shape was called on a descriptor::Tensor with dynamic shape");
} }
} }
......
...@@ -25,11 +25,11 @@ using namespace ngraph; ...@@ -25,11 +25,11 @@ using namespace ngraph;
Dimension::Dimension(size_t dimension) Dimension::Dimension(size_t dimension)
: m_dimension(dimension) : m_dimension(dimension)
{ {
if (dimension == s_undetermined_val) if (dimension == s_dynamic_val)
{ {
std::stringstream ss; std::stringstream ss;
ss << "Cannot convert the value 0x" << std::uppercase << std::hex << s_undetermined_val ss << "Cannot convert the value 0x" << std::uppercase << std::hex << s_dynamic_val
<< " to Dimension: this value is used internally to represent an undetermined " << " to Dimension: this value is used internally to represent a dynamic "
"dimension."; "dimension.";
throw std::invalid_argument(ss.str()); throw std::invalid_argument(ss.str());
} }
...@@ -37,7 +37,7 @@ Dimension::Dimension(size_t dimension) ...@@ -37,7 +37,7 @@ Dimension::Dimension(size_t dimension)
std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension) std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
{ {
if (dimension.is_determined()) if (dimension.is_static())
{ {
return (str << size_t(dimension)); return (str << size_t(dimension));
} }
...@@ -49,11 +49,33 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension) ...@@ -49,11 +49,33 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
Dimension ngraph::operator+(const Dimension& d1, const Dimension& d2) Dimension ngraph::operator+(const Dimension& d1, const Dimension& d2)
{ {
return (d1.is_determined() && d2.is_determined() ? size_t(d1) + size_t(d2) return (d1.is_static() && d2.is_static() ? size_t(d1) + size_t(d2) : Dimension::dynamic());
: Dimension::undetermined());
} }
bool Dimension::compatible(const Dimension& d) const bool Dimension::compatible(const Dimension& d) const
{ {
return (!is_determined() || !d.is_determined() || m_dimension == size_t(d)); return (is_dynamic() || d.is_dynamic() || m_dimension == size_t(d));
}
bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
{
    // Merging with a dynamic dimension always succeeds and yields the other operand.
    if (d1.is_dynamic())
    {
        dst = d2;
        return true;
    }
    if (d2.is_dynamic())
    {
        dst = d1;
        return true;
    }
    // Both operands are static: the merge succeeds only if they agree,
    // in which case the (shared) value is written to dst.
    if (size_t(d1) == size_t(d2))
    {
        dst = d1;
        return true;
    }
    // Static and unequal: dst is left untouched.
    return false;
}
...@@ -22,56 +22,97 @@ ...@@ -22,56 +22,97 @@
namespace ngraph namespace ngraph
{ {
/// \brief Class representing a possibly-unknown dimension in a shape or shape-like object. /// \brief Class representing a dimension, which may be dynamic (undetermined until runtime),
/// in a shape or shape-like object.
/// ///
/// Known dimensions may be implicitly converted from size_t. An unknown dimension is /// Static dimensions may be implicitly converted from size_t. A dynamic dimension is
/// constructed with Dimension() or Dimension::undetermined(). /// constructed with Dimension() or Dimension::dynamic().
/// ///
/// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE. /// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
class Dimension class Dimension
{ {
public: public:
/// \brief Constructs a known dimension. /// \brief Construct a static dimension.
/// /// \param dimension Value of the dimension. Must not be equal to
/// Requires that dimension != s_undetermined_val. If that condition does not hold, /// Dimension::s_dynamic_val.
/// throws std::invalid_argument. /// \throws std::invalid_argument If `dimension` == Dimension::s_dynamic_val.
Dimension(size_t dimension); Dimension(size_t dimension);
/// \brief Constructs an unknown dimension. /// \brief Construct a dynamic dimension.
Dimension() { m_dimension = s_undetermined_val; } Dimension() { m_dimension = s_dynamic_val; }
/// \brief Returns true if this dimension is determined. /// \brief Check whether this dimension is static.
bool is_determined() const { return m_dimension != s_undetermined_val; } /// \return `true` if the dimension is static, else `false`.
/// \brief Converts this dimension to size_t. If the dimension is undetermined, throws bool is_static() const { return m_dimension != s_dynamic_val; }
/// std::invalid_argument. /// \brief Check whether this dimension is dynamic.
/// \return `false` if the dimension is static, else `true`.
bool is_dynamic() const { return !is_static(); }
/// \brief Convert this dimension to `size_t`. This dimension must be static.
/// \throws std::invalid_argument If this dimension is dynamic.
explicit operator size_t() const explicit operator size_t() const
{ {
if (!is_determined()) if (is_dynamic())
{ {
throw std::invalid_argument("Cannot convert unknown dimension to size_t"); throw std::invalid_argument("Cannot convert dynamic dimension to size_t");
} }
return m_dimension; return m_dimension;
} }
/// \brief Returns true if the dimensions are compatible, i.e. if one of the dimensions /// \brief Check whether this dimension represents the same scheme as the argument (both
/// is undetermined, or both dimensions are determined and equal. /// dynamic, or equal).
/// \param dim The other dimension to compare this dimension to.
/// \return `true` if this dimension and `dim` are both dynamic, or if they are both
/// static and equal; otherwise, `false`.
bool same_scheme(const Dimension& dim) const
{
return (is_dynamic() && dim.is_dynamic()) ||
(is_static() && dim.is_static() && m_dimension == size_t(dim));
}
/// \brief Try to merge two Dimension objects together.
/// \param[out] dst Reference to write the merged Dimension into.
/// \param d1 First dimension to merge.
/// \param d2 Second dimension to merge.
/// \return `true` if merging succeeds, else `false`.
///
/// \li If `d1` is dynamic, writes `d2` to `dst` and returns `true`.
/// \li If `d2` is dynamic, writes `d1` to `dst` and returns `true`.
/// \li If `d1` and `d2` are static and equal, writes `d1` to `dst` and returns `true`.
/// \li If `d1` and `d2` are both static and unequal, leaves `dst` unchanged and
/// returns `false`.
static bool merge(Dimension& dst, const Dimension d1, const Dimension d2);
/// \brief Check whether this dimension is capable of being merged with the argument
/// dimension.
/// \param d The dimension to compare this dimension with.
/// \return `true` if this dimension is compatible with `d`, else `false`.
///
/// Two dimensions are considered compatible if it is possible to merge them. (See
/// Dimension::merge.)
bool compatible(const Dimension& d) const; bool compatible(const Dimension& d) const;
/// \brief Constructs an unknown dimension. /// \brief Create a dynamic dimension.
static Dimension undetermined() { return Dimension(); } /// \return A dynamic dimension.
/// \brief Constant for the value used internally to represent an unknown dimension. static Dimension dynamic() { return Dimension(); }
static const size_t s_undetermined_val{std::numeric_limits<size_t>::max()}; /// \brief Constant for the value used internally to represent a dynamic dimension.
static const size_t s_dynamic_val{std::numeric_limits<size_t>::max()};
private: private:
// The actual numerical value of the dimension. s_undetermined_val is a special case, // The actual numerical value of the dimension. s_dynamic_val is a special case,
// representing an unknown dimension. // representing a dynamic dimension.
size_t m_dimension; size_t m_dimension;
}; };
/// \brief Inserts a human-readable representation of "dimension" into "str". /// \brief Insert a human-readable representation of a dimension into an output stream.
/// \param str The output stream targeted for insertion.
/// \param dimension The dimension to be inserted into `str`.
/// \return A reference to `str` after insertion.
///
/// Inserts the string `?` if `dimension` is dynamic; else inserts `size_t(dimension)`.
std::ostream& operator<<(std::ostream& str, const Dimension& dimension); std::ostream& operator<<(std::ostream& str, const Dimension& dimension);
/// \brief Addition operator for dimensions. /// \brief Addition operator for dimensions.
/// /// \param d1 Left operand for addition.
/// If d1 and d2 are both known, returns size_t(d1)+size_t(d2). Otherwise, returns /// \param d2 Right operand for addition.
/// Dimension::undetermined(). /// \return Dimension::dynamic() if either of `d1` or `d2` is dynamic; else, a static
/// dimension with value `size_t(d1)+size_t(d2)`.
Dimension operator+(const Dimension& d1, const Dimension& d2); Dimension operator+(const Dimension& d1, const Dimension& d2);
} }
...@@ -52,9 +52,7 @@ namespace ngraph ...@@ -52,9 +52,7 @@ namespace ngraph
case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break; case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break;
case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break; case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break;
case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break; case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break;
case onnx::TensorProto_DataType_UNDEFINED: case onnx::TensorProto_DataType_UNDEFINED: elem_type = element::dynamic; break;
elem_type = element::unspecified;
break;
default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type"; default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type";
} }
......
...@@ -75,7 +75,7 @@ void Node::set_output_size(size_t n) ...@@ -75,7 +75,7 @@ void Node::set_output_size(size_t n)
for (size_t i = m_outputs.size(); i < n; ++i) for (size_t i = m_outputs.size(); i < n; ++i)
{ {
auto tensor_descriptor = make_shared<descriptor::Tensor>( auto tensor_descriptor = make_shared<descriptor::Tensor>(
element::unspecified, Shape(), get_name() + "_" + to_string(i)); element::dynamic, PartialShape::dynamic(), get_name() + "_" + to_string(i));
m_outputs.emplace_back(this, i, tensor_descriptor); m_outputs.emplace_back(this, i, tensor_descriptor);
} }
} }
...@@ -84,9 +84,9 @@ void Node::validate_and_infer_types() ...@@ -84,9 +84,9 @@ void Node::validate_and_infer_types()
{ {
} }
void Node::set_output_type(size_t i, const element::Type& element_type, const Shape& shape) void Node::set_output_type(size_t i, const element::Type& element_type, const PartialShape& pshape)
{ {
m_outputs.at(i).get_tensor_ptr()->set_tensor_type(element_type, shape); m_outputs.at(i).get_tensor_ptr()->set_tensor_type(element_type, pshape);
} }
std::deque<descriptor::Output>& Node::get_outputs() std::deque<descriptor::Output>& Node::get_outputs()
...@@ -208,13 +208,27 @@ std::ostream& Node::write_short_description(std::ostream& out) const ...@@ -208,13 +208,27 @@ std::ostream& Node::write_short_description(std::ostream& out) const
return out << get_name(); return out << get_name();
} }
// Render an element type for human-readable node descriptions:
// a dynamic (not-yet-known) element type prints as "?", a static one
// prints as its C type name.
static std::string pretty_element_type(const element::Type& et)
{
    if (!et.is_dynamic())
    {
        return et.c_type_string();
    }
    return "?";
}
std::ostream& Node::write_long_description(std::ostream& out) const std::ostream& Node::write_long_description(std::ostream& out) const
{ {
out << description() << '[' << get_name() << "]("; out << description() << '[' << get_name() << "](";
string sep = ""; string sep = "";
for (auto arg : get_arguments()) for (auto arg : get_arguments())
{ {
out << sep << NodeDescription(*arg, true); out << sep << NodeDescription(*arg, true) << ": "
<< pretty_element_type(arg->get_output_element_type(0))
<< arg->get_output_partial_shape(0) << "";
sep = ", "; sep = ", ";
} }
out << ")"; out << ")";
...@@ -404,39 +418,73 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args) ...@@ -404,39 +418,73 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args)
return args; return args;
} }
void Node::validate_and_infer_elementwise(element::Type result_type) std::tuple<element::Type, PartialShape> Node::validate_and_infer_elementwise_args()
{ {
const element::Type& element_type = get_input_element_type(0); element::Type element_type = get_input_element_type(0);
const Shape& shape = get_input_shape(0); PartialShape pshape = get_input_partial_shape(0);
if (get_input_size() > 1) if (get_input_size() > 1)
{ {
for (size_t i = 1; i < get_input_size(); ++i) for (size_t i = 1; i < get_input_size(); ++i)
{ {
NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == element_type) NODE_VALIDATION_ASSERT(
<< "Argument 0 element type " << element_type this, element::Type::merge(element_type, element_type, get_input_element_type(i)))
<< " differs in element type from argument " << i << " " << *get_argument(i) << "Argument element types are inconsistent.";
<< " element type " << get_input_element_type(i);
NODE_VALIDATION_ASSERT(this, get_input_shape(i) == shape) NODE_VALIDATION_ASSERT(this,
<< "Argument 0 shape " << shape << " differs in shape from argument " << i << " " PartialShape::merge_into(pshape, get_input_partial_shape(i)))
<< *get_argument(i) << " shape " << get_input_shape(i); << "Argument shapes are inconsistent.";
} }
} }
set_output_type(0, result_type, shape);
return std::make_tuple(element_type, pshape);
} }
void Node::validate_and_infer_elementwise_arithmetic() void Node::validate_and_infer_elementwise_arithmetic()
{ {
NODE_VALIDATION_ASSERT(this, get_input_element_type(0) != element::boolean) auto args_et_pshape = validate_and_infer_elementwise_args();
<< "Arguments cannot have boolean element type (argument element type: " element::Type& args_et = std::get<0>(args_et_pshape);
<< get_input_element_type(0) << ")."; PartialShape& args_pshape = std::get<1>(args_et_pshape);
validate_and_infer_elementwise(get_input_element_type(0));
NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et != element::boolean)
<< "Arguments cannot have boolean element type (argument element type: " << args_et << ").";
set_output_type(0, args_et, args_pshape);
} }
void Node::validate_and_infer_elementwise_logical() void Node::validate_and_infer_elementwise_logical()
{ {
NODE_VALIDATION_ASSERT(this, get_input_element_type(0) == element::boolean) auto args_et_pshape = validate_and_infer_elementwise_args();
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et == element::boolean)
<< "Operands for logical operators must have boolean element type but have element type " << "Operands for logical operators must have boolean element type but have element type "
<< get_input_element_type(0) << "."; << args_et << ".";
validate_and_infer_elementwise(get_input_element_type(0));
set_output_type(0, element::boolean, args_pshape);
}
// Temporary fallback for ops without partial-shape propagation: if any input
// has a dynamic shape or dynamic element type, mark every output as fully
// dynamic and report true so the caller's validate_and_infer_types() can
// return early. Returns false (doing nothing) when all inputs are static.
bool Node::validate_punt_if_dynamic()
{
    bool saw_dynamic = false;
    for (auto& input : m_inputs)
    {
        if (input.get_partial_shape().is_dynamic() || input.get_element_type().is_dynamic())
        {
            saw_dynamic = true;
            break;
        }
    }
    if (!saw_dynamic)
    {
        return false;
    }
    // Punt: every output gets dynamic rank and dynamic element type.
    for (size_t idx = 0; idx < get_output_size(); idx++)
    {
        set_output_type(idx, element::dynamic, PartialShape::dynamic());
    }
    return true;
}
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <memory> #include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <tuple>
#include <typeindex> #include <typeindex>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
...@@ -94,14 +95,21 @@ namespace ngraph ...@@ -94,14 +95,21 @@ namespace ngraph
// Called in constructors during transition // Called in constructors during transition
void constructor_validate_and_infer_types(); void constructor_validate_and_infer_types();
void validate_and_infer_elementwise(element::Type result_type); std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args();
void validate_and_infer_elementwise()
{
validate_and_infer_elementwise(get_input_element_type(0));
}
void validate_and_infer_elementwise_arithmetic(); void validate_and_infer_elementwise_arithmetic();
void validate_and_infer_elementwise_logical(); void validate_and_infer_elementwise_logical();
// Temporary hack while partial shape propagation is being implemented. If any input has
// dynamic shape or dynamic element type, sets all outputs to have a shape of dynamic
// rank and dynamic element type. Ops where we haven't yet implemented partial shape
// propagation can add this boilerplate at the top of their validate_and_infer_types():
//
// if (validate_punt_if_dynamic())
// {
// return;
// }
bool validate_punt_if_dynamic();
Node(const std::string& node_type, const NodeVector& arguments, size_t output_size = 1); Node(const std::string& node_type, const NodeVector& arguments, size_t output_size = 1);
virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) {} virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) {}
...@@ -125,7 +133,9 @@ namespace ngraph ...@@ -125,7 +133,9 @@ namespace ngraph
return std::type_index(typeid(*this)) == std::type_index(typeid(*n)); return std::type_index(typeid(*this)) == std::type_index(typeid(*n));
} }
void set_output_type(size_t i, const element::Type& element_type, const Shape& shape); void set_output_type(size_t i,
const element::Type& element_type,
const PartialShape& pshape);
bool is_parameter() const; bool is_parameter() const;
virtual bool is_output() const; virtual bool is_output() const;
......
...@@ -27,6 +27,11 @@ op::AllReduce::AllReduce(const shared_ptr<Node>& arg) ...@@ -27,6 +27,11 @@ op::AllReduce::AllReduce(const shared_ptr<Node>& arg)
void op::AllReduce::validate_and_infer_types() void op::AllReduce::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
NODE_VALIDATION_ASSERT(this, NODE_VALIDATION_ASSERT(this,
get_input_element_type(0) == element::f32 || get_input_element_type(0) == element::f32 ||
get_input_element_type(0) == element::f64) get_input_element_type(0) == element::f64)
......
...@@ -40,6 +40,11 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg, ...@@ -40,6 +40,11 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
void op::AvgPool::validate_and_infer_types() void op::AvgPool::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
auto& arg_shape = get_input_shape(0); auto& arg_shape = get_input_shape(0);
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3) NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
...@@ -120,6 +125,11 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape, ...@@ -120,6 +125,11 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
void op::AvgPoolBackprop::validate_and_infer_types() void op::AvgPoolBackprop::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
auto& delta_shape = get_input_shape(0); auto& delta_shape = get_input_shape(0);
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for // infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
......
...@@ -35,6 +35,11 @@ ngraph::op::BatchNorm::BatchNorm(double eps, ...@@ -35,6 +35,11 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
void ngraph::op::BatchNorm::validate_and_infer_types() void ngraph::op::BatchNorm::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
m_bn_input_shape = get_input_shape(INPUT); m_bn_input_shape = get_input_shape(INPUT);
NODE_VALIDATION_ASSERT(this, m_bn_input_shape.size() >= 2) NODE_VALIDATION_ASSERT(this, m_bn_input_shape.size() >= 2)
<< "Input argument must have rank of at least 2 (input argument shape: " << m_bn_input_shape << "Input argument must have rank of at least 2 (input argument shape: " << m_bn_input_shape
...@@ -158,6 +163,11 @@ ngraph::op::BatchNormBackprop::BatchNormBackprop(double eps, ...@@ -158,6 +163,11 @@ ngraph::op::BatchNormBackprop::BatchNormBackprop(double eps,
void ngraph::op::BatchNormBackprop::validate_and_infer_types() void ngraph::op::BatchNormBackprop::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
set_output_size(3); set_output_size(3);
NODE_VALIDATION_ASSERT(this, get_input_shape(INPUT).size() == 4) NODE_VALIDATION_ASSERT(this, get_input_shape(INPUT).size() == 4)
......
...@@ -40,6 +40,11 @@ op::Broadcast::Broadcast(const shared_ptr<Node>& arg, ...@@ -40,6 +40,11 @@ op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
void op::Broadcast::validate_and_infer_types() void op::Broadcast::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
infer_shape(); infer_shape();
Shape target_shape = m_shape; Shape target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i) for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
......
...@@ -32,6 +32,11 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis) ...@@ -32,6 +32,11 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
void op::Concat::validate_and_infer_types() void op::Concat::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required."; NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required.";
Shape first_input_shape = get_input_shape(0); Shape first_input_shape = get_input_shape(0);
......
...@@ -46,6 +46,11 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch, ...@@ -46,6 +46,11 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
void op::Convolution::validate_and_infer_types() void op::Convolution::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
auto& data_batch_shape = get_input_shape(0); auto& data_batch_shape = get_input_shape(0);
auto& data_batch_et = get_input_element_type(0); auto& data_batch_et = get_input_element_type(0);
auto& filters_shape = get_input_shape(1); auto& filters_shape = get_input_shape(1);
...@@ -220,6 +225,11 @@ op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_sha ...@@ -220,6 +225,11 @@ op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_sha
void op::ConvolutionBackpropData::validate_and_infer_types() void op::ConvolutionBackpropData::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
// Backprop to data is itself convolution, with inputs/outputs/attributes transmogrified as // Backprop to data is itself convolution, with inputs/outputs/attributes transmogrified as
// follows. // follows.
// //
...@@ -410,6 +420,11 @@ op::ConvolutionBackpropFilters::ConvolutionBackpropFilters( ...@@ -410,6 +420,11 @@ op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
void op::ConvolutionBackpropFilters::validate_and_infer_types() void op::ConvolutionBackpropFilters::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
// Backprop to filters is itself convolution, with inputs/outputs/attributes transmogrified as // Backprop to filters is itself convolution, with inputs/outputs/attributes transmogrified as
// follows. // follows.
// //
......
...@@ -34,6 +34,11 @@ op::Dequantize::Dequantize(shared_ptr<Node> input, ...@@ -34,6 +34,11 @@ op::Dequantize::Dequantize(shared_ptr<Node> input,
void op::Dequantize::validate_and_infer_types() void op::Dequantize::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
enum enum
{ {
INPUT, INPUT,
......
...@@ -47,6 +47,11 @@ op::Dot::Dot(const shared_ptr<Node>& arg0, ...@@ -47,6 +47,11 @@ op::Dot::Dot(const shared_ptr<Node>& arg0,
void op::Dot::validate_and_infer_types() void op::Dot::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
auto& input_0 = get_inputs().at(0); auto& input_0 = get_inputs().at(0);
auto& input_1 = get_inputs().at(1); auto& input_1 = get_inputs().at(1);
......
...@@ -42,6 +42,11 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg, ...@@ -42,6 +42,11 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
void op::MaxPool::validate_and_infer_types() void op::MaxPool::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
auto& arg_shape = get_input_shape(0); auto& arg_shape = get_input_shape(0);
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3) NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
...@@ -120,6 +125,11 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward, ...@@ -120,6 +125,11 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
void op::MaxPoolBackprop::validate_and_infer_types() void op::MaxPoolBackprop::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
auto forward_arg_et = get_input_element_type(0); auto forward_arg_et = get_input_element_type(0);
auto& forward_arg_shape = get_input_shape(0); auto& forward_arg_shape = get_input_shape(0);
auto delta_et = get_input_element_type(1); auto delta_et = get_input_element_type(1);
......
...@@ -26,9 +26,14 @@ op::Not::Not(const shared_ptr<Node>& arg) ...@@ -26,9 +26,14 @@ op::Not::Not(const shared_ptr<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
void op::Not::validate_and_infer_types() void op::Not::validate_and_infer_types()
{ {
validate_and_infer_elementwise(); auto args_et_pshape = validate_and_infer_elementwise_args();
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
set_output_type(0, args_et, args_pshape);
} }
shared_ptr<Node> op::Not::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Not::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -22,11 +22,11 @@ using namespace std; ...@@ -22,11 +22,11 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
op::Parameter::Parameter(const element::Type& element_type, op::Parameter::Parameter(const element::Type& element_type,
const Shape& shape, const PartialShape& pshape,
const bool cacheable) const bool cacheable)
: Op("Parameter", {}) : Op("Parameter", {})
, m_cacheable(cacheable) , m_cacheable(cacheable)
, m_shape(shape) , m_partial_shape(pshape)
, m_element_type(element_type) , m_element_type(element_type)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
...@@ -35,13 +35,13 @@ op::Parameter::Parameter(const element::Type& element_type, ...@@ -35,13 +35,13 @@ op::Parameter::Parameter(const element::Type& element_type,
void op::Parameter::validate_and_infer_types() void op::Parameter::validate_and_infer_types()
{ {
Op::validate_and_infer_types(); Op::validate_and_infer_types();
set_output_type(0, m_element_type, m_shape); set_output_type(0, m_element_type, m_partial_shape);
} }
shared_ptr<Node> op::Parameter::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Parameter::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Parameter>(m_element_type, m_shape); return make_shared<Parameter>(m_element_type, m_partial_shape);
} }
void op::Parameter::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::Parameter::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
......
...@@ -38,10 +38,10 @@ namespace ngraph ...@@ -38,10 +38,10 @@ namespace ngraph
/// \brief Constructions a tensor view-typed parameter node. /// \brief Constructions a tensor view-typed parameter node.
/// ///
/// \param element_type The element type of the parameter. /// \param element_type The element type of the parameter.
/// \param shape The shape of the parameter. /// \param pshape The partial shape of the parameter.
/// \param cacheable True if the parameter is not expected to be frequently updated. /// \param cacheable True if the parameter is not expected to be frequently updated.
Parameter(const ngraph::element::Type& element_type, Parameter(const ngraph::element::Type& element_type,
const Shape& shape, const PartialShape& pshape,
const bool cacheable = false); const bool cacheable = false);
void validate_and_infer_types() override; void validate_and_infer_types() override;
...@@ -52,7 +52,7 @@ namespace ngraph ...@@ -52,7 +52,7 @@ namespace ngraph
protected: protected:
bool m_cacheable; bool m_cacheable;
Shape m_shape; PartialShape m_partial_shape;
element::Type m_element_type; element::Type m_element_type;
}; };
} }
......
...@@ -36,6 +36,11 @@ op::Quantize::Quantize(shared_ptr<Node> input, ...@@ -36,6 +36,11 @@ op::Quantize::Quantize(shared_ptr<Node> input,
void op::Quantize::validate_and_infer_types() void op::Quantize::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
enum enum
{ {
INPUT, INPUT,
......
...@@ -35,6 +35,11 @@ op::Reshape::Reshape(const shared_ptr<Node>& arg, ...@@ -35,6 +35,11 @@ op::Reshape::Reshape(const shared_ptr<Node>& arg,
void op::Reshape::validate_and_infer_types() void op::Reshape::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
auto& input = get_inputs().at(0); auto& input = get_inputs().at(0);
auto input_shape = input.get_shape(); auto input_shape = input.get_shape();
auto input_rank = input_shape.size(); auto input_rank = input_shape.size();
......
...@@ -37,7 +37,7 @@ void op::Result::validate_and_infer_types() ...@@ -37,7 +37,7 @@ void op::Result::validate_and_infer_types()
// always borrow the placement conf even the default one // always borrow the placement conf even the default one
set_placement(get_argument(0)->get_placement()); set_placement(get_argument(0)->get_placement());
set_output_type(0, get_input_element_type(0), get_input_shape(0)); set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
} }
shared_ptr<Node> op::Result::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Result::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -32,6 +32,11 @@ op::Reverse::Reverse(const shared_ptr<Node>& arg, const AxisSet& reversed_axes) ...@@ -32,6 +32,11 @@ op::Reverse::Reverse(const shared_ptr<Node>& arg, const AxisSet& reversed_axes)
void op::Reverse::validate_and_infer_types() void op::Reverse::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
auto input_shape = get_input_shape(0); auto input_shape = get_input_shape(0);
auto input_rank = input_shape.size(); auto input_rank = input_shape.size();
......
...@@ -38,6 +38,11 @@ op::ReverseSequence::ReverseSequence(const std::shared_ptr<Node> arg, ...@@ -38,6 +38,11 @@ op::ReverseSequence::ReverseSequence(const std::shared_ptr<Node> arg,
void op::ReverseSequence::validate_and_infer_types() void op::ReverseSequence::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
NODE_VALIDATION_ASSERT(this, get_input_shape(1).size() == 1) NODE_VALIDATION_ASSERT(this, get_input_shape(1).size() == 1)
<< "Sequence indices must be a 1-dimensional tensor (sequence indices shape: " << "Sequence indices must be a 1-dimensional tensor (sequence indices shape: "
<< get_input_shape(1) << ")."; << get_input_shape(1) << ").";
......
...@@ -44,6 +44,11 @@ op::Slice::Slice(const shared_ptr<Node>& arg, ...@@ -44,6 +44,11 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
void op::Slice::validate_and_infer_types() void op::Slice::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
if (0 == m_strides.size()) if (0 == m_strides.size())
{ {
m_strides = Strides(m_lower_bounds.size(), 1); m_strides = Strides(m_lower_bounds.size(), 1);
......
...@@ -39,6 +39,11 @@ op::TopK::TopK(const shared_ptr<Node>& arg, ...@@ -39,6 +39,11 @@ op::TopK::TopK(const shared_ptr<Node>& arg,
void op::TopK::validate_and_infer_types() void op::TopK::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
auto& input = get_inputs().at(0); auto& input = get_inputs().at(0);
auto rank = input.get_shape().size(); auto rank = input.get_shape().size();
......
...@@ -29,6 +29,11 @@ op::util::ArithmeticReduction::ArithmeticReduction(const std::string& node_type, ...@@ -29,6 +29,11 @@ op::util::ArithmeticReduction::ArithmeticReduction(const std::string& node_type,
void op::util::ArithmeticReduction::validate_and_infer_types() void op::util::ArithmeticReduction::validate_and_infer_types()
{ {
if (validate_punt_if_dynamic())
{
return;
}
auto input_shape = get_input_shape(0); auto input_shape = get_input_shape(0);
for (auto axis : m_reduction_axes) for (auto axis : m_reduction_axes)
......
...@@ -28,5 +28,8 @@ op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const string& ...@@ -28,5 +28,8 @@ op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const string&
void op::util::BinaryElementwiseComparison::validate_and_infer_types() void op::util::BinaryElementwiseComparison::validate_and_infer_types()
{ {
validate_and_infer_elementwise(element::boolean); auto args_et_pshape = validate_and_infer_elementwise_args();
PartialShape& args_pshape = std::get<1>(args_et_pshape);
set_output_type(0, element::boolean, args_pshape);
} }
...@@ -28,19 +28,18 @@ PartialShape::PartialShape(const Shape& shape) ...@@ -28,19 +28,18 @@ PartialShape::PartialShape(const Shape& shape)
m_dimensions.assign(shape.begin(), shape.end()); m_dimensions.assign(shape.begin(), shape.end());
} }
bool ngraph::PartialShape::is_complete() const bool ngraph::PartialShape::is_static() const
{ {
return m_rank_is_determined && return m_rank_is_static && std::all_of(m_dimensions.begin(),
std::all_of(m_dimensions.begin(), m_dimensions.end(), [](const Dimension& d) { m_dimensions.end(),
return d.is_determined(); [](const Dimension& d) { return d.is_static(); });
});
} }
PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2) PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2)
{ {
if (!s1.rank_is_determined() || !s2.rank_is_determined()) if (s1.rank().is_dynamic() || s2.rank().is_dynamic())
{ {
return PartialShape::undetermined(); return PartialShape::dynamic();
} }
if (!s1.rank().compatible(s2.rank())) if (!s1.rank().compatible(s2.rank()))
...@@ -49,7 +48,7 @@ PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2) ...@@ -49,7 +48,7 @@ PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2)
} }
PartialShape result{}; PartialShape result{};
result.m_rank_is_determined = true; result.m_rank_is_static = true;
for (size_t i = 0; i < s1.m_dimensions.size(); i++) for (size_t i = 0; i < s1.m_dimensions.size(); i++)
{ {
result.m_dimensions.push_back(s1.m_dimensions[i] + s2.m_dimensions[i]); result.m_dimensions.push_back(s1.m_dimensions[i] + s2.m_dimensions[i]);
...@@ -59,7 +58,7 @@ PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2) ...@@ -59,7 +58,7 @@ PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2)
std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape) std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape)
{ {
if (shape.m_rank_is_determined) if (shape.m_rank_is_static)
{ {
str << "{"; str << "{";
bool first = true; bool first = true;
...@@ -83,7 +82,7 @@ std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape) ...@@ -83,7 +82,7 @@ std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape)
bool PartialShape::compatible(const PartialShape& s) const bool PartialShape::compatible(const PartialShape& s) const
{ {
// If we don't know *this's rank, or we don't know s's rank, they are compatible. // If we don't know *this's rank, or we don't know s's rank, they are compatible.
if (!rank_is_determined() || !s.rank_is_determined()) if (!m_rank_is_static || s.rank().is_dynamic())
{ {
return true; return true;
} }
...@@ -109,12 +108,69 @@ bool PartialShape::compatible(const PartialShape& s) const ...@@ -109,12 +108,69 @@ bool PartialShape::compatible(const PartialShape& s) const
} }
} }
bool PartialShape::same_scheme(const PartialShape& s) const
{
if (rank().is_dynamic() && s.rank().is_dynamic())
{
return true;
}
else if (rank().is_static() && s.rank().is_static())
{
if (size_t(rank()) != size_t(s.rank()))
{
return false;
}
bool success = true;
for (size_t i = 0; i < size_t(rank()); i++)
{
success &= (*this)[i].same_scheme(s[i]);
}
return success;
}
else
{
return false;
}
}
Shape PartialShape::to_shape() const Shape PartialShape::to_shape() const
{ {
if (!is_complete()) if (is_dynamic())
{ {
throw std::invalid_argument("to_shape was called on an incomplete shape."); throw std::invalid_argument("to_shape was called on a dynamic shape.");
} }
return Shape(m_dimensions.begin(), m_dimensions.end()); return Shape(m_dimensions.begin(), m_dimensions.end());
} }
bool PartialShape::merge_into(PartialShape& dst, const PartialShape& src)
{
if (dst.rank().is_dynamic())
{
dst = src;
return true;
}
else if (src.rank().is_dynamic())
{
// No change to dst.
return true;
}
else if (size_t(dst.rank()) != size_t(src.rank()))
{
// Mismatching static ranks, cannot merge.
return false;
}
else
{
// Ranks are both static, and they match.
bool success = true;
for (size_t i = 0; i < size_t(dst.rank()); i++)
{
success &= Dimension::merge(dst[i], dst[i], src[i]);
}
return success;
}
}
...@@ -24,93 +24,200 @@ ...@@ -24,93 +24,200 @@
namespace ngraph namespace ngraph
{ {
/// \brief Class representing a shape that may only be partially known. /// \brief Class representing a shape that may be partially or totally dynamic.
/// ///
/// XXX: THIS CLASS IS EXPERIMENTAL AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE. /// XXX: THIS CLASS IS EXPERIMENTAL AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
/// ///
/// A partially-known shape may have: /// A PartialShape may have:
/// ///
/// - Unknown rank. /// \li Dynamic rank. (Informal notation: `?`)
/// - Known rank, but unknown dimensions on some or all axes. /// \li Static rank, but dynamic dimensions on some or all axes.
/// - Known rank, and known dimensions on all axes. /// (Informal notation examples: `{1,2,?,4}`, `{?,?,?}`)
/// \li Static rank, and dynamic dimensions on all axes.
/// (Informal notation examples: `{1,2,3,4}`, `{6}`, `{}`)
class PartialShape class PartialShape
{ {
public: public:
/// \brief Constructs a shape with determined rank. /// \brief Constructs a shape with static rank from an initializer list of Dimension.
/// \param init The Dimension values for the constructed shape.
/// ///
/// Examples: /// Examples:
/// ///
/// PartialShape s{2,3,4}; // rank=3, all dimensions determined /// \code{.cpp}
/// PartialShape s{2,3,4}; // rank=3, all dimensions static
/// PartialShape s{}; // rank=0 /// PartialShape s{}; // rank=0
/// PartialShape s{2,Dimension::undetermined(),3}; // rank=2, dimension 1 undetermined /// PartialShape s{2,Dimension::dynamic(),3}; // rank=2, dimension 1 dynamic
/// \endcode
PartialShape(std::initializer_list<Dimension> init) PartialShape(std::initializer_list<Dimension> init)
: PartialShape(true, init) : PartialShape(true, init)
{ {
} }
/// \brief Constructs a complete PartialShape from a Shape. /// \brief Constructs a PartialShape with static rank from a vector of Dimension.
PartialShape(const Shape& shape); /// \param dimensions The Dimension values for the constructed shape.
PartialShape(const std::vector<Dimension>& dimensions)
/// \brief Returns true if the shape has determined rank. : m_rank_is_static(true)
bool rank_is_determined() const { return m_rank_is_determined; } , m_dimensions(dimensions)
/// \brief Returns true if the shape has known rank and all dimensions of the shape {
/// are determined. }
bool is_complete() const;
/// \brief Returns the rank of the shape. Returns Rank::undetermined() if the rank is undetermined. /// \brief Constructs a static PartialShape with zero rank (the shape of a scalar).
Rank rank() const PartialShape()
: PartialShape({})
{ {
return m_rank_is_determined ? Rank(m_dimensions.size()) : Rank::undetermined();
} }
/// \brief Appends another shape to this shape. /// \brief Constructs a static PartialShape from a Shape.
/// \param shape The Shape to convert into PartialShape.
PartialShape(const Shape& shape);
/// \brief Check if this shape is static.
/// \return `true` if this shape is static, else `false`.
/// ///
/// If "this" and "other" both have determined rank, returns a new shape two shape /// A shape is considered static if it has static rank, and all dimensions of the shape
/// whose dimensions are the concatenation of the dimensions of "this" and "other". /// are static.
/// If either "this" or "other" has undetermined rank, returns bool is_static() const;
/// PartialShape::undetermined().
PartialShape append(const PartialShape& other);
/// \brief Returns the undetermined shape. /// \brief Check if this shape is dynamic.
static PartialShape undetermined() { return PartialShape(false, {}); } /// \return `false` if this shape is static, else `true`.
/// \brief Returns true if *this is compatible with s. ///
/// A shape is considered static if it has static rank, and all dimensions of the shape
/// are static.
bool is_dynamic() const { return !is_static(); }
/// \brief Get the rank of the shape.
/// \return The rank of the shape. This will be Rank::dynamic() if the rank of
/// the shape is dynamic.
Rank rank() const { return m_rank_is_static ? Rank(m_dimensions.size()) : Rank::dynamic(); }
/// \brief Construct a PartialShape with dynamic rank.
/// \return A PartialShape with dynamic rank.
static PartialShape dynamic() { return PartialShape(false, {}); }
/// \brief Check whether this shape is compatible with the argument, i.e., whether it is
/// possible to merge them.
/// \param s The shape to be checked for compatibility with this shape.
/// \return `true` if this shape is compatible with `s`, else `false`.
/// ///
/// Two dimensions are compatible if one or both of them is undetermined, or if /// Two shapes are compatible if
/// they are both determined and equal. /// \li one or both of them has dynamic rank, or
/// \li both shapes have dynamic and equal rank, and their dimensions are elementwise
/// compatible (see Dimension::compatible()).
bool compatible(const PartialShape& s) const; bool compatible(const PartialShape& s) const;
/// \brief Converts a complete PartialShape to a Shape. /// \brief Check whether this shape represents the same scheme as the argument.
/// \param s The shape whose scheme is being compared with this shape.
/// \return `true` if this shape represents the same scheme as `s`, else `false`.
/// ///
/// Throws std::invalid_argument if the PartialShape is incomplete. /// Two shapes `s1` and `s2` represent the same scheme if
/// \li they both have dynamic rank, or
/// \li they both have static and equal rank `r`, and for every `i` from `0` to `r-1`,
/// `s1[i]` represents the same scheme as `s2[i]` (see Dimension::same_scheme()).
bool same_scheme(const PartialShape& s) const;
/// \brief Convert a static PartialShape to a Shape.
/// \return A new Shape `s` where `s[i] = size_t((*this)[i])`.
/// \throws std::invalid_argument If this PartialShape is dynamic.
Shape to_shape() const; Shape to_shape() const;
/// \brief Index operator for PartialShape.
/// \param i The index of the dimension being selected.
/// \return A reference to the `i`th Dimension of this shape.
const Dimension& operator[](size_t i) const { return m_dimensions[i]; }
/// \brief Index operator for PartialShape.
/// \param i The index of the dimension being selected.
/// \return A reference to the `i`th Dimension of this shape.
Dimension& operator[](size_t i) { return m_dimensions[i]; }
friend std::ostream& operator<<(std::ostream& str, const PartialShape& shape); friend std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
friend PartialShape operator+(const PartialShape& s1, const PartialShape& s2); friend PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
/// \brief Try to merge one shape into another.
/// \param[in,out] dst The shape that `src` will be merged into.
/// \param src The shape that will be merged into `dst`.
/// \return `true` if merging succeeds, else `false`.
///
/// Merges `src` into `dst`, returning `true` on success and `false` on failure. If
/// `false` is returned, the effect on `dst` is unspecified.
///
/// To merge two partial shapes `s1` and `s2` is to find the most permissive partial shape
/// `s` that is no more permissive than `s1` or `s2`, if `s` exists. For example:
///
/// \code
/// merge(?,?) -> ?
/// merge(?,{?,?}) -> {?,?}
/// merge({?,?},{?,?}) -> {?,?}
/// merge({1,2,3,4},?) -> {1,2,3,4}
/// merge({1,2},{1,?}) -> {1,2}
/// merge({1,2,?,?},{1,?,3,?}) -> {1,2,3,?}
/// merge({1,2,3},{1,2,3}) -> {1,2,3}
///
/// merge({1,?},{2,?}) fails [dimension 0 constraints are inconsistent]
/// merge({?,?},{?,?,?}) fails [ranks are inconsistent]
/// \endcode
///
/// This function (merge_into) performs the "merge" operation described above on `dst` and
/// `src`, but overwrites `dst` with the result and returns `true` if merging is
/// successful; if merging is unsuccessful, the function returns `false` and may make
/// unspecified changes to `dst`.
static bool merge_into(PartialShape& dst, const PartialShape& src);
private: private:
// Private constructor so PartialShape::undetermined() can construct an undetermined shape. // Private constructor so PartialShape::dynamic() can construct a shape with
PartialShape(bool rank_is_determined, std::initializer_list<Dimension> init) // m_rank_is_static set to false.
: m_rank_is_determined(rank_is_determined) PartialShape(bool rank_is_static, std::initializer_list<Dimension> init)
: m_rank_is_static(rank_is_static)
, m_dimensions(init) , m_dimensions(init)
{ {
} }
// True if the shape's rank is determined. // True if the shape's rank is static.
bool m_rank_is_determined; bool m_rank_is_static;
// Shape dimensions. This has no meaning if m_rank_is_determined is false. // Shape dimensions. This has no meaning if m_rank_is_static is false.
std::vector<Dimension> m_dimensions; std::vector<Dimension> m_dimensions;
}; };
/// \brief Elementwise addition of two shapes. /// \brief Elementwise addition of two PartialShape objects.
/// \param s1 Left operand for addition.
/// \param s2 Right operand for addition.
/// \return The result of elementwise adding `s1` to `s2` (see description).
/// \throws std::invalid_argument If `s1` and `s2` have inconsistent ranks.
/// ///
/// If s1 or s2 has undetermined rank, returns PartialShape::undetermined(). /// \li If `s1` or `s2` has dynamic rank, returns PartialShape::dynamic().
/// If s1 and s2 both have determined rank, and their ranks are unequal, /// \li If `s1 and `s2` both have static rank, and their ranks are unequal, throws
/// throws std::invalid_argument. /// std::invalid_argument.
/// If s1 and s2 both have determined rank, and their ranks are equal, /// \li If `s1` and `s2` both have static rank, and their ranks are equal,
/// returns a new shape whose ith dimension is s1[i] + s2[i]. /// returns a new shape whose `i`th dimension is `s1[i] + s2[i]`.
PartialShape operator+(const PartialShape& s1, const PartialShape& s2); PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
/// \brief Inserts a human-readable representation of "shape" into "str". /// \brief Inserts a human-readable representation of a PartialShape into an output stream.
/// \param str The output stream targeted for insertion.
/// \param shape The shape to be inserted into `str`.
/// \return A reference to `str` after insertion.
///
/// The output to the stream is in "informal" notation. In other words:
///
/// \li If `shape` has dynamic rank, inserts the string `?`.
/// \li If `shape` has static rank, inserts the string `{`, then inserts each dimension
/// of `shape` into the output stream separated by commas, then inserts `}`.
///
/// Example:
///
/// \code{.cpp}
/// PartialShape s1{PartialShape::dynamic())};
/// PartialShape s2{};
/// PartialShape s3{1,Dimension::dynamic(),2,3};
/// PartialShape s4{2,3,4};
/// std::cout << s1 << std::endl
/// << s2 << std::endl
/// << s3 << std::endl
/// << s4 << std::endl;
/// \endcode
///
/// Output:
///
/// \code
/// ?
/// {}
/// {1,?,2,3}
/// {2,3,4}
/// \endcode
std::ostream& operator<<(std::ostream& str, const PartialShape& shape); std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
} }
...@@ -20,9 +20,8 @@ ...@@ -20,9 +20,8 @@
namespace ngraph namespace ngraph
{ {
/// \brief Alias for "Dimension". Should be used to when the value represents the number of /// \brief Alias for Dimension, used when the value represents the number of axes in a shape,
/// axes in a shape-like object, rather than the size of one dimension in a shape-like /// rather than the size of one dimension in a shape.
/// object.
/// ///
/// XXX: THIS TYPE IS EXPERIMENTAL AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE. /// XXX: THIS TYPE IS EXPERIMENTAL AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
using Rank = Dimension; using Rank = Dimension;
......
...@@ -149,6 +149,59 @@ static json write(const ngraph::Node&, bool binary_constant_data); ...@@ -149,6 +149,59 @@ static json write(const ngraph::Node&, bool binary_constant_data);
static string static string
serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data); serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data);
static json write_dimension(Dimension d)
{
if (d.is_static())
{
return size_t(d);
}
else
{
return nullptr;
}
}
static json write_partial_shape(const PartialShape& s)
{
if (s.rank().is_dynamic())
{
return nullptr;
}
else
{
std::vector<json> vals(size_t(s.rank()));
for (size_t i = 0; i < vals.size(); i++)
{
vals[i] = write_dimension(s[i]);
}
return vals;
}
}
static PartialShape read_partial_shape(const json& j)
{
if (j.is_null())
{
return PartialShape::dynamic();
}
else
{
std::vector<Dimension> dims(j.size());
for (size_t i = 0; i < j.size(); i++)
{
if (j[i].is_null())
{
dims[i] = Dimension::dynamic();
}
else
{
dims[i] = size_t(j[i]);
}
}
return PartialShape(dims);
}
}
static json write_element_type(const ngraph::element::Type& n) static json write_element_type(const ngraph::element::Type& n)
{ {
json j; json j;
...@@ -839,7 +892,8 @@ static shared_ptr<ngraph::Function> ...@@ -839,7 +892,8 @@ static shared_ptr<ngraph::Function>
auto element_type = read_element_type(type_node_js.at("element_type")); auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape"); auto shape = type_node_js.at("shape");
auto cacheable = get_or_default<bool>(node_js, "cacheable", false); auto cacheable = get_or_default<bool>(node_js, "cacheable", false);
node = make_shared<op::Parameter>(element_type, shape, cacheable); node =
make_shared<op::Parameter>(element_type, read_partial_shape(shape), cacheable);
break; break;
} }
case OP_TYPEID::Power: case OP_TYPEID::Power:
...@@ -1389,7 +1443,7 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1389,7 +1443,7 @@ static json write(const Node& n, bool binary_constant_data)
case OP_TYPEID::Parameter: case OP_TYPEID::Parameter:
{ {
auto tmp = dynamic_cast<const op::Parameter*>(&n); auto tmp = dynamic_cast<const op::Parameter*>(&n);
node["shape"] = tmp->get_shape(); node["shape"] = write_partial_shape(tmp->get_output_partial_shape(0));
node["cacheable"] = tmp->get_cacheable(); node["cacheable"] = tmp->get_cacheable();
node["element_type"] = write_element_type(tmp->get_element_type()); node["element_type"] = write_element_type(tmp->get_element_type());
break; break;
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
using namespace ngraph; using namespace ngraph;
const element::Type element::unspecified(0, false, false, false, "unspecified"); const element::Type element::dynamic(0, false, false, false, "dynamic");
const element::Type element::boolean(8, false, true, false, "char"); const element::Type element::boolean(8, false, true, false, "char");
const element::Type element::f32(32, true, true, false, "float"); const element::Type element::f32(32, true, true, false, "float");
const element::Type element::f64(64, true, true, false, "double"); const element::Type element::f64(64, true, true, false, "double");
...@@ -184,3 +184,26 @@ std::ostream& element::operator<<(std::ostream& out, const element::Type& obj) ...@@ -184,3 +184,26 @@ std::ostream& element::operator<<(std::ostream& out, const element::Type& obj)
<< ", " << obj.m_is_quantized << ", \"" << obj.m_cname << "\"}"; << ", " << obj.m_is_quantized << ", \"" << obj.m_cname << "\"}";
return out; return out;
} }
bool element::Type::merge(element::Type& dst, const element::Type& t1, const element::Type& t2)
{
if (t1.is_dynamic())
{
dst = t2;
return true;
}
else if (t2.is_dynamic())
{
dst = t1;
return true;
}
else if (t1 == t2)
{
dst = t1;
return true;
}
else
{
return false;
}
}
...@@ -33,7 +33,7 @@ namespace ngraph ...@@ -33,7 +33,7 @@ namespace ngraph
{ {
class Type; class Type;
extern const Type unspecified; extern const Type dynamic;
extern const Type boolean; extern const Type boolean;
extern const Type f32; extern const Type f32;
extern const Type f64; extern const Type f64;
...@@ -61,6 +61,8 @@ namespace ngraph ...@@ -61,6 +61,8 @@ namespace ngraph
const std::string& c_type_string() const; const std::string& c_type_string() const;
size_t size() const; size_t size() const;
size_t hash() const; size_t hash() const;
bool is_static() const { return (*this != dynamic); }
bool is_dynamic() const { return !is_static(); }
bool is_real() const { return m_is_real; } bool is_real() const { return m_is_real; }
bool is_signed() const { return m_is_signed; } bool is_signed() const { return m_is_signed; }
bool is_quantized() const { return m_is_quantized; } bool is_quantized() const { return m_is_quantized; }
...@@ -73,12 +75,32 @@ namespace ngraph ...@@ -73,12 +75,32 @@ namespace ngraph
/// Returns true if the type is floating point, else false. /// Returns true if the type is floating point, else false.
bool get_is_real() const { return m_is_real; } bool get_is_real() const { return m_is_real; }
/// \brief Merges two element types t1 and t2, writing the result into dst and
/// returning true if successful, else returning false.
///
/// To "merge" two element types t1 and t2 is to find the least restrictive
/// element type t that is no more restrictive than t1 and t2, if t exists.
/// More simply:
///
/// merge(dst,element::Type::dynamic,t)
/// writes t to dst and returns true
///
/// merge(dst,t,element::Type::dynamic)
/// writes t to dst and returns true
///
/// merge(dst,t1,t2) where t1, t2 both static and equal
/// writes t1 to dst and returns true
///
/// merge(dst,t1,t2) where t1, t2 both static and unequal
/// does nothing to dst, and returns false
static bool merge(element::Type& dst, const element::Type& t1, const element::Type& t2);
private: private:
size_t m_bitwidth{0}; size_t m_bitwidth{0};
bool m_is_real{false}; bool m_is_real{false};
bool m_is_signed{false}; bool m_is_signed{false};
bool m_is_quantized{false}; bool m_is_quantized{false};
std::string m_cname{"unspecified"}; std::string m_cname{"dynamic"};
}; };
template <typename T> template <typename T>
......
...@@ -84,3 +84,42 @@ TEST(element_type, size) ...@@ -84,3 +84,42 @@ TEST(element_type, size)
EXPECT_EQ(2, t1.size()); EXPECT_EQ(2, t1.size());
} }
} }
TEST(element_type, merge_both_dynamic)
{
element::Type t;
ASSERT_TRUE(element::Type::merge(t, element::dynamic, element::dynamic));
ASSERT_TRUE(t.is_dynamic());
}
TEST(element_type, merge_left_dynamic)
{
element::Type t;
ASSERT_TRUE(element::Type::merge(t, element::dynamic, element::u64));
ASSERT_TRUE(t.is_static());
ASSERT_EQ(t, element::u64);
}
TEST(element_type, merge_right_dynamic)
{
element::Type t;
ASSERT_TRUE(element::Type::merge(t, element::i16, element::dynamic));
ASSERT_TRUE(t.is_static());
ASSERT_EQ(t, element::i16);
}
TEST(element_type, merge_both_static_equal)
{
element::Type t;
ASSERT_TRUE(element::Type::merge(t, element::f64, element::f64));
ASSERT_TRUE(t.is_static());
ASSERT_EQ(t, element::f64);
}
TEST(element_type, merge_both_static_unequal)
{
element::Type t = element::f32;
ASSERT_FALSE(element::Type::merge(t, element::i8, element::i16));
ASSERT_TRUE(t.is_static());
ASSERT_EQ(t, element::f32);
}
...@@ -24,110 +24,106 @@ using namespace ngraph; ...@@ -24,110 +24,106 @@ using namespace ngraph;
TEST(partial_shape, ps_construction_empty) TEST(partial_shape, ps_construction_empty)
{ {
auto ps = PartialShape{}; auto ps = PartialShape{};
ASSERT_TRUE(ps.rank_is_determined()); ASSERT_TRUE(ps.rank().is_static());
ASSERT_TRUE(ps.rank().is_determined()); ASSERT_TRUE(ps.is_static());
ASSERT_TRUE(ps.is_complete());
ASSERT_EQ(size_t(ps.rank()), 0); ASSERT_EQ(size_t(ps.rank()), 0);
} }
TEST(partial_shape, ps_construction_undetermined) TEST(partial_shape, ps_construction_rank_dynamic)
{ {
auto ps = PartialShape::undetermined(); auto ps = PartialShape::dynamic();
ASSERT_FALSE(ps.rank_is_determined()); ASSERT_TRUE(ps.rank().is_dynamic());
ASSERT_FALSE(ps.rank().is_determined()); ASSERT_TRUE(ps.is_dynamic());
ASSERT_FALSE(ps.is_complete());
} }
TEST(partial_shape, ps_construction_incomplete) TEST(partial_shape, ps_construction_rank_static_shape_dynamic)
{ {
auto ps = PartialShape{2, Dimension::undetermined(), 3}; auto ps = PartialShape{2, Dimension::dynamic(), 3};
ASSERT_TRUE(ps.rank_is_determined()); ASSERT_TRUE(ps.rank().is_static());
ASSERT_TRUE(ps.rank().is_determined()); ASSERT_TRUE(ps.is_dynamic());
ASSERT_FALSE(ps.is_complete());
ASSERT_EQ(size_t(ps.rank()), 3); ASSERT_EQ(size_t(ps.rank()), 3);
} }
TEST(partial_shape, ps_construction_complete) TEST(partial_shape, ps_construction_static)
{ {
auto ps = PartialShape{2, 5, 3, 6}; auto ps = PartialShape{2, 5, 3, 6};
ASSERT_TRUE(ps.rank_is_determined()); ASSERT_TRUE(ps.rank().is_static());
ASSERT_TRUE(ps.rank().is_determined()); ASSERT_TRUE(ps.is_static());
ASSERT_TRUE(ps.is_complete());
ASSERT_EQ(size_t(ps.rank()), 4); ASSERT_EQ(size_t(ps.rank()), 4);
} }
TEST(partial_shape, dim_construction_determined) TEST(partial_shape, dim_construction_static)
{ {
Dimension dim{3}; Dimension dim{3};
ASSERT_EQ(size_t(dim), 3); ASSERT_EQ(size_t(dim), 3);
ASSERT_TRUE(dim.is_determined()); ASSERT_TRUE(dim.is_static());
} }
TEST(partial_shape, dim_construction_undetermined) TEST(partial_shape, dim_construction_dynamic)
{ {
Dimension dim = Dimension::undetermined(); Dimension dim = Dimension::dynamic();
ASSERT_FALSE(dim.is_determined()); ASSERT_TRUE(dim.is_dynamic());
} }
TEST(partial_shape, dim_construction_size_t_max) TEST(partial_shape, dim_construction_size_t_max)
{ {
EXPECT_ANY_THROW({ Dimension d{Dimension::s_undetermined_val}; }); EXPECT_ANY_THROW({ Dimension d{Dimension::s_dynamic_val}; });
} }
TEST(partial_shape, dim_conversion_determined) TEST(partial_shape, dim_conversion_static)
{ {
Dimension d{42}; Dimension d{42};
size_t s{d}; size_t s{d};
ASSERT_EQ(s, 42); ASSERT_EQ(s, 42);
} }
TEST(partial_shape, dim_conversion_undetermined) TEST(partial_shape, dim_conversion_dynamic)
{ {
EXPECT_ANY_THROW({ EXPECT_ANY_THROW({
size_t s{Dimension::undetermined()}; size_t s{Dimension::dynamic()};
s = 0; // Silence compiler warning about unused s s = 0; // Silence compiler warning about unused s
}); });
} }
TEST(partial_shape, rank_construction_determined) TEST(partial_shape, rank_construction_static)
{ {
Rank r{4}; Rank r{4};
ASSERT_EQ(size_t(r), 4); ASSERT_EQ(size_t(r), 4);
ASSERT_TRUE(r.is_determined()); ASSERT_TRUE(r.is_static());
} }
TEST(partial_shape, rank_construction_undetermined) TEST(partial_shape, rank_construction_dynamic)
{ {
Rank r = Rank::undetermined(); Rank r = Rank::dynamic();
ASSERT_FALSE(r.is_determined()); ASSERT_TRUE(r.is_dynamic());
} }
TEST(partial_shape, dim_compatible_left_undetermined) TEST(partial_shape, dim_compatible_left_dynamic)
{ {
Dimension d1{Dimension::undetermined()}; Dimension d1{Dimension::dynamic()};
Dimension d2{3}; Dimension d2{3};
ASSERT_TRUE(d1.compatible(d2)); ASSERT_TRUE(d1.compatible(d2));
} }
TEST(partial_shape, dim_compatible_right_undetermined) TEST(partial_shape, dim_compatible_right_dynamic)
{ {
Dimension d1{3}; Dimension d1{3};
Dimension d2{Dimension::undetermined()}; Dimension d2{Dimension::dynamic()};
ASSERT_TRUE(d1.compatible(d2)); ASSERT_TRUE(d1.compatible(d2));
} }
TEST(partial_shape, dim_compatible_both_undetermined) TEST(partial_shape, dim_compatible_both_dynamic)
{ {
Dimension d1{Dimension::undetermined()}; Dimension d1{Dimension::dynamic()};
Dimension d2{Dimension::undetermined()}; Dimension d2{Dimension::dynamic()};
ASSERT_TRUE(d1.compatible(d2)); ASSERT_TRUE(d1.compatible(d2));
} }
TEST(partial_shape, dim_compatible_both_determined) TEST(partial_shape, dim_compatible_both_static)
{ {
Dimension d1{3}; Dimension d1{3};
Dimension d2{8}; Dimension d2{8};
...@@ -137,25 +133,25 @@ TEST(partial_shape, dim_compatible_both_determined) ...@@ -137,25 +133,25 @@ TEST(partial_shape, dim_compatible_both_determined)
ASSERT_TRUE(d1.compatible(d3)); ASSERT_TRUE(d1.compatible(d3));
} }
TEST(partial_shape, shapes_compatible_both_rank_undetermined) TEST(partial_shape, shapes_compatible_both_rank_dynamic)
{ {
PartialShape ps1{PartialShape::undetermined()}; PartialShape ps1{PartialShape::dynamic()};
PartialShape ps2{PartialShape::undetermined()}; PartialShape ps2{PartialShape::dynamic()};
ASSERT_TRUE(ps1.compatible(ps2)); ASSERT_TRUE(ps1.compatible(ps2));
} }
TEST(partial_shape, shapes_compatible_left_rank_undetermined) TEST(partial_shape, shapes_compatible_left_rank_dynamic)
{ {
PartialShape ps1{3}; PartialShape ps1{3};
PartialShape ps2{PartialShape::undetermined()}; PartialShape ps2{PartialShape::dynamic()};
ASSERT_TRUE(ps1.compatible(ps2)); ASSERT_TRUE(ps1.compatible(ps2));
} }
TEST(partial_shape, shapes_compatible_right_rank_undetermined) TEST(partial_shape, shapes_compatible_right_rank_dynamic)
{ {
PartialShape ps1{PartialShape::undetermined()}; PartialShape ps1{PartialShape::dynamic()};
PartialShape ps2{4}; PartialShape ps2{4};
ASSERT_TRUE(ps1.compatible(ps2)); ASSERT_TRUE(ps1.compatible(ps2));
...@@ -163,21 +159,21 @@ TEST(partial_shape, shapes_compatible_right_rank_undetermined) ...@@ -163,21 +159,21 @@ TEST(partial_shape, shapes_compatible_right_rank_undetermined)
TEST(partial_shape, shapes_compatible_both_partial_all_known_equal) TEST(partial_shape, shapes_compatible_both_partial_all_known_equal)
{ {
PartialShape ps1{2, Dimension::undetermined(), 3, Dimension::undetermined(), 5}; PartialShape ps1{2, Dimension::dynamic(), 3, Dimension::dynamic(), 5};
PartialShape ps2{2, Dimension::undetermined(), Dimension::undetermined(), 4, 5}; PartialShape ps2{2, Dimension::dynamic(), Dimension::dynamic(), 4, 5};
ASSERT_TRUE(ps1.compatible(ps2)); ASSERT_TRUE(ps1.compatible(ps2));
} }
TEST(partial_shape, shapes_compatible_both_partial_some_known_unequal) TEST(partial_shape, shapes_compatible_both_partial_some_known_unequal)
{ {
PartialShape ps1{2, Dimension::undetermined(), 3, Dimension::undetermined(), 5}; PartialShape ps1{2, Dimension::dynamic(), 3, Dimension::dynamic(), 5};
PartialShape ps2{1, Dimension::undetermined(), Dimension::undetermined(), 4, 5}; PartialShape ps2{1, Dimension::dynamic(), Dimension::dynamic(), 4, 5};
ASSERT_FALSE(ps1.compatible(ps2)); ASSERT_FALSE(ps1.compatible(ps2));
} }
TEST(partial_shape, shapes_compatible_both_complete_different_rank) TEST(partial_shape, shapes_compatible_both_static_different_rank)
{ {
PartialShape ps1{2, 4, 6, 8}; PartialShape ps1{2, 4, 6, 8};
PartialShape ps2{2, 4, 6, 8, 10}; PartialShape ps2{2, 4, 6, 8, 10};
...@@ -185,7 +181,7 @@ TEST(partial_shape, shapes_compatible_both_complete_different_rank) ...@@ -185,7 +181,7 @@ TEST(partial_shape, shapes_compatible_both_complete_different_rank)
ASSERT_FALSE(ps1.compatible(ps2)); ASSERT_FALSE(ps1.compatible(ps2));
} }
TEST(partial_shape, shapes_equal_both_complete_same_rank_same_dims) TEST(partial_shape, shapes_equal_both_static_same_rank_same_dims)
{ {
PartialShape ps1{2, 4, 6, 8}; PartialShape ps1{2, 4, 6, 8};
PartialShape ps2{2, 4, 6, 8}; PartialShape ps2{2, 4, 6, 8};
...@@ -193,7 +189,7 @@ TEST(partial_shape, shapes_equal_both_complete_same_rank_same_dims) ...@@ -193,7 +189,7 @@ TEST(partial_shape, shapes_equal_both_complete_same_rank_same_dims)
ASSERT_TRUE(ps1.compatible(ps2)); ASSERT_TRUE(ps1.compatible(ps2));
} }
TEST(partial_shape, shapes_equal_both_complete_same_rank_different_dims) TEST(partial_shape, shapes_equal_both_static_same_rank_different_dims)
{ {
PartialShape ps1{2, 4, 6, 8}; PartialShape ps1{2, 4, 6, 8};
PartialShape ps2{2, 4, 3, 8}; PartialShape ps2{2, 4, 3, 8};
...@@ -206,13 +202,16 @@ TEST(partial_shape, from_shape) ...@@ -206,13 +202,16 @@ TEST(partial_shape, from_shape)
Shape s{2, 4, 6, 8}; Shape s{2, 4, 6, 8};
PartialShape ps1{s}; PartialShape ps1{s};
// TODO(amprocte): No way to examine contents of ps1 yet. ASSERT_TRUE(ps1.rank().is_static());
ASSERT_TRUE(ps1.is_complete());
ASSERT_TRUE(ps1.rank_is_determined());
ASSERT_EQ(size_t(ps1.rank()), s.size()); ASSERT_EQ(size_t(ps1.rank()), s.size());
ASSERT_TRUE(ps1.is_static());
ASSERT_EQ(size_t(ps1[0]), 2);
ASSERT_EQ(size_t(ps1[1]), 4);
ASSERT_EQ(size_t(ps1[2]), 6);
ASSERT_EQ(size_t(ps1[3]), 8);
} }
TEST(partial_shape, to_shape_complete) TEST(partial_shape, to_shape_static)
{ {
PartialShape ps{2, 4, 6, 8}; PartialShape ps{2, 4, 6, 8};
Shape s{ps.to_shape()}; Shape s{ps.to_shape()};
...@@ -220,15 +219,15 @@ TEST(partial_shape, to_shape_complete) ...@@ -220,15 +219,15 @@ TEST(partial_shape, to_shape_complete)
ASSERT_EQ(s, (Shape{2, 4, 6, 8})); ASSERT_EQ(s, (Shape{2, 4, 6, 8}));
} }
TEST(partial_shape, to_shape_dims_undetermined) TEST(partial_shape, to_shape_dims_dynamic)
{ {
PartialShape ps{2, 4, Dimension::undetermined(), 8}; PartialShape ps{2, 4, Dimension::dynamic(), 8};
ASSERT_THROW({ ps.to_shape(); }, std::invalid_argument); ASSERT_THROW({ ps.to_shape(); }, std::invalid_argument);
} }
TEST(partial_shape, to_shape_rank_undetermined) TEST(partial_shape, to_shape_rank_dynamic)
{ {
PartialShape ps{PartialShape::undetermined()}; PartialShape ps{PartialShape::dynamic()};
ASSERT_THROW({ ps.to_shape(); }, std::invalid_argument); ASSERT_THROW({ ps.to_shape(); }, std::invalid_argument);
} }
...@@ -238,28 +237,250 @@ TEST(partial_shape, tensor_descriptor_from_shape) ...@@ -238,28 +237,250 @@ TEST(partial_shape, tensor_descriptor_from_shape)
ASSERT_EQ(t.get_shape(), (Shape{1, 2, 3})); ASSERT_EQ(t.get_shape(), (Shape{1, 2, 3}));
ASSERT_EQ(size_t(t.get_partial_shape().rank()), 3); ASSERT_EQ(size_t(t.get_partial_shape().rank()), 3);
ASSERT_TRUE(t.get_partial_shape().same_scheme(PartialShape{1, 2, 3}));
} }
TEST(partial_shape, tensor_descriptor_from_complete_partial_shape) TEST(partial_shape, tensor_descriptor_from_static_partial_shape)
{ {
descriptor::Tensor t{element::i32, PartialShape{1, 2, 3}, "Burnside"}; descriptor::Tensor t{element::i32, PartialShape{1, 2, 3}, "Burnside"};
ASSERT_EQ(t.get_shape(), (Shape{1, 2, 3})); ASSERT_EQ(t.get_shape(), (Shape{1, 2, 3}));
ASSERT_EQ(size_t(t.get_partial_shape().rank()), 3); ASSERT_EQ(size_t(t.get_partial_shape().rank()), 3);
ASSERT_TRUE(t.get_partial_shape().same_scheme(PartialShape{1, 2, 3}));
} }
TEST(partial_shape, tensor_descriptor_from_incomplete_partial_shape) TEST(partial_shape, tensor_descriptor_from_rank_static_dynamic_partial_shape)
{ {
descriptor::Tensor t{element::i32, PartialShape{1, Dimension::undetermined(), 3}, "Couch"}; descriptor::Tensor t{element::i32, PartialShape{1, Dimension::dynamic(), 3}, "Couch"};
ASSERT_EQ(size_t(t.get_partial_shape().rank()), 3); ASSERT_EQ(size_t(t.get_partial_shape().rank()), 3);
ASSERT_THROW({ t.get_shape(); }, std::invalid_argument); ASSERT_THROW({ t.get_shape(); }, std::invalid_argument);
ASSERT_TRUE(t.get_partial_shape().same_scheme(PartialShape{1, Dimension::dynamic(), 3}));
} }
TEST(partial_shape, tensor_descriptor_from_rankless_partial_shape) TEST(partial_shape, tensor_descriptor_from_rank_dynamic_partial_shape)
{ {
descriptor::Tensor t{element::i32, PartialShape::undetermined(), "Davis"}; descriptor::Tensor t{element::i32, PartialShape::dynamic(), "Davis"};
ASSERT_FALSE(t.get_partial_shape().rank().is_determined()); ASSERT_TRUE(t.get_partial_shape().rank().is_dynamic());
ASSERT_THROW({ t.get_shape(); }, std::invalid_argument); ASSERT_THROW({ t.get_shape(); }, std::invalid_argument);
ASSERT_TRUE(t.get_partial_shape().same_scheme(PartialShape::dynamic()));
}
// Dimension::same_scheme: true iff both dimensions are dynamic, or both are
// static with equal values.
TEST(partial_shape, dim_same_scheme_both_dynamic)
{
    Dimension lhs = Dimension::dynamic();
    Dimension rhs = Dimension::dynamic();
    ASSERT_TRUE(lhs.same_scheme(rhs));
}

TEST(partial_shape, dim_same_scheme_left_dynamic)
{
    Dimension lhs = Dimension::dynamic();
    Dimension rhs{6};
    ASSERT_FALSE(lhs.same_scheme(rhs));
}

TEST(partial_shape, dim_same_scheme_right_dynamic)
{
    Dimension lhs{6};
    Dimension rhs = Dimension::dynamic();
    ASSERT_FALSE(lhs.same_scheme(rhs));
}

TEST(partial_shape, dim_same_scheme_both_static_same)
{
    Dimension lhs{6};
    Dimension rhs{6};
    ASSERT_TRUE(lhs.same_scheme(rhs));
}

TEST(partial_shape, dim_same_scheme_both_static_different)
{
    Dimension lhs{6};
    Dimension rhs{7};
    ASSERT_FALSE(lhs.same_scheme(rhs));
}
// PartialShape::same_scheme: true iff both shapes are rank-dynamic, or both
// have the same static rank and corresponding dimensions have the same scheme.
TEST(partial_shape, partial_shape_same_scheme_both_dynamic)
{
    const PartialShape lhs{PartialShape::dynamic()};
    const PartialShape rhs{PartialShape::dynamic()};
    ASSERT_TRUE(lhs.same_scheme(rhs));
}

TEST(partial_shape, partial_shape_same_scheme_left_dynamic_right_rank_static_dynamic)
{
    const PartialShape lhs{PartialShape::dynamic()};
    const PartialShape rhs{1, Dimension::dynamic(), 3};
    ASSERT_FALSE(lhs.same_scheme(rhs));
}

TEST(partial_shape, partial_shape_same_scheme_left_dynamic_right_static)
{
    const PartialShape lhs{PartialShape::dynamic()};
    const PartialShape rhs{1, 2, 3};
    ASSERT_FALSE(lhs.same_scheme(rhs));
}

TEST(partial_shape, partial_shape_same_scheme_right_dynamic_left_rank_static_dynamic)
{
    const PartialShape lhs{1, Dimension::dynamic(), 3};
    const PartialShape rhs{PartialShape::dynamic()};
    ASSERT_FALSE(lhs.same_scheme(rhs));
}

TEST(partial_shape, partial_shape_same_scheme_right_dynamic_left_static)
{
    const PartialShape lhs{1, 2, 3};
    const PartialShape rhs{PartialShape::dynamic()};
    ASSERT_FALSE(lhs.same_scheme(rhs));
}

TEST(partial_shape, partial_shape_same_scheme_both_static_different_rank)
{
    const PartialShape lhs{1, 2, 3};
    const PartialShape rhs{1, 2, 3, 4};
    ASSERT_FALSE(lhs.same_scheme(rhs));
}

TEST(partial_shape, partial_shape_same_scheme_both_rank_static_dynamic_different_rank)
{
    const PartialShape lhs{1, Dimension::dynamic(), 3};
    const PartialShape rhs{1, Dimension::dynamic(), 3, 4};
    ASSERT_FALSE(lhs.same_scheme(rhs));
}

TEST(partial_shape, partial_shape_same_scheme_both_static_same_rank_different_dims)
{
    const PartialShape lhs{1, 2, 3};
    const PartialShape rhs{1, 3, 3};
    ASSERT_FALSE(lhs.same_scheme(rhs));
}

TEST(partial_shape, partial_shape_same_scheme_both_rank_static_dynamic_same_rank_different_dims)
{
    const PartialShape lhs{1, 2, Dimension::dynamic()};
    const PartialShape rhs{1, 3, Dimension::dynamic()};
    ASSERT_FALSE(lhs.same_scheme(rhs));
}

TEST(partial_shape, partial_shape_same_scheme_both_rank_static_dynamic_same_rank_compatible_not_same)
{
    // Compatible (mergeable) shapes are still not "same scheme" unless the
    // dynamic dimensions line up positionally.
    const PartialShape lhs{1, 2, Dimension::dynamic()};
    const PartialShape rhs{1, Dimension::dynamic(), 3};
    ASSERT_FALSE(lhs.same_scheme(rhs));
}

TEST(partial_shape, partial_shape_same_scheme_both_rank_static_dynamic_same_rank_compatible_same)
{
    const PartialShape lhs{1, 2, Dimension::dynamic()};
    const PartialShape rhs{1, 2, Dimension::dynamic()};
    ASSERT_TRUE(lhs.same_scheme(rhs));
}

TEST(partial_shape, partial_shape_same_scheme_both_static_same_rank_same_dims)
{
    const PartialShape lhs{1, 2, 3};
    const PartialShape rhs{1, 2, 3};
    ASSERT_TRUE(lhs.same_scheme(rhs));
}

TEST(partial_shape, partial_shape_same_scheme_scalar)
{
    const PartialShape lhs{};
    const PartialShape rhs{};
    ASSERT_TRUE(lhs.same_scheme(rhs));
}
// Dimension::merge: succeeds when the inputs are compatible, writing the
// merged dimension to the first argument; on failure the destination is left
// untouched.
TEST(partial_shape, dim_merge_both_dynamic)
{
    Dimension merged;
    ASSERT_TRUE(Dimension::merge(merged, Dimension::dynamic(), Dimension::dynamic()));
    ASSERT_TRUE(merged.is_dynamic());
}

TEST(partial_shape, dim_merge_left_dynamic)
{
    Dimension merged;
    ASSERT_TRUE(Dimension::merge(merged, Dimension::dynamic(), Dimension(3)));
    ASSERT_TRUE(merged.is_static());
    ASSERT_EQ(size_t(merged), 3);
}

TEST(partial_shape, dim_merge_right_dynamic)
{
    Dimension merged;
    ASSERT_TRUE(Dimension::merge(merged, Dimension(3), Dimension::dynamic()));
    ASSERT_TRUE(merged.is_static());
    ASSERT_EQ(size_t(merged), 3);
}

TEST(partial_shape, dim_merge_both_static_equal)
{
    Dimension merged;
    ASSERT_TRUE(Dimension::merge(merged, Dimension(3), Dimension(3)));
    ASSERT_TRUE(merged.is_static());
    ASSERT_EQ(size_t(merged), 3);
}

TEST(partial_shape, dim_merge_both_static_unequal)
{
    // Sentinel value: a failed merge must not clobber the destination.
    Dimension merged = 163;
    ASSERT_FALSE(Dimension::merge(merged, Dimension(3), Dimension(4)));
    ASSERT_TRUE(merged.is_static());
    ASSERT_EQ(size_t(merged), 163);
}
// PartialShape::merge_into: merges the source shape into the destination
// in place, returning false (and failing) when the shapes are incompatible.
TEST(partial_shape, partial_shape_merge_both_rank_dynamic)
{
    PartialShape dst{PartialShape::dynamic()};
    const PartialShape src{PartialShape::dynamic()};
    ASSERT_TRUE(PartialShape::merge_into(dst, src));
    ASSERT_TRUE(dst.rank().is_dynamic());
}

TEST(partial_shape, partial_shape_merge_left_rank_dynamic_right_rank_static_dynamic)
{
    PartialShape dst{PartialShape::dynamic()};
    const PartialShape src{1, 2, Dimension::dynamic()};
    ASSERT_TRUE(PartialShape::merge_into(dst, src));
    ASSERT_TRUE(dst.same_scheme(PartialShape{1, 2, Dimension::dynamic()}));
}

TEST(partial_shape, partial_shape_merge_left_rank_dynamic_right_static)
{
    PartialShape dst{PartialShape::dynamic()};
    const PartialShape src{1, 2, 3};
    ASSERT_TRUE(PartialShape::merge_into(dst, src));
    ASSERT_TRUE(dst.same_scheme(PartialShape{1, 2, 3}));
}

TEST(partial_shape, partial_shape_merge_left_rank_static_dynamic_right_rank_dynamic)
{
    PartialShape dst{1, 2, Dimension::dynamic()};
    const PartialShape src{PartialShape::dynamic()};
    ASSERT_TRUE(PartialShape::merge_into(dst, src));
    ASSERT_TRUE(dst.same_scheme(PartialShape{1, 2, Dimension::dynamic()}));
}

TEST(partial_shape, partial_shape_merge_left_static_right_rank_dynamic)
{
    PartialShape dst{1, 2, 3};
    const PartialShape src{PartialShape::dynamic()};
    ASSERT_TRUE(PartialShape::merge_into(dst, src));
    ASSERT_TRUE(dst.same_scheme(PartialShape{1, 2, 3}));
}

TEST(partial_shape, partial_shape_merge_both_rank_static_dynamic_consistent)
{
    // Each position resolves to the more specific of the two dimensions.
    PartialShape dst{1, Dimension::dynamic(), 3, Dimension::dynamic()};
    const PartialShape src{1, 2, Dimension::dynamic(), Dimension::dynamic()};
    ASSERT_TRUE(PartialShape::merge_into(dst, src));
    ASSERT_TRUE(dst.same_scheme(PartialShape{1, 2, 3, Dimension::dynamic()}));
}

TEST(partial_shape, partial_shape_merge_both_rank_static_dynamic_same_rank_inconsistent)
{
    PartialShape dst{1, Dimension::dynamic(), 3, Dimension::dynamic()};
    const PartialShape src{2, 2, Dimension::dynamic(), Dimension::dynamic()};
    ASSERT_FALSE(PartialShape::merge_into(dst, src));
}

TEST(partial_shape, partial_shape_merge_both_rank_static_dynamic_different_rank)
{
    PartialShape dst{1, Dimension::dynamic(), 3, Dimension::dynamic()};
    const PartialShape src{1, 2, Dimension::dynamic()};
    ASSERT_FALSE(PartialShape::merge_into(dst, src));
}

TEST(partial_shape, partial_shape_merge_both_static_consistent)
{
    PartialShape dst{1, 2, 3};
    const PartialShape src{1, 2, 3};
    ASSERT_TRUE(PartialShape::merge_into(dst, src));
    ASSERT_TRUE(dst.same_scheme(PartialShape{1, 2, 3}));
}

TEST(partial_shape, partial_shape_merge_both_static_inconsistent)
{
    PartialShape dst{1, 2, 3};
    const PartialShape src{1, 2, 4};
    ASSERT_FALSE(PartialShape::merge_into(dst, src));
}

TEST(partial_shape, partial_shape_merge_both_static_different_rank)
{
    PartialShape dst{1, 2, 3};
    const PartialShape src{1, 2, 3, 4};
    ASSERT_FALSE(PartialShape::merge_into(dst, src));
}
...@@ -476,9 +476,7 @@ void test_binary(std::string node_type, ...@@ -476,9 +476,7 @@ void test_binary(std::string node_type,
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
error.what(),
std::string("Argument 0 shape Shape{2, 4} differs in shape from argument 1"));
} }
catch (...) catch (...)
{ {
...@@ -497,10 +495,8 @@ void test_binary(std::string node_type, ...@@ -497,10 +495,8 @@ void test_binary(std::string node_type,
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Argument element types are inconsistent"));
std::string("Argument 0 element type element::Type{32, 1, "
"1, 0, \"float\"} differs in element type from argument 1"));
} }
catch (...) catch (...)
{ {
...@@ -572,9 +568,7 @@ void test_binary_logical(std::string node_type, ...@@ -572,9 +568,7 @@ void test_binary_logical(std::string node_type,
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
error.what(),
std::string("Argument 0 shape Shape{2, 4} differs in shape from argument 1"));
} }
catch (...) catch (...)
{ {
...@@ -593,10 +587,8 @@ void test_binary_logical(std::string node_type, ...@@ -593,10 +587,8 @@ void test_binary_logical(std::string node_type,
} }
catch (const NodeValidationError& error) catch (const NodeValidationError& error)
{ {
EXPECT_HAS_SUBSTRING( EXPECT_HAS_SUBSTRING(error.what(),
error.what(), std::string("Argument element types are inconsistent"));
std::string("Argument 0 element type element::Type{8, 0, 1, 0, \"char\"} "
"differs in element type from argument 1"));
} }
catch (...) catch (...)
{ {
...@@ -624,7 +616,7 @@ void test_binary_logical(std::string node_type, ...@@ -624,7 +616,7 @@ void test_binary_logical(std::string node_type,
}; };
test_binary_differ_arguments_view_element_types(tv0_2_4_param_0, tv0_2_4_param_2); test_binary_differ_arguments_view_element_types(tv0_2_4_param_0, tv0_2_4_param_2);
test_binary_non_bool_arguments_view_element_types(tv0_2_4_param_2, tv0_2_4_param_0); test_binary_differ_arguments_view_element_types(tv0_2_4_param_2, tv0_2_4_param_0);
test_binary_non_bool_arguments_view_element_types(tv0_2_4_param_2, tv0_2_4_param_3); test_binary_non_bool_arguments_view_element_types(tv0_2_4_param_2, tv0_2_4_param_3);
auto test_binary_good_arguments = [&](const shared_ptr<Node>& x, const shared_ptr<Node>& y) { auto test_binary_good_arguments = [&](const shared_ptr<Node>& x, const shared_ptr<Node>& y) {
...@@ -6598,6 +6590,384 @@ TEST(type_prop, topk_invalid_k) ...@@ -6598,6 +6590,384 @@ TEST(type_prop, topk_invalid_k)
} }
} }
// op::Parameter must propagate a dynamic partial shape to its output verbatim.
TEST(type_prop, param_partial_rank_dynamic)
{
    auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    const PartialShape& out_shape = param->get_output_partial_shape(0);
    ASSERT_TRUE(out_shape.is_dynamic());
    ASSERT_TRUE(out_shape.rank().is_dynamic());
}

// Rank-static but partially dynamic shapes must also be propagated verbatim,
// dimension by dimension.
TEST(type_prop, param_partial_rank_static)
{
    auto param =
        make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3, 4});
    const PartialShape& out_shape = param->get_output_partial_shape(0);
    ASSERT_TRUE(out_shape.is_dynamic());
    ASSERT_EQ(size_t(out_shape.rank()), 4);
    ASSERT_TRUE(out_shape[0].is_static());
    ASSERT_EQ(size_t(out_shape[0]), 2);
    ASSERT_TRUE(out_shape[1].is_dynamic());
    ASSERT_TRUE(out_shape[2].is_static());
    ASSERT_EQ(size_t(out_shape[2]), 3);
    ASSERT_TRUE(out_shape[3].is_static());
    ASSERT_EQ(size_t(out_shape[3]), 4);
}
// Shape propagation for binary elementwise ops: the output shape is the merge
// of the two (possibly partial) input shapes.
TEST(type_prop, binary_elementwise_arithmetic_both_dynamic)
{
    auto lhs = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto rhs = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto sum = make_shared<op::Add>(lhs, rhs);
    ASSERT_TRUE(sum->get_output_partial_shape(0).rank().is_dynamic());
}

TEST(type_prop, binary_elementwise_arithmetic_left_rank_dynamic_right_static)
{
    auto lhs = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto rhs = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
    auto sum = make_shared<op::Add>(lhs, rhs);
    ASSERT_TRUE(sum->get_output_partial_shape(0).is_static());
    ASSERT_EQ(sum->get_shape(), (Shape{1, 2, 3}));
}

TEST(type_prop, binary_elementwise_arithmetic_left_static_right_rank_dynamic)
{
    auto lhs = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
    auto rhs = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto sum = make_shared<op::Add>(lhs, rhs);
    ASSERT_TRUE(sum->get_output_partial_shape(0).is_static());
    ASSERT_EQ(sum->get_shape(), (Shape{1, 2, 3}));
}

TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_right_rank_dynamic)
{
    auto lhs = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3});
    auto rhs = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto sum = make_shared<op::Add>(lhs, rhs);
    ASSERT_TRUE(sum->get_output_partial_shape(0).rank().is_static());
    ASSERT_TRUE(sum->get_output_partial_shape(0).is_dynamic());
    ASSERT_TRUE(
        sum->get_output_partial_shape(0).same_scheme(PartialShape{1, Dimension::dynamic(), 3}));
}

TEST(type_prop, binary_elementwise_arithmetic_left_rank_dynamic_right_rank_static_dynamic)
{
    auto lhs = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto rhs = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3});
    auto sum = make_shared<op::Add>(lhs, rhs);
    ASSERT_TRUE(sum->get_output_partial_shape(0).rank().is_static());
    ASSERT_TRUE(sum->get_output_partial_shape(0).is_dynamic());
    ASSERT_TRUE(
        sum->get_output_partial_shape(0).same_scheme(PartialShape{1, Dimension::dynamic(), 3}));
}

TEST(type_prop,
     binary_elementwise_arithmetic_left_rank_static_dynamic_right_rank_static_dynamic_result_static)
{
    // The dynamic dimensions fill each other in, yielding a fully static result.
    auto lhs = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3});
    auto rhs = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
    auto sum = make_shared<op::Add>(lhs, rhs);
    ASSERT_TRUE(sum->get_output_partial_shape(0).is_static());
    ASSERT_EQ(sum->get_shape(), (Shape{1, 2, 3}));
}

TEST(
    type_prop,
    binary_elementwise_arithmetic_left_rank_static_dynamic_right_rank_static_dynamic_result_rank_static_dynamic)
{
    // One position remains dynamic in both inputs, so it stays dynamic in the result.
    auto lhs = make_shared<op::Parameter>(
        element::f32, PartialShape{1, Dimension::dynamic(), Dimension::dynamic()});
    auto rhs = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
    auto sum = make_shared<op::Add>(lhs, rhs);
    ASSERT_TRUE(sum->get_output_partial_shape(0).rank().is_static());
    ASSERT_TRUE(sum->get_output_partial_shape(0).is_dynamic());
    ASSERT_TRUE(
        sum->get_output_partial_shape(0).same_scheme(PartialShape{1, 2, Dimension::dynamic()}));
}

TEST(type_prop, binary_elementwise_arithmetic_left_static_right_rank_static_dynamic)
{
    auto lhs = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
    auto rhs = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
    auto sum = make_shared<op::Add>(lhs, rhs);
    ASSERT_TRUE(sum->get_output_partial_shape(0).is_static());
    ASSERT_EQ(sum->get_shape(), (Shape{1, 2, 3}));
}

TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_right_static)
{
    auto lhs = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
    auto rhs = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
    auto sum = make_shared<op::Add>(lhs, rhs);
    ASSERT_TRUE(sum->get_output_partial_shape(0).is_static());
    ASSERT_EQ(sum->get_shape(), (Shape{1, 2, 3}));
}
TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_inconsistent)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 3});
try
{
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, binary_elementwise_arithmetic_right_rank_static_dynamic_inconsistent)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 3});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
try
{
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, binary_elementwise_arithmetic_both_rank_static_dynamic_inconsistent)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 3, 3});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
try
{
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_different_rank)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
try
{
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, binary_elementwise_arithmetic_right_rank_static_dynamic_different_rank)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
try
{
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, binary_elementwise_arithmetic_both_rank_static_dynamic_different_rank)
{
auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3, 4});
auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
try
{
auto add = make_shared<op::Add>(a, b);
FAIL() << "Inconsistent partial shapes not detected";
}
catch (const NodeValidationError& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
// Element-type propagation for binary elementwise ops: a dynamic element type
// on one side resolves to the other side's type; dynamic on both sides stays
// dynamic.
TEST(type_prop, binary_elementwise_arithmetic_both_et_dynamic)
{
    auto lhs = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
    auto rhs = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
    auto sum = make_shared<op::Add>(lhs, rhs);
    ASSERT_TRUE(sum->get_output_element_type(0).is_dynamic());
}

TEST(type_prop, binary_elementwise_arithmetic_left_et_dynamic)
{
    auto lhs = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
    auto rhs = make_shared<op::Parameter>(element::u32, Shape{1, 2, 3, 4});
    auto sum = make_shared<op::Add>(lhs, rhs);
    ASSERT_EQ(sum->get_output_element_type(0), element::u32);
}

TEST(type_prop, binary_elementwise_arithmetic_right_et_dynamic)
{
    auto lhs = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
    auto rhs = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
    auto sum = make_shared<op::Add>(lhs, rhs);
    ASSERT_EQ(sum->get_output_element_type(0), element::i64);
}
// Exhaustive check of element-type validation/propagation for logical,
// arithmetic, comparison, and logical-negation ops over all combinations of
// int, boolean, and dynamic input element types.
TEST(type_prop, logic_arith_compare_partial_et)
{
    auto make_logical = [](element::Type lhs_et, element::Type rhs_et) -> std::shared_ptr<Node> {
        auto lhs = std::make_shared<op::Parameter>(lhs_et, Shape{1, 2, 3});
        auto rhs = std::make_shared<op::Parameter>(rhs_et, Shape{1, 2, 3});
        return std::make_shared<op::And>(lhs, rhs);
    };

    auto make_arith = [](element::Type lhs_et, element::Type rhs_et) -> std::shared_ptr<Node> {
        auto lhs = std::make_shared<op::Parameter>(lhs_et, Shape{1, 2, 3});
        auto rhs = std::make_shared<op::Parameter>(rhs_et, Shape{1, 2, 3});
        return std::make_shared<op::Add>(lhs, rhs);
    };

    auto make_compare = [](element::Type lhs_et, element::Type rhs_et) -> std::shared_ptr<Node> {
        auto lhs = std::make_shared<op::Parameter>(lhs_et, Shape{1, 2, 3});
        auto rhs = std::make_shared<op::Parameter>(rhs_et, Shape{1, 2, 3});
        return std::make_shared<op::Greater>(lhs, rhs);
    };

    auto make_not = [](element::Type in_et) -> std::shared_ptr<Node> {
        auto arg = std::make_shared<op::Parameter>(in_et, Shape{1, 2, 3});
        return std::make_shared<op::Not>(arg);
    };

    // Logical ops:
    //
    // int int -> !
    // int boo -> !
    // int dyn -> !
    // boo int -> !
    // boo boo -> boo
    // boo dyn -> boo
    // dyn int -> !
    // dyn boo -> boo
    // dyn dyn -> boo
    ASSERT_ANY_THROW({ make_logical(element::i32, element::i32); });
    ASSERT_ANY_THROW({ make_logical(element::i32, element::boolean); });
    ASSERT_ANY_THROW({ make_logical(element::i32, element::dynamic); });
    ASSERT_ANY_THROW({ make_logical(element::boolean, element::i32); });
    ASSERT_EQ(make_logical(element::boolean, element::boolean)->get_element_type(),
              element::boolean);
    ASSERT_EQ(make_logical(element::boolean, element::dynamic)->get_element_type(),
              element::boolean);
    ASSERT_ANY_THROW({ make_logical(element::dynamic, element::i32); });
    ASSERT_EQ(make_logical(element::dynamic, element::boolean)->get_element_type(),
              element::boolean);
    ASSERT_EQ(make_logical(element::dynamic, element::dynamic)->get_element_type(),
              element::boolean);

    // Arith ops:
    //
    // int int -> int
    // int boo -> !
    // int dyn -> int
    // boo int -> !
    // boo boo -> !
    // boo dyn -> !
    // dyn int -> int
    // dyn boo -> !
    // dyn dyn -> dyn
    ASSERT_EQ(make_arith(element::i32, element::i32)->get_element_type(), element::i32);
    ASSERT_ANY_THROW({ make_arith(element::i32, element::boolean); });
    ASSERT_EQ(make_arith(element::i32, element::dynamic)->get_element_type(), element::i32);
    ASSERT_ANY_THROW({ make_arith(element::boolean, element::i32); });
    ASSERT_ANY_THROW({ make_arith(element::boolean, element::boolean); });
    ASSERT_ANY_THROW({ make_arith(element::boolean, element::dynamic); });
    ASSERT_EQ(make_arith(element::dynamic, element::i32)->get_element_type(), element::i32);
    ASSERT_ANY_THROW({ make_arith(element::dynamic, element::boolean); });
    ASSERT_EQ(make_arith(element::dynamic, element::dynamic)->get_element_type(),
              element::dynamic);

    // Comparison ops:
    //
    // int int -> boo
    // int boo -> !
    // int dyn -> boo
    // boo int -> !
    // boo boo -> boo
    // boo dyn -> boo
    // dyn int -> boo
    // dyn boo -> boo
    // dyn dyn -> boo
    ASSERT_EQ(make_compare(element::i32, element::i32)->get_element_type(), element::boolean);
    ASSERT_ANY_THROW({ make_compare(element::i32, element::boolean); });
    ASSERT_EQ(make_compare(element::i32, element::dynamic)->get_element_type(), element::boolean);
    ASSERT_ANY_THROW({ make_compare(element::boolean, element::i32); });
    ASSERT_EQ(make_compare(element::boolean, element::boolean)->get_element_type(),
              element::boolean);
    ASSERT_EQ(make_compare(element::boolean, element::dynamic)->get_element_type(),
              element::boolean);
    ASSERT_EQ(make_compare(element::dynamic, element::i32)->get_element_type(), element::boolean);
    ASSERT_EQ(make_compare(element::dynamic, element::boolean)->get_element_type(),
              element::boolean);
    ASSERT_EQ(make_compare(element::dynamic, element::dynamic)->get_element_type(),
              element::boolean);

    // Logical negation op:
    //
    // Current behavior:
    // int -> int
    // boo -> boo
    // dyn -> dyn
    //
    // TODO(amprocte): I believe the behavior should actually be:
    // int -> !
    // boo -> boo
    // dyn -> boo
    ASSERT_EQ(make_not(element::i32)->get_element_type(), element::i32);
    ASSERT_EQ(make_not(element::boolean)->get_element_type(), element::boolean);
    ASSERT_EQ(make_not(element::dynamic)->get_element_type(), element::dynamic);
}
TEST(type_prop, quantize_f32_to_i8_nchw_per_channel_ok) TEST(type_prop, quantize_f32_to_i8_nchw_per_channel_ok)
{ {
Shape batch_shape{64, 3, 480, 640}; Shape batch_shape{64, 3, 480, 640};
...@@ -7543,3 +7913,20 @@ TEST(type_prop, dequantize_offset_shape_mismatch_different_rank_fails) ...@@ -7543,3 +7913,20 @@ TEST(type_prop, dequantize_offset_shape_mismatch_different_rank_fails)
FAIL() << "Deduced type check failed for unexpected reason"; FAIL() << "Deduced type check failed for unexpected reason";
} }
} }
//
// This is testing a temporary hack for ops that do not yet support partial-shape validation.
// The graph we construct here is bogus, but because there is some partiality in the input shapes,
// it should still pass validation but set the output shape and element types to be dynamic.
//
TEST(type_prop, validate_punt_if_dynamic)
{
    auto param0 = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
    auto param1 = make_shared<op::Parameter>(element::u32, PartialShape{1, Dimension::dynamic(), 3});
    auto param2 = make_shared<op::Parameter>(element::i32, Shape{1, 8, 3});
    // The concatenation axis is deliberately out of range; because one input has
    // a partial shape, validation is expected to punt rather than reject it.
    auto concat = make_shared<op::Concat>(NodeVector{param0, param1, param2},
                                          /*concatenation axis=*/1234);
    ASSERT_EQ(concat->get_output_size(), 1);
    ASSERT_TRUE(concat->get_output_partial_shape(0).rank().is_dynamic());
    ASSERT_TRUE(concat->get_output_element_type(0).is_dynamic());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment