Commit d18a9faf authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Switch to using has_class for trivial op::Skip predicates (#1148)

* switch to using has_class for op::Skip

* apply format
parent 1ebf4e6a
......@@ -47,11 +47,8 @@ static std::shared_ptr<pattern::Matcher>
create_binary_matcher(std::shared_ptr<pattern::op::Label> label,
std::shared_ptr<pattern::op::Label> const_label)
{
auto bcst_pred = [](std::shared_ptr<Node> n) {
return std::dynamic_pointer_cast<op::Broadcast>(n) != nullptr;
};
auto bcst = std::make_shared<pattern::op::Skip>(const_label, bcst_pred);
auto bcst =
std::make_shared<pattern::op::Skip>(const_label, pattern::has_class<op::Broadcast>());
auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
auto matcher =
std::make_shared<pattern::Matcher>(std::make_shared<T>(label, bcst_label), nullptr);
......@@ -86,11 +83,8 @@ static bool simplify_concat(std::shared_ptr<Node> n)
auto slice =
std::make_shared<op::Slice>(lgoe, Coordinate{0, 0}, Coordinate{2, 1}, Strides{1, 1});
auto reshape_pred = [](std::shared_ptr<Node> r) {
return std::dynamic_pointer_cast<op::Reshape>(r) != nullptr;
};
auto skip_reshape = std::make_shared<pattern::op::Skip>(slice, reshape_pred);
auto skip_reshape =
std::make_shared<pattern::op::Skip>(slice, pattern::has_class<op::Reshape>());
auto matcher = std::make_shared<pattern::Matcher>(skip_reshape, nullptr);
......
......@@ -55,10 +55,8 @@ void pass::CoreFusion::construct_relu()
auto val = make_shared<pattern::op::Label>(iconst0);
auto zero = make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
auto broadcast_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Broadcast>(n));
};
auto skip_broadcast = std::make_shared<pattern::op::Skip>(zero, broadcast_pred);
auto skip_broadcast =
std::make_shared<pattern::op::Skip>(zero, pattern::has_class<op::Broadcast>());
auto max = make_shared<op::Maximum>(skip_broadcast, val);
pattern::graph_rewrite_callback callback = [val, zero](pattern::Matcher& m) {
......
......@@ -175,9 +175,7 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul()
auto W = std::make_shared<pattern::op::Label>(element::f32, shape_w);
auto x = std::make_shared<pattern::op::Label>(element::f32, shape_x);
auto reshape_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n));
};
auto reshape_pred = pattern::has_class<op::Reshape>();
auto skip_w = std::make_shared<pattern::op::Skip>(W, reshape_pred);
auto skip_x = std::make_shared<pattern::op::Skip>(x, reshape_pred);
......
......@@ -609,10 +609,9 @@ static std::shared_ptr<Node>
void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_fusion_fprop()
{
auto src_layer_label = std::make_shared<pattern::op::Label>(element::f32, Shape{30, 100});
auto slice_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Slice>(n));
};
auto src_slice = std::make_shared<pattern::op::Skip>(src_layer_label, slice_pred);
auto src_slice =
std::make_shared<pattern::op::Skip>(src_layer_label, pattern::has_class<op::Slice>());
auto src_iter_label = std::make_shared<pattern::op::Label>(element::f32, Shape{20, 100});
auto weights_layer_label = std::make_shared<pattern::op::Label>(element::f32, Shape{400, 100});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment