Commit 870f5000 authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

dynamic shape support for algebraic_simplification pass (#2804)

* - add support to algebriac simplification to check dynamic shapes against the Pass Properties

* - added PassProperty to RecurrentMatcher
- added checks to check for dynamic function state  and PassProperty in GraphRewrite before applying graph optimization
- optimize number of calls to f->is_dynamic() in AlgebraicSimplification

* - make changes to Algebriac Simplicfication graph pass to accept PassProperty

* - test case for AlgebraicSimplification pass properties
- set the Pass Property in the pass ctor

* Address PR comments

* fix bug in pass manager

* Addressed PR comments
parent 2147309f
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "ngraph/util.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -32,7 +33,9 @@ public: ...@@ -32,7 +33,9 @@ public:
AlgebraicSimplification() AlgebraicSimplification()
: FunctionPass() : FunctionPass()
{ {
pass::PassPropertyMask property{pass::PassProperty::REGULAR_FUSIONS,
pass::PassProperty::REQUIRE_STATIC_SHAPE};
ngraph::pass::PassBase::set_property(property, true);
} }
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f); virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
}; };
...@@ -65,6 +65,7 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f) ...@@ -65,6 +65,7 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
const size_t NUM_TRIES = 10; const size_t NUM_TRIES = 10;
size_t tries = NUM_TRIES; size_t tries = NUM_TRIES;
vector<shared_ptr<pattern::Matcher>> original_matchers{m_matchers}; vector<shared_ptr<pattern::Matcher>> original_matchers{m_matchers};
bool is_dynamic_function = f->is_dynamic();
do do
{ {
rewritten = false; rewritten = false;
...@@ -74,6 +75,14 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f) ...@@ -74,6 +75,14 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
{ {
for (auto matcher : matchers) for (auto matcher : matchers)
{ {
if (is_dynamic_function &&
(matcher->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE)))
{
NGRAPH_DEBUG
<< "matcher requires static shape but the function is dynamic, "
<< "skipping this optimization till the shapes are fully materialized";
continue;
}
NGRAPH_DEBUG << "Running matcher " << matcher->get_name() << "(" NGRAPH_DEBUG << "Running matcher " << matcher->get_name() << "("
<< matcher->get_pattern()->get_name() << ") on " << node->get_name(); << matcher->get_pattern()->get_name() << ") on " << node->get_name();
if (matcher->match(node)) if (matcher->match(node))
...@@ -83,6 +92,7 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f) ...@@ -83,6 +92,7 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
if (matcher->process_match()) if (matcher->process_match())
{ {
rewritten = true; rewritten = true;
is_dynamic_function = f->is_dynamic();
break; break;
} }
} }
...@@ -141,12 +151,21 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f) ...@@ -141,12 +151,21 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f)
{ {
bool changed = false; bool changed = false;
size_t i = 0; size_t i = 0;
bool is_dynamic_function = f->is_dynamic();
do do
{ {
for (auto node : f->get_ops()) for (auto node : f->get_ops())
{ {
for (auto matcher : m_matchers) for (auto matcher : m_matchers)
{ {
if (is_dynamic_function &&
(matcher->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE)))
{
NGRAPH_DEBUG
<< "matcher requires static shape but the function is dynamic, "
<< "skipping this optimization till the shapes are fully materialized";
continue;
}
NGRAPH_DEBUG << "Running matcher " << matcher << " on " << node->get_name(); NGRAPH_DEBUG << "Running matcher " << matcher << " on " << node->get_name();
if (matcher->match(node)) if (matcher->match(node))
{ {
...@@ -154,6 +173,7 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f) ...@@ -154,6 +173,7 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f)
if (matcher->process_match()) if (matcher->process_match())
{ {
changed = true; changed = true;
is_dynamic_function = f->is_dynamic();
goto next_fusion; goto next_fusion;
} }
} }
......
...@@ -61,17 +61,26 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive) ...@@ -61,17 +61,26 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
{ {
bool profile_enabled = getenv("NGRAPH_PROFILE_PASS_ENABLE") != nullptr; bool profile_enabled = getenv("NGRAPH_PROFILE_PASS_ENABLE") != nullptr;
vector<shared_ptr<Function>> fs; vector<std::pair<shared_ptr<Function>, bool>> fs;
if (transitive) if (transitive)
{ {
// find all functions // find all functions
traverse_functions(func, [&](shared_ptr<Function> f) { fs.push_back(f); }); traverse_functions(func, [&](shared_ptr<Function> f) {
fs.push_back(std::make_pair(f, f->is_dynamic()));
});
} }
else else
{ {
fs = {func}; fs = {std::make_pair(func, func->is_dynamic())};
}
set<shared_ptr<Function>> tfs;
std::vector<shared_ptr<Function>> f_array;
for (auto f_pair : fs)
{
shared_ptr<Function> f = f_pair.first;
tfs.insert(f);
f_array.push_back(f);
} }
set<shared_ptr<Function>> tfs(begin(fs), end(fs));
get_state().set_functions(tfs); get_state().set_functions(tfs);
size_t index = 0; size_t index = 0;
...@@ -92,19 +101,30 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive) ...@@ -92,19 +101,30 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
{ {
vt_pass->set_ops_to_details(get_state().get_visualize_tree_ops_map()); vt_pass->set_ops_to_details(get_state().get_visualize_tree_ops_map());
} }
module_pass->run_on_module(fs); module_pass->run_on_module(f_array);
} }
else if (function_pass) else if (function_pass)
{ {
for (shared_ptr<Function> f : fs) for (auto f_pair : fs)
{ {
function_pass->run_on_function(f); shared_ptr<Function> f = f_pair.first;
// This checks is to skip the graph optimization when the graph pass relies on static shape
// but the function state is dynamic.
// we update the function dynamic state only if we run the graph pass successfully.
if (function_pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE) &&
f_pair.second)
{
continue;
}
bool function_modified = function_pass->run_on_function(f);
f_pair.second = (function_modified == true) ? f->is_dynamic() : f_pair.second;
} }
} }
else if (node_pass) else if (node_pass)
{ {
for (shared_ptr<Function> f : fs) for (auto f_pair : fs)
{ {
shared_ptr<Function> f = f_pair.first;
for (shared_ptr<Node> n : f->get_ops()) for (shared_ptr<Node> n : f->get_ops())
{ {
node_pass->run_on_node(n); node_pass->run_on_node(n);
...@@ -113,15 +133,23 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive) ...@@ -113,15 +133,23 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
} }
else if (call_graph_pass) else if (call_graph_pass)
{ {
for (shared_ptr<Function> f : fs) for (auto f_pair : fs)
{ {
call_graph_pass->run_on_call_graph(f->get_ordered_ops()); shared_ptr<Function> f = f_pair.first;
if (call_graph_pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE) &&
f_pair.second)
{
continue;
}
bool function_modified = call_graph_pass->run_on_call_graph(f->get_ordered_ops());
f_pair.second = (function_modified == true) ? f->is_dynamic() : f_pair.second;
} }
} }
// Better to do this in node replacement but this will do for now // Better to do this in node replacement but this will do for now
for (auto f : fs) for (auto f_pair : fs)
{ {
shared_ptr<Function> f = f_pair.first;
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
} }
...@@ -131,21 +159,21 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive) ...@@ -131,21 +159,21 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
const size_t num_digits_in_pass_index = 3; const size_t num_digits_in_pass_index = 3;
std::string index_str = std::to_string(index); std::string index_str = std::to_string(index);
index_str = std::string(num_digits_in_pass_index - index_str.length(), '0') + index_str; index_str = std::string(num_digits_in_pass_index - index_str.length(), '0') + index_str;
auto base_filename = fs.at(0)->get_name() + std::string("_") + index_str + auto base_filename = f_array.at(0)->get_name() + std::string("_") + index_str +
std::string("_") + m_pass_names.at(index) + std::string("."); std::string("_") + m_pass_names.at(index) + std::string(".");
if (m_visualize) if (m_visualize)
{ {
pass::VisualizeTree vt(base_filename + pass::VisualizeTree::get_file_ext()); pass::VisualizeTree vt(base_filename + pass::VisualizeTree::get_file_ext());
vt.set_ops_to_details(get_state().get_visualize_tree_ops_map()); vt.set_ops_to_details(get_state().get_visualize_tree_ops_map());
vt.run_on_module(fs); vt.run_on_module(f_array);
} }
if (m_serialize) if (m_serialize)
{ {
// no "." in the extension // no "." in the extension
pass::Serialization st(base_filename + "json"); pass::Serialization st(base_filename + "json");
st.run_on_module(fs); st.run_on_module(f_array);
} }
} }
index++; index++;
......
...@@ -397,6 +397,11 @@ namespace ngraph ...@@ -397,6 +397,11 @@ namespace ngraph
return is_match; return is_match;
} }
bool Matcher::get_property(const ngraph::pass::PassPropertyMask& prop) const
{
return m_property.is_set(prop);
}
bool RecurrentMatcher::match(std::shared_ptr<Node> graph) bool RecurrentMatcher::match(std::shared_ptr<Node> graph)
{ {
bool matched = false; bool matched = false;
...@@ -454,5 +459,9 @@ namespace ngraph ...@@ -454,5 +459,9 @@ namespace ngraph
} }
bool RecurrentMatcher::process_match() { return m_callback(*this); } bool RecurrentMatcher::process_match() { return m_callback(*this); }
bool RecurrentMatcher::get_property(const ngraph::pass::PassPropertyMask& prop) const
{
return m_property.is_set(prop);
}
} }
} }
...@@ -193,11 +193,13 @@ namespace ngraph ...@@ -193,11 +193,13 @@ namespace ngraph
RecurrentMatcher(std::shared_ptr<Node> pattern, RecurrentMatcher(std::shared_ptr<Node> pattern,
std::shared_ptr<op::Label> rpattern, std::shared_ptr<op::Label> rpattern,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns, const std::set<std::shared_ptr<op::Label>>& correlated_patterns,
recurrent_graph_rewrite_callback callback) recurrent_graph_rewrite_callback callback,
pass::PassPropertyMask property = pass::PassProperty::REGULAR_FUSIONS)
: m_pattern(pattern) : m_pattern(pattern)
, m_recurrent_pattern(rpattern) , m_recurrent_pattern(rpattern)
, m_correlated_patterns(correlated_patterns) , m_correlated_patterns(correlated_patterns)
, m_callback(callback) , m_callback(callback)
, m_property(property)
{ {
} }
...@@ -230,6 +232,8 @@ namespace ngraph ...@@ -230,6 +232,8 @@ namespace ngraph
/// \brief Invoked by a pass to process a successful match /// \brief Invoked by a pass to process a successful match
bool process_match(); bool process_match();
bool get_property(const pass::PassPropertyMask& prop) const;
std::shared_ptr<Node> get_match_root() { return m_match_root; } std::shared_ptr<Node> get_match_root() { return m_match_root; }
private: private:
std::shared_ptr<Node> m_pattern; std::shared_ptr<Node> m_pattern;
...@@ -237,6 +241,7 @@ namespace ngraph ...@@ -237,6 +241,7 @@ namespace ngraph
const std::set<std::shared_ptr<op::Label>> m_correlated_patterns; const std::set<std::shared_ptr<op::Label>> m_correlated_patterns;
RPatternMap m_matches; RPatternMap m_matches;
recurrent_graph_rewrite_callback m_callback; recurrent_graph_rewrite_callback m_callback;
pass::PassPropertyMask m_property;
std::shared_ptr<Node> m_match_root; std::shared_ptr<Node> m_match_root;
}; };
} }
......
...@@ -43,6 +43,7 @@ ...@@ -43,6 +43,7 @@
#include "ngraph/pass/algebraic_simplification.hpp" #include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/graph_rewrite.hpp" #include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
...@@ -583,3 +584,13 @@ TEST(algebraic_simplification, log_no_divide) ...@@ -583,3 +584,13 @@ TEST(algebraic_simplification, log_no_divide)
pass_manager.run_passes(f); pass_manager.run_passes(f);
ASSERT_EQ(neg_inner->get_argument(0), log_mul); ASSERT_EQ(neg_inner->get_argument(0), log_mul);
} }
TEST(algebraic_simplification, pass_property)
{
auto pass = std::make_shared<ngraph::pass::AlgebraicSimplification>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(true,
pass->get_property(pass::PassPropertyMask(pass::PassProperty::REGULAR_FUSIONS) |
pass::PassPropertyMask(pass::PassProperty::REQUIRE_STATIC_SHAPE)));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment