Commit 5f5c1843 authored by Adam Procter's avatar Adam Procter

Cleanup and some comments

parent bfd9655f
...@@ -48,5 +48,3 @@ bool ngraph::operator!=(const Dimension& d1, const Dimension& d2) ...@@ -48,5 +48,3 @@ bool ngraph::operator!=(const Dimension& d1, const Dimension& d2)
{ {
return (d1.is_determined() && d2.is_determined() && size_t(d1) != size_t(d2)); return (d1.is_determined() && d2.is_determined() && size_t(d1) != size_t(d2));
} }
const Dimension& Dimension::s_undetermined{};
...@@ -23,28 +23,59 @@ ...@@ -23,28 +23,59 @@
namespace ngraph namespace ngraph
{ {
/// \brief Class representing a possibly-unknown dimension in a shape or shape-like object.
///
/// Known dimensions may be implicitly converted from size_t. An unknown dimension is
/// constructed with Dimension() or Dimension::undetermined().
class Dimension class Dimension
{ {
public: public:
/// \brief Constructs a known dimension.
Dimension(size_t dimension) Dimension(size_t dimension)
: m_dimension(dimension) : m_dimension(dimension)
{ {
} }
/// \brief Constructs an unknown dimension.
Dimension() Dimension()
: m_dimension(s_undetermined_val) : m_dimension(s_undetermined_val)
{ {
} }
/// \brief Returns true if this dimension is known.
bool is_determined() const { return m_dimension != s_undetermined_val; } bool is_determined() const { return m_dimension != s_undetermined_val; }
/// \brief Converts this dimension to size_t. If the dimension is unknown, behavior is
/// undefined.
explicit operator size_t() const { return m_dimension; } explicit operator size_t() const { return m_dimension; }
static const Dimension& undetermined() { return s_undetermined; } /// \brief Constructs an unknown dimension.
static Dimension undetermined() { return s_undetermined_val; }
private: private:
// The actual numerical value of the dimension. s_undetermined_val is a special case,
// representing an unknown dimension.
size_t m_dimension; size_t m_dimension;
static const Dimension& s_undetermined;
// Constant for the size_t value used to represent an unknown dimension.
static const size_t s_undetermined_val{std::numeric_limits<size_t>::max()}; static const size_t s_undetermined_val{std::numeric_limits<size_t>::max()};
}; };
/// \brief Pushes a human-readable representation of "dimension" onto "str".
std::ostream& operator<<(std::ostream& str, const Dimension& dimension); std::ostream& operator<<(std::ostream& str, const Dimension& dimension);
/// \brief Addition operator for dimensions.
///
/// If d1 and d2 are both known, returns size_t(d1)+size_t(d2). Otherwise, returns
/// Dimension::undetermined().
Dimension operator+(const Dimension& d1, const Dimension& d2); Dimension operator+(const Dimension& d1, const Dimension& d2);
/// \brief Equality operator for dimensions.
///
/// If d1 and d2 are both known, returns size_t(d1)==size_t(d2). Otherwise, returns
/// false.
bool operator==(const Dimension& d1, const Dimension& d2); bool operator==(const Dimension& d1, const Dimension& d2);
/// \brief Inequality operator for dimensions.
///
/// If d1 and d2 are both known, returns size_t(d1)!=size_t(d2). Otherwise, returns
/// false.
bool operator!=(const Dimension& d1, const Dimension& d2); bool operator!=(const Dimension& d1, const Dimension& d2);
} }
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE. //
#pragma once #pragma once
...@@ -25,33 +25,79 @@ ...@@ -25,33 +25,79 @@
namespace ngraph namespace ngraph
{ {
/// \brief Class representing a shape that may only be partially known.
///
/// XXX: THIS CLASS IS EXPERIMENTAL AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
///
/// A partially-known shape may have:
///
/// - Unknown rank.
/// - Known rank, but unknown dimensions on some or all axes.
/// - Known rank, and known dimensions on all axes.
class PartialShape class PartialShape
{ {
public: public:
/// \brief Constructs a shape with undetermined rank.
///
/// Examples:
///
/// PartialShape s{2,3,4}; // rank=3, all dimensions determined
/// PartialShape s{}; // rank=0
/// PartialShape s{2,Dimension::undetermined(),3}; // rank=2, dimension 1 undetermined
PartialShape(std::initializer_list<Dimension> init) PartialShape(std::initializer_list<Dimension> init)
: PartialShape(true, init) : PartialShape(true, init)
{ {
} }
/// \brief Returns true if the shape has determined rank.
bool rank_is_determined() const { return m_rank_is_determined; } bool rank_is_determined() const { return m_rank_is_determined; }
/// \brief Returns true if the shape has known rank and all dimensions of the shape
/// are determined.
bool is_complete() const; bool is_complete() const;
/// \brief Returns the rank of the shape. Returns Rank::undetermined() if the rank is undetermined.
Rank rank() const Rank rank() const
{ {
return m_rank_is_determined ? Rank(m_dimensions.size()) : Rank::undetermined(); return m_rank_is_determined ? Rank(m_dimensions.size()) : Rank::undetermined();
} }
friend std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
friend PartialShape operator+(const PartialShape& s1, const PartialShape& s2); /// \brief Appends another shape to this shape.
///
/// If "this" and "other" both have determined rank, returns a new shape two shape
/// whose dimensions are the concatenation of the dimensions of "this" and "other".
/// If either "this" or "other" has undetermined rank, returns
/// PartialShape::undetermined().
PartialShape append(const PartialShape& other); PartialShape append(const PartialShape& other);
/// \brief Returns the undetermined shape.
static PartialShape undetermined() { return PartialShape(false, {}); } static PartialShape undetermined() { return PartialShape(false, {}); }
friend std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
friend PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
private: private:
// Private constructor so PartialShape::undetermined() can construct an undetermined shape.
PartialShape(bool rank_is_determined, std::initializer_list<Dimension> init) PartialShape(bool rank_is_determined, std::initializer_list<Dimension> init)
: m_rank_is_determined(rank_is_determined) : m_rank_is_determined(rank_is_determined)
, m_dimensions(init) , m_dimensions(init)
{ {
} }
// True if the shape's rank is determined.
bool m_rank_is_determined; bool m_rank_is_determined;
// Shape dimensions. This has no meaning if m_rank_is_determined is false.
std::vector<Dimension> m_dimensions; std::vector<Dimension> m_dimensions;
}; };
/// \brief Elementwise addition of two shapes.
///
/// If s1 or s2 has undetermined rank, returns PartialShape::undetermined().
/// If s1 and s2 both have determined rank, and their ranks are unequal,
/// throws std::invalid_argument.
/// If s1 and s2 both have determined rank, and their ranks are equal,
/// returns a new shape whose ith dimension is s1[i] + s2[i].
PartialShape operator+(const PartialShape& s1, const PartialShape& s2); PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
/// \brief Pushes a human-readable representation of "shape" onto "str".
std::ostream& operator<<(std::ostream& str, const PartialShape& shape); std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
} }
...@@ -20,5 +20,10 @@ ...@@ -20,5 +20,10 @@
namespace ngraph namespace ngraph
{ {
/// \brief Alias for "Dimension". Should be used to when the value represents the number of
/// axes in a shape-like object, rather than the size of one dimension in a shape-like
/// object.
///
/// XXX: THIS TYPE IS EXPERIMENTAL AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
using Rank = Dimension; using Rank = Dimension;
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment