Commit a0be5231 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Partial Shapes, Part 2: Adapt Tensor class to have partial shapes (#1718)

* Adapt Tensor class to have partial shapes

* Add PartialShapes to Input, Output, Function, Node classes

* Terminological cleanup
parent 780b56bf
......@@ -84,6 +84,11 @@ const Shape& Input::get_shape() const
return m_output->get_shape();
}
const PartialShape& Input::get_partial_shape() const
{
return m_output->get_partial_shape();
}
const element::Type& Input::get_element_type() const
{
return m_output->get_element_type();
......
......@@ -68,6 +68,9 @@ namespace ngraph
/// \return the shape of the connected output
const Shape& get_shape() const;
/// \return the partial shape of the connected output
const PartialShape& get_partial_shape() const;
/// \return the element type of the connected output
const element::Type& get_element_type() const;
......
......@@ -54,6 +54,11 @@ const Shape& descriptor::Output::get_shape() const
return m_tensor->get_shape();
}
const PartialShape& descriptor::Output::get_partial_shape() const
{
return m_tensor->get_partial_shape();
}
const element::Type& descriptor::Output::get_element_type() const
{
return m_tensor->get_element_type();
......
......@@ -53,6 +53,10 @@ namespace ngraph
/// \return the shape of the output
const Shape& get_shape() const;
/// \return the partial shape of the output
const PartialShape& get_partial_shape() const;
/// \return the element type of the output
const element::Type& get_element_type() const;
......
......@@ -22,20 +22,43 @@ using namespace ngraph;
using namespace std;
descriptor::Tensor::Tensor(const element::Type& element_type,
const Shape& shape,
const PartialShape& pshape,
const std::string& name)
: m_element_type(element_type)
, m_shape(shape)
, m_shape(pshape.is_complete() ? pshape.to_shape() : Shape{})
, m_partial_shape(pshape)
, m_name(name)
{
}
void descriptor::Tensor::set_tensor_type(const element::Type& element_type, const Shape& shape)
void descriptor::Tensor::set_tensor_type(const element::Type& element_type,
const PartialShape& pshape)
{
m_shape = shape;
if (pshape.is_complete())
{
m_shape = pshape.to_shape();
}
else
{
m_shape = Shape{};
}
m_partial_shape = pshape;
m_element_type = element_type;
}
const Shape& descriptor::Tensor::get_shape() const
{
if (m_partial_shape.is_complete())
{
return m_shape;
}
else
{
throw std::invalid_argument(
"get_shape was called on a descriptor::Tensor with incomplete shape");
}
}
void descriptor::Tensor::set_pool_offset(size_t offset)
{
m_pool_offset = offset;
......
......@@ -20,6 +20,7 @@
#include <string>
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
......@@ -41,13 +42,16 @@ namespace ngraph
Tensor& operator=(const Tensor&) = delete;
public:
Tensor(const element::Type& element_type, const Shape& shape, const std::string& name);
Tensor(const element::Type& element_type,
const PartialShape& pshape,
const std::string& name);
const std::string& get_name() const { return m_name; }
void set_tensor_type(const element::Type& element_type, const Shape& shape);
void set_tensor_type(const element::Type& element_type, const PartialShape& pshape);
const element::Type& get_element_type() const { return m_element_type; }
const Shape& get_shape() const { return m_shape; }
const Shape& get_shape() const;
const PartialShape& get_partial_shape() const { return m_partial_shape; }
const std::shared_ptr<layout::TensorLayout>& get_tensor_layout() const
{
return m_tensor_layout;
......@@ -62,7 +66,14 @@ namespace ngraph
protected:
element::Type m_element_type;
// TODO(amprocte): For now we are maintaining both m_shape and m_partial_shape fields,
// with m_shape possibly being invalid (get_shape will throw an exception if it
// is). This is because get_shape() returns a const reference. I think ideally we
// should refactor so that get_shape returns by value.
Shape m_shape;
PartialShape m_partial_shape;
std::string m_name;
std::shared_ptr<layout::TensorLayout> m_tensor_layout;
size_t m_pool_offset{0};
......
......@@ -161,6 +161,11 @@ const Shape& Function::get_output_shape(size_t i) const
return m_results.at(i)->get_shape();
}
const PartialShape& Function::get_output_partial_shape(size_t i) const
{
return m_results.at(i)->get_output_partial_shape(0);
}
shared_ptr<Node> Function::get_output_op(size_t i) const
{
return m_results.at(i);
......
......@@ -61,6 +61,9 @@ namespace ngraph
/// Return the shape of element i
const Shape& get_output_shape(size_t i) const;
/// Return the partial shape of element i
const PartialShape& get_output_partial_shape(size_t i) const;
/// Return the function parameters
const op::ParameterVector& get_parameters() const { return m_parameters; }
/// Return a list of function's outputs
......
......@@ -246,6 +246,11 @@ const Shape& Node::get_output_shape(size_t i) const
return m_outputs.at(i).get_shape();
}
const PartialShape& Node::get_output_partial_shape(size_t i) const
{
return m_outputs.at(i).get_partial_shape();
}
const Shape& Node::get_shape() const
{
if (get_output_size() != 1)
......@@ -307,6 +312,11 @@ const Shape& Node::get_input_shape(size_t i) const
return m_inputs.at(i).get_shape();
}
const PartialShape& Node::get_input_partial_shape(size_t i) const
{
return m_inputs.at(i).get_partial_shape();
}
bool Node::has_same_type(std::shared_ptr<const Node> node) const
{
if (get_output_size() != node->get_output_size())
......
......@@ -168,6 +168,9 @@ namespace ngraph
/// Returns the shape for output i
const Shape& get_output_shape(size_t i) const;
/// Returns the partial shape for output i
const PartialShape& get_output_partial_shape(size_t i) const;
/// Checks that there is exactly one output and returns its shape
const Shape& get_shape() const;
......@@ -195,6 +198,9 @@ namespace ngraph
/// Returns the shape of input i
const Shape& get_input_shape(size_t i) const;
/// Returns the partial shape of input i
const PartialShape& get_input_partial_shape(size_t i) const;
std::unordered_set<descriptor::Tensor*> liveness_new_list;
std::unordered_set<descriptor::Tensor*> liveness_free_list;
......
......@@ -22,6 +22,12 @@
using namespace ngraph;
PartialShape::PartialShape(const Shape& shape)
: PartialShape(true, {})
{
m_dimensions.assign(shape.begin(), shape.end());
}
bool ngraph::PartialShape::is_complete() const
{
return m_rank_is_determined &&
......@@ -102,3 +108,13 @@ bool PartialShape::compatible(const PartialShape& s) const
return true;
}
}
Shape PartialShape::to_shape() const
{
if (!is_complete())
{
throw std::invalid_argument("to_shape was called on an incomplete shape.");
}
return Shape(m_dimensions.begin(), m_dimensions.end());
}
......@@ -20,6 +20,7 @@
#include "ngraph/dimension.hpp"
#include "ngraph/rank.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
......@@ -47,6 +48,9 @@ namespace ngraph
{
}
/// \brief Constructs a complete PartialShape from a Shape.
PartialShape(const Shape& shape);
/// \brief Returns true if the shape has determined rank.
bool rank_is_determined() const { return m_rank_is_determined; }
/// \brief Returns true if the shape has known rank and all dimensions of the shape
......@@ -75,6 +79,11 @@ namespace ngraph
/// they are both determined and equal.
bool compatible(const PartialShape& s) const;
/// \brief Converts a complete PartialShape to a Shape.
///
/// Throws std::invalid_argument if the PartialShape is incomplete.
Shape to_shape() const;
friend std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
friend PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
......
......@@ -200,3 +200,66 @@ TEST(partial_shape, shapes_equal_both_complete_same_rank_different_dims)
ASSERT_FALSE(ps1.compatible(ps2));
}
TEST(partial_shape, from_shape)
{
Shape s{2, 4, 6, 8};
PartialShape ps1{s};
// TODO(amprocte): No way to examine contents of ps1 yet.
ASSERT_TRUE(ps1.is_complete());
ASSERT_TRUE(ps1.rank_is_determined());
ASSERT_EQ(size_t(ps1.rank()), s.size());
}
TEST(partial_shape, to_shape_complete)
{
PartialShape ps{2, 4, 6, 8};
Shape s{ps.to_shape()};
ASSERT_EQ(s, (Shape{2, 4, 6, 8}));
}
TEST(partial_shape, to_shape_dims_undetermined)
{
PartialShape ps{2, 4, Dimension::undetermined(), 8};
ASSERT_THROW({ ps.to_shape(); }, std::invalid_argument);
}
TEST(partial_shape, to_shape_rank_undetermined)
{
PartialShape ps{PartialShape::undetermined()};
ASSERT_THROW({ ps.to_shape(); }, std::invalid_argument);
}
TEST(partial_shape, tensor_descriptor_from_shape)
{
descriptor::Tensor t{element::i32, Shape{1, 2, 3}, "Ankeny"};
ASSERT_EQ(t.get_shape(), (Shape{1, 2, 3}));
ASSERT_EQ(size_t(t.get_partial_shape().rank()), 3);
}
TEST(partial_shape, tensor_descriptor_from_complete_partial_shape)
{
descriptor::Tensor t{element::i32, PartialShape{1, 2, 3}, "Burnside"};
ASSERT_EQ(t.get_shape(), (Shape{1, 2, 3}));
ASSERT_EQ(size_t(t.get_partial_shape().rank()), 3);
}
TEST(partial_shape, tensor_descriptor_from_incomplete_partial_shape)
{
descriptor::Tensor t{element::i32, PartialShape{1, Dimension::undetermined(), 3}, "Couch"};
ASSERT_EQ(size_t(t.get_partial_shape().rank()), 3);
ASSERT_THROW({ t.get_shape(); }, std::invalid_argument);
}
TEST(partial_shape, tensor_descriptor_from_rankless_partial_shape)
{
descriptor::Tensor t{element::i32, PartialShape::undetermined(), "Davis"};
ASSERT_FALSE(t.get_partial_shape().rank().is_determined());
ASSERT_THROW({ t.get_shape(); }, std::invalid_argument);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment