Commit 870f5000 authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

dynamic shape support for algebraic_simplification pass (#2804)

* - add support to algebriac simplification to check dynamic shapes against the Pass Properties

* - added PassProperty to RecurrentMatcher
- added checks to check for dynamic function state  and PassProperty in GraphRewrite before applying graph optimization
- optimize number of calls to f->is_dynamic() in AlgebraicSimplification

* - make changes to Algebriac Simplicfication graph pass to accept PassProperty

* - test case for AlgebraicSimplification pass properties
- set the Pass Property in the pass ctor

* Address PR comments

* fix bug in pass manager

* Addressed PR comments
parent 2147309f
......@@ -17,6 +17,7 @@
#pragma once
#include "ngraph/pass/pass.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
......@@ -32,7 +33,9 @@ public:
AlgebraicSimplification()
: FunctionPass()
{
pass::PassPropertyMask property{pass::PassProperty::REGULAR_FUSIONS,
pass::PassProperty::REQUIRE_STATIC_SHAPE};
ngraph::pass::PassBase::set_property(property, true);
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
};
......@@ -65,6 +65,7 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
const size_t NUM_TRIES = 10;
size_t tries = NUM_TRIES;
vector<shared_ptr<pattern::Matcher>> original_matchers{m_matchers};
bool is_dynamic_function = f->is_dynamic();
do
{
rewritten = false;
......@@ -74,6 +75,14 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
{
for (auto matcher : matchers)
{
if (is_dynamic_function &&
(matcher->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE)))
{
NGRAPH_DEBUG
<< "matcher requires static shape but the function is dynamic, "
<< "skipping this optimization till the shapes are fully materialized";
continue;
}
NGRAPH_DEBUG << "Running matcher " << matcher->get_name() << "("
<< matcher->get_pattern()->get_name() << ") on " << node->get_name();
if (matcher->match(node))
......@@ -83,6 +92,7 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
if (matcher->process_match())
{
rewritten = true;
is_dynamic_function = f->is_dynamic();
break;
}
}
......@@ -141,12 +151,21 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f)
{
bool changed = false;
size_t i = 0;
bool is_dynamic_function = f->is_dynamic();
do
{
for (auto node : f->get_ops())
{
for (auto matcher : m_matchers)
{
if (is_dynamic_function &&
(matcher->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE)))
{
NGRAPH_DEBUG
<< "matcher requires static shape but the function is dynamic, "
<< "skipping this optimization till the shapes are fully materialized";
continue;
}
NGRAPH_DEBUG << "Running matcher " << matcher << " on " << node->get_name();
if (matcher->match(node))
{
......@@ -154,6 +173,7 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f)
if (matcher->process_match())
{
changed = true;
is_dynamic_function = f->is_dynamic();
goto next_fusion;
}
}
......
......@@ -61,17 +61,26 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
{
bool profile_enabled = getenv("NGRAPH_PROFILE_PASS_ENABLE") != nullptr;
vector<shared_ptr<Function>> fs;
vector<std::pair<shared_ptr<Function>, bool>> fs;
if (transitive)
{
// find all functions
traverse_functions(func, [&](shared_ptr<Function> f) { fs.push_back(f); });
traverse_functions(func, [&](shared_ptr<Function> f) {
fs.push_back(std::make_pair(f, f->is_dynamic()));
});
}
else
{
fs = {func};
fs = {std::make_pair(func, func->is_dynamic())};
}
set<shared_ptr<Function>> tfs;
std::vector<shared_ptr<Function>> f_array;
for (auto f_pair : fs)
{
shared_ptr<Function> f = f_pair.first;
tfs.insert(f);
f_array.push_back(f);
}
set<shared_ptr<Function>> tfs(begin(fs), end(fs));
get_state().set_functions(tfs);
size_t index = 0;
......@@ -92,19 +101,30 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
{
vt_pass->set_ops_to_details(get_state().get_visualize_tree_ops_map());
}
module_pass->run_on_module(fs);
module_pass->run_on_module(f_array);
}
else if (function_pass)
{
for (shared_ptr<Function> f : fs)
for (auto f_pair : fs)
{
function_pass->run_on_function(f);
shared_ptr<Function> f = f_pair.first;
// This checks is to skip the graph optimization when the graph pass relies on static shape
// but the function state is dynamic.
// we update the function dynamic state only if we run the graph pass successfully.
if (function_pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE) &&
f_pair.second)
{
continue;
}
bool function_modified = function_pass->run_on_function(f);
f_pair.second = (function_modified == true) ? f->is_dynamic() : f_pair.second;
}
}
else if (node_pass)
{
for (shared_ptr<Function> f : fs)
for (auto f_pair : fs)
{
shared_ptr<Function> f = f_pair.first;
for (shared_ptr<Node> n : f->get_ops())
{
node_pass->run_on_node(n);
......@@ -113,15 +133,23 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
}
else if (call_graph_pass)
{
for (shared_ptr<Function> f : fs)
for (auto f_pair : fs)
{
call_graph_pass->run_on_call_graph(f->get_ordered_ops());
shared_ptr<Function> f = f_pair.first;
if (call_graph_pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE) &&
f_pair.second)
{
continue;
}
bool function_modified = call_graph_pass->run_on_call_graph(f->get_ordered_ops());
f_pair.second = (function_modified == true) ? f->is_dynamic() : f_pair.second;
}
}
// Better to do this in node replacement but this will do for now
for (auto f : fs)
for (auto f_pair : fs)
{
shared_ptr<Function> f = f_pair.first;
f->validate_nodes_and_infer_types();
}
......@@ -131,21 +159,21 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
const size_t num_digits_in_pass_index = 3;
std::string index_str = std::to_string(index);
index_str = std::string(num_digits_in_pass_index - index_str.length(), '0') + index_str;
auto base_filename = fs.at(0)->get_name() + std::string("_") + index_str +
auto base_filename = f_array.at(0)->get_name() + std::string("_") + index_str +
std::string("_") + m_pass_names.at(index) + std::string(".");
if (m_visualize)
{
pass::VisualizeTree vt(base_filename + pass::VisualizeTree::get_file_ext());
vt.set_ops_to_details(get_state().get_visualize_tree_ops_map());
vt.run_on_module(fs);
vt.run_on_module(f_array);
}
if (m_serialize)
{
// no "." in the extension
pass::Serialization st(base_filename + "json");
st.run_on_module(fs);
st.run_on_module(f_array);
}
}
index++;
......
......@@ -397,6 +397,11 @@ namespace ngraph
return is_match;
}
bool Matcher::get_property(const ngraph::pass::PassPropertyMask& prop) const
{
return m_property.is_set(prop);
}
bool RecurrentMatcher::match(std::shared_ptr<Node> graph)
{
bool matched = false;
......@@ -454,5 +459,9 @@ namespace ngraph
}
bool RecurrentMatcher::process_match() { return m_callback(*this); }
bool RecurrentMatcher::get_property(const ngraph::pass::PassPropertyMask& prop) const
{
return m_property.is_set(prop);
}
}
}
......@@ -193,11 +193,13 @@ namespace ngraph
RecurrentMatcher(std::shared_ptr<Node> pattern,
std::shared_ptr<op::Label> rpattern,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns,
recurrent_graph_rewrite_callback callback)
recurrent_graph_rewrite_callback callback,
pass::PassPropertyMask property = pass::PassProperty::REGULAR_FUSIONS)
: m_pattern(pattern)
, m_recurrent_pattern(rpattern)
, m_correlated_patterns(correlated_patterns)
, m_callback(callback)
, m_property(property)
{
}
......@@ -230,6 +232,8 @@ namespace ngraph
/// \brief Invoked by a pass to process a successful match
bool process_match();
bool get_property(const pass::PassPropertyMask& prop) const;
std::shared_ptr<Node> get_match_root() { return m_match_root; }
private:
std::shared_ptr<Node> m_pattern;
......@@ -237,6 +241,7 @@ namespace ngraph
const std::set<std::shared_ptr<op::Label>> m_correlated_patterns;
RPatternMap m_matches;
recurrent_graph_rewrite_callback m_callback;
pass::PassPropertyMask m_property;
std::shared_ptr<Node> m_match_root;
};
}
......
......@@ -43,6 +43,7 @@
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
......@@ -583,3 +584,13 @@ TEST(algebraic_simplification, log_no_divide)
pass_manager.run_passes(f);
ASSERT_EQ(neg_inner->get_argument(0), log_mul);
}
TEST(algebraic_simplification, pass_property)
{
auto pass = std::make_shared<ngraph::pass::AlgebraicSimplification>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(true,
pass->get_property(pass::PassPropertyMask(pass::PassProperty::REGULAR_FUSIONS) |
pass::PassPropertyMask(pass::PassProperty::REQUIRE_STATIC_SHAPE)));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment