Commit 391d50e0 authored by Louis Feng's avatar Louis Feng Committed by Scott Cyphers

[Dynamic Shape] GraphRewrite refactor with constant folding pass (#2860)

* add dynamic function support to GraphRewrite.

* add pass property to constant folding pass.

* namespace clean up.

* fixed comment.
parent 11c3ca8b
...@@ -33,9 +33,9 @@ public: ...@@ -33,9 +33,9 @@ public:
AlgebraicSimplification() AlgebraicSimplification()
: FunctionPass() : FunctionPass()
{ {
pass::PassPropertyMask property{pass::PassProperty::REGULAR_FUSIONS, PassPropertyMask property{PassProperty::REGULAR_FUSIONS,
pass::PassProperty::REQUIRE_STATIC_SHAPE}; PassProperty::REQUIRE_STATIC_SHAPE};
ngraph::pass::PassBase::set_property(property, true); set_property(property, true);
} }
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f); virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
}; };
...@@ -181,7 +181,7 @@ void pass::ConstantFolding::construct_constant_pad() ...@@ -181,7 +181,7 @@ void pass::ConstantFolding::construct_constant_pad()
}; };
auto pad_matcher = make_shared<pattern::Matcher>(pad, "ConstantFolding.ConstantPad"); auto pad_matcher = make_shared<pattern::Matcher>(pad, "ConstantFolding.ConstantPad");
this->add_matcher(pad_matcher, constant_pad_callback); this->add_matcher(pad_matcher, constant_pad_callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
void pass::ConstantFolding::construct_constant_reshape() void pass::ConstantFolding::construct_constant_reshape()
...@@ -246,7 +246,8 @@ void pass::ConstantFolding::construct_constant_reshape() ...@@ -246,7 +246,8 @@ void pass::ConstantFolding::construct_constant_reshape()
auto reshape_matcher = auto reshape_matcher =
make_shared<pattern::Matcher>(reshape, "ConstantFolding.ConstantReshape"); make_shared<pattern::Matcher>(reshape, "ConstantFolding.ConstantReshape");
this->add_matcher(reshape_matcher, constant_reshape_callback); this->add_matcher(
reshape_matcher, constant_reshape_callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
template <class T> template <class T>
...@@ -341,7 +342,8 @@ void pass::ConstantFolding::construct_constant_broadcast() ...@@ -341,7 +342,8 @@ void pass::ConstantFolding::construct_constant_broadcast()
auto broadcast_matcher = auto broadcast_matcher =
make_shared<pattern::Matcher>(broadcast, "ConstantFolding.ConstantBroadcast"); make_shared<pattern::Matcher>(broadcast, "ConstantFolding.ConstantBroadcast");
this->add_matcher(broadcast_matcher, constant_broadcast_callback); this->add_matcher(
broadcast_matcher, constant_broadcast_callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
template <class T> template <class T>
...@@ -478,7 +480,8 @@ void pass::ConstantFolding::construct_constant_binary() ...@@ -478,7 +480,8 @@ void pass::ConstantFolding::construct_constant_binary()
}; };
auto reshape_matcher = make_shared<pattern::Matcher>(bea, "ConstantFolding.ConstantBinary"); auto reshape_matcher = make_shared<pattern::Matcher>(bea, "ConstantFolding.ConstantBinary");
this->add_matcher(reshape_matcher, constant_binary_callback); this->add_matcher(
reshape_matcher, constant_binary_callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
bool is_supported_unary_op(std::shared_ptr<Node> n) bool is_supported_unary_op(std::shared_ptr<Node> n)
...@@ -608,7 +611,7 @@ void pass::ConstantFolding::construct_constant_unary() ...@@ -608,7 +611,7 @@ void pass::ConstantFolding::construct_constant_unary()
}; };
auto reshape_matcher = make_shared<pattern::Matcher>(uea, "ConstantFolding.ConstantUnary"); auto reshape_matcher = make_shared<pattern::Matcher>(uea, "ConstantFolding.ConstantUnary");
this->add_matcher(reshape_matcher, constant_unary_callback); this->add_matcher(reshape_matcher, constant_unary_callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
template <class QUANT, class REAL> template <class QUANT, class REAL>
...@@ -681,7 +684,8 @@ void pass::ConstantFolding::construct_constant_dequantize() ...@@ -681,7 +684,8 @@ void pass::ConstantFolding::construct_constant_dequantize()
auto dequantize_matcher = auto dequantize_matcher =
make_shared<pattern::Matcher>(dequant, "ConstantFolding.ConstantDequantize"); make_shared<pattern::Matcher>(dequant, "ConstantFolding.ConstantDequantize");
this->add_matcher(dequantize_matcher, constant_dequantize_callback); this->add_matcher(
dequantize_matcher, constant_dequantize_callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
template <class REAL, class QUANT> template <class REAL, class QUANT>
...@@ -756,5 +760,6 @@ void pass::ConstantFolding::construct_constant_quantize() ...@@ -756,5 +760,6 @@ void pass::ConstantFolding::construct_constant_quantize()
auto quantize_matcher = auto quantize_matcher =
make_shared<pattern::Matcher>(quant, "ConstantFolding.ConstantQuantize"); make_shared<pattern::Matcher>(quant, "ConstantFolding.ConstantQuantize");
this->add_matcher(quantize_matcher, constant_quantize_callback); this->add_matcher(
quantize_matcher, constant_quantize_callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
...@@ -64,27 +64,37 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f) ...@@ -64,27 +64,37 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
const size_t NUM_TRIES = 10; const size_t NUM_TRIES = 10;
size_t tries = NUM_TRIES; size_t tries = NUM_TRIES;
vector<MatchClosure> original_matchers{m_matchers}; vector<MatchClosure> original_matchers{m_matchers};
bool is_dyn_func = f->is_dynamic();
do do
{ {
rewritten = false; rewritten = false;
// m_matchers may contain newly constructed matchers for matchers // m_matchers may contain newly constructed matchers for matchers
// that need multiple passes. See comments above. // that need multiple passes. See comments above.
vector<MatchClosure> run_matchers{m_matchers}; vector<MatchClosure> matchers_to_run{m_matchers};
m_matchers.clear(); m_matchers.clear();
for (auto node : f->get_ordered_ops()) for (auto node : f->get_ordered_ops())
{ {
for (auto& mc : run_matchers) for (auto& closure : matchers_to_run)
{ {
NGRAPH_DEBUG << "Running matcher " << mc.matcher->get_name() << "(" if (is_dyn_func && closure.property[PassProperty::REQUIRE_STATIC_SHAPE])
<< mc.matcher->get_pattern()->get_name() << ") on " {
NGRAPH_DEBUG << "matcher callback requires static shape but the "
"function is dynamic, skipping this "
"optimization till the shapes are fully "
"materialized";
continue;
}
NGRAPH_DEBUG << "Running matcher " << closure.matcher->get_name() << "("
<< closure.matcher->get_pattern()->get_name() << ") on "
<< node->get_name(); << node->get_name();
if (mc.matcher->match(node)) if (closure.matcher->match(node))
{ {
NGRAPH_DEBUG << "Matcher " << mc.matcher << mc.matcher->get_name() NGRAPH_DEBUG << "Matcher " << closure.matcher << closure.matcher->get_name()
<< " matched " << node->get_name(); << " matched " << node->get_name();
if (mc.callback(*mc.matcher.get())) if (closure.callback(*closure.matcher.get()))
{ {
rewritten = true; rewritten = true;
is_dyn_func = f->is_dynamic();
break; break;
} }
} }
...@@ -132,44 +142,78 @@ bool pass::GraphRewrite::is_enabled(const shared_ptr<pattern::Matcher>& m) const ...@@ -132,44 +142,78 @@ bool pass::GraphRewrite::is_enabled(const shared_ptr<pattern::Matcher>& m) const
} }
void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m, void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
const graph_rewrite_callback& callback) const graph_rewrite_callback& callback,
const PassPropertyMask& property)
{ {
if (is_enabled(m)) if (is_enabled(m))
{ {
m_matchers.push_back({m, callback}); m_matchers.push_back({m, callback, property});
} }
} }
void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
const graph_rewrite_callback& callback)
{
// TODO: before deprecate this function, by default expect the
// callback require static shape.
add_matcher(m, callback, {PassProperty::REQUIRE_STATIC_SHAPE});
}
void pass::RecurrentGraphRewrite::add_matcher(
const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback,
const PassPropertyMask& property)
{
m_matchers.push_back({m, callback, property});
}
void pass::RecurrentGraphRewrite::add_matcher( void pass::RecurrentGraphRewrite::add_matcher(
const std::shared_ptr<pattern::RecurrentMatcher>& m, const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback) const ngraph::recurrent_graph_rewrite_callback& callback)
{ {
m_matchers.push_back({m, callback}); // TODO: before deprecate this function, by default expect the
// callback require static shape.
add_matcher(m, callback, {PassProperty::REQUIRE_STATIC_SHAPE});
} }
bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f) bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f)
{ {
bool changed = false; bool changed = false;
size_t i = 0; size_t i = 0;
do bool is_dyn_func = f->is_dynamic();
{
auto run_matchers = [&]() -> bool {
for (auto node : f->get_ops()) for (auto node : f->get_ops())
{ {
for (auto& mc : m_matchers) for (auto& closure : m_matchers)
{
if (is_dyn_func && closure.property[PassProperty::REQUIRE_STATIC_SHAPE])
{ {
NGRAPH_DEBUG << "Running matcher " << mc.matcher << " on " << node->get_name(); NGRAPH_DEBUG << "matcher callback requires static shape but the "
if (mc.matcher->match(node)) "function is dynamic, skipping this "
"optimization till the shapes are fully "
"materialized";
continue;
}
NGRAPH_DEBUG << "Running matcher " << closure.matcher << " on " << node->get_name();
if (closure.matcher->match(node))
{ {
NGRAPH_DEBUG << "Matcher " << mc.matcher << " matched " << node->get_name(); NGRAPH_DEBUG << "Matcher " << closure.matcher << " matched "
if (mc.callback(*mc.matcher.get())) << node->get_name();
if (closure.callback(*closure.matcher.get()))
{ {
changed = true; is_dyn_func = f->is_dynamic();
goto next_fusion; return true;
} }
} }
} }
} }
next_fusion: return false;
};
do
{
changed = run_matchers();
i++; i++;
} while (changed && i < m_num_iters); } while (changed && i < m_num_iters);
return changed; return changed;
......
...@@ -52,10 +52,19 @@ public: ...@@ -52,10 +52,19 @@ public:
GraphRewrite() GraphRewrite()
: FunctionPass() : FunctionPass()
{ {
// Setting REQUIRE_STATIC_SHAPE to false because we will check if each
// callback needs static shape during run_on_function().
set_property(PassProperty::REQUIRE_STATIC_SHAPE, false);
} }
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback,
const PassPropertyMask& property);
// TODO: This interface may deprecate after all passes are refactored.
void add_matcher(const std::shared_ptr<pattern::Matcher>& m, void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback); const ngraph::graph_rewrite_callback& callback);
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f); virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
protected: protected:
...@@ -66,6 +75,7 @@ private: ...@@ -66,6 +75,7 @@ private:
{ {
std::shared_ptr<pattern::Matcher> matcher; std::shared_ptr<pattern::Matcher> matcher;
ngraph::graph_rewrite_callback callback; ngraph::graph_rewrite_callback callback;
PassPropertyMask property;
}; };
std::vector<MatchClosure> m_matchers; std::vector<MatchClosure> m_matchers;
}; };
...@@ -77,10 +87,19 @@ public: ...@@ -77,10 +87,19 @@ public:
: FunctionPass() : FunctionPass()
, m_num_iters(num_iters) , m_num_iters(num_iters)
{ {
// Setting REQUIRE_STATIC_SHAPE to false because we will check if each
// callback needs static shape during run_on_function().
set_property(PassProperty::REQUIRE_STATIC_SHAPE, false);
} }
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback,
const PassPropertyMask& property);
// TODO: This interface may deprecate after all passes are refactored.
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m, void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback); const ngraph::recurrent_graph_rewrite_callback& callback);
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f); virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
private: private:
...@@ -90,6 +109,7 @@ private: ...@@ -90,6 +109,7 @@ private:
{ {
std::shared_ptr<pattern::RecurrentMatcher> matcher; std::shared_ptr<pattern::RecurrentMatcher> matcher;
ngraph::recurrent_graph_rewrite_callback callback; ngraph::recurrent_graph_rewrite_callback callback;
PassPropertyMask property;
}; };
std::vector<MatchClosure> m_matchers; std::vector<MatchClosure> m_matchers;
}; };
...@@ -111,7 +111,7 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive) ...@@ -111,7 +111,7 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
// This checks is to skip the graph optimization when the graph pass relies on static shape // This checks is to skip the graph optimization when the graph pass relies on static shape
// but the function state is dynamic. // but the function state is dynamic.
// we update the function dynamic state only if we run the graph pass successfully. // we update the function dynamic state only if we run the graph pass successfully.
if (function_pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE) && if (function_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) &&
f_pair.second) f_pair.second)
{ {
continue; continue;
...@@ -136,7 +136,7 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive) ...@@ -136,7 +136,7 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
for (auto f_pair : fs) for (auto f_pair : fs)
{ {
shared_ptr<Function> f = f_pair.first; shared_ptr<Function> f = f_pair.first;
if (call_graph_pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE) && if (call_graph_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) &&
f_pair.second) f_pair.second)
{ {
continue; continue;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment