Commit 391d50e0 authored by Louis Feng's avatar Louis Feng Committed by Scott Cyphers

[Dynamic Shape] GraphRewrite refactor with constant folding pass (#2860)

* add dynamic function support to GraphRewrite.

* add pass property to constant folding pass.

* namespace clean up.

* fixed comment.
parent 11c3ca8b
......@@ -33,9 +33,9 @@ public:
AlgebraicSimplification()
: FunctionPass()
{
pass::PassPropertyMask property{pass::PassProperty::REGULAR_FUSIONS,
pass::PassProperty::REQUIRE_STATIC_SHAPE};
ngraph::pass::PassBase::set_property(property, true);
PassPropertyMask property{PassProperty::REGULAR_FUSIONS,
PassProperty::REQUIRE_STATIC_SHAPE};
set_property(property, true);
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
};
......@@ -181,7 +181,7 @@ void pass::ConstantFolding::construct_constant_pad()
};
auto pad_matcher = make_shared<pattern::Matcher>(pad, "ConstantFolding.ConstantPad");
this->add_matcher(pad_matcher, constant_pad_callback);
this->add_matcher(pad_matcher, constant_pad_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
void pass::ConstantFolding::construct_constant_reshape()
......@@ -246,7 +246,8 @@ void pass::ConstantFolding::construct_constant_reshape()
auto reshape_matcher =
make_shared<pattern::Matcher>(reshape, "ConstantFolding.ConstantReshape");
this->add_matcher(reshape_matcher, constant_reshape_callback);
this->add_matcher(
reshape_matcher, constant_reshape_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
template <class T>
......@@ -341,7 +342,8 @@ void pass::ConstantFolding::construct_constant_broadcast()
auto broadcast_matcher =
make_shared<pattern::Matcher>(broadcast, "ConstantFolding.ConstantBroadcast");
this->add_matcher(broadcast_matcher, constant_broadcast_callback);
this->add_matcher(
broadcast_matcher, constant_broadcast_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
template <class T>
......@@ -478,7 +480,8 @@ void pass::ConstantFolding::construct_constant_binary()
};
auto reshape_matcher = make_shared<pattern::Matcher>(bea, "ConstantFolding.ConstantBinary");
this->add_matcher(reshape_matcher, constant_binary_callback);
this->add_matcher(
reshape_matcher, constant_binary_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
bool is_supported_unary_op(std::shared_ptr<Node> n)
......@@ -608,7 +611,7 @@ void pass::ConstantFolding::construct_constant_unary()
};
auto reshape_matcher = make_shared<pattern::Matcher>(uea, "ConstantFolding.ConstantUnary");
this->add_matcher(reshape_matcher, constant_unary_callback);
this->add_matcher(reshape_matcher, constant_unary_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
template <class QUANT, class REAL>
......@@ -681,7 +684,8 @@ void pass::ConstantFolding::construct_constant_dequantize()
auto dequantize_matcher =
make_shared<pattern::Matcher>(dequant, "ConstantFolding.ConstantDequantize");
this->add_matcher(dequantize_matcher, constant_dequantize_callback);
this->add_matcher(
dequantize_matcher, constant_dequantize_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
template <class REAL, class QUANT>
......@@ -756,5 +760,6 @@ void pass::ConstantFolding::construct_constant_quantize()
auto quantize_matcher =
make_shared<pattern::Matcher>(quant, "ConstantFolding.ConstantQuantize");
this->add_matcher(quantize_matcher, constant_quantize_callback);
this->add_matcher(
quantize_matcher, constant_quantize_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
......@@ -64,27 +64,37 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
const size_t NUM_TRIES = 10;
size_t tries = NUM_TRIES;
vector<MatchClosure> original_matchers{m_matchers};
bool is_dyn_func = f->is_dynamic();
do
{
rewritten = false;
// m_matchers may contain newly constructed matchers for matchers
// that need multiple passes. See comments above.
vector<MatchClosure> run_matchers{m_matchers};
vector<MatchClosure> matchers_to_run{m_matchers};
m_matchers.clear();
for (auto node : f->get_ordered_ops())
{
for (auto& mc : run_matchers)
for (auto& closure : matchers_to_run)
{
NGRAPH_DEBUG << "Running matcher " << mc.matcher->get_name() << "("
<< mc.matcher->get_pattern()->get_name() << ") on "
if (is_dyn_func && closure.property[PassProperty::REQUIRE_STATIC_SHAPE])
{
NGRAPH_DEBUG << "matcher callback requires static shape but the "
"function is dynamic, skipping this "
"optimization till the shapes are fully "
"materialized";
continue;
}
NGRAPH_DEBUG << "Running matcher " << closure.matcher->get_name() << "("
<< closure.matcher->get_pattern()->get_name() << ") on "
<< node->get_name();
if (mc.matcher->match(node))
if (closure.matcher->match(node))
{
NGRAPH_DEBUG << "Matcher " << mc.matcher << mc.matcher->get_name()
NGRAPH_DEBUG << "Matcher " << closure.matcher << closure.matcher->get_name()
<< " matched " << node->get_name();
if (mc.callback(*mc.matcher.get()))
if (closure.callback(*closure.matcher.get()))
{
rewritten = true;
is_dyn_func = f->is_dynamic();
break;
}
}
......@@ -132,44 +142,78 @@ bool pass::GraphRewrite::is_enabled(const shared_ptr<pattern::Matcher>& m) const
}
void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
const graph_rewrite_callback& callback)
const graph_rewrite_callback& callback,
const PassPropertyMask& property)
{
if (is_enabled(m))
{
m_matchers.push_back({m, callback});
m_matchers.push_back({m, callback, property});
}
}
void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
const graph_rewrite_callback& callback)
{
// TODO: before deprecate this function, by default expect the
// callback require static shape.
add_matcher(m, callback, {PassProperty::REQUIRE_STATIC_SHAPE});
}
void pass::RecurrentGraphRewrite::add_matcher(
const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback,
const PassPropertyMask& property)
{
m_matchers.push_back({m, callback, property});
}
void pass::RecurrentGraphRewrite::add_matcher(
const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback)
{
m_matchers.push_back({m, callback});
// TODO: before deprecate this function, by default expect the
// callback require static shape.
add_matcher(m, callback, {PassProperty::REQUIRE_STATIC_SHAPE});
}
bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f)
{
bool changed = false;
size_t i = 0;
do
{
bool is_dyn_func = f->is_dynamic();
auto run_matchers = [&]() -> bool {
for (auto node : f->get_ops())
{
for (auto& mc : m_matchers)
for (auto& closure : m_matchers)
{
NGRAPH_DEBUG << "Running matcher " << mc.matcher << " on " << node->get_name();
if (mc.matcher->match(node))
if (is_dyn_func && closure.property[PassProperty::REQUIRE_STATIC_SHAPE])
{
NGRAPH_DEBUG << "Matcher " << mc.matcher << " matched " << node->get_name();
if (mc.callback(*mc.matcher.get()))
NGRAPH_DEBUG << "matcher callback requires static shape but the "
"function is dynamic, skipping this "
"optimization till the shapes are fully "
"materialized";
continue;
}
NGRAPH_DEBUG << "Running matcher " << closure.matcher << " on " << node->get_name();
if (closure.matcher->match(node))
{
NGRAPH_DEBUG << "Matcher " << closure.matcher << " matched "
<< node->get_name();
if (closure.callback(*closure.matcher.get()))
{
changed = true;
goto next_fusion;
is_dyn_func = f->is_dynamic();
return true;
}
}
}
}
next_fusion:
return false;
};
do
{
changed = run_matchers();
i++;
} while (changed && i < m_num_iters);
return changed;
......
......@@ -52,10 +52,19 @@ public:
GraphRewrite()
: FunctionPass()
{
// Setting REQUIRE_STATIC_SHAPE to false because we will check if each
// callback needs static shape during run_on_function().
set_property(PassProperty::REQUIRE_STATIC_SHAPE, false);
}
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback,
const PassPropertyMask& property);
// TODO: This interface may deprecate after all passes are refactored.
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback);
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
protected:
......@@ -66,6 +75,7 @@ private:
{
std::shared_ptr<pattern::Matcher> matcher;
ngraph::graph_rewrite_callback callback;
PassPropertyMask property;
};
std::vector<MatchClosure> m_matchers;
};
......@@ -77,10 +87,19 @@ public:
: FunctionPass()
, m_num_iters(num_iters)
{
// Setting REQUIRE_STATIC_SHAPE to false because we will check if each
// callback needs static shape during run_on_function().
set_property(PassProperty::REQUIRE_STATIC_SHAPE, false);
}
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback,
const PassPropertyMask& property);
// TODO: This interface may deprecate after all passes are refactored.
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback);
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
private:
......@@ -90,6 +109,7 @@ private:
{
std::shared_ptr<pattern::RecurrentMatcher> matcher;
ngraph::recurrent_graph_rewrite_callback callback;
PassPropertyMask property;
};
std::vector<MatchClosure> m_matchers;
};
......@@ -111,7 +111,7 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
// This checks is to skip the graph optimization when the graph pass relies on static shape
// but the function state is dynamic.
// we update the function dynamic state only if we run the graph pass successfully.
if (function_pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE) &&
if (function_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) &&
f_pair.second)
{
continue;
......@@ -136,7 +136,7 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
for (auto f_pair : fs)
{
shared_ptr<Function> f = f_pair.first;
if (call_graph_pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE) &&
if (call_graph_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) &&
f_pair.second)
{
continue;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment