Unverified Commit bcddc600 authored by Adam Procter's avatar Adam Procter Committed by GitHub

Merge pull request #1692 from NervanaSystems/aprocter/partial-shape

Partial Shapes, Part 1: Classes for partially known shapes, possibly unknown dimensions
parents c7183e46 7277f0e8
......@@ -29,6 +29,7 @@ set (SRC
descriptor/layout/tensor_layout.cpp
descriptor/output.cpp
descriptor/tensor.cpp
dimension.cpp
file_util.cpp
function.cpp
log.cpp
......@@ -112,6 +113,7 @@ set (SRC
op/util/binary_elementwise_logical.cpp
op/util/index_reduction.cpp
op/util/unary_elementwise_arithmetic.cpp
partial_shape.cpp
pass/assign_placement.cpp
pass/algebraic_simplification.cpp
pass/common_function_collection.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <iostream>
#include <limits>
#include <sstream>
#include "ngraph/dimension.hpp"
using namespace ngraph;
Dimension::Dimension(size_t dimension)
: m_dimension(dimension)
{
if (dimension == s_undetermined_val)
{
std::stringstream ss;
ss << "Cannot convert the value 0x" << std::uppercase << std::hex << s_undetermined_val
<< " to Dimension: this value is used internally to represent an undetermined "
"dimension.";
throw std::invalid_argument(ss.str());
}
}
std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
{
if (dimension.is_determined())
{
return (str << size_t(dimension));
}
else
{
return (str << "?");
}
}
Dimension ngraph::operator+(const Dimension& d1, const Dimension& d2)
{
return (d1.is_determined() && d2.is_determined() ? size_t(d1) + size_t(d2)
: Dimension::undetermined());
}
bool Dimension::compatible(const Dimension& d) const
{
return (!is_determined() || !d.is_determined() || m_dimension == size_t(d));
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <limits>
#include <stddef.h>
#include <stdexcept>
namespace ngraph
{
/// \brief Class representing a possibly-unknown dimension in a shape or shape-like object.
///
/// Known dimensions may be implicitly converted from size_t. An unknown dimension is
/// constructed with Dimension() or Dimension::undetermined().
///
/// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
class Dimension
{
public:
/// \brief Constructs a known dimension.
///
/// Requires that dimension != s_undetermined_val. If that condition does not hold,
/// throws std::invalid_argument.
Dimension(size_t dimension);
/// \brief Constructs an unknown dimension.
Dimension() { m_dimension = s_undetermined_val; }
/// \brief Returns true if this dimension is determined.
bool is_determined() const { return m_dimension != s_undetermined_val; }
/// \brief Converts this dimension to size_t. If the dimension is undetermined, throws
/// std::invalid_argument.
explicit operator size_t() const
{
if (!is_determined())
{
throw std::invalid_argument("Cannot convert unknown dimension to size_t");
}
return m_dimension;
}
/// \brief Returns true if the dimensions are compatible, i.e. if one of the dimensions
/// is undetermined, or both dimensions are determined and equal.
bool compatible(const Dimension& d) const;
/// \brief Constructs an unknown dimension.
static Dimension undetermined() { return Dimension(); }
/// \brief Constant for the value used internally to represent an unknown dimension.
static const size_t s_undetermined_val{std::numeric_limits<size_t>::max()};
private:
// The actual numerical value of the dimension. s_undetermined_val is a special case,
// representing an unknown dimension.
size_t m_dimension;
};
/// \brief Inserts a human-readable representation of "dimension" into "str".
std::ostream& operator<<(std::ostream& str, const Dimension& dimension);
/// \brief Addition operator for dimensions.
///
/// If d1 and d2 are both known, returns size_t(d1)+size_t(d2). Otherwise, returns
/// Dimension::undetermined().
Dimension operator+(const Dimension& d1, const Dimension& d2);
}
......@@ -53,6 +53,7 @@
#include "ngraph/descriptor/layout/tensor_layout.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/dimension.hpp"
#include "ngraph/except.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
......@@ -128,6 +129,7 @@
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "ngraph/shape.hpp"
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <iostream>
#include <vector>
#include "ngraph/partial_shape.hpp"
using namespace ngraph;
bool ngraph::PartialShape::is_complete() const
{
return m_rank_is_determined &&
std::all_of(m_dimensions.begin(), m_dimensions.end(), [](const Dimension& d) {
return d.is_determined();
});
}
PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2)
{
if (!s1.rank_is_determined() || !s2.rank_is_determined())
{
return PartialShape::undetermined();
}
if (!s1.rank().compatible(s2.rank()))
{
throw std::invalid_argument("rank mismatch");
}
PartialShape result{};
result.m_rank_is_determined = true;
for (size_t i = 0; i < s1.m_dimensions.size(); i++)
{
result.m_dimensions.push_back(s1.m_dimensions[i] + s2.m_dimensions[i]);
}
return result;
}
std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape)
{
if (shape.m_rank_is_determined)
{
str << "{";
bool first = true;
for (auto& d : shape.m_dimensions)
{
if (!first)
{
str << ",";
}
str << d;
first = false;
}
return (str << "}");
}
else
{
return (str << "?");
}
}
bool PartialShape::compatible(const PartialShape& s) const
{
// If we don't know *this's rank, or we don't know s's rank, they are compatible.
if (!rank_is_determined() || !s.rank_is_determined())
{
return true;
}
// If we do know *this's rank and s's rank, and they are unequal, they are incompatible.
else if (size_t(rank()) != size_t(s.rank()))
{
return false;
}
// If we know both the ranks and they are equal, then *this and s are compatible iff they
// are elementwise compatible everywhere.
else
{
for (size_t i = 0; i < size_t(rank()); i++)
{
if (!m_dimensions[i].compatible(s.m_dimensions[i]))
{
return false;
}
}
// If we are still here, we know that s1 and s2 have the same rank and are elementwise
// compatible everywhere.
return true;
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <stddef.h>
#include "ngraph/dimension.hpp"
#include "ngraph/rank.hpp"
namespace ngraph
{
/// \brief Class representing a shape that may only be partially known.
///
/// XXX: THIS CLASS IS EXPERIMENTAL AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
///
/// A partially-known shape may have:
///
/// - Unknown rank.
/// - Known rank, but unknown dimensions on some or all axes.
/// - Known rank, and known dimensions on all axes.
class PartialShape
{
public:
/// \brief Constructs a shape with determined rank.
///
/// Examples:
///
/// PartialShape s{2,3,4}; // rank=3, all dimensions determined
/// PartialShape s{}; // rank=0
/// PartialShape s{2,Dimension::undetermined(),3}; // rank=2, dimension 1 undetermined
PartialShape(std::initializer_list<Dimension> init)
: PartialShape(true, init)
{
}
/// \brief Returns true if the shape has determined rank.
bool rank_is_determined() const { return m_rank_is_determined; }
/// \brief Returns true if the shape has known rank and all dimensions of the shape
/// are determined.
bool is_complete() const;
/// \brief Returns the rank of the shape. Returns Rank::undetermined() if the rank is undetermined.
Rank rank() const
{
return m_rank_is_determined ? Rank(m_dimensions.size()) : Rank::undetermined();
}
/// \brief Appends another shape to this shape.
///
/// If "this" and "other" both have determined rank, returns a new shape two shape
/// whose dimensions are the concatenation of the dimensions of "this" and "other".
/// If either "this" or "other" has undetermined rank, returns
/// PartialShape::undetermined().
PartialShape append(const PartialShape& other);
/// \brief Returns the undetermined shape.
static PartialShape undetermined() { return PartialShape(false, {}); }
/// \brief Returns true if *this is compatible with s.
///
/// Two dimensions are compatible if one or both of them is undetermined, or if
/// they are both determined and equal.
bool compatible(const PartialShape& s) const;
friend std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
friend PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
private:
// Private constructor so PartialShape::undetermined() can construct an undetermined shape.
PartialShape(bool rank_is_determined, std::initializer_list<Dimension> init)
: m_rank_is_determined(rank_is_determined)
, m_dimensions(init)
{
}
// True if the shape's rank is determined.
bool m_rank_is_determined;
// Shape dimensions. This has no meaning if m_rank_is_determined is false.
std::vector<Dimension> m_dimensions;
};
/// \brief Elementwise addition of two shapes.
///
/// If s1 or s2 has undetermined rank, returns PartialShape::undetermined().
/// If s1 and s2 both have determined rank, and their ranks are unequal,
/// throws std::invalid_argument.
/// If s1 and s2 both have determined rank, and their ranks are equal,
/// returns a new shape whose ith dimension is s1[i] + s2[i].
PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
/// \brief Inserts a human-readable representation of "shape" into "str".
std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/dimension.hpp"
namespace ngraph
{
/// \brief Alias for "Dimension". Should be used to when the value represents the number of
/// axes in a shape-like object, rather than the size of one dimension in a shape-like
/// object.
///
/// XXX: THIS TYPE IS EXPERIMENTAL AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
using Rank = Dimension;
}
......@@ -34,6 +34,7 @@ set(SRC
main.cpp
nop_elimination.cpp
op.cpp
partial_shape.cpp
pass_liveness.cpp
pass_manager.cpp
pass_memory_layout.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
TEST(partial_shape, ps_construction_empty)
{
auto ps = PartialShape{};
ASSERT_TRUE(ps.rank_is_determined());
ASSERT_TRUE(ps.rank().is_determined());
ASSERT_TRUE(ps.is_complete());
ASSERT_EQ(size_t(ps.rank()), 0);
}
TEST(partial_shape, ps_construction_undetermined)
{
auto ps = PartialShape::undetermined();
ASSERT_FALSE(ps.rank_is_determined());
ASSERT_FALSE(ps.rank().is_determined());
ASSERT_FALSE(ps.is_complete());
}
TEST(partial_shape, ps_construction_incomplete)
{
auto ps = PartialShape{2, Dimension::undetermined(), 3};
ASSERT_TRUE(ps.rank_is_determined());
ASSERT_TRUE(ps.rank().is_determined());
ASSERT_FALSE(ps.is_complete());
ASSERT_EQ(size_t(ps.rank()), 3);
}
TEST(partial_shape, ps_construction_complete)
{
auto ps = PartialShape{2, 5, 3, 6};
ASSERT_TRUE(ps.rank_is_determined());
ASSERT_TRUE(ps.rank().is_determined());
ASSERT_TRUE(ps.is_complete());
ASSERT_EQ(size_t(ps.rank()), 4);
}
TEST(partial_shape, dim_construction_determined)
{
Dimension dim{3};
ASSERT_EQ(size_t(dim), 3);
ASSERT_TRUE(dim.is_determined());
}
TEST(partial_shape, dim_construction_undetermined)
{
Dimension dim = Dimension::undetermined();
ASSERT_FALSE(dim.is_determined());
}
TEST(partial_shape, dim_construction_size_t_max)
{
EXPECT_ANY_THROW({ Dimension d{Dimension::s_undetermined_val}; });
}
TEST(partial_shape, dim_conversion_determined)
{
Dimension d{42};
size_t s{d};
ASSERT_EQ(s, 42);
}
TEST(partial_shape, dim_conversion_undetermined)
{
EXPECT_ANY_THROW({
size_t s{Dimension::undetermined()};
s = 0; // Silence compiler warning about unused s
});
}
TEST(partial_shape, rank_construction_determined)
{
Rank r{4};
ASSERT_EQ(size_t(r), 4);
ASSERT_TRUE(r.is_determined());
}
TEST(partial_shape, rank_construction_undetermined)
{
Rank r = Rank::undetermined();
ASSERT_FALSE(r.is_determined());
}
TEST(partial_shape, dim_compatible_left_undetermined)
{
Dimension d1{Dimension::undetermined()};
Dimension d2{3};
ASSERT_TRUE(d1.compatible(d2));
}
TEST(partial_shape, dim_compatible_right_undetermined)
{
Dimension d1{3};
Dimension d2{Dimension::undetermined()};
ASSERT_TRUE(d1.compatible(d2));
}
TEST(partial_shape, dim_compatible_both_undetermined)
{
Dimension d1{Dimension::undetermined()};
Dimension d2{Dimension::undetermined()};
ASSERT_TRUE(d1.compatible(d2));
}
TEST(partial_shape, dim_compatible_both_determined)
{
Dimension d1{3};
Dimension d2{8};
Dimension d3{3};
ASSERT_FALSE(d1.compatible(d2));
ASSERT_TRUE(d1.compatible(d3));
}
TEST(partial_shape, shapes_compatible_both_rank_undetermined)
{
PartialShape ps1{PartialShape::undetermined()};
PartialShape ps2{PartialShape::undetermined()};
ASSERT_TRUE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_compatible_left_rank_undetermined)
{
PartialShape ps1{3};
PartialShape ps2{PartialShape::undetermined()};
ASSERT_TRUE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_compatible_right_rank_undetermined)
{
PartialShape ps1{PartialShape::undetermined()};
PartialShape ps2{4};
ASSERT_TRUE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_compatible_both_partial_all_known_equal)
{
PartialShape ps1{2, Dimension::undetermined(), 3, Dimension::undetermined(), 5};
PartialShape ps2{2, Dimension::undetermined(), Dimension::undetermined(), 4, 5};
ASSERT_TRUE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_compatible_both_partial_some_known_unequal)
{
PartialShape ps1{2, Dimension::undetermined(), 3, Dimension::undetermined(), 5};
PartialShape ps2{1, Dimension::undetermined(), Dimension::undetermined(), 4, 5};
ASSERT_FALSE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_compatible_both_complete_different_rank)
{
PartialShape ps1{2, 4, 6, 8};
PartialShape ps2{2, 4, 6, 8, 10};
ASSERT_FALSE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_equal_both_complete_same_rank_same_dims)
{
PartialShape ps1{2, 4, 6, 8};
PartialShape ps2{2, 4, 6, 8};
ASSERT_TRUE(ps1.compatible(ps2));
}
TEST(partial_shape, shapes_equal_both_complete_same_rank_different_dims)
{
PartialShape ps1{2, 4, 6, 8};
PartialShape ps2{2, 4, 3, 8};
ASSERT_FALSE(ps1.compatible(ps2));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment