Commit 0563a3cf authored by Adam Procter, committed by Robert Kimball

Partial Shapes and Types, Part 3: Framework for partial shape/element type validation (#1728)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goof-up in the deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item
parent 631d7253
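Before the diff, a quick orientation: the net effect of this change set is that Parameters can carry partial shapes, element types and shapes are merged across arguments where possible, and ops without partial-shape support punt to fully dynamic outputs. A minimal usage sketch (include path and construction style assumed, not taken verbatim from this commit):

    #include "ngraph/ngraph.hpp"
    using namespace ngraph;

    int main()
    {
        // A Parameter may now carry a partial shape: here rank 2, with a
        // dynamic first dimension ({?,4}).
        auto param = std::make_shared<op::Parameter>(
            element::f32, PartialShape{Dimension::dynamic(), 4});

        // Elementwise ops merge their argument element types and shapes, so
        // the sum also has element type f32 and partial shape {?,4}.
        auto sum = std::make_shared<op::Add>(param, param);
        PartialShape out_shape = sum->get_output_partial_shape(0);

        return 0;
    }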
@@ -25,7 +25,7 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
const PartialShape& pshape,
const std::string& name)
: m_element_type(element_type)
, m_shape(pshape.is_complete() ? pshape.to_shape() : Shape{})
, m_shape(pshape.is_static() ? pshape.to_shape() : Shape{})
, m_partial_shape(pshape)
, m_name(name)
{
@@ -34,7 +34,7 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
void descriptor::Tensor::set_tensor_type(const element::Type& element_type,
const PartialShape& pshape)
{
if (pshape.is_complete())
if (pshape.is_static())
{
m_shape = pshape.to_shape();
}
@@ -48,14 +48,14 @@ void descriptor::Tensor::set_tensor_type(const element::Type& element_type,
const Shape& descriptor::Tensor::get_shape() const
{
if (m_partial_shape.is_complete())
if (m_partial_shape.is_static())
{
return m_shape;
}
else
{
throw std::invalid_argument(
"get_shape was called on a descriptor::Tensor with incomplete shape");
"get_shape was called on a descriptor::Tensor with dynamic shape");
}
}
@@ -25,11 +25,11 @@ using namespace ngraph;
Dimension::Dimension(size_t dimension)
: m_dimension(dimension)
{
if (dimension == s_undetermined_val)
if (dimension == s_dynamic_val)
{
std::stringstream ss;
ss << "Cannot convert the value 0x" << std::uppercase << std::hex << s_undetermined_val
<< " to Dimension: this value is used internally to represent an undetermined "
ss << "Cannot convert the value 0x" << std::uppercase << std::hex << s_dynamic_val
<< " to Dimension: this value is used internally to represent a dynamic "
"dimension.";
throw std::invalid_argument(ss.str());
}
@@ -37,7 +37,7 @@ Dimension::Dimension(size_t dimension)
std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
{
if (dimension.is_determined())
if (dimension.is_static())
{
return (str << size_t(dimension));
}
@@ -49,11 +49,33 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
Dimension ngraph::operator+(const Dimension& d1, const Dimension& d2)
{
return (d1.is_determined() && d2.is_determined() ? size_t(d1) + size_t(d2)
: Dimension::undetermined());
return (d1.is_static() && d2.is_static() ? size_t(d1) + size_t(d2) : Dimension::dynamic());
}
bool Dimension::compatible(const Dimension& d) const
{
return (!is_determined() || !d.is_determined() || m_dimension == size_t(d));
return (is_dynamic() || d.is_dynamic() || m_dimension == size_t(d));
}
bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
{
if (d1.is_dynamic())
{
dst = d2;
return true;
}
else if (d2.is_dynamic())
{
dst = d1;
return true;
}
else if (size_t(d1) != size_t(d2))
{
return false;
}
else
{
dst = d1;
return true;
}
}
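The merge semantics above, restated as a usage sketch (include path assumed; behavior follows the implementation in this hunk):

    #include "ngraph/dimension.hpp"
    using namespace ngraph;

    void merge_examples()
    {
        Dimension d;
        Dimension::merge(d, Dimension::dynamic(), 3); // true; d == 3
        Dimension::merge(d, 3, Dimension::dynamic()); // true; d == 3
        Dimension::merge(d, 3, 3);                    // true; d == 3
        Dimension::merge(d, 3, 4);                    // false; d left unchanged
    }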
@@ -22,56 +22,97 @@
namespace ngraph
{
/// \brief Class representing a possibly-unknown dimension in a shape or shape-like object.
/// \brief Class representing a dimension, which may be dynamic (undetermined until runtime),
/// in a shape or shape-like object.
///
/// Known dimensions may be implicitly converted from size_t. An unknown dimension is
/// constructed with Dimension() or Dimension::undetermined().
/// Static dimensions may be implicitly converted from size_t. A dynamic dimension is
/// constructed with Dimension() or Dimension::dynamic().
///
/// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
class Dimension
{
public:
/// \brief Constructs a known dimension.
///
/// Requires that dimension != s_undetermined_val. If that condition does not hold,
/// throws std::invalid_argument.
/// \brief Construct a static dimension.
/// \param dimension Value of the dimension. Must not be equal to
/// Dimension::s_dynamic_val.
/// \throws std::invalid_argument If `dimension` == Dimension::s_dynamic_val.
Dimension(size_t dimension);
/// \brief Constructs an unknown dimension.
Dimension() { m_dimension = s_undetermined_val; }
/// \brief Returns true if this dimension is determined.
bool is_determined() const { return m_dimension != s_undetermined_val; }
/// \brief Converts this dimension to size_t. If the dimension is undetermined, throws
/// std::invalid_argument.
/// \brief Construct a dynamic dimension.
Dimension() { m_dimension = s_dynamic_val; }
/// \brief Check whether this dimension is static.
/// \return `true` if the dimension is static, else `false`.
bool is_static() const { return m_dimension != s_dynamic_val; }
/// \brief Check whether this dimension is dynamic.
/// \return `false` if the dimension is static, else `true`.
bool is_dynamic() const { return !is_static(); }
/// \brief Convert this dimension to `size_t`. This dimension must be static.
/// \throws std::invalid_argument If this dimension is dynamic.
explicit operator size_t() const
{
if (!is_determined())
if (is_dynamic())
{
throw std::invalid_argument("Cannot convert unknown dimension to size_t");
throw std::invalid_argument("Cannot convert dynamic dimension to size_t");
}
return m_dimension;
}
/// \brief Returns true if the dimensions are compatible, i.e. if one of the dimensions
/// is undetermined, or both dimensions are determined and equal.
/// \brief Check whether this dimension represents the same scheme as the argument (both
/// dynamic, or equal).
/// \param dim The other dimension to compare this dimension to.
/// \return `true` if this dimension and `dim` are both dynamic, or if they are both
/// static and equal; otherwise, `false`.
bool same_scheme(const Dimension& dim) const
{
return (is_dynamic() && dim.is_dynamic()) ||
(is_static() && dim.is_static() && m_dimension == size_t(dim));
}
/// \brief Try to merge two Dimension objects together.
/// \param[out] dst Reference to write the merged Dimension into.
/// \param d1 First dimension to merge.
/// \param d2 Second dimension to merge.
/// \return `true` if merging succeeds, else `false`.
///
/// \li If `d1` is dynamic, writes `d2` to `dst` and returns `true`.
/// \li If `d2` is dynamic, writes `d1` to `dst` and returns `true`.
/// \li If `d1` and `d2` are static and equal, writes `d1` to `dst` and returns `true`.
/// \li If `d1` and `d2` are both static and unequal, leaves `dst` unchanged and
/// returns `false`.
static bool merge(Dimension& dst, const Dimension d1, const Dimension d2);
/// \brief Check whether this dimension is capable of being merged with the argument
/// dimension.
/// \param d The dimension to compare this dimension with.
/// \return `true` if this dimension is compatible with `d`, else `false`.
///
/// Two dimensions are considered compatible if it is possible to merge them. (See
/// Dimension::merge.)
bool compatible(const Dimension& d) const;
/// \brief Constructs an unknown dimension.
static Dimension undetermined() { return Dimension(); }
/// \brief Constant for the value used internally to represent an unknown dimension.
static const size_t s_undetermined_val{std::numeric_limits<size_t>::max()};
/// \brief Create a dynamic dimension.
/// \return A dynamic dimension.
static Dimension dynamic() { return Dimension(); }
/// \brief Constant for the value used internally to represent a dynamic dimension.
static const size_t s_dynamic_val{std::numeric_limits<size_t>::max()};
private:
// The actual numerical value of the dimension. s_undetermined_val is a special case,
// representing an unknown dimension.
// The actual numerical value of the dimension. s_dynamic_val is a special case,
// representing a dynamic dimension.
size_t m_dimension;
};
/// \brief Inserts a human-readable representation of "dimension" into "str".
/// \brief Insert a human-readable representation of a dimension into an output stream.
/// \param str The output stream targeted for insertion.
/// \param dimension The dimension to be inserted into `str`.
/// \return A reference to `str` after insertion.
///
/// Inserts the string `?` if `dimension` is dynamic; else inserts `size_t(dimension)`.
std::ostream& operator<<(std::ostream& str, const Dimension& dimension);
/// \brief Addition operator for dimensions.
///
/// If d1 and d2 are both known, returns size_t(d1)+size_t(d2). Otherwise, returns
/// Dimension::undetermined().
/// \param d1 Left operand for addition.
/// \param d2 Right operand for addition.
/// \return Dimension::dynamic() if either of `d1` or `d2` is dynamic; else, a static
/// dimension with value `size_t(d1)+size_t(d2)`.
Dimension operator+(const Dimension& d1, const Dimension& d2);
}
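One subtlety in the header above: compatible() asks whether two dimensions could be merged, while same_scheme() asks whether they already describe the same thing (both dynamic, or both static and equal). A sketch, with the include path assumed:

    #include "ngraph/dimension.hpp"
    using namespace ngraph;

    void scheme_vs_compatible()
    {
        Dimension dyn = Dimension::dynamic();
        Dimension three = 3;

        dyn.compatible(three);  // true: a dynamic dimension can merge with anything
        dyn.same_scheme(three); // false: one is dynamic, the other static

        three.compatible(4);    // false: 3 and 4 cannot merge
        three.same_scheme(3);   // true: both static and equal
    }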
@@ -52,9 +52,7 @@ namespace ngraph
case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break;
case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break;
case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break;
case onnx::TensorProto_DataType_UNDEFINED:
elem_type = element::unspecified;
break;
case onnx::TensorProto_DataType_UNDEFINED: elem_type = element::dynamic; break;
default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type";
}
@@ -75,7 +75,7 @@ void Node::set_output_size(size_t n)
for (size_t i = m_outputs.size(); i < n; ++i)
{
auto tensor_descriptor = make_shared<descriptor::Tensor>(
element::unspecified, Shape(), get_name() + "_" + to_string(i));
element::dynamic, PartialShape::dynamic(), get_name() + "_" + to_string(i));
m_outputs.emplace_back(this, i, tensor_descriptor);
}
}
@@ -84,9 +84,9 @@ void Node::validate_and_infer_types()
{
}
void Node::set_output_type(size_t i, const element::Type& element_type, const Shape& shape)
void Node::set_output_type(size_t i, const element::Type& element_type, const PartialShape& pshape)
{
m_outputs.at(i).get_tensor_ptr()->set_tensor_type(element_type, shape);
m_outputs.at(i).get_tensor_ptr()->set_tensor_type(element_type, pshape);
}
std::deque<descriptor::Output>& Node::get_outputs()
@@ -208,13 +208,27 @@ std::ostream& Node::write_short_description(std::ostream& out) const
return out << get_name();
}
static std::string pretty_element_type(const element::Type& et)
{
if (et.is_dynamic())
{
return "?";
}
else
{
return et.c_type_string();
}
}
std::ostream& Node::write_long_description(std::ostream& out) const
{
out << description() << '[' << get_name() << "](";
string sep = "";
for (auto arg : get_arguments())
{
out << sep << NodeDescription(*arg, true);
out << sep << NodeDescription(*arg, true) << ": "
<< pretty_element_type(arg->get_output_element_type(0))
<< arg->get_output_partial_shape(0) << "";
sep = ", ";
}
out << ")";
@@ -404,39 +418,73 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args)
return args;
}
void Node::validate_and_infer_elementwise(element::Type result_type)
std::tuple<element::Type, PartialShape> Node::validate_and_infer_elementwise_args()
{
const element::Type& element_type = get_input_element_type(0);
const Shape& shape = get_input_shape(0);
element::Type element_type = get_input_element_type(0);
PartialShape pshape = get_input_partial_shape(0);
if (get_input_size() > 1)
{
for (size_t i = 1; i < get_input_size(); ++i)
{
NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == element_type)
<< "Argument 0 element type " << element_type
<< " differs in element type from argument " << i << " " << *get_argument(i)
<< " element type " << get_input_element_type(i);
NODE_VALIDATION_ASSERT(this, get_input_shape(i) == shape)
<< "Argument 0 shape " << shape << " differs in shape from argument " << i << " "
<< *get_argument(i) << " shape " << get_input_shape(i);
NODE_VALIDATION_ASSERT(
this, element::Type::merge(element_type, element_type, get_input_element_type(i)))
<< "Argument element types are inconsistent.";
NODE_VALIDATION_ASSERT(this,
PartialShape::merge_into(pshape, get_input_partial_shape(i)))
<< "Argument shapes are inconsistent.";
}
}
set_output_type(0, result_type, shape);
return std::make_tuple(element_type, pshape);
}
void Node::validate_and_infer_elementwise_arithmetic()
{
NODE_VALIDATION_ASSERT(this, get_input_element_type(0) != element::boolean)
<< "Arguments cannot have boolean element type (argument element type: "
<< get_input_element_type(0) << ").";
validate_and_infer_elementwise(get_input_element_type(0));
auto args_et_pshape = validate_and_infer_elementwise_args();
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et != element::boolean)
<< "Arguments cannot have boolean element type (argument element type: " << args_et << ").";
set_output_type(0, args_et, args_pshape);
}
void Node::validate_and_infer_elementwise_logical()
{
NODE_VALIDATION_ASSERT(this, get_input_element_type(0) == element::boolean)
auto args_et_pshape = validate_and_infer_elementwise_args();
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et == element::boolean)
<< "Operands for logical operators must have boolean element type but have element type "
<< get_input_element_type(0) << ".";
validate_and_infer_elementwise(get_input_element_type(0));
<< args_et << ".";
set_output_type(0, element::boolean, args_pshape);
}
bool Node::validate_punt_if_dynamic()
{
bool any_dynamic = false;
for (auto& input : m_inputs)
{
any_dynamic |= input.get_partial_shape().is_dynamic();
any_dynamic |= input.get_element_type().is_dynamic();
}
if (any_dynamic)
{
for (size_t i = 0; i < get_output_size(); i++)
{
set_output_type(i, element::dynamic, PartialShape::dynamic());
}
return true;
}
else
{
return false;
}
}
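The observable effect of the punt: if any input has a dynamic shape or element type, every output becomes fully dynamic, regardless of what static information could in principle be preserved. A sketch (the choice of op is illustrative; any op that calls validate_punt_if_dynamic behaves this way):

    #include "ngraph/ngraph.hpp"
    using namespace ngraph;

    void punt_example()
    {
        auto arg = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
        auto rev = std::make_shared<op::Reverse>(arg, AxisSet{});

        // The input shape is dynamic, so Reverse punts: its output gets a
        // dynamic element type and a shape of dynamic rank.
        bool dyn = rev->get_output_partial_shape(0).rank().is_dynamic(); // true
        (void)dyn;
    }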
@@ -22,6 +22,7 @@
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <typeindex>
#include <unordered_map>
#include <unordered_set>
@@ -94,14 +95,21 @@ namespace ngraph
// Called in constructors during transition
void constructor_validate_and_infer_types();
void validate_and_infer_elementwise(element::Type result_type);
void validate_and_infer_elementwise()
{
validate_and_infer_elementwise(get_input_element_type(0));
}
std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args();
void validate_and_infer_elementwise_arithmetic();
void validate_and_infer_elementwise_logical();
// Temporary hack while partial shape propagation is being implemented. If any input has
// dynamic shape or dynamic element type, sets all outputs to have a shape of dynamic
// rank and dynamic element type. Ops where we haven't yet implemented partial shape
// propagation can add this boilerplate at the top of their validate_and_infer_types():
//
// if (validate_punt_if_dynamic())
// {
// return;
// }
bool validate_punt_if_dynamic();
Node(const std::string& node_type, const NodeVector& arguments, size_t output_size = 1);
virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) {}
@@ -125,7 +133,9 @@ namespace ngraph
return std::type_index(typeid(*this)) == std::type_index(typeid(*n));
}
void set_output_type(size_t i, const element::Type& element_type, const Shape& shape);
void set_output_type(size_t i,
const element::Type& element_type,
const PartialShape& pshape);
bool is_parameter() const;
virtual bool is_output() const;
@@ -27,6 +27,11 @@ op::AllReduce::AllReduce(const shared_ptr<Node>& arg)
void op::AllReduce::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
NODE_VALIDATION_ASSERT(this,
get_input_element_type(0) == element::f32 ||
get_input_element_type(0) == element::f64)
@@ -40,6 +40,11 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
void op::AvgPool::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& arg_shape = get_input_shape(0);
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
@@ -120,6 +125,11 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
void op::AvgPoolBackprop::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& delta_shape = get_input_shape(0);
// infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
@@ -35,6 +35,11 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
void ngraph::op::BatchNorm::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
m_bn_input_shape = get_input_shape(INPUT);
NODE_VALIDATION_ASSERT(this, m_bn_input_shape.size() >= 2)
<< "Input argument must have rank of at least 2 (input argument shape: " << m_bn_input_shape
@@ -158,6 +163,11 @@ ngraph::op::BatchNormBackprop::BatchNormBackprop(double eps,
void ngraph::op::BatchNormBackprop::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
set_output_size(3);
NODE_VALIDATION_ASSERT(this, get_input_shape(INPUT).size() == 4)
@@ -40,6 +40,11 @@ op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
void op::Broadcast::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
infer_shape();
Shape target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
@@ -32,6 +32,11 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
void op::Concat::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required.";
Shape first_input_shape = get_input_shape(0);
@@ -46,6 +46,11 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
void op::Convolution::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& data_batch_shape = get_input_shape(0);
auto& data_batch_et = get_input_element_type(0);
auto& filters_shape = get_input_shape(1);
@@ -220,6 +225,11 @@ op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_sha
void op::ConvolutionBackpropData::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
// Backprop to data is itself convolution, with inputs/outputs/attributes transmogrified as
// follows.
//
@@ -410,6 +420,11 @@ op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
void op::ConvolutionBackpropFilters::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
// Backprop to filters is itself convolution, with inputs/outputs/attributes transmogrified as
// follows.
//
@@ -34,6 +34,11 @@ op::Dequantize::Dequantize(shared_ptr<Node> input,
void op::Dequantize::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
enum
{
INPUT,
@@ -47,6 +47,11 @@ op::Dot::Dot(const shared_ptr<Node>& arg0,
void op::Dot::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& input_0 = get_inputs().at(0);
auto& input_1 = get_inputs().at(1);
@@ -42,6 +42,11 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
void op::MaxPool::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& arg_shape = get_input_shape(0);
NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
@@ -120,6 +125,11 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
void op::MaxPoolBackprop::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto forward_arg_et = get_input_element_type(0);
auto& forward_arg_shape = get_input_shape(0);
auto delta_et = get_input_element_type(1);
@@ -26,9 +26,14 @@ op::Not::Not(const shared_ptr<Node>& arg)
constructor_validate_and_infer_types();
}
// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
void op::Not::validate_and_infer_types()
{
validate_and_infer_elementwise();
auto args_et_pshape = validate_and_infer_elementwise_args();
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
set_output_type(0, args_et, args_pshape);
}
shared_ptr<Node> op::Not::copy_with_new_args(const NodeVector& new_args) const
@@ -22,11 +22,11 @@ using namespace std;
using namespace ngraph;
op::Parameter::Parameter(const element::Type& element_type,
const Shape& shape,
const PartialShape& pshape,
const bool cacheable)
: Op("Parameter", {})
, m_cacheable(cacheable)
, m_shape(shape)
, m_partial_shape(pshape)
, m_element_type(element_type)
{
constructor_validate_and_infer_types();
@@ -35,13 +35,13 @@ op::Parameter::Parameter(const element::Type& element_type,
void op::Parameter::validate_and_infer_types()
{
Op::validate_and_infer_types();
set_output_type(0, m_element_type, m_shape);
set_output_type(0, m_element_type, m_partial_shape);
}
shared_ptr<Node> op::Parameter::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Parameter>(m_element_type, m_shape);
return make_shared<Parameter>(m_element_type, m_partial_shape);
}
void op::Parameter::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
@@ -38,10 +38,10 @@ namespace ngraph
/// \brief Constructs a tensor view-typed parameter node.
///
/// \param element_type The element type of the parameter.
/// \param shape The shape of the parameter.
/// \param pshape The partial shape of the parameter.
/// \param cacheable True if the parameter is not expected to be frequently updated.
Parameter(const ngraph::element::Type& element_type,
const Shape& shape,
const PartialShape& pshape,
const bool cacheable = false);
void validate_and_infer_types() override;
@@ -52,7 +52,7 @@ namespace ngraph
protected:
bool m_cacheable;
Shape m_shape;
PartialShape m_partial_shape;
element::Type m_element_type;
};
}
@@ -36,6 +36,11 @@ op::Quantize::Quantize(shared_ptr<Node> input,
void op::Quantize::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
enum
{
INPUT,
@@ -35,6 +35,11 @@ op::Reshape::Reshape(const shared_ptr<Node>& arg,
void op::Reshape::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& input = get_inputs().at(0);
auto input_shape = input.get_shape();
auto input_rank = input_shape.size();
@@ -37,7 +37,7 @@ void op::Result::validate_and_infer_types()
// always borrow the placement conf even the default one
set_placement(get_argument(0)->get_placement());
set_output_type(0, get_input_element_type(0), get_input_shape(0));
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
shared_ptr<Node> op::Result::copy_with_new_args(const NodeVector& new_args) const
@@ -32,6 +32,11 @@ op::Reverse::Reverse(const shared_ptr<Node>& arg, const AxisSet& reversed_axes)
void op::Reverse::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto input_shape = get_input_shape(0);
auto input_rank = input_shape.size();
@@ -38,6 +38,11 @@ op::ReverseSequence::ReverseSequence(const std::shared_ptr<Node> arg,
void op::ReverseSequence::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
NODE_VALIDATION_ASSERT(this, get_input_shape(1).size() == 1)
<< "Sequence indices must be a 1-dimensional tensor (sequence indices shape: "
<< get_input_shape(1) << ").";
@@ -44,6 +44,11 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
void op::Slice::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
if (0 == m_strides.size())
{
m_strides = Strides(m_lower_bounds.size(), 1);
@@ -39,6 +39,11 @@ op::TopK::TopK(const shared_ptr<Node>& arg,
void op::TopK::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto& input = get_inputs().at(0);
auto rank = input.get_shape().size();
@@ -29,6 +29,11 @@ op::util::ArithmeticReduction::ArithmeticReduction(const std::string& node_type,
void op::util::ArithmeticReduction::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
auto input_shape = get_input_shape(0);
for (auto axis : m_reduction_axes)
@@ -28,5 +28,8 @@ op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const string&
void op::util::BinaryElementwiseComparison::validate_and_infer_types()
{
validate_and_infer_elementwise(element::boolean);
auto args_et_pshape = validate_and_infer_elementwise_args();
PartialShape& args_pshape = std::get<1>(args_et_pshape);
set_output_type(0, element::boolean, args_pshape);
}
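With this change, comparison ops keep their boolean output element type but infer the output shape by merging the (possibly partial) argument shapes. A sketch, assuming op::Equal as a representative comparison op:

    #include "ngraph/ngraph.hpp"
    using namespace ngraph;

    void comparison_example()
    {
        auto a = std::make_shared<op::Parameter>(
            element::f32, PartialShape{2, Dimension::dynamic()});
        auto b = std::make_shared<op::Parameter>(
            element::f32, PartialShape{Dimension::dynamic(), 3});

        // The argument shapes {2,?} and {?,3} merge to the static shape {2,3};
        // the output element type is boolean.
        auto eq = std::make_shared<op::Equal>(a, b);
    }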
@@ -28,19 +28,18 @@ PartialShape::PartialShape(const Shape& shape)
m_dimensions.assign(shape.begin(), shape.end());
}
bool ngraph::PartialShape::is_complete() const
bool ngraph::PartialShape::is_static() const
{
return m_rank_is_determined &&
std::all_of(m_dimensions.begin(), m_dimensions.end(), [](const Dimension& d) {
return d.is_determined();
});
return m_rank_is_static && std::all_of(m_dimensions.begin(),
m_dimensions.end(),
[](const Dimension& d) { return d.is_static(); });
}
PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2)
{
if (!s1.rank_is_determined() || !s2.rank_is_determined())
if (s1.rank().is_dynamic() || s2.rank().is_dynamic())
{
return PartialShape::undetermined();
return PartialShape::dynamic();
}
if (!s1.rank().compatible(s2.rank()))
@@ -49,7 +48,7 @@ PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2)
}
PartialShape result{};
result.m_rank_is_determined = true;
result.m_rank_is_static = true;
for (size_t i = 0; i < s1.m_dimensions.size(); i++)
{
result.m_dimensions.push_back(s1.m_dimensions[i] + s2.m_dimensions[i]);
@@ -59,7 +58,7 @@ std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape)
std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape)
{
if (shape.m_rank_is_determined)
if (shape.m_rank_is_static)
{
str << "{";
bool first = true;
@@ -83,7 +82,7 @@ std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape)
bool PartialShape::compatible(const PartialShape& s) const
{
// If we don't know *this's rank, or we don't know s's rank, they are compatible.
if (!rank_is_determined() || !s.rank_is_determined())
if (!m_rank_is_static || s.rank().is_dynamic())
{
return true;
}
@@ -109,12 +108,69 @@ bool PartialShape::compatible(const PartialShape& s) const
}
}
bool PartialShape::same_scheme(const PartialShape& s) const
{
if (rank().is_dynamic() && s.rank().is_dynamic())
{
return true;
}
else if (rank().is_static() && s.rank().is_static())
{
if (size_t(rank()) != size_t(s.rank()))
{
return false;
}
bool success = true;
for (size_t i = 0; i < size_t(rank()); i++)
{
success &= (*this)[i].same_scheme(s[i]);
}
return success;
}
else
{
return false;
}
}
Shape PartialShape::to_shape() const
{
if (!is_complete())
if (is_dynamic())
{
throw std::invalid_argument("to_shape was called on an incomplete shape.");
throw std::invalid_argument("to_shape was called on a dynamic shape.");
}
return Shape(m_dimensions.begin(), m_dimensions.end());
}
bool PartialShape::merge_into(PartialShape& dst, const PartialShape& src)
{
if (dst.rank().is_dynamic())
{
dst = src;
return true;
}
else if (src.rank().is_dynamic())
{
// No change to dst.
return true;
}
else if (size_t(dst.rank()) != size_t(src.rank()))
{
// Mismatching static ranks, cannot merge.
return false;
}
else
{
// Ranks are both static, and they match.
bool success = true;
for (size_t i = 0; i < size_t(dst.rank()); i++)
{
success &= Dimension::merge(dst[i], dst[i], src[i]);
}
return success;
}
}
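PartialShape::merge_into applies Dimension::merge elementwise once the ranks are reconciled. A usage sketch (include path assumed):

    #include "ngraph/partial_shape.hpp"
    using namespace ngraph;

    void merge_into_examples()
    {
        PartialShape s{2, Dimension::dynamic(), 3};

        // Dimensions merge elementwise: {2,?,3} merged with {?,4,3} gives {2,4,3}.
        bool ok = PartialShape::merge_into(s, PartialShape{Dimension::dynamic(), 4, 3});
        // ok == true; s is now the static shape {2,4,3}.

        // Mismatched static ranks cannot merge: this returns false and
        // leaves s unchanged.
        PartialShape::merge_into(s, PartialShape{1, 2});
    }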
@@ -20,9 +20,8 @@
namespace ngraph
{
/// \brief Alias for "Dimension". Should be used to when the value represents the number of
/// axes in a shape-like object, rather than the size of one dimension in a shape-like
/// object.
/// \brief Alias for Dimension, used when the value represents the number of axes in a shape,
/// rather than the size of one dimension in a shape.
///
/// XXX: THIS TYPE IS EXPERIMENTAL AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
using Rank = Dimension;
......
@@ -149,6 +149,59 @@ static json write(const ngraph::Node&, bool binary_constant_data);
static string
serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data);
static json write_dimension(Dimension d)
{
if (d.is_static())
{
return size_t(d);
}
else
{
return nullptr;
}
}
static json write_partial_shape(const PartialShape& s)
{
if (s.rank().is_dynamic())
{
return nullptr;
}
else
{
std::vector<json> vals(size_t(s.rank()));
for (size_t i = 0; i < vals.size(); i++)
{
vals[i] = write_dimension(s[i]);
}
return vals;
}
}
static PartialShape read_partial_shape(const json& j)
{
if (j.is_null())
{
return PartialShape::dynamic();
}
else
{
std::vector<Dimension> dims(j.size());
for (size_t i = 0; i < j.size(); i++)
{
if (j[i].is_null())
{
dims[i] = Dimension::dynamic();
}
else
{
dims[i] = size_t(j[i]);
}
}
return PartialShape(dims);
}
}
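The encoding these helpers imply, illustrated with the JSON library the serializer already uses (nlohmann::json; include path assumed): a dynamic rank is stored as null, and a static rank as an array whose entries are either a number (static dimension) or null (dynamic dimension).

    #include <nlohmann/json.hpp>
    using json = nlohmann::json;

    // PartialShape::dynamic()                  <-> null
    json dynamic_rank_shape = nullptr;
    // PartialShape{2, Dimension::dynamic(), 3} <-> [2, null, 3]
    json static_rank_shape = {2, nullptr, 3};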
static json write_element_type(const ngraph::element::Type& n)
{
json j;
@@ -839,7 +892,8 @@ static shared_ptr<ngraph::Function>
auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape");
auto cacheable = get_or_default<bool>(node_js, "cacheable", false);
node = make_shared<op::Parameter>(element_type, shape, cacheable);
node =
make_shared<op::Parameter>(element_type, read_partial_shape(shape), cacheable);
break;
}
case OP_TYPEID::Power:
@@ -1389,7 +1443,7 @@ static json write(const Node& n, bool binary_constant_data)
case OP_TYPEID::Parameter:
{
auto tmp = dynamic_cast<const op::Parameter*>(&n);
node["shape"] = tmp->get_shape();
node["shape"] = write_partial_shape(tmp->get_output_partial_shape(0));
node["cacheable"] = tmp->get_cacheable();
node["element_type"] = write_element_type(tmp->get_element_type());
break;
@@ -21,7 +21,7 @@
using namespace ngraph;
const element::Type element::unspecified(0, false, false, false, "unspecified");
const element::Type element::dynamic(0, false, false, false, "dynamic");
const element::Type element::boolean(8, false, true, false, "char");
const element::Type element::f32(32, true, true, false, "float");
const element::Type element::f64(64, true, true, false, "double");
@@ -184,3 +184,26 @@ std::ostream& element::operator<<(std::ostream& out, const element::Type& obj)
<< ", " << obj.m_is_quantized << ", \"" << obj.m_cname << "\"}";
return out;
}
bool element::Type::merge(element::Type& dst, const element::Type& t1, const element::Type& t2)
{
if (t1.is_dynamic())
{
dst = t2;
return true;
}
else if (t2.is_dynamic())
{
dst = t1;
return true;
}
else if (t1 == t2)
{
dst = t1;
return true;
}
else
{
return false;
}
}
@@ -33,7 +33,7 @@ namespace ngraph
{
class Type;
extern const Type unspecified;
extern const Type dynamic;
extern const Type boolean;
extern const Type f32;
extern const Type f64;
@@ -61,6 +61,8 @@ namespace ngraph
const std::string& c_type_string() const;
size_t size() const;
size_t hash() const;
bool is_static() const { return (*this != dynamic); }
bool is_dynamic() const { return !is_static(); }
bool is_real() const { return m_is_real; }
bool is_signed() const { return m_is_signed; }
bool is_quantized() const { return m_is_quantized; }
@@ -73,12 +75,32 @@ namespace ngraph
/// Returns true if the type is floating point, else false.
bool get_is_real() const { return m_is_real; }
/// \brief Merges two element types t1 and t2, writing the result into dst and
/// returning true if successful, else returning false.
///
/// To "merge" two element types t1 and t2 is to find the least restrictive
/// element type t that is no more restrictive than t1 and t2, if t exists.
/// More simply:
///
/// merge(dst,element::Type::dynamic,t)
/// writes t to dst and returns true
///
/// merge(dst,t,element::Type::dynamic)
/// writes t to dst and returns true
///
/// merge(dst,t1,t2) where t1, t2 both static and equal
/// writes t1 to dst and returns true
///
/// merge(dst,t1,t2) where t1, t2 both static and unequal
/// does nothing to dst, and returns false
static bool merge(element::Type& dst, const element::Type& t1, const element::Type& t2);
private:
size_t m_bitwidth{0};
bool m_is_real{false};
bool m_is_signed{false};
bool m_is_quantized{false};
std::string m_cname{"unspecified"};
std::string m_cname{"dynamic"};
};
template <typename T>
@@ -84,3 +84,42 @@ TEST(element_type, size)
EXPECT_EQ(2, t1.size());
}
}
TEST(element_type, merge_both_dynamic)
{
element::Type t;
ASSERT_TRUE(element::Type::merge(t, element::dynamic, element::dynamic));
ASSERT_TRUE(t.is_dynamic());
}
TEST(element_type, merge_left_dynamic)
{
element::Type t;
ASSERT_TRUE(element::Type::merge(t, element::dynamic, element::u64));
ASSERT_TRUE(t.is_static());
ASSERT_EQ(t, element::u64);
}
TEST(element_type, merge_right_dynamic)
{
element::Type t;
ASSERT_TRUE(element::Type::merge(t, element::i16, element::dynamic));
ASSERT_TRUE(t.is_static());
ASSERT_EQ(t, element::i16);
}
TEST(element_type, merge_both_static_equal)
{
element::Type t;
ASSERT_TRUE(element::Type::merge(t, element::f64, element::f64));
ASSERT_TRUE(t.is_static());
ASSERT_EQ(t, element::f64);
}
TEST(element_type, merge_both_static_unequal)
{
element::Type t = element::f32;
ASSERT_FALSE(element::Type::merge(t, element::i8, element::i16));
ASSERT_TRUE(t.is_static());
ASSERT_EQ(t, element::f32);
}