Commit 86fcc656 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

ShapeRelevance pass (#2817)

* Add ShapeRelevance pass

* Typo
parent 41a6b782
......@@ -342,6 +342,8 @@ set (SRC
pass/reshape_sinking.hpp
pass/serialize.cpp
pass/serialize.hpp
pass/shape_relevance.cpp
pass/shape_relevance.hpp
pass/shape_specialization.cpp
pass/shape_specialization.hpp
pass/validate_graph.cpp
......
......@@ -94,6 +94,8 @@ void op::DynPad::validate_and_infer_types()
auto out_shape = PartialShape::dynamic(output_rank);
set_input_is_relevant_to_shape(1);
set_input_is_relevant_to_shape(2);
set_output_type(0, arg_t, out_shape);
}
......
......@@ -44,6 +44,8 @@ void op::DynReshape::validate_and_infer_types()
pattern_shape.rank(),
".");
Rank output_rank = pattern_shape.rank().is_dynamic() ? Rank::dynamic() : pattern_shape[0];
set_input_is_relevant_to_shape(1);
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank));
}
......
......@@ -76,6 +76,9 @@ void op::DynSlice::validate_and_infer_types()
strides_shape.compatible(PartialShape{arg_shape.rank()}),
"Strides shape must have shape [n], where n is the rank of arg.");
set_input_is_relevant_to_shape(1);
set_input_is_relevant_to_shape(2);
set_input_is_relevant_to_shape(3);
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(arg_shape.rank()));
}
......
......@@ -28,6 +28,7 @@ op::Parameter::Parameter(const element::Type& element_type,
, m_cacheable(cacheable)
, m_partial_shape(pshape)
, m_element_type(element_type)
, m_is_relevant_to_shapes(false)
{
constructor_validate_and_infer_types();
}
......@@ -48,3 +49,13 @@ void op::Parameter::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe
{
auto delta = deltas.at(0);
}
bool op::Parameter::is_relevant_to_shapes() const
{
return m_is_relevant_to_shapes;
}
void op::Parameter::set_is_relevant_to_shapes(bool is_relevant)
{
m_is_relevant_to_shapes = is_relevant;
}
......@@ -50,10 +50,14 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool is_relevant_to_shapes() const;
void set_is_relevant_to_shapes(bool is_relevant);
protected:
bool m_cacheable;
PartialShape m_partial_shape;
element::Type m_element_type;
bool m_is_relevant_to_shapes;
};
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/shape_relevance.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/constant.hpp"
using namespace ngraph;
//
// This pass refreshes the "is_relevant_to_shape" flag on each parameter. A parameter will be
// flagged as relevant to shapes if there is any path from that parameter to a shape-relevant
// input that does _not_ pass through a value-irrelevant input. For example:
//
// N0[Parameter] N1[Parameter]
// | |
// | |
// | |
// N2[DynReshape]
//
// N1 (but not N0) will be flagged as shape-relevant, because N1 feeds into the "shape" input
// of N2.
//
// N0[Parameter] N1[Parameter]
// | |
// | N2[ShapeOf]
// | |
// N3[DynReshape]
//
// Neither N0 nor N1 will be flagged as shape-relevant. (N1 does feed into the "shape" input of N3,
// but only via the value-irrelevant input of ShapeOf.)
//
bool pass::ShapeRelevance::run_on_function(std::shared_ptr<Function> f)
{
// TODO(amprocte): We are probably reinventing the wheel with the graph traversal here; the
// reason is that we need to cut the traversal short in cases where input values are
// irrelevant. See if there is a way to reduce this duplication.
// Set of nodes that must be evaluated to determine the value of shape-relevant inputs.
std::set<Node*> shape_determinants;
// Step 1: Find root nodes (these are nodes with an output connected to a shape-relevant
// input).
for (auto& n : f->get_ops())
{
for (auto& output : n->outputs())
{
for (auto& input : output.get_target_inputs())
{
if (input.get_is_relevant_to_shapes())
{
shape_determinants.insert(n.get());
break;
}
}
}
}
// Step 2: Find all shape determinants. This is the transitive closure of R, where n1 R n2
// iff there is a data flow edge from n2 to n1 and that data flow edge is not
// value-irrelevant.
bool changes_made = false;
{
std::list<Node*> to_visit{shape_determinants.begin(), shape_determinants.end()};
std::set<Node*> already_visited;
while (!to_visit.empty())
{
auto node = to_visit.front();
to_visit.pop_front();
if (already_visited.count(node) > 0)
{
continue;
}
shape_determinants.insert(node);
already_visited.insert(node);
if (node->is_parameter())
{
auto node_as_param = static_cast<op::Parameter*>(node);
if (!node_as_param->is_relevant_to_shapes())
{
node_as_param->set_is_relevant_to_shapes(true);
changes_made = true;
}
}
for (size_t i = 0; i < node->get_input_size(); i++)
{
if (!node->input(i).get_is_relevant_to_values())
{
continue;
}
auto source_node = node->input(i).get_source_output().get_node();
if (already_visited.count(source_node) == 0)
{
to_visit.push_front(source_node);
}
}
}
}
return changes_made;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class ShapeRelevance : public FunctionPass
{
public:
ShapeRelevance()
: FunctionPass()
{
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
}
}
......@@ -55,6 +55,7 @@ set(SRC
pass_liveness.cpp
pass_manager.cpp
pass_memory_layout.cpp
pass_shape_relevance.cpp
pass_shape_specialization.cpp
pattern.cpp
provenance.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/shape_relevance.hpp"
using namespace ngraph;
using namespace std;
TEST(shape_relevance, simple)
{
auto param0 = make_shared<op::Parameter>(element::f32, Shape{4, 6});
auto param1 = make_shared<op::Parameter>(element::f32, Shape{4, 6});
auto x = make_shared<op::Add>(param0, param1);
auto f = make_shared<Function>(x, ParameterVector{param0, param1});
pass::Manager manager;
manager.register_pass<pass::ShapeRelevance>();
manager.run_passes(f);
ASSERT_FALSE(param0->is_relevant_to_shapes());
ASSERT_FALSE(param1->is_relevant_to_shapes());
}
TEST(shape_relevance, param_direct)
{
auto param0 = make_shared<op::Parameter>(element::f32, Shape{4, 6});
auto param1 = make_shared<op::Parameter>(element::i64, Shape{4});
auto x = make_shared<op::DynReshape>(param0, param1);
auto f = make_shared<Function>(x, ParameterVector{param0, param1});
pass::Manager manager;
manager.register_pass<pass::ShapeRelevance>();
manager.run_passes(f);
ASSERT_FALSE(param0->is_relevant_to_shapes());
ASSERT_TRUE(param1->is_relevant_to_shapes());
}
TEST(shape_relevance, param_indirect)
{
auto param0 = make_shared<op::Parameter>(element::f32, Shape{4, 6});
auto param1 = make_shared<op::Parameter>(element::i64, Shape{4});
auto param2 = make_shared<op::Parameter>(element::i64, Shape{2});
auto c = make_shared<op::Concat>(NodeVector{param1, param2}, 0);
auto x = make_shared<op::DynReshape>(param0, c);
auto f = make_shared<Function>(x, ParameterVector{param0, param1, param2});
pass::Manager manager;
manager.register_pass<pass::ShapeRelevance>();
manager.run_passes(f);
ASSERT_FALSE(param0->is_relevant_to_shapes());
ASSERT_TRUE(param1->is_relevant_to_shapes());
ASSERT_TRUE(param2->is_relevant_to_shapes());
}
TEST(shape_relevance, param_shape_of_direct)
{
auto param0 = make_shared<op::Parameter>(element::f32, Shape{4, 6});
auto x = make_shared<op::DynReshape>(param0, make_shared<op::ShapeOf>(param0));
auto f = make_shared<Function>(x, ParameterVector{param0});
pass::Manager manager;
manager.register_pass<pass::ShapeRelevance>();
manager.run_passes(f);
ASSERT_FALSE(param0->is_relevant_to_shapes());
}
TEST(shape_relevance, param_shape_of_indirect)
{
auto param0 = make_shared<op::Parameter>(element::f32, Shape{4, 6});
auto s = make_shared<op::ShapeOf>(param0);
auto r = make_shared<op::Reverse>(s, AxisSet{0});
auto x = make_shared<op::DynReshape>(param0, r);
auto f = make_shared<Function>(x, ParameterVector{param0});
pass::Manager manager;
manager.register_pass<pass::ShapeRelevance>();
manager.run_passes(f);
ASSERT_FALSE(param0->is_relevant_to_shapes());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment