Commit 4db9fac4 authored by Adam Procter's avatar Adam Procter

Add experimental partial-shape classes

parent f516ba5f
...@@ -31,6 +31,7 @@ set (SRC ...@@ -31,6 +31,7 @@ set (SRC
descriptor/tensor.cpp descriptor/tensor.cpp
file_util.cpp file_util.cpp
function.cpp function.cpp
length.cpp
log.cpp log.cpp
node.cpp node.cpp
op/abs.cpp op/abs.cpp
...@@ -110,6 +111,7 @@ set (SRC ...@@ -110,6 +111,7 @@ set (SRC
op/util/binary_elementwise_logical.cpp op/util/binary_elementwise_logical.cpp
op/util/index_reduction.cpp op/util/index_reduction.cpp
op/util/unary_elementwise_arithmetic.cpp op/util/unary_elementwise_arithmetic.cpp
partial_shape.cpp
pass/assign_placement.cpp pass/assign_placement.cpp
pass/algebraic_simplification.cpp pass/algebraic_simplification.cpp
pass/common_function_collection.cpp pass/common_function_collection.cpp
...@@ -135,6 +137,7 @@ set (SRC ...@@ -135,6 +137,7 @@ set (SRC
pass/serialize.cpp pass/serialize.cpp
pass/zero_dim_tensor_elimination.cpp pass/zero_dim_tensor_elimination.cpp
pattern/matcher.cpp pattern/matcher.cpp
rank.cpp
runtime/aligned_buffer.cpp runtime/aligned_buffer.cpp
runtime/backend.cpp runtime/backend.cpp
runtime/backend_manager.cpp runtime/backend_manager.cpp
...@@ -144,6 +147,7 @@ set (SRC ...@@ -144,6 +147,7 @@ set (SRC
shape.cpp shape.cpp
strides.cpp strides.cpp
type/element_type.cpp type/element_type.cpp
undetermined.cpp
util.cpp util.cpp
graph_util.cpp graph_util.cpp
placement.cpp placement.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/length.hpp"
std::ostream& ngraph::operator<<(std::ostream& str, const Length& length)
{
if (length.fixed())
{
return (str << size_t(length));
}
else
{
return (str << "?");
}
}
ngraph::Length ngraph::operator+(const Length& l1, const Length& l2)
{
return (l1.fixed() && l2.fixed() ? size_t(l1) + size_t(l2) : Length(undet));
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
#pragma once
#include <stddef.h>
#include "ngraph/undetermined.hpp"
namespace ngraph
{
class Length
{
public:
Length(size_t length)
: m_length(length)
, m_fixed(true)
{
}
Length(const Undetermined&)
: m_length(0)
, m_fixed(false)
{
}
Length()
: m_length(0)
, m_fixed(true)
{
}
bool fixed() const { return m_fixed; }
explicit operator size_t() const { return m_length; }
private:
size_t m_length;
bool m_fixed;
};
std::ostream& operator<<(std::ostream& str, const Length& length);
Length operator+(const Length& l1, const Length& l2);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <vector>
#include "ngraph/partial_shape.hpp"
using namespace ngraph;
bool ngraph::PartialShape::fixed() const
{
return m_rank_fixed && std::all_of(m_lengths.begin(), m_lengths.end(), [](const Length& l) {
return l.fixed();
});
}
ngraph::PartialShape ngraph::operator+(const PartialShape& s1, const PartialShape& s2)
{
if (!s1.rank_fixed() || !s2.rank_fixed())
{
return undet;
}
if (s1.rank() != s2.rank())
{
throw std::invalid_argument("rank mismatch");
}
PartialShape result{};
result.m_rank_fixed = true;
for (size_t i = 0; i < s1.m_lengths.size(); i++)
{
result.m_lengths.push_back(s1.m_lengths[i] + s2.m_lengths[i]);
}
return result;
}
std::ostream& ngraph::operator<<(std::ostream& str, const PartialShape& shape)
{
if (shape.m_rank_fixed)
{
str << "{";
bool first = true;
for (auto& l : shape.m_lengths)
{
if (!first)
{
str << ",";
}
str << l;
first = false;
}
return (str << "}");
}
else
{
return (str << "?");
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
#pragma once
#include <stddef.h>
#include "ngraph/length.hpp"
#include "ngraph/rank.hpp"
namespace ngraph
{
class PartialShape
{
public:
PartialShape(std::initializer_list<Length> init)
: m_rank_fixed(true)
, m_lengths(init)
{
}
PartialShape(const Undetermined&)
: m_rank_fixed(false)
, m_lengths()
{
}
bool rank_fixed() const { return m_rank_fixed; }
bool fixed() const;
Rank rank() const { return m_rank_fixed ? Rank(m_lengths.size()) : undet; }
friend std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
friend PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
PartialShape append(const PartialShape& other);
private:
bool m_rank_fixed;
std::vector<Length> m_lengths;
};
PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/rank.hpp"
std::ostream& ngraph::operator<<(std::ostream& str, const Rank& rank)
{
if (rank.fixed())
{
return (str << size_t(rank));
}
else
{
return (str << "?");
}
}
bool ngraph::operator==(const Rank& r1, const Rank& r2)
{
return (r1.fixed() && r2.fixed() && size_t(r1) == size_t(r2));
}
bool ngraph::operator!=(const Rank& r1, const Rank& r2)
{
return (r1.fixed() && r2.fixed() && size_t(r1) != size_t(r2));
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
#pragma once
#include <stddef.h>
#include "ngraph/undetermined.hpp"
namespace ngraph
{
class Rank
{
public:
Rank(size_t rank)
: m_rank(rank)
, m_fixed(true)
{
}
Rank(const Undetermined&)
: m_rank(0)
, m_fixed(false)
{
}
Rank()
: m_rank(0)
, m_fixed(true)
{
}
bool fixed() const { return m_fixed; }
explicit operator size_t() const { return m_rank; }
private:
size_t m_rank;
bool m_fixed;
};
std::ostream& operator<<(std::ostream& str, const Rank& rank);
bool operator==(const Rank& r1, const Rank& r2);
bool operator!=(const Rank& r1, const Rank& r2);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/undetermined.hpp"
std::ostream& ngraph::operator<<(std::ostream& str, const Undetermined&)
{
return (str << "?");
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
#pragma once
#include <iostream>
namespace ngraph
{
class Undetermined
{
public:
Undetermined() {}
friend std::ostream& operator<<(std::ostream& str, const Undetermined&);
};
std::ostream& operator<<(std::ostream& str, const Undetermined&);
const ngraph::Undetermined undet;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment