Commit d18a9faf authored by Nick Korovaiko, committed by Scott Cyphers

Switch to using has_class for trivial op::Skip predicates (#1148)

* switch to using has_class for op::Skip

* apply format
parent 1ebf4e6a
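
For context: `pattern::has_class<T>()` is a predicate factory that replaces the hand-rolled lambdas removed below, each of which merely tested whether a node could be dynamic-cast to a given op type. A minimal sketch of such a factory, inferred from the lambdas it replaces (the exact signature and header in nGraph may differ):

    #include <functional>
    #include <memory>

    namespace pattern
    {
        // Sketch: returns a reusable predicate that is true iff the node
        // is (or derives from) op type T -- the same test as the removed
        // `std::dynamic_pointer_cast<T>(n) != nullptr` lambdas.
        // `Node` is assumed to be ngraph::Node, in scope at this point.
        template <typename T>
        std::function<bool(std::shared_ptr<Node>)> has_class()
        {
            return [](std::shared_ptr<Node> node) {
                return std::dynamic_pointer_cast<T>(node) != nullptr;
            };
        }
    }

Each call site then shrinks from a four-line lambda definition plus its use to a single expression, e.g. `pattern::has_class<op::Broadcast>()`.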
@@ -47,11 +47,8 @@ static std::shared_ptr<pattern::Matcher>
     create_binary_matcher(std::shared_ptr<pattern::op::Label> label,
                           std::shared_ptr<pattern::op::Label> const_label)
 {
-    auto bcst_pred = [](std::shared_ptr<Node> n) {
-        return std::dynamic_pointer_cast<op::Broadcast>(n) != nullptr;
-    };
-    auto bcst = std::make_shared<pattern::op::Skip>(const_label, bcst_pred);
+    auto bcst =
+        std::make_shared<pattern::op::Skip>(const_label, pattern::has_class<op::Broadcast>());
     auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
     auto matcher =
         std::make_shared<pattern::Matcher>(std::make_shared<T>(label, bcst_label), nullptr);
@@ -86,11 +83,8 @@ static bool simplify_concat(std::shared_ptr<Node> n)
     auto slice =
         std::make_shared<op::Slice>(lgoe, Coordinate{0, 0}, Coordinate{2, 1}, Strides{1, 1});
-    auto reshape_pred = [](std::shared_ptr<Node> r) {
-        return std::dynamic_pointer_cast<op::Reshape>(r) != nullptr;
-    };
-    auto skip_reshape = std::make_shared<pattern::op::Skip>(slice, reshape_pred);
+    auto skip_reshape =
+        std::make_shared<pattern::op::Skip>(slice, pattern::has_class<op::Reshape>());
     auto matcher = std::make_shared<pattern::Matcher>(skip_reshape, nullptr);
...
@@ -55,10 +55,8 @@ void pass::CoreFusion::construct_relu()
     auto val = make_shared<pattern::op::Label>(iconst0);
     auto zero = make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
-    auto broadcast_pred = [](std::shared_ptr<Node> n) {
-        return static_cast<bool>(std::dynamic_pointer_cast<op::Broadcast>(n));
-    };
-    auto skip_broadcast = std::make_shared<pattern::op::Skip>(zero, broadcast_pred);
+    auto skip_broadcast =
+        std::make_shared<pattern::op::Skip>(zero, pattern::has_class<op::Broadcast>());
     auto max = make_shared<op::Maximum>(skip_broadcast, val);
     pattern::graph_rewrite_callback callback = [val, zero](pattern::Matcher& m) {
...
@@ -175,9 +175,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul()
     auto W = std::make_shared<pattern::op::Label>(element::f32, shape_w);
     auto x = std::make_shared<pattern::op::Label>(element::f32, shape_x);
-    auto reshape_pred = [](std::shared_ptr<Node> n) {
-        return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n));
-    };
+    auto reshape_pred = pattern::has_class<op::Reshape>();
     auto skip_w = std::make_shared<pattern::op::Skip>(W, reshape_pred);
     auto skip_x = std::make_shared<pattern::op::Skip>(x, reshape_pred);
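Note that `construct_matmul` binds the predicate returned by `has_class<op::Reshape>()` to a local (`reshape_pred`) and reuses it for both `skip_w` and `skip_x`; since the factory returns an ordinary callable value, it drops into the existing `Skip` constructor unchanged.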
...
@@ -609,10 +609,9 @@ static std::shared_ptr<Node>
 void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_fusion_fprop()
 {
     auto src_layer_label = std::make_shared<pattern::op::Label>(element::f32, Shape{30, 100});
-    auto slice_pred = [](std::shared_ptr<Node> n) {
-        return static_cast<bool>(std::dynamic_pointer_cast<op::Slice>(n));
-    };
-    auto src_slice = std::make_shared<pattern::op::Skip>(src_layer_label, slice_pred);
+    auto src_slice =
+        std::make_shared<pattern::op::Skip>(src_layer_label, pattern::has_class<op::Slice>());
     auto src_iter_label = std::make_shared<pattern::op::Label>(element::f32, Shape{20, 100});
     auto weights_layer_label = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
...