Commit b2ca3e79 authored by Louis Feng's avatar Louis Feng Committed by Scott Cyphers

[Dynamic Shape] Added Pass Properties to Core Passes (#2935)

* constexpr ctor for EnumMask

* added pass properties to core passes.

* added unit tests.

* minor fixes.
parent 4fb4be5e
...@@ -33,9 +33,7 @@ public: ...@@ -33,9 +33,7 @@ public:
AlgebraicSimplification() AlgebraicSimplification()
: FunctionPass() : FunctionPass()
{ {
PassPropertyMask property{PassProperty::REGULAR_FUSIONS, set_property(PassProperty::REQUIRE_STATIC_SHAPE, true);
PassProperty::REQUIRE_STATIC_SHAPE};
set_property(property, true);
} }
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f); virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
}; };
...@@ -29,6 +29,7 @@ namespace ngraph ...@@ -29,6 +29,7 @@ namespace ngraph
: FunctionPass() : FunctionPass()
, m_fusion_type(type) , m_fusion_type(type)
{ {
set_property(PassProperty::REQUIRE_STATIC_SHAPE, true);
} }
virtual bool run_on_function(std::shared_ptr<ngraph::Function> function) override; virtual bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
......
...@@ -114,7 +114,7 @@ void pass::ConcatElimination::construct_concat_elimination() ...@@ -114,7 +114,7 @@ void pass::ConcatElimination::construct_concat_elimination()
}; };
auto m = std::make_shared<pattern::Matcher>(concat_label, "ConcatElimination"); auto m = std::make_shared<pattern::Matcher>(concat_label, "ConcatElimination");
this->add_matcher(m, callback); this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
bool ngraph::pass::SelfConcatFusion::run_on_function(std::shared_ptr<Function> function) bool ngraph::pass::SelfConcatFusion::run_on_function(std::shared_ptr<Function> function)
......
...@@ -47,6 +47,7 @@ private: ...@@ -47,6 +47,7 @@ private:
class ngraph::pass::SelfConcatFusion : public ngraph::pass::FunctionPass class ngraph::pass::SelfConcatFusion : public ngraph::pass::FunctionPass
{ {
public: public:
SelfConcatFusion() { set_property(PassProperty::REQUIRE_STATIC_SHAPE, true); }
virtual bool run_on_function(std::shared_ptr<ngraph::Function> function) override; virtual bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
private: private:
......
...@@ -83,7 +83,7 @@ void pass::CoreFusion::construct_relu() ...@@ -83,7 +83,7 @@ void pass::CoreFusion::construct_relu()
}; };
auto m = make_shared<pattern::Matcher>(max, "CoreFusion.Relu"); auto m = make_shared<pattern::Matcher>(max, "CoreFusion.Relu");
this->add_matcher(m, callback); this->add_matcher(m, callback, all_pass_property_off);
} }
void pass::CoreFusion::construct_sigmoid() void pass::CoreFusion::construct_sigmoid()
...@@ -131,7 +131,7 @@ void pass::CoreFusion::construct_sigmoid() ...@@ -131,7 +131,7 @@ void pass::CoreFusion::construct_sigmoid()
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(divide_1_over_exp, "CoreFusion.Sigmoid"); auto m = std::make_shared<ngraph::pattern::Matcher>(divide_1_over_exp, "CoreFusion.Sigmoid");
this->add_matcher(m, callback); this->add_matcher(m, callback, all_pass_property_off);
} }
void pass::CoreFusion::construct_sigmoid_bprop() void pass::CoreFusion::construct_sigmoid_bprop()
...@@ -183,7 +183,7 @@ void pass::CoreFusion::construct_sigmoid_bprop() ...@@ -183,7 +183,7 @@ void pass::CoreFusion::construct_sigmoid_bprop()
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(negative_2, "CoreFusion.SigmoidBprop"); auto m = std::make_shared<ngraph::pattern::Matcher>(negative_2, "CoreFusion.SigmoidBprop");
this->add_matcher(m, callback); this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
void pass::CoreFusion::construct_folded_batch_norm() void pass::CoreFusion::construct_folded_batch_norm()
...@@ -263,7 +263,7 @@ void pass::CoreFusion::construct_folded_batch_norm() ...@@ -263,7 +263,7 @@ void pass::CoreFusion::construct_folded_batch_norm()
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(bn, "CoreFusion.FoldedBatchNorm"); auto m = std::make_shared<ngraph::pattern::Matcher>(bn, "CoreFusion.FoldedBatchNorm");
this->add_matcher(m, callback); this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
void pass::CoreFusion::construct_conv_affine_folding() void pass::CoreFusion::construct_conv_affine_folding()
...@@ -373,7 +373,7 @@ void pass::CoreFusion::construct_conv_affine_folding() ...@@ -373,7 +373,7 @@ void pass::CoreFusion::construct_conv_affine_folding()
}; };
auto m = make_shared<pattern::Matcher>(add, "CoreFusion.ConvAffineFolding"); auto m = make_shared<pattern::Matcher>(add, "CoreFusion.ConvAffineFolding");
this->add_matcher(m, callback); this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
static bool is_trivial_convolution(std::shared_ptr<op::Convolution> conv, static bool is_trivial_convolution(std::shared_ptr<op::Convolution> conv,
...@@ -503,7 +503,7 @@ void pass::CoreFusion::construct_reshape_broadcast() ...@@ -503,7 +503,7 @@ void pass::CoreFusion::construct_reshape_broadcast()
}; };
auto m = make_shared<pattern::Matcher>(broadcast, "CoreFusion.ReshapeBroadcast"); auto m = make_shared<pattern::Matcher>(broadcast, "CoreFusion.ReshapeBroadcast");
this->add_matcher(m, callback); this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
// conv(56w3s1) conv(28w3s2) // conv(56w3s1) conv(28w3s2)
...@@ -691,7 +691,7 @@ void pass::CoreFusion::construct_optimized_strided_conv() ...@@ -691,7 +691,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
}; };
auto m = make_shared<pattern::Matcher>(eltwise_conv, "CoreFusion.OptimizedStridedConv"); auto m = make_shared<pattern::Matcher>(eltwise_conv, "CoreFusion.OptimizedStridedConv");
this->add_matcher(m, callback); this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
void pass::CoreFusion::construct_reshape_softmax_reshape() void pass::CoreFusion::construct_reshape_softmax_reshape()
...@@ -738,7 +738,7 @@ void pass::CoreFusion::construct_reshape_softmax_reshape() ...@@ -738,7 +738,7 @@ void pass::CoreFusion::construct_reshape_softmax_reshape()
}; };
auto m = make_shared<pattern::Matcher>(reshape2, "CoreFusion.ReshapeSoftmaxReshape"); auto m = make_shared<pattern::Matcher>(reshape2, "CoreFusion.ReshapeSoftmaxReshape");
this->add_matcher(m, callback); this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
void ngraph::pass::CoreFusion::construct_conv_bias() void ngraph::pass::CoreFusion::construct_conv_bias()
...@@ -814,7 +814,7 @@ void ngraph::pass::CoreFusion::construct_conv_bias() ...@@ -814,7 +814,7 @@ void ngraph::pass::CoreFusion::construct_conv_bias()
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, "CoreFusion.ConvBias"); auto m = std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, "CoreFusion.ConvBias");
this->add_matcher(m, callback); this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
void ngraph::pass::CoreFusion::construct_conv_bias_add() void ngraph::pass::CoreFusion::construct_conv_bias_add()
...@@ -864,5 +864,5 @@ void ngraph::pass::CoreFusion::construct_conv_bias_add() ...@@ -864,5 +864,5 @@ void ngraph::pass::CoreFusion::construct_conv_bias_add()
}; };
auto m = std::make_shared<pattern::Matcher>(padd, "CoreFusion.ConvBiasAdd"); auto m = std::make_shared<pattern::Matcher>(padd, "CoreFusion.ConvBiasAdd");
this->add_matcher(m, callback); this->add_matcher(m, callback, all_pass_property_off);
} }
...@@ -32,6 +32,7 @@ public: ...@@ -32,6 +32,7 @@ public:
CommonSubexpressionElimination() CommonSubexpressionElimination()
: FunctionPass() : FunctionPass()
{ {
set_property(PassProperty::REQUIRE_STATIC_SHAPE, true);
} }
CommonSubexpressionElimination( CommonSubexpressionElimination(
...@@ -41,6 +42,7 @@ public: ...@@ -41,6 +42,7 @@ public:
: FunctionPass() : FunctionPass()
, m_backend_cse_handlers(backend_cse_handlers) , m_backend_cse_handlers(backend_cse_handlers)
{ {
set_property(PassProperty::REQUIRE_STATIC_SHAPE, true);
} }
std::unordered_map<std::type_index, std::unordered_map<std::type_index,
......
...@@ -94,7 +94,12 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f) ...@@ -94,7 +94,12 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
if (closure.callback(*closure.matcher.get())) if (closure.callback(*closure.matcher.get()))
{ {
rewritten = true; rewritten = true;
// If call back may change function's is_dynamic state, we need to
// update the cached value.
if (closure.property.is_set(PassProperty::CHANGE_DYNAMIC_STATE))
{
is_dyn_func = f->is_dynamic(); is_dyn_func = f->is_dynamic();
}
break; break;
} }
} }
...@@ -148,6 +153,12 @@ void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m, ...@@ -148,6 +153,12 @@ void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
if (is_enabled(m)) if (is_enabled(m))
{ {
m_matchers.push_back({m, callback, property}); m_matchers.push_back({m, callback, property});
// If any matcher call back may change dynamic state, we need to
// update the pass property.
if (property.is_set(PassProperty::CHANGE_DYNAMIC_STATE))
{
set_property(PassProperty::CHANGE_DYNAMIC_STATE, true);
}
} }
} }
...@@ -165,6 +176,12 @@ void pass::RecurrentGraphRewrite::add_matcher( ...@@ -165,6 +176,12 @@ void pass::RecurrentGraphRewrite::add_matcher(
const PassPropertyMask& property) const PassPropertyMask& property)
{ {
m_matchers.push_back({m, callback, property}); m_matchers.push_back({m, callback, property});
// If any matcher call back may change dynamic state, we need to
// update the pass property.
if (property.is_set(PassProperty::CHANGE_DYNAMIC_STATE))
{
set_property(PassProperty::CHANGE_DYNAMIC_STATE, true);
}
} }
void pass::RecurrentGraphRewrite::add_matcher( void pass::RecurrentGraphRewrite::add_matcher(
...@@ -201,8 +218,13 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f) ...@@ -201,8 +218,13 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f)
NGRAPH_DEBUG << "Matcher " << closure.matcher << " matched " NGRAPH_DEBUG << "Matcher " << closure.matcher << " matched "
<< node->get_name(); << node->get_name();
if (closure.callback(*closure.matcher.get())) if (closure.callback(*closure.matcher.get()))
{
// If call back may change function's is_dynamic state, we need to
// update the cached value.
if (closure.property.is_set(PassProperty::CHANGE_DYNAMIC_STATE))
{ {
is_dyn_func = f->is_dynamic(); is_dyn_func = f->is_dynamic();
}
return true; return true;
} }
} }
......
...@@ -52,6 +52,7 @@ public: ...@@ -52,6 +52,7 @@ public:
GraphRewrite() GraphRewrite()
: FunctionPass() : FunctionPass()
{ {
// Being explicit:
// Setting REQUIRE_STATIC_SHAPE to false because we will check if each // Setting REQUIRE_STATIC_SHAPE to false because we will check if each
// callback needs static shape during run_on_function(). // callback needs static shape during run_on_function().
set_property(PassProperty::REQUIRE_STATIC_SHAPE, false); set_property(PassProperty::REQUIRE_STATIC_SHAPE, false);
...@@ -87,6 +88,7 @@ public: ...@@ -87,6 +88,7 @@ public:
: FunctionPass() : FunctionPass()
, m_num_iters(num_iters) , m_num_iters(num_iters)
{ {
// Being explicit:
// Setting REQUIRE_STATIC_SHAPE to false because we will check if each // Setting REQUIRE_STATIC_SHAPE to false because we will check if each
// callback needs static shape during run_on_function(). // callback needs static shape during run_on_function().
set_property(PassProperty::REQUIRE_STATIC_SHAPE, false); set_property(PassProperty::REQUIRE_STATIC_SHAPE, false);
......
...@@ -117,7 +117,13 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive) ...@@ -117,7 +117,13 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
continue; continue;
} }
bool function_modified = function_pass->run_on_function(f); bool function_modified = function_pass->run_on_function(f);
f_pair.second = (function_modified == true) ? f->is_dynamic() : f_pair.second; // If the pass may change the function's is_dynamic property, we need to
// update the cached value.
if (function_modified &&
function_pass->get_property(PassProperty::CHANGE_DYNAMIC_STATE))
{
f_pair.second = f->is_dynamic();
}
} }
} }
else if (node_pass) else if (node_pass)
......
...@@ -25,6 +25,7 @@ namespace ngraph ...@@ -25,6 +25,7 @@ namespace ngraph
class NopElimination : public FunctionPass class NopElimination : public FunctionPass
{ {
public: public:
NopElimination() { set_property(PassProperty::REQUIRE_STATIC_SHAPE, true); }
bool run_on_function(std::shared_ptr<ngraph::Function> function) override; bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
}; };
} }
......
...@@ -21,8 +21,8 @@ using namespace std; ...@@ -21,8 +21,8 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
pass::PassBase::PassBase() pass::PassBase::PassBase()
: m_property{all_pass_property_off}
{ {
set_property(PassProperty::REGULAR_FUSIONS, true);
} }
pass::ManagerState& pass::PassBase::get_state() pass::ManagerState& pass::PassBase::get_state()
......
...@@ -48,11 +48,13 @@ namespace ngraph ...@@ -48,11 +48,13 @@ namespace ngraph
}; };
enum class PassProperty : uint32_t enum class PassProperty : uint32_t
{ {
REGULAR_FUSIONS = 1 << 1, // Pass requires node shapes to be static
REQUIRE_STATIC_SHAPE = 1 << 2, REQUIRE_STATIC_SHAPE = 0x1,
CHANGE_FUNCTION_STATE = 1 << 3 // Pass transformation will change the function's dynamic state
CHANGE_DYNAMIC_STATE = 1 << 1
}; };
typedef EnumMask<PassProperty> PassPropertyMask; typedef EnumMask<PassProperty> PassPropertyMask;
constexpr PassPropertyMask all_pass_property_off;
} }
} }
......
...@@ -81,5 +81,7 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination() ...@@ -81,5 +81,7 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
replace_node(m.get_matched_nodes().at(1), m.get_matched_nodes().at(2)); replace_node(m.get_matched_nodes().at(1), m.get_matched_nodes().at(2));
return true; return true;
}; };
add_matcher(make_shared<pattern::Matcher>(target_op, "PrefixReshapeElimination"), callback); add_matcher(make_shared<pattern::Matcher>(target_op, "PrefixReshapeElimination"),
callback,
PassProperty::REQUIRE_STATIC_SHAPE);
} }
...@@ -73,7 +73,7 @@ void pass::ReshapeElimination::construct_identity_reshape_pattern() ...@@ -73,7 +73,7 @@ void pass::ReshapeElimination::construct_identity_reshape_pattern()
}; };
auto m = make_shared<pattern::Matcher>(reshape1); auto m = make_shared<pattern::Matcher>(reshape1);
this->add_matcher(m, callback); this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
void pass::ReshapeElimination::construct_reshapex2_pattern() void pass::ReshapeElimination::construct_reshapex2_pattern()
...@@ -132,7 +132,7 @@ void pass::ReshapeElimination::construct_reshapex2_pattern() ...@@ -132,7 +132,7 @@ void pass::ReshapeElimination::construct_reshapex2_pattern()
return false; return false;
}; };
auto m = make_shared<pattern::Matcher>(reshape2); auto m = make_shared<pattern::Matcher>(reshape2);
this->add_matcher(m, callback); this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
void pass::ReshapeElimination::construct_dot_transpose_pattern() void pass::ReshapeElimination::construct_dot_transpose_pattern()
...@@ -189,7 +189,7 @@ void pass::ReshapeElimination::construct_dot_transpose_pattern() ...@@ -189,7 +189,7 @@ void pass::ReshapeElimination::construct_dot_transpose_pattern()
}; };
auto m = make_shared<pattern::Matcher>(preshape); auto m = make_shared<pattern::Matcher>(preshape);
this->add_matcher(m, callback); this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
void pass::RecurrentReshapeElimination::construct_recurrent_reshape() void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
...@@ -291,5 +291,5 @@ void pass::RecurrentReshapeElimination::construct_recurrent_reshape() ...@@ -291,5 +291,5 @@ void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches; std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto m = auto m =
std::make_shared<pattern::RecurrentMatcher>(reshape_label, op, empty_correlated_matches); std::make_shared<pattern::RecurrentMatcher>(reshape_label, op, empty_correlated_matches);
this->add_matcher(m, callback); this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
...@@ -26,6 +26,7 @@ namespace ngraph ...@@ -26,6 +26,7 @@ namespace ngraph
class ReshapeSinking : public ngraph::pass::FunctionPass class ReshapeSinking : public ngraph::pass::FunctionPass
{ {
public: public:
ReshapeSinking() { set_property(PassProperty::REQUIRE_STATIC_SHAPE, true); }
bool run_on_function(std::shared_ptr<ngraph::Function> function) override; bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
}; };
} }
......
...@@ -28,6 +28,7 @@ namespace ngraph ...@@ -28,6 +28,7 @@ namespace ngraph
ShapeSpecialization() ShapeSpecialization()
: FunctionPass() : FunctionPass()
{ {
set_property(PassProperty::CHANGE_DYNAMIC_STATE, true);
} }
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f) override; virtual bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
}; };
......
...@@ -32,6 +32,7 @@ public: ...@@ -32,6 +32,7 @@ public:
ZeroDimTensorElimination() ZeroDimTensorElimination()
: FunctionPass() : FunctionPass()
{ {
set_property(PassProperty::REQUIRE_STATIC_SHAPE, true);
} }
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f); virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
......
...@@ -267,11 +267,11 @@ namespace ngraph ...@@ -267,11 +267,11 @@ namespace ngraph
/// type to use unsigned underlying type. /// type to use unsigned underlying type.
static_assert(std::is_unsigned<value_type>::value, "EnumMask enum must use unsigned type."); static_assert(std::is_unsigned<value_type>::value, "EnumMask enum must use unsigned type.");
EnumMask() constexpr EnumMask()
: m_value{0} : m_value{0}
{ {
} }
EnumMask(const T& enum_value) constexpr EnumMask(const T& enum_value)
: m_value{static_cast<value_type>(enum_value)} : m_value{static_cast<value_type>(enum_value)}
{ {
} }
...@@ -288,13 +288,13 @@ namespace ngraph ...@@ -288,13 +288,13 @@ namespace ngraph
} }
} }
value_type value() const { return m_value; } value_type value() const { return m_value; }
/// Check if any of the enum bit mask match /// Check if any of the input parameter enum bit mask match
bool is_any_set(const EnumMask& p) const { return m_value & p.m_value; } bool is_any_set(const EnumMask& p) const { return m_value & p.m_value; }
/// Check if all of the enum bit mask match /// Check if all of the input parameter enum bit mask match
bool is_set(const EnumMask& p) const { return (m_value & p.m_value) == p.m_value; } bool is_set(const EnumMask& p) const { return (m_value & p.m_value) == p.m_value; }
/// Check if any of the enum bit mask does not match /// Check if any of the input parameter enum bit mask does not match
bool is_any_clear(const EnumMask& p) const { return !is_set(p); } bool is_any_clear(const EnumMask& p) const { return !is_set(p); }
/// Check if all of the enum bit mask do not match /// Check if all of the input parameter enum bit mask do not match
bool is_clear(const EnumMask& p) const { return !is_any_set(p); } bool is_clear(const EnumMask& p) const { return !is_any_set(p); }
void set(const EnumMask& p) { m_value |= p.m_value; } void set(const EnumMask& p) { m_value |= p.m_value; }
void clear(const EnumMask& p) { m_value &= ~p.m_value; } void clear(const EnumMask& p) { m_value &= ~p.m_value; }
......
...@@ -590,7 +590,5 @@ TEST(algebraic_simplification, pass_property) ...@@ -590,7 +590,5 @@ TEST(algebraic_simplification, pass_property)
auto pass = std::make_shared<ngraph::pass::AlgebraicSimplification>(); auto pass = std::make_shared<ngraph::pass::AlgebraicSimplification>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE)); ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(true, ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
pass->get_property(pass::PassPropertyMask(pass::PassProperty::REGULAR_FUSIONS) |
pass::PassPropertyMask(pass::PassProperty::REQUIRE_STATIC_SHAPE)));
} }
...@@ -292,3 +292,18 @@ TEST(concat_fusion, self_concat_with_fan_out) ...@@ -292,3 +292,18 @@ TEST(concat_fusion, self_concat_with_fan_out)
ASSERT_EQ(num_reshapes_optimized, 1); ASSERT_EQ(num_reshapes_optimized, 1);
ASSERT_EQ(num_broadcast_optimzed, 1); ASSERT_EQ(num_broadcast_optimzed, 1);
} }
TEST(concat_fusion, pass_property)
{
{
auto pass = std::make_shared<ngraph::pass::ConcatElimination>();
ASSERT_EQ(false, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
{
auto pass = std::make_shared<ngraph::pass::SelfConcatFusion>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
}
...@@ -260,3 +260,9 @@ TEST(constant_folding, const_quantize) ...@@ -260,3 +260,9 @@ TEST(constant_folding, const_quantize)
vector<output_c_type> values_quantize{2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5}; vector<output_c_type> values_quantize{2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5};
ASSERT_EQ(values_quantize, values_out); ASSERT_EQ(values_quantize, values_out);
} }
TEST(constant_folding, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
ASSERT_EQ(false, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
...@@ -525,3 +525,17 @@ TEST(batch_fusion, group_convolution_fusion) ...@@ -525,3 +525,17 @@ TEST(batch_fusion, group_convolution_fusion)
std::dynamic_pointer_cast<op::GroupConvolution>(f->get_results().at(0)->get_argument(0)); std::dynamic_pointer_cast<op::GroupConvolution>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(gc); ASSERT_TRUE(gc);
} }
TEST(core_fusion, pass_property)
{
auto pass = std::make_shared<ngraph::pass::CoreFusion>();
ASSERT_EQ(false, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
TEST(batch_fusion, pass_property)
{
auto pass = std::make_shared<ngraph::pass::BatchFusion>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
...@@ -334,3 +334,10 @@ TEST(CSE, one_hot) ...@@ -334,3 +334,10 @@ TEST(CSE, one_hot)
ASSERT_EQ(f->get_results().at(0)->get_argument(0), f->get_results().at(1)->get_argument(0)); ASSERT_EQ(f->get_results().at(0)->get_argument(0), f->get_results().at(1)->get_argument(0));
} }
} }
TEST(CSE, pass_property)
{
auto pass = std::make_shared<ngraph::pass::CommonSubexpressionElimination>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
...@@ -110,3 +110,10 @@ TEST(nop_elimination, eliminate_stop_gradient) ...@@ -110,3 +110,10 @@ TEST(nop_elimination, eliminate_stop_gradient)
ASSERT_EQ(count_ops_of_type<op::StopGradient>(f), 0); ASSERT_EQ(count_ops_of_type<op::StopGradient>(f), 0);
} }
TEST(nop_elimination, pass_property)
{
auto pass = std::make_shared<ngraph::pass::NopElimination>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
...@@ -185,3 +185,10 @@ TEST(shape_specialization, specialization_pass_add_concat_transpose) ...@@ -185,3 +185,10 @@ TEST(shape_specialization, specialization_pass_add_concat_transpose)
ASSERT_EQ(constant_after->get_element_type(), element::i64); ASSERT_EQ(constant_after->get_element_type(), element::i64);
ASSERT_EQ(constant_after->get_vector<int64_t>(), (vector<int64_t>{1, 0})); ASSERT_EQ(constant_after->get_vector<int64_t>(), (vector<int64_t>{1, 0}));
} }
TEST(shape_specialization, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ShapeSpecialization>();
ASSERT_EQ(false, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(true, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
...@@ -433,3 +433,17 @@ TEST(reshape_elimination, recurrent_reshapes_multiple_branches) ...@@ -433,3 +433,17 @@ TEST(reshape_elimination, recurrent_reshapes_multiple_branches)
size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f); size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
ASSERT_EQ(num_reshapes_optimized, 2); ASSERT_EQ(num_reshapes_optimized, 2);
} }
TEST(reshape_elimination, pass_property)
{
{
auto pass = std::make_shared<ngraph::pass::ReshapeElimination>();
ASSERT_EQ(false, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
{
auto pass = std::make_shared<ngraph::pass::RecurrentReshapeElimination>();
ASSERT_EQ(false, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
}
...@@ -275,3 +275,10 @@ TEST(reshape_sinking, concat) ...@@ -275,3 +275,10 @@ TEST(reshape_sinking, concat)
size_t before_after = count_ops_of_type<op::Reshape>(f); size_t before_after = count_ops_of_type<op::Reshape>(f);
ASSERT_LE(before_after, before_count); ASSERT_LE(before_after, before_count);
} }
TEST(reshape_sinking, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ReshapeSinking>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
...@@ -191,3 +191,10 @@ TEST(zero_dim_tensor_elimination, zero_const_slice) ...@@ -191,3 +191,10 @@ TEST(zero_dim_tensor_elimination, zero_const_slice)
EXPECT_EQ(count_ops_of_type<op::Broadcast>(f), 1); EXPECT_EQ(count_ops_of_type<op::Broadcast>(f), 1);
EXPECT_EQ(count_ops_of_type<op::Slice>(f), 0); EXPECT_EQ(count_ops_of_type<op::Slice>(f), 0);
} }
TEST(zero_dim_tensor_elimination, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ZeroDimTensorElimination>();
ASSERT_EQ(true, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment