Commit 341205cf authored by Amy Zhuang

Check broadcast axes instead of broadcast input shape.

Add comments.

Add more unit tests.
parent 0ad2a3dd
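In other words (my reading of the change, spelled out for reviewers): two Broadcast ops can take identically shaped inputs yet broadcast them along different axes, so checking only that the Broadcast inputs match gamma's shape could fuse a per-batch scale/shift as if it were per-channel. The callback now checks that both Broadcasts use axes {0, 2, ...} (every axis except the channel axis 1), and the new batchnorm_multiply_add_relu_no_fusion test covers the per-batch case that must not fuse.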
@@ -650,6 +650,29 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_stats()
     this->add_matcher(m, callback);
 }
 
+// graph before this fusion:
+//   input  mean  var  gamma  beta     broadcast1_input   broadcast2_input
+//       \     \   |    /    /                |                  |
+//        BatchNormInference              Broadcast1         Broadcast2
+//                  \                         /                  /
+//                   Multiply ----------------                  /
+//                       \                                      /
+//                        Add -----------------------------------
+//                         |
+//                        Relu
+//
+//
+// graph after this fusion:
+//   input  mean  var  gamma  broadcast1_input  beta  broadcast2_input
+//       \     \   |     \       /        \      /          /
+//        \     \  |    Multiply1        Multiply2         /
+//         \     \ |       /                  \           /
+//          \     \|      /                   newAdd -----
+//           \     |     /                      /
+//            BatchNormInferenceRelu -----------
+//
+// Multiply1, Multiply2, and newAdd operate on vectors, while Multiply and Add operate on multi-dimensional tensors.
+// Multiply1, Multiply2, and newAdd may be folded away by a later constant folding pass.
 void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_infer_relu_with_multiply_add()
 {
     auto input_shape = Shape{1, 3, 2, 2};
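For reference, the algebra behind the rewrite sketched in the comment block above (my summary, not text from the commit): with a = broadcast1_input and b = broadcast2_input as per-channel vectors, BatchNormInference computes, per channel,

    BN(x) = gamma * (x - mean) / sqrt(var + eps) + beta

so a scale followed by a shift of its output folds into the normalization parameters:

    a * BN(x) + b = (a * gamma) * (x - mean) / sqrt(var + eps) + (a * beta + b)

This is exactly what the callback below builds: new_gamma = Multiply1(gamma, a) and new_beta = newAdd(Multiply2(beta, a), b).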
@@ -683,68 +706,81 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_infer_relu_with_multiply_add()
     auto add = std::make_shared<ngraph::op::Add>(multi_label, broadcast2_label);
     auto prelu = std::make_shared<ngraph::op::Relu>(add);
 
-    auto callback =
-        [input, mean, var, gamma, beta, bn_label, multi_label, broadcast1_input, broadcast2_input](
-            pattern::Matcher& m) {
-            NGRAPH_DEBUG
-                << "In callback for construct_batch_norm_infer_relu_with_multi_add against node = "
-                << m.get_match_root()->get_name();
-
-            auto pattern_map = m.get_pattern_map();
-
-            auto bn_match = pattern_map[bn_label];
-            if (bn_match->get_users().size() > 1)
-            {
-                NGRAPH_DEBUG << "Multiply isn't the only user of BatchNorm's output";
-                return false;
-            }
-            auto multi_match = pattern_map[multi_label];
-            if (multi_match->get_users().size() > 1)
-            {
-                NGRAPH_DEBUG << "Add isn't the only user of Multiply's output";
-                return false;
-            }
-
-            if (pattern_map[broadcast1_input]->output(0).get_shape() !=
-                    pattern_map[gamma]->output(0).get_shape() ||
-                pattern_map[broadcast2_input]->output(0).get_shape() !=
-                    pattern_map[gamma]->output(0).get_shape())
-            {
-                NGRAPH_DEBUG << "shapes of Broadcast input and gamma do not match";
-                return false;
-            }
-
-            auto new_gamma = std::make_shared<ngraph::op::Multiply>(pattern_map[gamma],
-                                                                    pattern_map[broadcast1_input]);
-            auto new_multi = std::make_shared<ngraph::op::Multiply>(pattern_map[beta],
-                                                                    pattern_map[broadcast1_input]);
-            auto new_beta =
-                std::make_shared<ngraph::op::Add>(new_multi, pattern_map[broadcast2_input]);
-
-            std::shared_ptr<Node> bn_relu;
-            if (auto bn_inference =
-                    std::dynamic_pointer_cast<ngraph::op::BatchNormInference>(bn_match))
-            {
-                if (!mkldnn_utils::can_use_mkldnn_batchnorm_fprop(bn_inference.get()))
-                {
-                    return false;
-                }
-                bn_relu = std::make_shared<ngraph::op::BatchNormInferenceRelu>(
-                    bn_inference->get_eps_value(),
-                    new_gamma,
-                    new_beta,
-                    pattern_map[input],
-                    pattern_map[mean],
-                    pattern_map[var]);
-            }
-
-            if (bn_relu)
-            {
-                ngraph::replace_node(m.get_match_root(), bn_relu);
-                return true;
-            }
-            return false;
-        };
+    auto callback = [input,
+                     mean,
+                     var,
+                     gamma,
+                     beta,
+                     bn_label,
+                     multi_label,
+                     broadcast1_input,
+                     broadcast2_input,
+                     broadcast1_label,
+                     broadcast2_label](pattern::Matcher& m) {
+        NGRAPH_DEBUG
+            << "In callback for construct_batch_norm_infer_relu_with_multi_add against node = "
+            << m.get_match_root()->get_name();
+
+        auto pattern_map = m.get_pattern_map();
+
+        auto bn_match = pattern_map[bn_label];
+        if (bn_match->get_users().size() > 1)
+        {
+            NGRAPH_DEBUG << "Multiply isn't the only user of BatchNorm's output";
+            return false;
+        }
+        auto multi_match = pattern_map[multi_label];
+        if (multi_match->get_users().size() > 1)
+        {
+            NGRAPH_DEBUG << "Add isn't the only user of Multiply's output";
+            return false;
+        }
+
+        auto new_gamma = std::make_shared<ngraph::op::Multiply>(pattern_map[gamma],
+                                                                pattern_map[broadcast1_input]);
+        auto new_multi = std::make_shared<ngraph::op::Multiply>(pattern_map[beta],
+                                                                pattern_map[broadcast1_input]);
+        auto new_beta = std::make_shared<ngraph::op::Add>(new_multi, pattern_map[broadcast2_input]);
+
+        std::vector<size_t> vec{0};
+        for (auto i = 2; i < pattern_map[input]->output(0).get_shape().size(); i++)
+        {
+            vec.push_back(i);
+        }
+        AxisSet axisSet{vec};
+        if (std::static_pointer_cast<ngraph::op::Broadcast>(pattern_map[broadcast1_label])
+                ->get_broadcast_axes() != axisSet ||
+            std::static_pointer_cast<ngraph::op::Broadcast>(pattern_map[broadcast2_label])
+                ->get_broadcast_axes() != axisSet)
+        {
+            NGRAPH_DEBUG << "Broadcast axes is not {0, 2, ...}";
+            return false;
+        }
+
+        std::shared_ptr<Node> bn_relu;
+        if (auto bn_inference = std::dynamic_pointer_cast<ngraph::op::BatchNormInference>(bn_match))
+        {
+            if (!mkldnn_utils::can_use_mkldnn_batchnorm_fprop(bn_inference.get()))
+            {
+                return false;
+            }
+            bn_relu =
+                std::make_shared<ngraph::op::BatchNormInferenceRelu>(bn_inference->get_eps_value(),
+                                                                     new_gamma,
+                                                                     new_beta,
+                                                                     pattern_map[input],
+                                                                     pattern_map[mean],
+                                                                     pattern_map[var]);
+        }
+
+        if (bn_relu)
+        {
+            ngraph::replace_node(m.get_match_root(), bn_relu);
+            return true;
+        }
+        return false;
+    };
 
     auto m = std::make_shared<ngraph::pattern::Matcher>(prelu,
                                                         "CPUFusion.BatchNormInferReluWithMultiAdd");
@@ -560,9 +560,8 @@ TEST(cpu_fusion, conv_bias_bprop)
     ASSERT_EQ(ccg, 1);
 }
 
-TEST(cpu_fusion, batchnorm_multiply_add_relu)
+static void test_batchnorm_multiply_add_relu(Shape input_shape)
 {
-    auto input_shape = Shape{1, 3, 2, 2};
     auto make_bn_relu_function = [&]() {
         auto c_axis = input_shape[1];
         auto input = make_shared<op::Parameter>(element::f32, input_shape);
@@ -602,7 +601,7 @@ TEST(cpu_fusion, batchnorm_multiply_add_relu)
     auto cpu_f = make_bn_relu_function();
     auto int_f = make_bn_relu_function();
 
-    test::Uniform<float> rng(-10.0f, 10.0f);
+    test::Uniform<float> rng(1.0f, 10.0f);
     vector<vector<float>> args;
     for (shared_ptr<op::Parameter> param : int_f->get_parameters())
@@ -622,6 +621,75 @@ TEST(cpu_fusion, batchnorm_multiply_add_relu)
     ASSERT_EQ(bn_relu, 1);
 }
 
+TEST(cpu_fusion, batchnorm_multiply_add_relu)
+{
+    test_batchnorm_multiply_add_relu(Shape{1, 3, 2, 2});
+    test_batchnorm_multiply_add_relu(Shape{1, 2, 2, 2, 2});
+    test_batchnorm_multiply_add_relu(Shape{2, 2, 2, 4, 4});
+}
+
+TEST(cpu_fusion, batchnorm_multiply_add_relu_no_fusion)
+{
+    auto input_shape = Shape{3, 3, 2, 2};
+    auto make_bn_relu_function = [&]() {
+        auto c_axis = input_shape[1];
+        auto input = make_shared<op::Parameter>(element::f32, input_shape);
+        auto mean_shape = Shape{c_axis};
+        auto mean = std::make_shared<op::Parameter>(element::f32, mean_shape);
+        auto var_shape = Shape{c_axis};
+        auto var = std::make_shared<op::Parameter>(element::f32, var_shape);
+        auto gamma_shape = Shape{c_axis};
+        auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
+        auto beta_shape = Shape{c_axis};
+        auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
+        double eps = 0.001;
+        auto bn =
+            std::make_shared<ngraph::op::BatchNormInference>(eps, gamma, beta, input, mean, var);
+
+        std::vector<size_t> vec;
+        for (auto i = 1; i < input_shape.size(); i++)
+        {
+            vec.push_back(i);
+        }
+        auto broadcast1_input = std::make_shared<op::Parameter>(element::f32, Shape{3});
+        auto broadcast1 =
+            std::make_shared<ngraph::op::Broadcast>(broadcast1_input, input_shape, AxisSet(vec));
+        auto multiply = std::make_shared<ngraph::op::Multiply>(bn, broadcast1);
+        auto broadcast2_input = std::make_shared<op::Parameter>(element::f32, Shape{3});
+        auto broadcast2 =
+            std::make_shared<ngraph::op::Broadcast>(broadcast2_input, input_shape, AxisSet(vec));
+        auto add = std::make_shared<ngraph::op::Add>(multiply, broadcast2);
+        auto relu = std::make_shared<ngraph::op::Relu>(add);
+        auto f = make_shared<Function>(
+            relu,
+            ParameterVector{gamma, beta, input, mean, var, broadcast1_input, broadcast2_input});
+        return f;
+    };
+
+    auto cpu_f = make_bn_relu_function();
+    auto int_f = make_bn_relu_function();
+
+    test::Uniform<float> rng(1.0f, 10.0f);
+    vector<vector<float>> args;
+    for (shared_ptr<op::Parameter> param : int_f->get_parameters())
+    {
+        vector<float> tensor_val(shape_size(param->get_shape()));
+        rng.initialize(tensor_val);
+        args.push_back(tensor_val);
+    }
+
+    auto int_results = execute(int_f, args, "INTERPRETER");
+    auto cpu_results = execute(cpu_f, args, "CPU");
+    for (size_t i = 0; i < cpu_results.size(); i++)
+    {
+        EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
+    }
+    size_t bn_relu = count_ops_of_type<op::BatchNormInferenceRelu>(cpu_f);
+    ASSERT_EQ(bn_relu, 0);
+}
+
 TEST(cpu_fusion, batchnorm_fprop_relu_b1c2h2w2)
 {
     auto input_shape = Shape{1, 2, 2, 2};
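A note on the negative test (derived from the test code above, not from the commit message): for input_shape {3, 3, 2, 2} the test's loop starts at axis 1 and builds AxisSet{1, 2, 3}, so the Shape{3} broadcast inputs are laid out per batch rather than per channel. The callback's AxisSet{0, 2, 3} check therefore rejects the match, and count_ops_of_type<op::BatchNormInferenceRelu> is expected to stay 0.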