Commit d19ae275 authored by Jayaram Bobba, committed by Scott Cyphers

Supports more cases for Convolution Bias fusion (#2908)

* Supports more cases for Convolution Bias fusion and also removes redundant namespace qualifiers in core_fusion.cpp

* added unit test

* select all fusions
parent 72bf9831
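
The core change extends CoreFusion::construct_conv_bias so that a bias which is broadcast and then reshaped before being added to a convolution output is still folded into op::ConvolutionBias. A minimal sketch of the graph shape the extended matcher accepts, mirroring the new conv_bias_bcast_reshape unit test further down (the shapes and the helper name are illustrative, not part of the commit):

#include <ngraph/ngraph.hpp>

using namespace ngraph;
using namespace std;

// Build the PaddlePaddle-style subgraph that the extended matcher now accepts:
// Add(Convolution(data, weights), Reshape(Broadcast(bias))).
shared_ptr<Function> make_conv_bcast_reshape_bias()
{
    auto data = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5});
    auto weights = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2, 2});
    auto bias = make_shared<op::Parameter>(element::f32, Shape{4});

    auto conv = make_shared<op::Convolution>(data, weights); // output shape {2, 4, 3, 4}
    // The bias is broadcast with the spatial dimensions flattened into one axis...
    auto bias_bcast = make_shared<op::Broadcast>(bias, Shape{2, 4, 12}, AxisSet{0, 2});
    // ...and a non-transposing Reshape restores the convolution's shape. The new
    // reshape_pred skips this node because dims 0 (N) and 1 (C) are unchanged.
    auto reshaped = make_shared<op::Reshape>(bias_bcast, AxisVector{0, 1, 2}, conv->get_shape());

    return make_shared<Function>(conv + reshaped, ParameterVector{data, weights, bias});
}

Running pass::CoreFusion with ngraph::pass::ALL_FUSIONS over such a function replaces the Convolution/Broadcast/Reshape/Add subgraph with a single op::ConvolutionBias, which is what the new test asserts.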
@@ -60,8 +60,7 @@ void pass::CoreFusion::construct_relu()
     auto val = make_shared<pattern::op::Label>(iconst0);
     auto zero = make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
-    auto skip_broadcast =
-        std::make_shared<pattern::op::Skip>(zero, pattern::has_class<op::Broadcast>());
+    auto skip_broadcast = make_shared<pattern::op::Skip>(zero, pattern::has_class<op::Broadcast>());
     auto max = make_shared<op::Maximum>(skip_broadcast, val);
     auto callback = [val, zero](pattern::Matcher& m) {
@@ -89,16 +88,16 @@ void pass::CoreFusion::construct_relu()
 void pass::CoreFusion::construct_sigmoid()
 {
     // construct variance
-    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
-    auto neg_input = std::make_shared<op::Negative>(input);
-    auto exp_neg_input = std::make_shared<op::Exp>(neg_input);
-    auto constant = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
+    auto input = make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
+    auto neg_input = make_shared<op::Negative>(input);
+    auto exp_neg_input = make_shared<op::Exp>(neg_input);
+    auto constant = make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
     auto skip_broadcast =
-        std::make_shared<pattern::op::Skip>(constant, pattern::has_class<op::Broadcast>());
-    auto add_exp = std::make_shared<op::Add>(exp_neg_input, skip_broadcast);
-    auto divide_1_over_exp = std::make_shared<op::Divide>(skip_broadcast, add_exp);
+        make_shared<pattern::op::Skip>(constant, pattern::has_class<op::Broadcast>());
+    auto add_exp = make_shared<op::Add>(exp_neg_input, skip_broadcast);
+    auto divide_1_over_exp = make_shared<op::Divide>(skip_broadcast, add_exp);
     // Define a call back that needs to called once the DFG matches the pattern
     auto callback = [input, constant](pattern::Matcher& m) {
@@ -125,7 +124,7 @@ void pass::CoreFusion::construct_sigmoid()
             NGRAPH_DEBUG << "Node not constant or not 1";
             return false;
         }
-        auto sigmoid_node = std::make_shared<op::Sigmoid>(pattern_map[input]);
+        auto sigmoid_node = make_shared<op::Sigmoid>(pattern_map[input]);
         replace_node(m.get_match_root(), sigmoid_node);
         return true;
     };
@@ -137,26 +136,26 @@ void pass::CoreFusion::construct_sigmoid()
 void pass::CoreFusion::construct_sigmoid_bprop()
 {
     // construct variance
-    auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
-    auto neg_input = std::make_shared<op::Negative>(input);
-    auto exp_neg_input = std::make_shared<op::Exp>(neg_input);
+    auto input = make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
+    auto neg_input = make_shared<op::Negative>(input);
+    auto exp_neg_input = make_shared<op::Exp>(neg_input);
     // broadcast input
-    auto constant = std::make_shared<pattern::op::Label>(element::f32, Shape{});
-    auto broadcast_constant = std::make_shared<op::Broadcast>(constant, Shape{3, 4}, AxisSet{0, 1});
-    auto add_exp = std::make_shared<op::Add>(exp_neg_input, broadcast_constant);
-    // auto divide_1_over_exp = std::make_shared<op::Divide>(broadcast_constant, add_exp);
-    auto sigmoid_fwd = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
-    auto delta = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
-    auto neg_delta = std::make_shared<op::Negative>(delta);
-    auto multiply_sigmoid_delta = std::make_shared<op::Multiply>(sigmoid_fwd, neg_delta);
-    auto divide_2 = std::make_shared<op::Divide>(multiply_sigmoid_delta, add_exp);
-    auto multiply_2 = std::make_shared<op::Multiply>(divide_2, exp_neg_input);
-    auto negative_2 = std::make_shared<op::Negative>(multiply_2);
+    auto constant = make_shared<pattern::op::Label>(element::f32, Shape{});
+    auto broadcast_constant = make_shared<op::Broadcast>(constant, Shape{3, 4}, AxisSet{0, 1});
+    auto add_exp = make_shared<op::Add>(exp_neg_input, broadcast_constant);
+    // auto divide_1_over_exp = make_shared<op::Divide>(broadcast_constant, add_exp);
+    auto sigmoid_fwd = make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
+    auto delta = make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
+    auto neg_delta = make_shared<op::Negative>(delta);
+    auto multiply_sigmoid_delta = make_shared<op::Multiply>(sigmoid_fwd, neg_delta);
+    auto divide_2 = make_shared<op::Divide>(multiply_sigmoid_delta, add_exp);
+    auto multiply_2 = make_shared<op::Multiply>(divide_2, exp_neg_input);
+    auto negative_2 = make_shared<op::Negative>(multiply_2);
     // Define a call back that needs to called once the DFG matches the pattern
     auto callback = [input, delta](pattern::Matcher& m) {
@@ -176,8 +175,7 @@ void pass::CoreFusion::construct_sigmoid_bprop()
                          << "input= " << pattern_map[input]->get_name() << "size dont match!";
             return false;
         }
-        auto dsigmoid =
-            std::make_shared<op::SigmoidBackprop>(pattern_map[input], pattern_map[delta]);
+        auto dsigmoid = make_shared<op::SigmoidBackprop>(pattern_map[input], pattern_map[delta]);
         replace_node(m.get_match_root(), dsigmoid);
         return true;
     };
@@ -189,10 +187,10 @@ void pass::CoreFusion::construct_sigmoid_bprop()
 void pass::CoreFusion::construct_folded_batch_norm()
 {
     Shape shape{2, 2, 1, 1};
-    auto input = std::make_shared<pattern::op::Label>(element::f32, shape);
-    auto filters = std::make_shared<pattern::op::Label>(element::f32, shape);
-    auto pconv = std::make_shared<op::Convolution>(input,
+    auto input = make_shared<pattern::op::Label>(element::f32, shape);
+    auto filters = make_shared<pattern::op::Label>(element::f32, shape);
+    auto pconv = make_shared<op::Convolution>(input,
                                                filters,
                                                Strides{1, 1},
                                                Strides{1, 1},
@@ -200,24 +198,24 @@ void pass::CoreFusion::construct_folded_batch_norm()
                                                CoordinateDiff{0, 0},
                                                Strides{1, 1});
     auto mean_shape = Shape{2};
-    auto mean = std::make_shared<pattern::op::Label>(element::f32, mean_shape);
+    auto mean = make_shared<pattern::op::Label>(element::f32, mean_shape);
     auto var_shape = Shape{2};
-    auto var = std::make_shared<pattern::op::Label>(element::f32, var_shape);
+    auto var = make_shared<pattern::op::Label>(element::f32, var_shape);
     auto gamma_shape = Shape{2};
-    auto gamma = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
+    auto gamma = make_shared<pattern::op::Label>(element::f32, gamma_shape);
     auto beta_shape = Shape{2};
-    auto beta = std::make_shared<pattern::op::Label>(element::f32, beta_shape);
+    auto beta = make_shared<pattern::op::Label>(element::f32, beta_shape);
     double eps = 0.001;
     auto shape_r = Shape{1, 2, 2, 2};
-    auto bn = std::make_shared<op::BatchNormInference>(eps, gamma, beta, pconv, mean, var);
+    auto bn = make_shared<op::BatchNormInference>(eps, gamma, beta, pconv, mean, var);
     auto callback = [input, filters, mean, var, gamma, beta](pattern::Matcher& m) {
         NGRAPH_DEBUG << "In callback for folded batch norm against node = "
                      << m.get_match_root()->get_name();
         auto pattern_map = m.get_pattern_map();
-        auto m_bn = std::static_pointer_cast<op::BatchNormInference>(m.get_match_root());
-        auto m_conv = std::static_pointer_cast<op::Convolution>(m_bn->get_argument(2));
+        auto m_bn = static_pointer_cast<op::BatchNormInference>(m.get_match_root());
+        auto m_conv = static_pointer_cast<op::Convolution>(m_bn->get_argument(2));
         if (m_conv->get_users().size() > 1)
         {
@@ -233,21 +231,21 @@ void pass::CoreFusion::construct_folded_batch_norm()
         // new biases = -mean * gamma / sqrt(variance + epsilon) + beta
         auto bn_eps = op::Constant::create(element::f32, Shape{}, {m_bn->get_eps_value()});
-        auto var_eps = std::make_shared<op::Add>(
+        auto var_eps = make_shared<op::Add>(
             pattern_map[var],
-            std::make_shared<op::Broadcast>(bn_eps, pattern_map[var]->get_shape(), AxisSet{0}));
-        auto sqrt_var_eps = std::make_shared<op::Sqrt>(var_eps);
-        auto mean_gamma = std::make_shared<op::Multiply>(pattern_map[mean], pattern_map[gamma]);
-        auto new_biases = std::make_shared<op::Subtract>(
-            pattern_map[beta], std::make_shared<op::Divide>(mean_gamma, sqrt_var_eps));
-        auto weight_scaling = std::make_shared<op::Divide>(pattern_map[gamma], sqrt_var_eps);
-        auto new_weights = std::make_shared<op::Multiply>(
+            make_shared<op::Broadcast>(bn_eps, pattern_map[var]->get_shape(), AxisSet{0}));
+        auto sqrt_var_eps = make_shared<op::Sqrt>(var_eps);
+        auto mean_gamma = make_shared<op::Multiply>(pattern_map[mean], pattern_map[gamma]);
+        auto new_biases = make_shared<op::Subtract>(
+            pattern_map[beta], make_shared<op::Divide>(mean_gamma, sqrt_var_eps));
+        auto weight_scaling = make_shared<op::Divide>(pattern_map[gamma], sqrt_var_eps);
+        auto new_weights = make_shared<op::Multiply>(
             pattern_map[filters],
-            std::make_shared<op::Broadcast>(
+            make_shared<op::Broadcast>(
                 weight_scaling, pattern_map[filters]->get_shape(), AxisSet{1, 2, 3}));
-        auto conv = std::make_shared<op::Convolution>(pattern_map[input],
+        auto conv = make_shared<op::Convolution>(pattern_map[input],
                                                   new_weights,
                                                   m_conv->get_window_movement_strides(),
                                                   m_conv->get_window_dilation_strides(),
@@ -255,7 +253,7 @@ void pass::CoreFusion::construct_folded_batch_norm()
                                                   m_conv->get_padding_above(),
                                                   m_conv->get_data_dilation_strides());
         auto conv_bias =
-            conv + std::make_shared<op::Broadcast>(new_biases, conv->get_shape(), AxisSet{0, 2, 3});
+            conv + make_shared<op::Broadcast>(new_biases, conv->get_shape(), AxisSet{0, 2, 3});
         replace_node(m.get_match_root(), conv_bias);
         return true;
@@ -270,33 +268,33 @@ void pass::CoreFusion::construct_conv_affine_folding()
 {
     // A * Conv (input, filters) + B -> ConvBias (input, filters * A_c, B_c)
     Shape shape{2, 2, 1, 1};
-    auto input = std::make_shared<pattern::op::Label>(element::f32, shape);
-    auto filters = std::make_shared<pattern::op::Label>(element::f32, shape);
-    auto conv = std::make_shared<op::Convolution>(input,
+    auto input = make_shared<pattern::op::Label>(element::f32, shape);
+    auto filters = make_shared<pattern::op::Label>(element::f32, shape);
+    auto conv = make_shared<op::Convolution>(input,
                                               filters,
                                               Strides{1, 1},
                                               Strides{1, 1},
                                               CoordinateDiff{0, 0},
                                               CoordinateDiff{0, 0},
                                               Strides{1, 1});
-    auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});
-    auto Ac = std::make_shared<pattern::op::Label>(element::f32, Shape{2});
-    auto A = std::make_shared<op::Broadcast>(Ac, Shape{2, 2, 1, 1}, AxisSet{0, 2, 3});
-    auto A_label = std::make_shared<pattern::op::Label>(A, nullptr, NodeVector{A});
-    auto Bc = std::make_shared<pattern::op::Label>(element::f32, Shape{2});
-    auto B = std::make_shared<op::Broadcast>(Bc, Shape{2, 2, 1, 1}, AxisSet{0, 2, 3});
-    auto B_label = std::make_shared<pattern::op::Label>(B, nullptr, NodeVector{B});
-    auto multiply = std::make_shared<op::Multiply>(conv_label, A_label);
-    auto add = std::make_shared<op::Add>(multiply, B_label);
+    auto conv_label = make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});
+    auto Ac = make_shared<pattern::op::Label>(element::f32, Shape{2});
+    auto A = make_shared<op::Broadcast>(Ac, Shape{2, 2, 1, 1}, AxisSet{0, 2, 3});
+    auto A_label = make_shared<pattern::op::Label>(A, nullptr, NodeVector{A});
+    auto Bc = make_shared<pattern::op::Label>(element::f32, Shape{2});
+    auto B = make_shared<op::Broadcast>(Bc, Shape{2, 2, 1, 1}, AxisSet{0, 2, 3});
+    auto B_label = make_shared<pattern::op::Label>(B, nullptr, NodeVector{B});
+    auto multiply = make_shared<op::Multiply>(conv_label, A_label);
+    auto add = make_shared<op::Add>(multiply, B_label);
     auto callback = [input, filters, conv_label, A_label, B_label](pattern::Matcher& m) {
         NGRAPH_DEBUG << "In callback for conv affine folding against node = "
                      << m.get_match_root()->get_name();
         auto pattern_map = m.get_pattern_map();
-        auto conv_m = std::static_pointer_cast<op::Convolution>(pattern_map[conv_label]);
+        auto conv_m = static_pointer_cast<op::Convolution>(pattern_map[conv_label]);
         if (conv_m->get_users().size() > 1)
         {
@@ -308,8 +306,8 @@ void pass::CoreFusion::construct_conv_affine_folding()
             return false;
         }
-        auto A_m = std::static_pointer_cast<op::Broadcast>(pattern_map[A_label]);
-        auto B_m = std::static_pointer_cast<op::Broadcast>(pattern_map[B_label]);
+        auto A_m = static_pointer_cast<op::Broadcast>(pattern_map[A_label]);
+        auto B_m = static_pointer_cast<op::Broadcast>(pattern_map[B_label]);
         // Check if values are being broadcast along channel (2nd) dimension
         auto is_channel_bcast = [](const shared_ptr<op::Broadcast>& bcast) {
@@ -342,8 +340,8 @@ void pass::CoreFusion::construct_conv_affine_folding()
             if (bcast->get_argument(0)->get_shape().size() == 2)
             {
                 Shape bshape{bcast->get_argument(0)->get_shape()[1]};
-                return static_pointer_cast<Node>(std::make_shared<op::Reshape>(
-                    bcast->get_argument(0), AxisVector{0, 1}, bshape));
+                return static_pointer_cast<Node>(
+                    make_shared<op::Reshape>(bcast->get_argument(0), AxisVector{0, 1}, bshape));
             }
             throw ngraph_error("Unexpected shape for bcast input");
         };
@@ -353,12 +351,11 @@ void pass::CoreFusion::construct_conv_affine_folding()
         // new weights = old weights * Ac_m
         // new biases = Bc_m
-        auto filters_n = std::make_shared<op::Multiply>(
+        auto filters_n = make_shared<op::Multiply>(
             pattern_map[filters],
-            std::make_shared<op::Broadcast>(
-                Ac_m, pattern_map[filters]->get_shape(), AxisSet{1, 2, 3}));
-        auto conv_n = std::make_shared<op::Convolution>(pattern_map[input],
+            make_shared<op::Broadcast>(Ac_m, pattern_map[filters]->get_shape(), AxisSet{1, 2, 3}));
+        auto conv_n = make_shared<op::Convolution>(pattern_map[input],
                                                     filters_n,
                                                     conv_m->get_window_movement_strides(),
                                                     conv_m->get_window_dilation_strides(),
@@ -376,8 +373,7 @@ void pass::CoreFusion::construct_conv_affine_folding()
     this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
 }
-static bool is_trivial_convolution(std::shared_ptr<op::Convolution> conv,
-                                   bool skip_pad_checks = false)
+static bool is_trivial_convolution(shared_ptr<op::Convolution> conv, bool skip_pad_checks = false)
 {
     Strides stride_1{1, 1};
     CoordinateDiff pad_0{0, 0};
@@ -398,11 +394,11 @@ static bool are_img_dims_equal(Shape conv_shape, Shape image_shape)
     return conv_shape[2] == image_shape[0] && conv_shape[3] == image_shape[1];
 }
-static std::shared_ptr<Node> reduce_broadcast(std::shared_ptr<Node> broadcast)
+static shared_ptr<Node> reduce_broadcast(shared_ptr<Node> broadcast)
 {
     const size_t H = 2;
     const size_t W = 3;
-    auto matched_broadcast_w1 = std::static_pointer_cast<op::Broadcast>(broadcast);
+    auto matched_broadcast_w1 = static_pointer_cast<op::Broadcast>(broadcast);
     Shape shape_w1{matched_broadcast_w1->get_shape()};
     shape_w1[H] /= 2;
     shape_w1[W] /= 2;
@@ -448,8 +444,8 @@ void pass::CoreFusion::construct_reshape_broadcast()
                      << m.get_match_root()->get_name();
         auto pattern_map = m.get_pattern_map();
-        auto broadcast_m = std::static_pointer_cast<op::Broadcast>(m.get_match_root());
-        auto reshape1_m = std::static_pointer_cast<op::Reshape>(broadcast_m->get_argument(0));
+        auto broadcast_m = static_pointer_cast<op::Broadcast>(m.get_match_root());
+        auto reshape1_m = static_pointer_cast<op::Reshape>(broadcast_m->get_argument(0));
         auto input_m = m.get_pattern_map()[input];
         //it doesn't seem to make sense to support shapes : [0] or [1]
@@ -517,36 +513,36 @@ void pass::CoreFusion::construct_optimized_strided_conv()
 {
     Shape win_size_1{1, 1, 1, 1};
     auto is_bc = pattern::has_class<op::Broadcast>();
-    auto data_stride3 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 128, 128});
-    auto weights_stride3 = std::make_shared<pattern::op::Label>(element::f32, win_size_1);
-    auto conv_stride3 = std::make_shared<op::Convolution>(data_stride3, weights_stride3);
+    auto data_stride3 = make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 128, 128});
+    auto weights_stride3 = make_shared<pattern::op::Label>(element::f32, win_size_1);
+    auto conv_stride3 = make_shared<op::Convolution>(data_stride3, weights_stride3);
     auto conv_stride3_label =
-        std::make_shared<pattern::op::Label>(conv_stride3, nullptr, NodeVector{conv_stride3});
-    auto broadcast_w3_label = std::make_shared<pattern::op::Label>(conv_stride3_label, is_bc);
-    auto add_w3 = std::make_shared<op::Add>(conv_stride3_label, broadcast_w3_label);
-    auto relu_w3 = std::make_shared<op::Relu>(add_w3);
-    auto weights_stride1 = std::make_shared<pattern::op::Label>(element::f32, win_size_1);
-    auto conv_stride1 = std::make_shared<op::Convolution>(relu_w3, weights_stride1);
+        make_shared<pattern::op::Label>(conv_stride3, nullptr, NodeVector{conv_stride3});
+    auto broadcast_w3_label = make_shared<pattern::op::Label>(conv_stride3_label, is_bc);
+    auto add_w3 = make_shared<op::Add>(conv_stride3_label, broadcast_w3_label);
+    auto relu_w3 = make_shared<op::Relu>(add_w3);
+    auto weights_stride1 = make_shared<pattern::op::Label>(element::f32, win_size_1);
+    auto conv_stride1 = make_shared<op::Convolution>(relu_w3, weights_stride1);
     auto conv_stride1_label =
-        std::make_shared<pattern::op::Label>(conv_stride1, nullptr, NodeVector{conv_stride1});
-    auto broadcast_w1_label = std::make_shared<pattern::op::Label>(conv_stride1_label, is_bc);
-    auto add_w1 = std::make_shared<op::Add>(conv_stride1_label, broadcast_w1_label);
+        make_shared<pattern::op::Label>(conv_stride1, nullptr, NodeVector{conv_stride1});
+    auto broadcast_w1_label = make_shared<pattern::op::Label>(conv_stride1_label, is_bc);
+    auto add_w1 = make_shared<op::Add>(conv_stride1_label, broadcast_w1_label);
     auto eltwise_arg_label =
-        std::make_shared<pattern::op::Label>(element::f32, conv_stride1->get_shape());
-    auto add_two_convs = std::make_shared<op::Add>(add_w1, eltwise_arg_label);
-    auto relu_two_convs = std::make_shared<op::Relu>(add_two_convs);
+        make_shared<pattern::op::Label>(element::f32, conv_stride1->get_shape());
+    auto add_two_convs = make_shared<op::Add>(add_w1, eltwise_arg_label);
+    auto relu_two_convs = make_shared<op::Relu>(add_two_convs);
     auto eltwise_label =
-        std::make_shared<pattern::op::Label>(relu_two_convs, nullptr, NodeVector{relu_two_convs});
-    auto weights_eltwise = std::make_shared<pattern::op::Label>(element::f32, win_size_1);
-    auto eltwise_conv = std::make_shared<op::Convolution>(eltwise_label, weights_eltwise);
+        make_shared<pattern::op::Label>(relu_two_convs, nullptr, NodeVector{relu_two_convs});
+    auto weights_eltwise = make_shared<pattern::op::Label>(element::f32, win_size_1);
+    auto eltwise_conv = make_shared<op::Convolution>(eltwise_label, weights_eltwise);
     auto callback = [win_size_1,
                      eltwise_label,
@@ -568,12 +564,12 @@ void pass::CoreFusion::construct_optimized_strided_conv()
         auto pattern_map = m.get_pattern_map();
         auto m_eltwise = pattern_map[eltwise_label];
-        std::vector<std::shared_ptr<Node>> strided_convs;
+        vector<shared_ptr<Node>> strided_convs;
         for (auto n : m_eltwise->get_users())
         {
             if (is_used(n.get()))
             {
-                if (std::dynamic_pointer_cast<op::Convolution>(n) == nullptr)
+                if (dynamic_pointer_cast<op::Convolution>(n) == nullptr)
                 {
                     NGRAPH_DEBUG << "Not all live users of element wise operation are Convolution";
                     return false;
@@ -606,7 +602,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
                     NGRAPH_DEBUG << "element-wise isn't data";
                     return false;
                 }
-                auto sconv = std::static_pointer_cast<op::Convolution>(sc);
+                auto sconv = static_pointer_cast<op::Convolution>(sc);
                 sparse_shape_index = shape_to_index(sconv->get_shape());
                 if (sparse_shape_index == 0)
                 {
@@ -627,8 +623,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
         const size_t full_shape_index = sparse_shape_index - 1;
-        auto m_conv_stride1 =
-            std::static_pointer_cast<op::Convolution>(pattern_map[conv_stride1_label]);
+        auto m_conv_stride1 = static_pointer_cast<op::Convolution>(pattern_map[conv_stride1_label]);
         if (!are_img_dims_equal(m_conv_stride1->get_shape(), supported_shapes[full_shape_index]) ||
             !are_img_dims_equal(m_conv_stride1->get_argument(1)->get_shape(), win_size_1) ||
@@ -642,8 +637,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
             return false;
         }
-        auto m_conv_stride3 =
-            std::static_pointer_cast<op::Convolution>(pattern_map[conv_stride3_label]);
+        auto m_conv_stride3 = static_pointer_cast<op::Convolution>(pattern_map[conv_stride3_label]);
         if (!are_img_dims_equal(m_conv_stride3->get_shape(), supported_shapes[full_shape_index]) ||
             !are_img_dims_equal(m_conv_stride3->get_argument(1)->get_shape(), shape_3) ||
@@ -657,31 +651,31 @@ void pass::CoreFusion::construct_optimized_strided_conv()
             return false;
         }
-        auto conv_28w3s2 = std::make_shared<op::Convolution>(m_conv_stride3->get_argument(0),
+        auto conv_28w3s2 = make_shared<op::Convolution>(m_conv_stride3->get_argument(0),
                                                         m_conv_stride3->get_argument(1),
                                                         stride_2,
                                                         stride_1,
                                                         pad_1,
                                                         pad_1);
-        auto new_add_conv_28w3s2 = std::make_shared<op::Add>(
-            conv_28w3s2, reduce_broadcast(pattern_map[broadcast_w3_label]));
-        auto new_relu_28w3s2 = std::make_shared<op::Relu>(new_add_conv_28w3s2);
-        auto conv_28w1s1 = std::make_shared<op::Convolution>(
+        auto new_add_conv_28w3s2 =
+            make_shared<op::Add>(conv_28w3s2, reduce_broadcast(pattern_map[broadcast_w3_label]));
+        auto new_relu_28w3s2 = make_shared<op::Relu>(new_add_conv_28w3s2);
+        auto conv_28w1s1 = make_shared<op::Convolution>(
             new_relu_28w3s2, m_conv_stride1->get_argument(1), stride_1, stride_1);
-        auto new_add_conv28s1 = std::make_shared<op::Add>(
-            conv_28w1s1, reduce_broadcast(pattern_map[broadcast_w1_label]));
+        auto new_add_conv28s1 =
+            make_shared<op::Add>(conv_28w1s1, reduce_broadcast(pattern_map[broadcast_w1_label]));
         auto maxpool =
-            std::make_shared<op::MaxPool>(pattern_map[eltwise_arg_label], Shape{1, 1}, stride_2);
-        auto new_add_two_convs = std::make_shared<op::Add>(new_add_conv28s1, maxpool);
-        auto new_relu_two_convs = std::make_shared<op::Relu>(new_add_two_convs);
+            make_shared<op::MaxPool>(pattern_map[eltwise_arg_label], Shape{1, 1}, stride_2);
+        auto new_add_two_convs = make_shared<op::Add>(new_add_conv28s1, maxpool);
+        auto new_relu_two_convs = make_shared<op::Relu>(new_add_two_convs);
         for (auto sconv : sconvs)
         {
-            auto sconv_28w1s1 = std::make_shared<op::Convolution>(
+            auto sconv_28w1s1 = make_shared<op::Convolution>(
                 new_relu_two_convs, sconv->get_argument(1), stride_1, stride_1);
             NGRAPH_DEBUG << "Replacing " << sconv->get_name() << " with "
                          << sconv_28w1s1->get_name();
@@ -708,9 +702,9 @@ void pass::CoreFusion::construct_reshape_softmax_reshape()
                      << m.get_match_root()->get_name();
         auto pattern_map = m.get_pattern_map();
-        auto reshape2_m = std::static_pointer_cast<op::Reshape>(m.get_match_root());
-        auto softmax_m = std::static_pointer_cast<op::Softmax>(reshape2_m->get_argument(0));
-        auto reshape1_m = std::static_pointer_cast<op::Reshape>(softmax_m->get_argument(0));
+        auto reshape2_m = static_pointer_cast<op::Reshape>(m.get_match_root());
+        auto softmax_m = static_pointer_cast<op::Softmax>(reshape2_m->get_argument(0));
+        auto reshape1_m = static_pointer_cast<op::Reshape>(softmax_m->get_argument(0));
         auto input_m = m.get_pattern_map()[input];
         if (!reshape2_m->get_is_transpose() || !reshape1_m->get_is_transpose())
@@ -741,39 +735,51 @@ void pass::CoreFusion::construct_reshape_softmax_reshape()
     this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
 }
-void ngraph::pass::CoreFusion::construct_conv_bias()
+void pass::CoreFusion::construct_conv_bias()
 {
     Shape shape{2, 2, 1, 1};
-    auto data_batch = std::make_shared<pattern::op::Label>(element::f32, shape);
-    auto filters = std::make_shared<pattern::op::Label>(element::f32, shape);
-    auto pbias = std::make_shared<pattern::op::Label>(element::f32, Shape{});
-    auto pbroadcast = std::make_shared<ngraph::op::Broadcast>(pbias, shape, AxisSet{0, 1, 2, 3});
+    auto data_batch = make_shared<pattern::op::Label>(element::f32, shape);
+    auto filters = make_shared<pattern::op::Label>(element::f32, shape);
+    auto pbias = make_shared<pattern::op::Label>(element::f32, Shape{});
+    auto pbcast = make_shared<op::Broadcast>(pbias, shape, AxisSet{0, 1, 2, 3});
+    auto pbcast_label = make_shared<pattern::op::Label>(pbcast, nullptr, NodeVector{pbcast});
+    auto reshape_pred = [](shared_ptr<Node> node) -> bool {
+        if (auto reshape = dynamic_pointer_cast<op::Reshape>(node))
+        {
+            auto ishape = reshape->get_input_shape(0);
+            auto oshape = reshape->get_shape();
+            // Callback will check that broadcast happens along channel (1) dimension.
+            // Reshape should not alter that
+            if (!reshape->get_is_transpose() && ishape.size() > 1 && oshape.size() > 1 &&
+                ishape[0] == oshape[0] && ishape[1] == oshape[1])
+            {
+                return true;
+            }
+        }
+        return false;
+    };
+    auto pskip = make_shared<pattern::op::Skip>(pbcast_label, reshape_pred);
-    auto pconv1 = std::make_shared<ngraph::op::Convolution>(data_batch,
+    auto pconv1 = make_shared<op::Convolution>(data_batch,
                                                filters,
                                                Strides{1, 1},
                                                Strides{1, 1},
                                                CoordinateDiff{0, 0},
                                                CoordinateDiff{0, 0},
                                                Strides{1, 1});
-    auto p_conv_bias = pbroadcast + pconv1;
-    auto callback = [](pattern::Matcher& m) {
+    auto p_conv_bias = pskip + pconv1;
+    auto callback = [pbcast_label](pattern::Matcher& m) {
         NGRAPH_DEBUG << "In callback for construct_conv_bias against node = "
                      << m.get_match_root()->get_name();
         auto pattern_map = m.get_pattern_map();
-        auto conv_m =
-            std::dynamic_pointer_cast<ngraph::op::Convolution>(m.get_match_root()->get_argument(0));
-        auto bcast_m =
-            std::dynamic_pointer_cast<op::Broadcast>(m.get_match_root()->get_argument(1));
-        if (conv_m == nullptr || bcast_m == nullptr)
+        auto conv_m = dynamic_pointer_cast<op::Convolution>(m.get_match_root()->get_argument(0));
+        if (conv_m == nullptr)
         {
-            conv_m = std::static_pointer_cast<ngraph::op::Convolution>(
-                m.get_match_root()->get_argument(1));
-            bcast_m = std::static_pointer_cast<op::Broadcast>(m.get_match_root()->get_argument(0));
+            conv_m = static_pointer_cast<op::Convolution>(m.get_match_root()->get_argument(1));
         }
         if (conv_m->get_shape().size() > 5 || conv_m->get_element_type() != element::f32)
@@ -782,6 +788,7 @@ void ngraph::pass::CoreFusion::construct_conv_bias()
             return false;
         }
+        auto bcast_m = static_pointer_cast<op::Broadcast>(pattern_map[pbcast_label]);
         // Except for the 2nd axis (channel dimension), we should either be broadcasting
         // to it or the dimension size should be 1.
         auto bcast_axes = bcast_m->get_broadcast_axes();
@@ -798,17 +805,16 @@ void ngraph::pass::CoreFusion::construct_conv_bias()
         {
             NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
                          << "conv_bias bias shape != 1, requires reshape to match filter count.";
-            auto order = ngraph::get_default_order(bias->get_shape());
-            auto bias_reshape = std::make_shared<ngraph::op::Reshape>(
-                bias, order, Shape{conv_m->get_input_shape(1)[0]});
-            auto conv_bias =
-                std::shared_ptr<Node>(new ngraph::op::ConvolutionBias(conv_m, bias_reshape));
-            ngraph::replace_node(m.get_match_root(), conv_bias);
+            auto order = get_default_order(bias->get_shape());
+            auto bias_reshape =
+                make_shared<op::Reshape>(bias, order, Shape{conv_m->get_input_shape(1)[0]});
+            auto conv_bias = shared_ptr<Node>(new op::ConvolutionBias(conv_m, bias_reshape));
+            replace_node(m.get_match_root(), conv_bias);
         }
         else
         {
-            auto conv_bias = std::shared_ptr<Node>(new ngraph::op::ConvolutionBias(conv_m, bias));
-            ngraph::replace_node(m.get_match_root(), conv_bias);
+            auto conv_bias = shared_ptr<Node>(new op::ConvolutionBias(conv_m, bias));
+            replace_node(m.get_match_root(), conv_bias);
         }
         return true;
     };
@@ -817,14 +823,14 @@ void ngraph::pass::CoreFusion::construct_conv_bias()
     this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
 }
-void ngraph::pass::CoreFusion::construct_conv_bias_add()
+void pass::CoreFusion::construct_conv_bias_add()
 {
     Shape shape{2, 2, 1, 1};
-    auto data_batch = std::make_shared<pattern::op::Label>(element::f32, shape);
-    auto filters = std::make_shared<pattern::op::Label>(element::f32, shape);
-    auto bias = std::make_shared<pattern::op::Label>(element::f32, Shape{shape[0]});
-    auto pconv = std::make_shared<ngraph::op::ConvolutionBias>(data_batch,
+    auto data_batch = make_shared<pattern::op::Label>(element::f32, shape);
+    auto filters = make_shared<pattern::op::Label>(element::f32, shape);
+    auto bias = make_shared<pattern::op::Label>(element::f32, Shape{shape[0]});
+    auto pconv = make_shared<op::ConvolutionBias>(data_batch,
                                                   filters,
                                                   bias,
                                                   Strides{1, 1},
@@ -832,8 +838,8 @@ void ngraph::pass::CoreFusion::construct_conv_bias_add()
                                                   CoordinateDiff{0, 0},
                                                   CoordinateDiff{0, 0},
                                                   Strides{1, 1});
-    auto add_input = std::make_shared<pattern::op::Label>(element::f32, pconv->get_shape());
-    auto padd = std::make_shared<ngraph::op::Add>(add_input, pconv);
+    auto add_input = make_shared<pattern::op::Label>(element::f32, pconv->get_shape());
+    auto padd = make_shared<op::Add>(add_input, pconv);
     auto callback = [data_batch, filters](pattern::Matcher& m) {
         NGRAPH_DEBUG << "In a callback for construct_conv_sum against "
@@ -841,13 +847,12 @@ void ngraph::pass::CoreFusion::construct_conv_bias_add()
         auto add_m = m.get_match_root();
         auto pattern_map = m.get_pattern_map();
-        auto conv_m =
-            std::dynamic_pointer_cast<ngraph::op::ConvolutionBias>(add_m->get_argument(1));
+        auto conv_m = dynamic_pointer_cast<op::ConvolutionBias>(add_m->get_argument(1));
         auto add_input_m = add_m->get_argument(0);
         if (!conv_m)
         {
-            conv_m = std::static_pointer_cast<ngraph::op::ConvolutionBias>(add_m->get_argument(0));
+            conv_m = static_pointer_cast<op::ConvolutionBias>(add_m->get_argument(0));
             add_input_m = add_m->get_argument(1);
         }
@@ -857,9 +862,8 @@ void ngraph::pass::CoreFusion::construct_conv_bias_add()
             return false;
         }
-        auto conv_add =
-            std::shared_ptr<Node>(new ngraph::op::ConvolutionBiasAdd(conv_m, add_input_m, false));
-        ngraph::replace_node(m.get_match_root(), conv_add);
+        auto conv_add = shared_ptr<Node>(new op::ConvolutionBiasAdd(conv_m, add_input_m, false));
+        replace_node(m.get_match_root(), conv_add);
         return true;
     };
......
@@ -1190,7 +1190,8 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
     REGISTER_KNOBBED_PASS(CPUBatchFusion, true, runtime::cpu::pass);
     REGISTER_KNOBBED_PASS(ReshapeSinking, false, ngraph::pass);
     REGISTER_KNOBBED_PASS(ReshapeElimination, false, ngraph::pass);
-    REGISTER_KNOBBED_PASS(CoreFusion, true, ngraph::pass);
+    REGISTER_KNOBBED_PASS_WITH_ARGS(CoreFusion, true, ngraph::pass, ngraph::pass::ALL_FUSIONS);
+    REGISTER_KNOBBED_PASS_WITH_ARGS(FusedOpDecomposition, true, ngraph::pass, is_supported);
     REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
     REGISTER_KNOBBED_PASS(CPUQuantFusion, true, runtime::cpu::pass);
     REGISTER_KNOBBED_PASS(CPUHorizontalFusion, true, runtime::cpu::pass);
......
@@ -363,6 +363,57 @@ TEST(core_fusion, conv_bias)
     }
 }
+TEST(core_fusion, conv_bias_bcast_reshape)
+{
+    // PaddlePaddle pattern
+    auto gen_f = [](bool with_fused_op) {
+        auto data = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5});
+        auto weights = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2, 2});
+        auto bias = make_shared<op::Parameter>(element::f32, Shape{4});
+        if (with_fused_op)
+        {
+            return make_shared<Function>(make_shared<op::ConvolutionBias>(data, weights, bias),
+                                         ParameterVector{data, weights, bias});
+        }
+        else
+        {
+            auto conv = make_shared<op::Convolution>(data, weights);
+            auto bias_bcast = make_shared<op::Broadcast>(bias, Shape{2, 4, 12}, AxisSet{0, 2});
+            auto conv_bias =
+                conv + make_shared<op::Reshape>(bias_bcast, AxisVector{0, 1, 2}, conv->get_shape());
+            return make_shared<Function>(conv_bias, ParameterVector{data, weights, bias});
+        }
+    };
+
+    auto fused_f = gen_f(true);
+    auto decomp_f1 = gen_f(false);
+    auto decomp_f2 = gen_f(false);
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::ALL_FUSIONS);
+    pass_manager.run_passes(decomp_f1);
+    ASSERT_EQ(count_ops_of_type<op::ConvolutionBias>(decomp_f1), 1);
+
+    test::Uniform<float> rng(0.0f, 1.0f);
+    vector<vector<float>> args;
+    for (shared_ptr<op::Parameter> param : fused_f->get_parameters())
+    {
+        vector<float> tensor_val(shape_size(param->get_shape()));
+        rng.initialize(tensor_val);
+        args.push_back(tensor_val);
+    }
+
+    auto fused_r = execute(fused_f, args, "INTERPRETER");
+    auto decomp_r1 = execute(decomp_f1, args, "INTERPRETER");
+    auto decomp_r2 = execute(decomp_f2, args, "INTERPRETER");
+
+    for (size_t i = 0; i < fused_r.size(); i++)
+    {
+        EXPECT_TRUE(test::all_close(fused_r.at(i), decomp_r1.at(i)));
+        EXPECT_TRUE(test::all_close(fused_r.at(i), decomp_r2.at(i)));
+    }
+}
+
 TEST(core_fusion, conv_bias_add)
 {
     auto gen_f = [](bool with_fused_op) {
......
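
For pipelines that do not go through the CPU backend's REGISTER_KNOBBED_PASS_WITH_ARGS machinery, the fusion can be requested explicitly through a pass manager. A minimal usage sketch following the pattern in the test above (header paths are assumptions; adjust them to your build):

#include <ngraph/ngraph.hpp>
#include <ngraph/pass/core_fusion.hpp> // assumed install path for the CoreFusion pass header
#include <ngraph/pass/manager.hpp>     // assumed install path for pass::Manager

using namespace ngraph;

// Run CoreFusion with all fusions enabled; eligible Convolution + bias subgraphs
// in `f` are collapsed into op::ConvolutionBias / op::ConvolutionBiasAdd nodes.
void run_core_fusion(std::shared_ptr<Function> f)
{
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::ALL_FUSIONS);
    pass_manager.run_passes(f);
}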