Commit b2ca3e79 authored by Louis Feng, committed by Scott Cyphers

[Dynamic Shape] Added Pass Properties to Core Passes (#2935)

* constexpr ctor for EnumMask

* added pass properties to core passes.

* added unit tests.

* minor fixes.
parent 4fb4be5e
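All of the hunks below apply one pattern: a pass declares its properties in its constructor, and infrastructure code queries those properties before and after running the pass. A minimal sketch of that pattern, up front; MyPass and its body are hypothetical, while FunctionPass, PassProperty, set_property(), and get_property() are the real interfaces this commit touches, and the include path is assumed.

```cpp
#include "ngraph/pass/pass.hpp" // assumed header for FunctionPass/PassProperty

// Hypothetical pass declaring its properties in the constructor, the way
// this commit does for AlgebraicSimplification, CSE, NopElimination, etc.
class MyPass : public ngraph::pass::FunctionPass
{
public:
    MyPass()
        : FunctionPass()
    {
        // This pass can only run when every shape in the function is static.
        set_property(ngraph::pass::PassProperty::REQUIRE_STATIC_SHAPE, true);
    }
    bool run_on_function(std::shared_ptr<ngraph::Function> f) override
    {
        return false; // no-op body, for illustration only
    }
};
```

The pass manager and GraphRewrite (see their hunks below) query these bits via get_property()/is_set() to skip static-shape-only passes on dynamic functions and to decide when a cached is_dynamic value must be recomputed.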
@@ -33,9 +33,7 @@ public:
AlgebraicSimplification()
: FunctionPass()
{
PassPropertyMask property{PassProperty::REGULAR_FUSIONS,
PassProperty::REQUIRE_STATIC_SHAPE};
set_property(property, true);
set_property(PassProperty::REQUIRE_STATIC_SHAPE, true);
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
};
@@ -29,6 +29,7 @@ namespace ngraph
: FunctionPass()
, m_fusion_type(type)
{
set_property(PassProperty::REQUIRE_STATIC_SHAPE, true);
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
@@ -114,7 +114,7 @@ void pass::ConcatElimination::construct_concat_elimination()
};
auto m = std::make_shared<pattern::Matcher>(concat_label, "ConcatElimination");
this->add_matcher(m, callback);
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
bool ngraph::pass::SelfConcatFusion::run_on_function(std::shared_ptr<Function> function)
@@ -47,6 +47,7 @@ private:
class ngraph::pass::SelfConcatFusion : public ngraph::pass::FunctionPass
{
public:
SelfConcatFusion() { set_property(PassProperty::REQUIRE_STATIC_SHAPE, true); }
virtual bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
private:
@@ -83,7 +83,7 @@ void pass::CoreFusion::construct_relu()
};
auto m = make_shared<pattern::Matcher>(max, "CoreFusion.Relu");
this->add_matcher(m, callback);
this->add_matcher(m, callback, all_pass_property_off);
}
void pass::CoreFusion::construct_sigmoid()
@@ -131,7 +131,7 @@ void pass::CoreFusion::construct_sigmoid()
};
auto m = std::make_shared<ngraph::pattern::Matcher>(divide_1_over_exp, "CoreFusion.Sigmoid");
this->add_matcher(m, callback);
this->add_matcher(m, callback, all_pass_property_off);
}
void pass::CoreFusion::construct_sigmoid_bprop()
@@ -183,7 +183,7 @@ void pass::CoreFusion::construct_sigmoid_bprop()
};
auto m = std::make_shared<ngraph::pattern::Matcher>(negative_2, "CoreFusion.SigmoidBprop");
this->add_matcher(m, callback);
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
void pass::CoreFusion::construct_folded_batch_norm()
@@ -263,7 +263,7 @@ void pass::CoreFusion::construct_folded_batch_norm()
};
auto m = std::make_shared<ngraph::pattern::Matcher>(bn, "CoreFusion.FoldedBatchNorm");
this->add_matcher(m, callback);
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
void pass::CoreFusion::construct_conv_affine_folding()
@@ -373,7 +373,7 @@ void pass::CoreFusion::construct_conv_affine_folding()
};
auto m = make_shared<pattern::Matcher>(add, "CoreFusion.ConvAffineFolding");
this->add_matcher(m, callback);
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
static bool is_trivial_convolution(std::shared_ptr<op::Convolution> conv,
@@ -503,7 +503,7 @@ void pass::CoreFusion::construct_reshape_broadcast()
};
auto m = make_shared<pattern::Matcher>(broadcast, "CoreFusion.ReshapeBroadcast");
this->add_matcher(m, callback);
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
// conv(56w3s1) conv(28w3s2)
@@ -691,7 +691,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
};
auto m = make_shared<pattern::Matcher>(eltwise_conv, "CoreFusion.OptimizedStridedConv");
this->add_matcher(m, callback);
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
void pass::CoreFusion::construct_reshape_softmax_reshape()
@@ -738,7 +738,7 @@ void pass::CoreFusion::construct_reshape_softmax_reshape()
};
auto m = make_shared<pattern::Matcher>(reshape2, "CoreFusion.ReshapeSoftmaxReshape");
this->add_matcher(m, callback);
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
void ngraph::pass::CoreFusion::construct_conv_bias()
@@ -814,7 +814,7 @@ void ngraph::pass::CoreFusion::construct_conv_bias()
};
auto m = std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, "CoreFusion.ConvBias");
this->add_matcher(m, callback);
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
void ngraph::pass::CoreFusion::construct_conv_bias_add()
@@ -864,5 +864,5 @@ void ngraph::pass::CoreFusion::construct_conv_bias_add()
};
auto m = std::make_shared<pattern::Matcher>(padd, "CoreFusion.ConvBiasAdd");
this->add_matcher(m, callback);
this->add_matcher(m, callback, all_pass_property_off);
}
@@ -32,6 +32,7 @@ public:
CommonSubexpressionElimination()
: FunctionPass()
{
set_property(PassProperty::REQUIRE_STATIC_SHAPE, true);
}
CommonSubexpressionElimination(
@@ -41,6 +42,7 @@ public:
: FunctionPass()
, m_backend_cse_handlers(backend_cse_handlers)
{
set_property(PassProperty::REQUIRE_STATIC_SHAPE, true);
}
std::unordered_map<std::type_index,
@@ -94,7 +94,12 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
if (closure.callback(*closure.matcher.get()))
{
rewritten = true;
is_dyn_func = f->is_dynamic();
// If the callback may change the function's is_dynamic state, we need to
// update the cached value.
if (closure.property.is_set(PassProperty::CHANGE_DYNAMIC_STATE))
{
is_dyn_func = f->is_dynamic();
}
break;
}
}
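The hunk above shows only the cache refresh; the guard that consumes the per-matcher property sits slightly earlier in run_on_function and falls outside the hunk. A hedged reconstruction of the surrounding loop, inferred from the GraphRewrite constructor comment later in this diff ("we will check if each callback needs static shape during run_on_function()"); the exact loop shape is an assumption:

```cpp
// Assumed shape of the matcher loop in GraphRewrite::run_on_function.
for (auto& closure : m_matchers)
{
    // Matchers that require static shapes are skipped while the
    // function is dynamic; is_dyn_func caches f->is_dynamic().
    if (is_dyn_func && closure.property.is_set(PassProperty::REQUIRE_STATIC_SHAPE))
    {
        continue;
    }
    if (closure.matcher->match(node) && closure.callback(*closure.matcher.get()))
    {
        rewritten = true;
        // Refresh the cache only for callbacks that may flip the state.
        if (closure.property.is_set(PassProperty::CHANGE_DYNAMIC_STATE))
        {
            is_dyn_func = f->is_dynamic();
        }
        break;
    }
}
```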
@@ -148,6 +153,12 @@ void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
if (is_enabled(m))
{
m_matchers.push_back({m, callback, property});
// If any matcher callback may change the dynamic state, we need to
// update the pass property.
if (property.is_set(PassProperty::CHANGE_DYNAMIC_STATE))
{
set_property(PassProperty::CHANGE_DYNAMIC_STATE, true);
}
}
}
@@ -165,6 +176,12 @@ void pass::RecurrentGraphRewrite::add_matcher(
const PassPropertyMask& property)
{
m_matchers.push_back({m, callback, property});
// If any matcher callback may change the dynamic state, we need to
// update the pass property.
if (property.is_set(PassProperty::CHANGE_DYNAMIC_STATE))
{
set_property(PassProperty::CHANGE_DYNAMIC_STATE, true);
}
}
void pass::RecurrentGraphRewrite::add_matcher(
@@ -202,7 +219,12 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f)
<< node->get_name();
if (closure.callback(*closure.matcher.get()))
{
is_dyn_func = f->is_dynamic();
// If the callback may change the function's is_dynamic state, we need to
// update the cached value.
if (closure.property.is_set(PassProperty::CHANGE_DYNAMIC_STATE))
{
is_dyn_func = f->is_dynamic();
}
return true;
}
}
@@ -52,6 +52,7 @@ public:
GraphRewrite()
: FunctionPass()
{
// Being explicit:
// Setting REQUIRE_STATIC_SHAPE to false because we will check if each
// callback needs static shape during run_on_function().
set_property(PassProperty::REQUIRE_STATIC_SHAPE, false);
@@ -87,6 +88,7 @@ public:
: FunctionPass()
, m_num_iters(num_iters)
{
// Being explicit:
// Setting REQUIRE_STATIC_SHAPE to false because we will check if each
// callback needs static shape during run_on_function().
set_property(PassProperty::REQUIRE_STATIC_SHAPE, false);
@@ -117,7 +117,13 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
continue;
}
bool function_modified = function_pass->run_on_function(f);
f_pair.second = (function_modified == true) ? f->is_dynamic() : f_pair.second;
// If the pass may change the function's is_dynamic property, we need to
// update the cached value.
if (function_modified &&
function_pass->get_property(PassProperty::CHANGE_DYNAMIC_STATE))
{
f_pair.second = f->is_dynamic();
}
}
}
else if (node_pass)
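The rationale for the property check here is presumably cost: Function::is_dynamic() has to inspect the graph, so run_passes keeps the answer cached in f_pair.second and recomputes it only when a pass both modified the function and declares CHANGE_DYNAMic_STATE is not right; it declares CHANGE_DYNAMIC_STATE. Restated as a sketch (f, f_pair, and function_pass are the names from the hunk above; the loop skeleton around them, and the body of the `continue` branch, are assumptions):

```cpp
// Inside the per-function-pass loop of Manager::run_passes (sketch).
if (f_pair.second && // cached f->is_dynamic()
    function_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE))
{
    continue; // presumably the `continue` visible at the top of the hunk
}
bool function_modified = function_pass->run_on_function(f);
if (function_modified &&
    function_pass->get_property(PassProperty::CHANGE_DYNAMIC_STATE))
{
    f_pair.second = f->is_dynamic(); // recompute only when it may have changed
}
```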
@@ -25,6 +25,7 @@ namespace ngraph
class NopElimination : public FunctionPass
{
public:
NopElimination() { set_property(PassProperty::REQUIRE_STATIC_SHAPE, true); }
bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
};
}
@@ -21,8 +21,8 @@ using namespace std;
using namespace ngraph;
pass::PassBase::PassBase()
: m_property{all_pass_property_off}
{
set_property(PassProperty::REGULAR_FUSIONS, true);
}
pass::ManagerState& pass::PassBase::get_state()
@@ -48,11 +48,13 @@ namespace ngraph
};
enum class PassProperty : uint32_t
{
REGULAR_FUSIONS = 1 << 1,
REQUIRE_STATIC_SHAPE = 1 << 2,
CHANGE_FUNCTION_STATE = 1 << 3
// Pass requires node shapes to be static
REQUIRE_STATIC_SHAPE = 0x1,
// Pass transformation will change the function's dynamic state
CHANGE_DYNAMIC_STATE = 1 << 1
};
typedef EnumMask<PassProperty> PassPropertyMask;
constexpr PassPropertyMask all_pass_property_off;
}
}
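Two things worth noting in this hunk: the properties now occupy one bit each starting at 0x1 (the old enum left bit 0 unused, and REGULAR_FUSIONS is dropped while CHANGE_FUNCTION_STATE becomes CHANGE_DYNAMIC_STATE), and all_pass_property_off can be a namespace-scope constexpr constant only because EnumMask gains constexpr constructors later in this diff. A brief usage sketch; the type and member names are from this commit, only the local variables are invented:

```cpp
using namespace ngraph::pass;

// Compose a mask from individual properties.
PassPropertyMask props{PassProperty::REQUIRE_STATIC_SHAPE};
props.set(PassProperty::CHANGE_DYNAMIC_STATE);

bool needs_static = props.is_set(PassProperty::REQUIRE_STATIC_SHAPE); // true
bool flips_state = props.is_set(PassProperty::CHANGE_DYNAMIC_STATE);  // true

// The default-constructed mask has every property off.
bool any = all_pass_property_off.is_any_set(props); // false
```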
@@ -81,5 +81,7 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
replace_node(m.get_matched_nodes().at(1), m.get_matched_nodes().at(2));
return true;
};
add_matcher(make_shared<pattern::Matcher>(target_op, "PrefixReshapeElimination"), callback);
add_matcher(make_shared<pattern::Matcher>(target_op, "PrefixReshapeElimination"),
callback,
PassProperty::REQUIRE_STATIC_SHAPE);
}
@@ -73,7 +73,7 @@ void pass::ReshapeElimination::construct_identity_reshape_pattern()
};
auto m = make_shared<pattern::Matcher>(reshape1);
this->add_matcher(m, callback);
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
void pass::ReshapeElimination::construct_reshapex2_pattern()
@@ -132,7 +132,7 @@ void pass::ReshapeElimination::construct_reshapex2_pattern()
return false;
};
auto m = make_shared<pattern::Matcher>(reshape2);
this->add_matcher(m, callback);
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
void pass::ReshapeElimination::construct_dot_transpose_pattern()
@@ -189,7 +189,7 @@ void pass::ReshapeElimination::construct_dot_transpose_pattern()
};
auto m = make_shared<pattern::Matcher>(preshape);
this->add_matcher(m, callback);
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
@@ -291,5 +291,5 @@ void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto m =
std::make_shared<pattern::RecurrentMatcher>(reshape_label, op, empty_correlated_matches);
this->add_matcher(m, callback);
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
@@ -26,6 +26,7 @@ namespace ngraph
class ReshapeSinking : public ngraph::pass::FunctionPass
{
public:
ReshapeSinking() { set_property(PassProperty::REQUIRE_STATIC_SHAPE, true); }
bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
};
}
@@ -28,6 +28,7 @@ namespace ngraph
ShapeSpecialization()
: FunctionPass()
{
set_property(PassProperty::CHANGE_DYNAMIC_STATE, true);
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
@@ -32,6 +32,7 @@ public:
ZeroDimTensorElimination()
: FunctionPass()
{
set_property(PassProperty::REQUIRE_STATIC_SHAPE, true);
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
@@ -267,11 +267,11 @@ namespace ngraph
/// type to use unsigned underlying type.
static_assert(std::is_unsigned<value_type>::value, "EnumMask enum must use unsigned type.");
EnumMask()
constexpr EnumMask()
: m_value{0}
{
}
EnumMask(const T& enum_value)
constexpr EnumMask(const T& enum_value)
: m_value{static_cast<value_type>(enum_value)}
{
}
@@ -288,13 +288,13 @@
}
}
value_type value() const { return m_value; }
/// Check if any of the enum bit mask match
/// Check if any of the input parameter's mask bits are set
bool is_any_set(const EnumMask& p) const { return m_value & p.m_value; }
/// Check if all of the enum bit mask match
/// Check if all of the input parameter's mask bits are set
bool is_set(const EnumMask& p) const { return (m_value & p.m_value) == p.m_value; }
/// Check if any of the enum bit mask does not match
/// Check if any of the input parameter's mask bits are clear
bool is_any_clear(const EnumMask& p) const { return !is_set(p); }
/// Check if all of the enum bit mask do not match
/// Check if all of the input parameter's mask bits are clear
bool is_clear(const EnumMask& p) const { return !is_any_set(p); }
void set(const EnumMask& p) { m_value |= p.m_value; }
void clear(const EnumMask& p) { m_value &= ~p.m_value; }
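Because the four predicates are easy to confuse, here is a hedged sketch of their semantics with a toy enum. Flag, FlagMask, and demo() are invented; EnumMask and its members are the ones in the hunk above, and the initializer-list construction mirrors the old AlgebraicSimplification code earlier in this diff:

```cpp
#include <cstdint>
// Assumes the nGraph header defining ngraph::EnumMask is included.

enum class Flag : uint32_t
{
    A = 1 << 0,
    B = 1 << 1,
    C = 1 << 2,
};
using FlagMask = ngraph::EnumMask<Flag>;

void demo()
{
    constexpr FlagMask none; // legal at compile time thanks to the constexpr ctor
    FlagMask ab{Flag::A, Flag::B}; // initializer-list construction

    bool r1 = ab.is_set(Flag::A);                        // true:  every bit of {A} is set
    bool r2 = ab.is_set(FlagMask{Flag::A, Flag::C});     // false: C is missing
    bool r3 = ab.is_any_set(FlagMask{Flag::B, Flag::C}); // true:  B matches
    bool r4 = ab.is_clear(Flag::C);                      // true:  no bit of {C} is set
    (void)none; (void)r1; (void)r2; (void)r3; (void)r4;  // silence unused warnings
}
```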
@@ -590,7 +590,5 @@ TEST(algebraic_simplification, pass_property)
auto pass = std::make_shared<ngraph::pass::AlgebraicSimplification>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(true,
pass->get_property(pass::PassPropertyMask(pass::PassProperty::REGULAR_FUSIONS) |
pass::PassPropertyMask(pass::PassProperty::REQUIRE_STATIC_SHAPE)));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
@@ -292,3 +292,18 @@ TEST(concat_fusion, self_concat_with_fan_out)
ASSERT_EQ(num_reshapes_optimized, 1);
ASSERT_EQ(num_broadcast_optimzed, 1);
}
TEST(concat_fusion, pass_property)
{
{
auto pass = std::make_shared<ngraph::pass::ConcatElimination>();
ASSERT_EQ(false, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
{
auto pass = std::make_shared<ngraph::pass::SelfConcatFusion>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
}
@@ -260,3 +260,9 @@ TEST(constant_folding, const_quantize)
vector<output_c_type> values_quantize{2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5};
ASSERT_EQ(values_quantize, values_out);
}
TEST(constant_folding, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
ASSERT_EQ(false, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
@@ -525,3 +525,17 @@ TEST(batch_fusion, group_convolution_fusion)
std::dynamic_pointer_cast<op::GroupConvolution>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(gc);
}
TEST(core_fusion, pass_property)
{
auto pass = std::make_shared<ngraph::pass::CoreFusion>();
ASSERT_EQ(false, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
TEST(batch_fusion, pass_property)
{
auto pass = std::make_shared<ngraph::pass::BatchFusion>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
@@ -334,3 +334,10 @@ TEST(CSE, one_hot)
ASSERT_EQ(f->get_results().at(0)->get_argument(0), f->get_results().at(1)->get_argument(0));
}
}
TEST(CSE, pass_property)
{
auto pass = std::make_shared<ngraph::pass::CommonSubexpressionElimination>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
@@ -110,3 +110,10 @@ TEST(nop_elimination, eliminate_stop_gradient)
ASSERT_EQ(count_ops_of_type<op::StopGradient>(f), 0);
}
TEST(nop_elimination, pass_property)
{
auto pass = std::make_shared<ngraph::pass::NopElimination>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
@@ -185,3 +185,10 @@ TEST(shape_specialization, specialization_pass_add_concat_transpose)
ASSERT_EQ(constant_after->get_element_type(), element::i64);
ASSERT_EQ(constant_after->get_vector<int64_t>(), (vector<int64_t>{1, 0}));
}
TEST(shape_specialization, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ShapeSpecialization>();
ASSERT_EQ(false, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(true, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
@@ -433,3 +433,17 @@ TEST(reshape_elimination, recurrent_reshapes_multiple_branches)
size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
ASSERT_EQ(num_reshapes_optimized, 2);
}
TEST(reshape_elimination, pass_property)
{
{
auto pass = std::make_shared<ngraph::pass::ReshapeElimination>();
ASSERT_EQ(false, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
{
auto pass = std::make_shared<ngraph::pass::RecurrentReshapeElimination>();
ASSERT_EQ(false, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
}
@@ -275,3 +275,10 @@ TEST(reshape_sinking, concat)
size_t before_after = count_ops_of_type<op::Reshape>(f);
ASSERT_LE(before_after, before_count);
}
TEST(reshape_sinking, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ReshapeSinking>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
@@ -191,3 +191,10 @@ TEST(zero_dim_tensor_elimination, zero_const_slice)
EXPECT_EQ(count_ops_of_type<op::Broadcast>(f), 1);
EXPECT_EQ(count_ops_of_type<op::Slice>(f), 0);
}
TEST(zero_dim_tensor_elimination, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ZeroDimTensorElimination>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}