Commit 0563a3cf authored by Adam Procter, committed by Robert Kimball

Partial Shapes and Types, Part 3: Framework for partial shape/element type validation (#1728)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup

* Add PartialShape propagation for Parameter and Result

* Implement partial-shape propagation for elementwise ops

* More comments

* One more comment tweak

* Add tests for the merge functions

* Add merging of undetermined element types

* Fix a goof-up in deserializer implementation

* Implement fallback for ops that do not support partial shape/type validation

* Updates for some older unit tests, now that operator[] exists

* Add missing validate_punt_if_incomplete to AllReduce

* Better docstrings for the stuff introduced in #1692; remove prototype for unimplemented, unused PartialShape::append()

* One more docstring thing I forgot to save

* Switch terminology from 'determined/undetermined' to 'static/dynamic'

* Switch terminology from 'complete/incomplete' to 'static/dynamic' for shapes; fix up some mushily worded comments

* Fix overzealous edits from the last commit

* Rename one test that escaped the Great Renaming

* Remove unnecessary validate_punt_if_dynamic from Reshape

* Show argument types/shapes in long NodeDescription; tank unit tests to block merge

* Fix dynamic element type propagation for elementwise ops, add some unit tests for same

* Roll 'Not' back to existing behavior (non-boolean input types allowed)

* Add a TODO tag to a todo item
parent 631d7253
@@ -25,7 +25,7 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
                            const PartialShape& pshape,
                            const std::string& name)
     : m_element_type(element_type)
-    , m_shape(pshape.is_complete() ? pshape.to_shape() : Shape{})
+    , m_shape(pshape.is_static() ? pshape.to_shape() : Shape{})
     , m_partial_shape(pshape)
     , m_name(name)
 {
@@ -34,7 +34,7 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
 void descriptor::Tensor::set_tensor_type(const element::Type& element_type,
                                          const PartialShape& pshape)
 {
-    if (pshape.is_complete())
+    if (pshape.is_static())
     {
         m_shape = pshape.to_shape();
     }
@@ -48,14 +48,14 @@ void descriptor::Tensor::set_tensor_type(const element::Type& element_type,
 const Shape& descriptor::Tensor::get_shape() const
 {
-    if (m_partial_shape.is_complete())
+    if (m_partial_shape.is_static())
     {
         return m_shape;
     }
     else
     {
         throw std::invalid_argument(
-            "get_shape was called on a descriptor::Tensor with incomplete shape");
+            "get_shape was called on a descriptor::Tensor with dynamic shape");
     }
 }
...
@@ -25,11 +25,11 @@ using namespace ngraph;
 Dimension::Dimension(size_t dimension)
     : m_dimension(dimension)
 {
-    if (dimension == s_undetermined_val)
+    if (dimension == s_dynamic_val)
     {
         std::stringstream ss;
-        ss << "Cannot convert the value 0x" << std::uppercase << std::hex << s_undetermined_val
-           << " to Dimension: this value is used internally to represent an undetermined "
+        ss << "Cannot convert the value 0x" << std::uppercase << std::hex << s_dynamic_val
+           << " to Dimension: this value is used internally to represent a dynamic "
              "dimension.";
         throw std::invalid_argument(ss.str());
     }
@@ -37,7 +37,7 @@ Dimension::Dimension(size_t dimension)
 std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
 {
-    if (dimension.is_determined())
+    if (dimension.is_static())
     {
         return (str << size_t(dimension));
     }
@@ -49,11 +49,33 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
 Dimension ngraph::operator+(const Dimension& d1, const Dimension& d2)
 {
-    return (d1.is_determined() && d2.is_determined() ? size_t(d1) + size_t(d2)
-                                                     : Dimension::undetermined());
+    return (d1.is_static() && d2.is_static() ? size_t(d1) + size_t(d2) : Dimension::dynamic());
 }

 bool Dimension::compatible(const Dimension& d) const
 {
-    return (!is_determined() || !d.is_determined() || m_dimension == size_t(d));
+    return (is_dynamic() || d.is_dynamic() || m_dimension == size_t(d));
+}
+
+bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
+{
+    if (d1.is_dynamic())
+    {
+        dst = d2;
+        return true;
+    }
+    else if (d2.is_dynamic())
+    {
+        dst = d1;
+        return true;
+    }
+    else if (size_t(d1) != size_t(d2))
+    {
+        return false;
+    }
+    else
+    {
+        dst = d1;
+        return true;
+    }
 }
...
@@ -22,56 +22,97 @@
 namespace ngraph
 {
-    /// \brief Class representing a possibly-unknown dimension in a shape or shape-like object.
+    /// \brief Class representing a dimension, which may be dynamic (undetermined until
+    ///        runtime), in a shape or shape-like object.
     ///
-    /// Known dimensions may be implicitly converted from size_t. An unknown dimension is
-    /// constructed with Dimension() or Dimension::undetermined().
+    /// Static dimensions may be implicitly converted from size_t. A dynamic dimension is
+    /// constructed with Dimension() or Dimension::dynamic().
     ///
     /// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
     class Dimension
     {
     public:
-        /// \brief Constructs a known dimension.
-        ///
-        /// Requires that dimension != s_undetermined_val. If that condition does not hold,
-        /// throws std::invalid_argument.
+        /// \brief Construct a static dimension.
+        /// \param dimension Value of the dimension. Must not be equal to
+        ///        Dimension::s_dynamic_val.
+        /// \throws std::invalid_argument If `dimension` == Dimension::s_dynamic_val.
         Dimension(size_t dimension);

-        /// \brief Constructs an unknown dimension.
-        Dimension() { m_dimension = s_undetermined_val; }
-        /// \brief Returns true if this dimension is determined.
-        bool is_determined() const { return m_dimension != s_undetermined_val; }
-        /// \brief Converts this dimension to size_t. If the dimension is undetermined, throws
-        /// std::invalid_argument.
+        /// \brief Construct a dynamic dimension.
+        Dimension() { m_dimension = s_dynamic_val; }
+        /// \brief Check whether this dimension is static.
+        /// \return `true` if the dimension is static, else `false`.
+        bool is_static() const { return m_dimension != s_dynamic_val; }
+        /// \brief Check whether this dimension is dynamic.
+        /// \return `false` if the dimension is static, else `true`.
+        bool is_dynamic() const { return !is_static(); }
+        /// \brief Convert this dimension to `size_t`. This dimension must be static.
+        /// \throws std::invalid_argument If this dimension is dynamic.
         explicit operator size_t() const
         {
-            if (!is_determined())
+            if (is_dynamic())
             {
-                throw std::invalid_argument("Cannot convert unknown dimension to size_t");
+                throw std::invalid_argument("Cannot convert dynamic dimension to size_t");
             }
             return m_dimension;
         }

-        /// \brief Returns true if the dimensions are compatible, i.e. if one of the dimensions
-        ///        is undetermined, or both dimensions are determined and equal.
+        /// \brief Check whether this dimension represents the same scheme as the argument
+        ///        (both dynamic, or equal).
+        /// \param dim The other dimension to compare this dimension to.
+        /// \return `true` if this dimension and `dim` are both dynamic, or if they are both
+        ///         static and equal; otherwise, `false`.
+        bool same_scheme(const Dimension& dim) const
+        {
+            return (is_dynamic() && dim.is_dynamic()) ||
+                   (is_static() && dim.is_static() && m_dimension == size_t(dim));
+        }
+        /// \brief Try to merge two Dimension objects together.
+        /// \param[out] dst Reference to write the merged Dimension into.
+        /// \param d1 First dimension to merge.
+        /// \param d2 Second dimension to merge.
+        /// \return `true` if merging succeeds, else `false`.
+        ///
+        /// \li If `d1` is dynamic, writes `d2` to `dst` and returns `true`.
+        /// \li If `d2` is dynamic, writes `d1` to `dst` and returns `true`.
+        /// \li If `d1` and `d2` are static and equal, writes `d1` to `dst` and returns `true`.
+        /// \li If `d1` and `d2` are both static and unequal, leaves `dst` unchanged and
+        ///     returns `false`.
+        static bool merge(Dimension& dst, const Dimension d1, const Dimension d2);
+        /// \brief Check whether this dimension is capable of being merged with the argument
+        ///        dimension.
+        /// \param d The dimension to compare this dimension with.
+        /// \return `true` if this dimension is compatible with `d`, else `false`.
+        ///
+        /// Two dimensions are considered compatible if it is possible to merge them. (See
+        /// Dimension::merge.)
         bool compatible(const Dimension& d) const;

-        /// \brief Constructs an unknown dimension.
-        static Dimension undetermined() { return Dimension(); }
-        /// \brief Constant for the value used internally to represent an unknown dimension.
-        static const size_t s_undetermined_val{std::numeric_limits<size_t>::max()};
+        /// \brief Create a dynamic dimension.
+        /// \return A dynamic dimension.
+        static Dimension dynamic() { return Dimension(); }
+        /// \brief Constant for the value used internally to represent a dynamic dimension.
+        static const size_t s_dynamic_val{std::numeric_limits<size_t>::max()};
     private:
-        // The actual numerical value of the dimension. s_undetermined_val is a special case,
-        // representing an unknown dimension.
+        // The actual numerical value of the dimension. s_dynamic_val is a special case,
+        // representing a dynamic dimension.
         size_t m_dimension;
     };

-    /// \brief Inserts a human-readable representation of "dimension" into "str".
+    /// \brief Insert a human-readable representation of a dimension into an output stream.
+    /// \param str The output stream targeted for insertion.
+    /// \param dimension The dimension to be inserted into `str`.
+    /// \return A reference to `str` after insertion.
+    ///
+    /// Inserts the string `?` if `dimension` is dynamic; else inserts `size_t(dimension)`.
     std::ostream& operator<<(std::ostream& str, const Dimension& dimension);

     /// \brief Addition operator for dimensions.
-    ///
-    /// If d1 and d2 are both known, returns size_t(d1)+size_t(d2). Otherwise, returns
-    /// Dimension::undetermined().
+    /// \param d1 Left operand for addition.
+    /// \param d2 Right operand for addition.
+    /// \return Dimension::dynamic() if either of `d1` or `d2` is dynamic; else, a static
+    ///         dimension with value `size_t(d1)+size_t(d2)`.
     Dimension operator+(const Dimension& d1, const Dimension& d2);
 }
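As a quick orientation to the semantics documented above, here is a minimal standalone sketch exercising Dimension's compatibility and merge rules. It is illustrative only: the "ngraph/dimension.hpp" include path and the main() harness are assumptions of the sketch, not part of this commit.

    #include <cassert>
    #include "ngraph/dimension.hpp" // assumed include path

    using ngraph::Dimension;

    int main()
    {
        Dimension three{3};                       // static dimension of size 3
        Dimension unknown = Dimension::dynamic(); // dynamic: unknown until runtime

        // A dynamic dimension is compatible with anything; two static
        // dimensions are compatible only if they are equal.
        assert(three.compatible(unknown));
        assert(!three.compatible(Dimension{4}));

        // merge() resolves dynamic-vs-static to the static value, and fails
        // (leaving dst untouched) on a static mismatch.
        Dimension merged;
        assert(Dimension::merge(merged, unknown, three));
        assert(size_t(merged) == 3);
        assert(!Dimension::merge(merged, Dimension{2}, Dimension{5}));
        return 0;
    }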
@@ -52,9 +52,7 @@ namespace ngraph
             case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break;
             case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break;
             case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break;
-            case onnx::TensorProto_DataType_UNDEFINED:
-                elem_type = element::unspecified;
-                break;
+            case onnx::TensorProto_DataType_UNDEFINED: elem_type = element::dynamic; break;
             default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type";
         }
...
@@ -75,7 +75,7 @@ void Node::set_output_size(size_t n)
     for (size_t i = m_outputs.size(); i < n; ++i)
     {
         auto tensor_descriptor = make_shared<descriptor::Tensor>(
-            element::unspecified, Shape(), get_name() + "_" + to_string(i));
+            element::dynamic, PartialShape::dynamic(), get_name() + "_" + to_string(i));
         m_outputs.emplace_back(this, i, tensor_descriptor);
     }
 }
@@ -84,9 +84,9 @@ void Node::validate_and_infer_types()
 {
 }

-void Node::set_output_type(size_t i, const element::Type& element_type, const Shape& shape)
+void Node::set_output_type(size_t i, const element::Type& element_type, const PartialShape& pshape)
 {
-    m_outputs.at(i).get_tensor_ptr()->set_tensor_type(element_type, shape);
+    m_outputs.at(i).get_tensor_ptr()->set_tensor_type(element_type, pshape);
 }

 std::deque<descriptor::Output>& Node::get_outputs()
@@ -208,13 +208,27 @@ std::ostream& Node::write_short_description(std::ostream& out) const
     return out << get_name();
 }

+static std::string pretty_element_type(const element::Type& et)
+{
+    if (et.is_dynamic())
+    {
+        return "?";
+    }
+    else
+    {
+        return et.c_type_string();
+    }
+}
+
 std::ostream& Node::write_long_description(std::ostream& out) const
 {
     out << description() << '[' << get_name() << "](";
     string sep = "";
     for (auto arg : get_arguments())
     {
-        out << sep << NodeDescription(*arg, true);
+        out << sep << NodeDescription(*arg, true) << ": "
+            << pretty_element_type(arg->get_output_element_type(0))
+            << arg->get_output_partial_shape(0) << "";
         sep = ", ";
     }
     out << ")";
@@ -404,39 +418,73 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args)
     return args;
 }

-void Node::validate_and_infer_elementwise(element::Type result_type)
+std::tuple<element::Type, PartialShape> Node::validate_and_infer_elementwise_args()
 {
-    const element::Type& element_type = get_input_element_type(0);
-    const Shape& shape = get_input_shape(0);
+    element::Type element_type = get_input_element_type(0);
+    PartialShape pshape = get_input_partial_shape(0);

     if (get_input_size() > 1)
     {
         for (size_t i = 1; i < get_input_size(); ++i)
         {
-            NODE_VALIDATION_ASSERT(this, get_input_element_type(i) == element_type)
-                << "Argument 0 element type " << element_type
-                << " differs in element type from argument " << i << " " << *get_argument(i)
-                << " element type " << get_input_element_type(i);
+            NODE_VALIDATION_ASSERT(
+                this, element::Type::merge(element_type, element_type, get_input_element_type(i)))
+                << "Argument element types are inconsistent.";

-            NODE_VALIDATION_ASSERT(this, get_input_shape(i) == shape)
-                << "Argument 0 shape " << shape << " differs in shape from argument " << i << " "
-                << *get_argument(i) << " shape " << get_input_shape(i);
+            NODE_VALIDATION_ASSERT(this,
+                                   PartialShape::merge_into(pshape, get_input_partial_shape(i)))
+                << "Argument shapes are inconsistent.";
         }
     }
-    set_output_type(0, result_type, shape);
+
+    return std::make_tuple(element_type, pshape);
 }

 void Node::validate_and_infer_elementwise_arithmetic()
 {
-    NODE_VALIDATION_ASSERT(this, get_input_element_type(0) != element::boolean)
-        << "Arguments cannot have boolean element type (argument element type: "
-        << get_input_element_type(0) << ").";
-    validate_and_infer_elementwise(get_input_element_type(0));
+    auto args_et_pshape = validate_and_infer_elementwise_args();
+    element::Type& args_et = std::get<0>(args_et_pshape);
+    PartialShape& args_pshape = std::get<1>(args_et_pshape);
+
+    NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et != element::boolean)
+        << "Arguments cannot have boolean element type (argument element type: " << args_et << ").";
+
+    set_output_type(0, args_et, args_pshape);
 }

 void Node::validate_and_infer_elementwise_logical()
 {
-    NODE_VALIDATION_ASSERT(this, get_input_element_type(0) == element::boolean)
-        << "Operands for logical operators must have boolean element type but have element type "
-        << get_input_element_type(0) << ".";
-    validate_and_infer_elementwise(get_input_element_type(0));
+    auto args_et_pshape = validate_and_infer_elementwise_args();
+    element::Type& args_et = std::get<0>(args_et_pshape);
+    PartialShape& args_pshape = std::get<1>(args_et_pshape);
+
+    NODE_VALIDATION_ASSERT(this, args_et.is_dynamic() || args_et == element::boolean)
+        << "Operands for logical operators must have boolean element type but have element type "
+        << args_et << ".";
+
+    set_output_type(0, element::boolean, args_pshape);
+}
+
+bool Node::validate_punt_if_dynamic()
+{
+    bool any_dynamic = false;
+
+    for (auto& input : m_inputs)
+    {
+        any_dynamic |= input.get_partial_shape().is_dynamic();
+        any_dynamic |= input.get_element_type().is_dynamic();
+    }
+
+    if (any_dynamic)
+    {
+        for (size_t i = 0; i < get_output_size(); i++)
+        {
+            set_output_type(i, element::dynamic, PartialShape::dynamic());
+        }
+        return true;
+    }
+    else
+    {
+        return false;
+    }
 }
...
@@ -22,6 +22,7 @@
 #include <memory>
 #include <set>
 #include <string>
+#include <tuple>
 #include <typeindex>
 #include <unordered_map>
 #include <unordered_set>
@@ -94,14 +95,21 @@ namespace ngraph
         // Called in constructors during transition
         void constructor_validate_and_infer_types();

-        void validate_and_infer_elementwise(element::Type result_type);
-        void validate_and_infer_elementwise()
-        {
-            validate_and_infer_elementwise(get_input_element_type(0));
-        }
+        std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args();
         void validate_and_infer_elementwise_arithmetic();
         void validate_and_infer_elementwise_logical();

+        // Temporary hack while partial shape propagation is being implemented. If any input has
+        // dynamic shape or dynamic element type, sets all outputs to have a shape of dynamic
+        // rank and dynamic element type. Ops where we haven't yet implemented partial shape
+        // propagation can add this boilerplate at the top of their validate_and_infer_types():
+        //
+        //    if (validate_punt_if_dynamic())
+        //    {
+        //        return;
+        //    }
+        bool validate_punt_if_dynamic();
         Node(const std::string& node_type, const NodeVector& arguments, size_t output_size = 1);

         virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) {}
@@ -125,7 +133,9 @@ namespace ngraph
             return std::type_index(typeid(*this)) == std::type_index(typeid(*n));
         }

-        void set_output_type(size_t i, const element::Type& element_type, const Shape& shape);
+        void set_output_type(size_t i,
+                             const element::Type& element_type,
+                             const PartialShape& pshape);

         bool is_parameter() const;
         virtual bool is_output() const;
...
@@ -27,6 +27,11 @@ op::AllReduce::AllReduce(const shared_ptr<Node>& arg)
 void op::AllReduce::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     NODE_VALIDATION_ASSERT(this,
                            get_input_element_type(0) == element::f32 ||
                                get_input_element_type(0) == element::f64)
...
@@ -40,6 +40,11 @@ op::AvgPool::AvgPool(const shared_ptr<Node>& arg,
 void op::AvgPool::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     auto& arg_shape = get_input_shape(0);

     NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
@@ -120,6 +125,11 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
 void op::AvgPoolBackprop::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     auto& delta_shape = get_input_shape(0);

     // infer_batched_forward_pooling wants CoordinateDiffs for these, while the pooling ops for
...
@@ -35,6 +35,11 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
 void ngraph::op::BatchNorm::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     m_bn_input_shape = get_input_shape(INPUT);
     NODE_VALIDATION_ASSERT(this, m_bn_input_shape.size() >= 2)
         << "Input argument must have rank of at least 2 (input argument shape: " << m_bn_input_shape
@@ -158,6 +163,11 @@ ngraph::op::BatchNormBackprop::BatchNormBackprop(double eps,
 void ngraph::op::BatchNormBackprop::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     set_output_size(3);
     NODE_VALIDATION_ASSERT(this, get_input_shape(INPUT).size() == 4)
...
@@ -40,6 +40,11 @@ op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
 void op::Broadcast::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     infer_shape();

     Shape target_shape = m_shape;
     for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
...
@@ -32,6 +32,11 @@ op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
 void op::Concat::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     NODE_VALIDATION_ASSERT(this, m_inputs.size() >= 1) << "At least one argument required.";

     Shape first_input_shape = get_input_shape(0);
...
@@ -46,6 +46,11 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
 void op::Convolution::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     auto& data_batch_shape = get_input_shape(0);
     auto& data_batch_et = get_input_element_type(0);
     auto& filters_shape = get_input_shape(1);
@@ -220,6 +225,11 @@ op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_sha
 void op::ConvolutionBackpropData::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     // Backprop to data is itself convolution, with inputs/outputs/attributes transmogrified as
     // follows.
     //
@@ -410,6 +420,11 @@ op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
 void op::ConvolutionBackpropFilters::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     // Backprop to filters is itself convolution, with inputs/outputs/attributes transmogrified as
     // follows.
     //
...
@@ -34,6 +34,11 @@ op::Dequantize::Dequantize(shared_ptr<Node> input,
 void op::Dequantize::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     enum
     {
         INPUT,
...
@@ -47,6 +47,11 @@ op::Dot::Dot(const shared_ptr<Node>& arg0,
 void op::Dot::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     auto& input_0 = get_inputs().at(0);
     auto& input_1 = get_inputs().at(1);
...
@@ -42,6 +42,11 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
 void op::MaxPool::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     auto& arg_shape = get_input_shape(0);

     NODE_VALIDATION_ASSERT(this, arg_shape.size() >= 3)
@@ -120,6 +125,11 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
 void op::MaxPoolBackprop::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     auto forward_arg_et = get_input_element_type(0);
     auto& forward_arg_shape = get_input_shape(0);
     auto delta_et = get_input_element_type(1);
...
@@ -26,9 +26,14 @@ op::Not::Not(const shared_ptr<Node>& arg)
     constructor_validate_and_infer_types();
 }

+// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
 void op::Not::validate_and_infer_types()
 {
-    validate_and_infer_elementwise();
+    auto args_et_pshape = validate_and_infer_elementwise_args();
+    element::Type& args_et = std::get<0>(args_et_pshape);
+    PartialShape& args_pshape = std::get<1>(args_et_pshape);
+
+    set_output_type(0, args_et, args_pshape);
 }

 shared_ptr<Node> op::Not::copy_with_new_args(const NodeVector& new_args) const
...
@@ -22,11 +22,11 @@ using namespace std;
 using namespace ngraph;

 op::Parameter::Parameter(const element::Type& element_type,
-                         const Shape& shape,
+                         const PartialShape& pshape,
                          const bool cacheable)
     : Op("Parameter", {})
     , m_cacheable(cacheable)
-    , m_shape(shape)
+    , m_partial_shape(pshape)
     , m_element_type(element_type)
 {
     constructor_validate_and_infer_types();
@@ -35,13 +35,13 @@ op::Parameter::Parameter(const element::Type& element_type,
 void op::Parameter::validate_and_infer_types()
 {
     Op::validate_and_infer_types();
-    set_output_type(0, m_element_type, m_shape);
+    set_output_type(0, m_element_type, m_partial_shape);
 }

 shared_ptr<Node> op::Parameter::copy_with_new_args(const NodeVector& new_args) const
 {
     check_new_args_count(this, new_args);
-    return make_shared<Parameter>(m_element_type, m_shape);
+    return make_shared<Parameter>(m_element_type, m_partial_shape);
 }

 void op::Parameter::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
...
@@ -38,10 +38,10 @@ namespace ngraph
         /// \brief Constructions a tensor view-typed parameter node.
         ///
         /// \param element_type The element type of the parameter.
-        /// \param shape The shape of the parameter.
+        /// \param pshape The partial shape of the parameter.
         /// \param cacheable True if the parameter is not expected to be frequently updated.
         Parameter(const ngraph::element::Type& element_type,
-                  const Shape& shape,
+                  const PartialShape& pshape,
                   const bool cacheable = false);

         void validate_and_infer_types() override;
@@ -52,7 +52,7 @@ namespace ngraph
     protected:
         bool m_cacheable;
-        Shape m_shape;
+        PartialShape m_partial_shape;
         element::Type m_element_type;
     };
 }
...
@@ -36,6 +36,11 @@ op::Quantize::Quantize(shared_ptr<Node> input,
 void op::Quantize::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     enum
     {
         INPUT,
...
@@ -35,6 +35,11 @@ op::Reshape::Reshape(const shared_ptr<Node>& arg,
 void op::Reshape::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     auto& input = get_inputs().at(0);
     auto input_shape = input.get_shape();
     auto input_rank = input_shape.size();
...
@@ -37,7 +37,7 @@ void op::Result::validate_and_infer_types()
     // always borrow the placement conf even the default one
     set_placement(get_argument(0)->get_placement());

-    set_output_type(0, get_input_element_type(0), get_input_shape(0));
+    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
 }

 shared_ptr<Node> op::Result::copy_with_new_args(const NodeVector& new_args) const
...
@@ -32,6 +32,11 @@ op::Reverse::Reverse(const shared_ptr<Node>& arg, const AxisSet& reversed_axes)
 void op::Reverse::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     auto input_shape = get_input_shape(0);
     auto input_rank = input_shape.size();
...
@@ -38,6 +38,11 @@ op::ReverseSequence::ReverseSequence(const std::shared_ptr<Node> arg,
 void op::ReverseSequence::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     NODE_VALIDATION_ASSERT(this, get_input_shape(1).size() == 1)
         << "Sequence indices must be a 1-dimensional tensor (sequence indices shape: "
         << get_input_shape(1) << ").";
...
@@ -44,6 +44,11 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
 void op::Slice::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     if (0 == m_strides.size())
     {
         m_strides = Strides(m_lower_bounds.size(), 1);
...
@@ -39,6 +39,11 @@ op::TopK::TopK(const shared_ptr<Node>& arg,
 void op::TopK::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     auto& input = get_inputs().at(0);
     auto rank = input.get_shape().size();
...
@@ -29,6 +29,11 @@ op::util::ArithmeticReduction::ArithmeticReduction(const std::string& node_type,
 void op::util::ArithmeticReduction::validate_and_infer_types()
 {
+    if (validate_punt_if_dynamic())
+    {
+        return;
+    }
+
     auto input_shape = get_input_shape(0);

     for (auto axis : m_reduction_axes)
...
@@ -28,5 +28,8 @@ op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const string&
 void op::util::BinaryElementwiseComparison::validate_and_infer_types()
 {
-    validate_and_infer_elementwise(element::boolean);
+    auto args_et_pshape = validate_and_infer_elementwise_args();
+    PartialShape& args_pshape = std::get<1>(args_et_pshape);
+
+    set_output_type(0, element::boolean, args_pshape);
 }
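Taken together, the elementwise plumbing above means a graph can now be built over partially-known shapes, with each op either propagating the merged partial shape or punting. A hedged end-to-end sketch follows; the umbrella header, the initializer-list PartialShape construction, and the choice of op::Add are assumptions of the sketch, not guarantees from this commit.

    #include <cassert>
    #include "ngraph/ngraph.hpp" // assumed umbrella header

    using namespace ngraph;

    int main()
    {
        // {2,?} merged with {?,4} yields {2,4} for the elementwise result.
        auto a = std::make_shared<op::Parameter>(
            element::f32, PartialShape{2, Dimension::dynamic()});
        auto b = std::make_shared<op::Parameter>(
            element::f32, PartialShape{Dimension::dynamic(), 4});
        auto sum = std::make_shared<op::Add>(a, b);

        // Both dynamic dimensions were resolved by the static one on the
        // other input, so the output shape is fully static here.
        assert(sum->get_output_partial_shape(0).is_static());
        assert(sum->get_output_element_type(0) == element::f32);
        return 0;
    }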
@@ -28,19 +28,18 @@ PartialShape::PartialShape(const Shape& shape)
     m_dimensions.assign(shape.begin(), shape.end());
 }

-bool ngraph::PartialShape::is_complete() const
+bool ngraph::PartialShape::is_static() const
 {
-    return m_rank_is_determined &&
-           std::all_of(m_dimensions.begin(), m_dimensions.end(), [](const Dimension& d) {
-               return d.is_determined();
-           });
+    return m_rank_is_static && std::all_of(m_dimensions.begin(),
+                                           m_dimensions.end(),
+                                           [](const Dimension& d) { return d.is_static(); });
 }

 PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2)
 {
-    if (!s1.rank_is_determined() || !s2.rank_is_determined())
+    if (s1.rank().is_dynamic() || s2.rank().is_dynamic())
     {
-        return PartialShape::undetermined();
+        return PartialShape::dynamic();
     }

     if (!s1.rank().compatible(s2.rank()))
@@ -49,7 +48,7 @@ PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2)
     }

     PartialShape result{};
-    result.m_rank_is_determined = true;
+    result.m_rank_is_static = true;
     for (size_t i = 0; i < s1.m_dimensions.size(); i++)
     {
         result.m_dimensions.push_back(s1.m_dimensions[i] + s2.m_dimensions[i]);
@@ -59,7 +58,7 @@ PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2)
 std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape)
 {
-    if (shape.m_rank_is_determined)
+    if (shape.m_rank_is_static)
     {
         str << "{";
         bool first = true;
@@ -83,7 +82,7 @@ std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape)
 bool PartialShape::compatible(const PartialShape& s) const
 {
     // If we don't know *this's rank, or we don't know s's rank, they are compatible.
-    if (!rank_is_determined() || !s.rank_is_determined())
+    if (!m_rank_is_static || s.rank().is_dynamic())
     {
         return true;
     }
@@ -109,12 +108,69 @@ bool PartialShape::compatible(const PartialShape& s) const
     }
 }

+bool PartialShape::same_scheme(const PartialShape& s) const
+{
+    if (rank().is_dynamic() && s.rank().is_dynamic())
+    {
+        return true;
+    }
+    else if (rank().is_static() && s.rank().is_static())
+    {
+        if (size_t(rank()) != size_t(s.rank()))
+        {
+            return false;
+        }
+
+        bool success = true;
+
+        for (size_t i = 0; i < size_t(rank()); i++)
+        {
+            success &= (*this)[i].same_scheme(s[i]);
+        }
+
+        return success;
+    }
+    else
+    {
+        return false;
+    }
+}
+
 Shape PartialShape::to_shape() const
 {
-    if (!is_complete())
+    if (is_dynamic())
     {
-        throw std::invalid_argument("to_shape was called on an incomplete shape.");
+        throw std::invalid_argument("to_shape was called on a dynamic shape.");
     }

     return Shape(m_dimensions.begin(), m_dimensions.end());
 }
+
+bool PartialShape::merge_into(PartialShape& dst, const PartialShape& src)
+{
+    if (dst.rank().is_dynamic())
+    {
+        dst = src;
+        return true;
+    }
+    else if (src.rank().is_dynamic())
+    {
+        // No change to dst.
+        return true;
+    }
+    else if (size_t(dst.rank()) != size_t(src.rank()))
+    {
+        // Mismatching static ranks, cannot merge.
+        return false;
+    }
+    else
+    {
+        // Ranks are both static, and they match.
+        bool success = true;
+        for (size_t i = 0; i < size_t(dst.rank()); i++)
+        {
+            success &= Dimension::merge(dst[i], dst[i], src[i]);
+        }
+        return success;
+    }
+}
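To make the shape-merging semantics concrete, here is a small sketch of merge_into at work. The "ngraph/partial_shape.hpp" include path, the initializer-list construction, and the main() harness are assumptions of the sketch, not part of this commit.

    #include <cassert>
    #include "ngraph/partial_shape.hpp" // assumed include path

    using ngraph::Dimension;
    using ngraph::PartialShape;
    using ngraph::Shape;

    int main()
    {
        // {2,?,3} merged with {?,4,3} resolves elementwise to {2,4,3}.
        PartialShape dst{2, Dimension::dynamic(), 3};
        PartialShape src{Dimension::dynamic(), 4, 3};
        assert(PartialShape::merge_into(dst, src));
        assert(dst.is_static());
        Shape expected{2, 4, 3};
        assert(dst.to_shape() == expected);

        // A destination of dynamic rank simply adopts the source shape.
        PartialShape any = PartialShape::dynamic();
        PartialShape src2{5, 6};
        assert(PartialShape::merge_into(any, src2));

        // Mismatching static ranks cannot merge.
        PartialShape rank2{1, 2};
        PartialShape rank3{1, 2, 3};
        assert(!PartialShape::merge_into(rank2, rank3));
        return 0;
    }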
@@ -20,9 +20,8 @@
 namespace ngraph
 {
-    /// \brief Alias for "Dimension". Should be used to when the value represents the number of
-    ///        axes in a shape-like object, rather than the size of one dimension in a shape-like
-    ///        object.
+    /// \brief Alias for Dimension, used when the value represents the number of axes in a
+    ///        shape, rather than the size of one dimension in a shape.
     ///
     /// XXX: THIS TYPE IS EXPERIMENTAL AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
     using Rank = Dimension;
...
@@ -149,6 +149,59 @@ static json write(const ngraph::Node&, bool binary_constant_data);
 static string
     serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data);

+static json write_dimension(Dimension d)
+{
+    if (d.is_static())
+    {
+        return size_t(d);
+    }
+    else
+    {
+        return nullptr;
+    }
+}
+
+static json write_partial_shape(const PartialShape& s)
+{
+    if (s.rank().is_dynamic())
+    {
+        return nullptr;
+    }
+    else
+    {
+        std::vector<json> vals(size_t(s.rank()));
+        for (size_t i = 0; i < vals.size(); i++)
+        {
+            vals[i] = write_dimension(s[i]);
+        }
+        return vals;
+    }
+}
+
+static PartialShape read_partial_shape(const json& j)
+{
+    if (j.is_null())
+    {
+        return PartialShape::dynamic();
+    }
+    else
+    {
+        std::vector<Dimension> dims(j.size());
+        for (size_t i = 0; i < j.size(); i++)
+        {
+            if (j[i].is_null())
+            {
+                dims[i] = Dimension::dynamic();
+            }
+            else
+            {
+                dims[i] = size_t(j[i]);
+            }
+        }
+        return PartialShape(dims);
+    }
+}
+
 static json write_element_type(const ngraph::element::Type& n)
 {
     json j;
@@ -839,7 +892,8 @@ static shared_ptr<ngraph::Function>
             auto element_type = read_element_type(type_node_js.at("element_type"));
             auto shape = type_node_js.at("shape");
             auto cacheable = get_or_default<bool>(node_js, "cacheable", false);
-            node = make_shared<op::Parameter>(element_type, shape, cacheable);
+            node =
+                make_shared<op::Parameter>(element_type, read_partial_shape(shape), cacheable);
             break;
         }
         case OP_TYPEID::Power:
@@ -1389,7 +1443,7 @@ static json write(const Node& n, bool binary_constant_data)
     case OP_TYPEID::Parameter:
     {
         auto tmp = dynamic_cast<const op::Parameter*>(&n);
-        node["shape"] = tmp->get_shape();
+        node["shape"] = write_partial_shape(tmp->get_output_partial_shape(0));
         node["cacheable"] = tmp->get_cacheable();
         node["element_type"] = write_element_type(tmp->get_element_type());
         break;
...
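The wire format produced by these helpers is simple: a static dimension serializes as a JSON number, a dynamic dimension as null, and a shape of dynamic rank as null in place of the whole array, so a Parameter of shape {2,?,4} carries "shape": [2, null, 4]. Since write_partial_shape and read_partial_shape are static to serializer.cpp, the sketch below only restates that encoding directly with nlohmann::json; the include path is an assumption.

    #include <cassert>
    #include <nlohmann/json.hpp> // assumed include path

    using json = nlohmann::json;

    int main()
    {
        // Shape {2,?,4}: static dims as numbers, the dynamic dim as null.
        json pshape = {2, nullptr, 4};
        assert(pshape.dump() == "[2,null,4]");

        // Dynamic rank: the whole shape serializes as null.
        json dynamic_rank = nullptr;
        assert(dynamic_rank.is_null());
        return 0;
    }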
@@ -21,7 +21,7 @@
 using namespace ngraph;

-const element::Type element::unspecified(0, false, false, false, "unspecified");
+const element::Type element::dynamic(0, false, false, false, "dynamic");
 const element::Type element::boolean(8, false, true, false, "char");
 const element::Type element::f32(32, true, true, false, "float");
 const element::Type element::f64(64, true, true, false, "double");
@@ -184,3 +184,26 @@ std::ostream& element::operator<<(std::ostream& out, const element::Type& obj)
         << ", " << obj.m_is_quantized << ", \"" << obj.m_cname << "\"}";
     return out;
 }
+
+bool element::Type::merge(element::Type& dst, const element::Type& t1, const element::Type& t2)
+{
+    if (t1.is_dynamic())
+    {
+        dst = t2;
+        return true;
+    }
+    else if (t2.is_dynamic())
+    {
+        dst = t1;
+        return true;
+    }
+    else if (t1 == t2)
+    {
+        dst = t1;
+        return true;
+    }
+    else
+    {
+        return false;
+    }
+}
@@ -33,7 +33,7 @@ namespace ngraph
     {
         class Type;

-        extern const Type unspecified;
+        extern const Type dynamic;
         extern const Type boolean;
         extern const Type f32;
         extern const Type f64;
@@ -61,6 +61,8 @@ namespace ngraph
         const std::string& c_type_string() const;
         size_t size() const;
         size_t hash() const;
+        bool is_static() const { return (*this != dynamic); }
+        bool is_dynamic() const { return !is_static(); }
         bool is_real() const { return m_is_real; }
         bool is_signed() const { return m_is_signed; }
         bool is_quantized() const { return m_is_quantized; }
@@ -73,12 +75,32 @@ namespace ngraph
         /// Returns true if the type is floating point, else false.
         bool get_is_real() const { return m_is_real; }

+        /// \brief Merges two element types t1 and t2, writing the result into dst and
+        ///        returning true if successful, else returning false.
+        ///
+        ///        To "merge" two element types t1 and t2 is to find the least restrictive
+        ///        element type t that is no more restrictive than t1 and t2, if t exists.
+        ///        More simply:
+        ///
+        ///           merge(dst,element::Type::dynamic,t)
+        ///              writes t to dst and returns true
+        ///
+        ///           merge(dst,t,element::Type::dynamic)
+        ///              writes t to dst and returns true
+        ///
+        ///           merge(dst,t1,t2) where t1, t2 both static and equal
+        ///              writes t1 to dst and returns true
+        ///
+        ///           merge(dst,t1,t2) where t1, t2 both static and unequal
+        ///              does nothing to dst, and returns false
+        static bool merge(element::Type& dst, const element::Type& t1, const element::Type& t2);
+
     private:
         size_t m_bitwidth{0};
         bool m_is_real{false};
         bool m_is_signed{false};
         bool m_is_quantized{false};
-        std::string m_cname{"unspecified"};
+        std::string m_cname{"dynamic"};
     };

     template <typename T>
...
@@ -84,3 +84,42 @@ TEST(element_type, size)
         EXPECT_EQ(2, t1.size());
     }
 }
+
+TEST(element_type, merge_both_dynamic)
+{
+    element::Type t;
+    ASSERT_TRUE(element::Type::merge(t, element::dynamic, element::dynamic));
+    ASSERT_TRUE(t.is_dynamic());
+}
+
+TEST(element_type, merge_left_dynamic)
+{
+    element::Type t;
+    ASSERT_TRUE(element::Type::merge(t, element::dynamic, element::u64));
+    ASSERT_TRUE(t.is_static());
+    ASSERT_EQ(t, element::u64);
+}
+
+TEST(element_type, merge_right_dynamic)
+{
+    element::Type t;
+    ASSERT_TRUE(element::Type::merge(t, element::i16, element::dynamic));
+    ASSERT_TRUE(t.is_static());
+    ASSERT_EQ(t, element::i16);
+}
+
+TEST(element_type, merge_both_static_equal)
+{
+    element::Type t;
+    ASSERT_TRUE(element::Type::merge(t, element::f64, element::f64));
+    ASSERT_TRUE(t.is_static());
+    ASSERT_EQ(t, element::f64);
+}
+
+TEST(element_type, merge_both_static_unequal)
+{
+    element::Type t = element::f32;
+    ASSERT_FALSE(element::Type::merge(t, element::i8, element::i16));
+    ASSERT_TRUE(t.is_static());
+    ASSERT_EQ(t, element::f32);
+}