Commit 66198b33 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Support for batchnorm+relu fusion for all batchnorm variants. (#903)

parent 82c19d24
......@@ -366,8 +366,12 @@ namespace ngraph
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNorm)
void CPU_Emitter::emitBatchNorm(CPU_ExternalFunction* external_function,
codegen::CodeWriter& writer,
const ngraph::Node* node,
const std::vector<TensorViewWrapper>& args,
const std::vector<TensorViewWrapper>& out,
bool append_relu)
{
const ngraph::op::BatchNorm* batchnorm =
static_cast<const ngraph::op::BatchNorm*>(node);
......@@ -382,6 +386,17 @@ namespace ngraph
<< args[1].get_name() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
const float ops_scale = 1.f;
const float ops_alpha = -0.f; // relu negative slope
const float ops_beta = 0.f;
mkldnn::post_ops ops;
if (append_relu)
{
ops.append_eltwise(
ops_scale, mkldnn::algorithm::eltwise_relu, ops_alpha, ops_beta);
}
if (batchnorm->get_training_flag() && args.size() == 3)
{
auto input_format =
......@@ -413,7 +428,8 @@ namespace ngraph
variance_desc,
batchnorm->get_eps_value(),
false,
batchnorm->get_training_flag());
batchnorm->get_training_flag(),
ops);
auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
......@@ -459,7 +475,8 @@ namespace ngraph
variance_desc,
batchnorm->get_eps_value(),
true,
batchnorm->get_training_flag());
batchnorm->get_training_flag(),
ops);
auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
......@@ -480,83 +497,23 @@ namespace ngraph
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNormRelu)
void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNorm)
{
if (!mkldnn_utils::use_mkldnn_kernel(node))
{
throw ngraph_error("BatchNormRelu is only supported with MKLDNN kernel.");
throw ngraph_error("BatchNorm is only supported with 4-D MKLDNN kernel.");
}
emitBatchNorm(external_function, writer, node, args, out, false);
}
const ngraph::op::BatchNormRelu* batchnorm =
static_cast<const ngraph::op::BatchNormRelu*>(node);
if (!batchnorm->get_training_flag() || batchnorm->get_inputs().size() != 3)
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNormRelu)
{
if (!mkldnn_utils::use_mkldnn_kernel(node))
{
throw ngraph_error("Only training batchnorm should have been fused");
throw ngraph_error("BatchNormRelu is only supported with 4-D MKLDNN kernel.");
}
const float ops_scale = 1.f;
const float ops_alpha = -0.f; // relu negative slope
const float ops_beta = 0.f;
mkldnn::post_ops ops;
ops.append_eltwise(ops_scale, mkldnn::algorithm::eltwise_relu, ops_alpha, ops_beta);
writer.block_begin();
writer << "{\n";
// define weights
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_weights(2*" << args[0].get_size() << ");\n";
writer << "memcpy(&bn_weights[0], " << args[0].get_name() << ", "
<< args[0].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(&bn_weights[0]+" << args[0].get_size() << ", "
<< args[1].get_name() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
auto input_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
auto result_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto mean_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 1);
auto variance_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 2);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto weights_shape = Shape{2, args[0].get_size()};
auto input_desc = mkldnn_emitter->build_memory_descriptor(args[2], input_format);
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
auto results_desc = mkldnn_emitter->build_memory_descriptor(out[0], result_format);
auto mean_desc = mkldnn_emitter->build_memory_descriptor(out[1], mean_format);
auto variance_desc =
mkldnn_emitter->build_memory_descriptor(out[2], variance_format);
auto batchnorm_index =
mkldnn_emitter->build_batchnorm_forward(input_desc,
weights_desc,
results_desc,
mean_desc,
variance_desc,
batchnorm->get_eps_value(),
false,
batchnorm->get_training_flag(),
ops);
auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", "
<< args[2].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", bn_weights.data());\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", "
<< out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[3]) << ", "
<< out[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[4]) << ", "
<< out[2].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(batchnorm_index) << ");\n";
writer.block_end();
writer << "}\n";
emitBatchNorm(external_function, writer, node, args, out, true);
}
template <>
......
......@@ -58,6 +58,13 @@ namespace ngraph
{
}
static void emitBatchNorm(CPU_ExternalFunction* external_function,
codegen::CodeWriter& writer,
const ngraph::Node* node,
const std::vector<TensorViewWrapper>& args,
const std::vector<TensorViewWrapper>& out,
bool append_relu = false);
private:
static std::string emit_vector(const TensorViewWrapper&,
const std::string& name = "");
......
......@@ -76,11 +76,104 @@ ngraph::op::BatchNormRelu::BatchNormRelu(double eps,
add_output(input->get_element_type(), m_bn_variance_shape);
}
ngraph::op::BatchNormRelu::BatchNormRelu(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
bool training)
: RequiresTensorViewArgs("BatchNormRelu", {gamma, beta, input, mean, variance})
, m_bn_input_shape(input->get_shape())
, m_bn_variance_shape(variance->get_shape())
, m_bn_mean_shape(mean->get_shape())
, m_epsilon(eps)
, m_training(training)
{
if (m_bn_input_shape.size() != 4)
{
throw ngraph_error("input tensor to batchnorm must have rank 4");
}
else
{
this->m_bn_variance_shape.push_back(input->get_shape()[1]);
this->m_bn_mean_shape.push_back(input->get_shape()[1]);
}
if (m_bn_input_shape[1] == 0)
{
throw ngraph_error(
"input tensor must have at least one channel axis for batch normalization");
}
auto et = input->get_element_type();
const char* input_names[] = {"gamma", "beta"};
for (size_t i = 0; i < 2; i++)
{
if (get_argument(i)->get_element_type() != et)
{
auto err_msg = std::string("The element type of ") + input_names[i] +
" isn't equal to input data's type";
throw ngraph_error(err_msg.c_str());
}
}
if ((gamma->get_shape().size() != 1) || (beta->get_shape().size() != 1))
{
throw ngraph_error("gamma and beta shoud have rank 1");
}
if (gamma->get_shape().size() != beta->get_shape().size())
{
throw ngraph_error("gamma and beta rank does not match");
}
if (gamma->get_element_type() != beta->get_element_type())
{
throw ngraph_error("gamma and beta element type does not match");
}
add_output(input->get_element_type(), m_bn_input_shape);
}
std::shared_ptr<ngraph::Node>
ngraph::op::BatchNormRelu::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 3)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<BatchNormRelu>(
m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2));
if (this->m_training)
{
if (new_args.size() == 3)
{
return std::make_shared<BatchNormRelu>(
m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2));
}
else if (new_args.size() == 5)
{
return std::make_shared<BatchNormRelu>(m_epsilon,
new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
true);
}
else
{
throw ngraph_error("BatchNormRelu: Incorrect number of new arguments");
}
}
else
{
if (new_args.size() != 5)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<BatchNormRelu>(m_epsilon,
new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
false);
}
}
......@@ -35,6 +35,14 @@ namespace ngraph
std::shared_ptr<Node> beta,
std::shared_ptr<Node> input);
BatchNormRelu(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
bool training = false);
const Shape& get_inputs_shape() const { return m_bn_input_shape; }
const Shape& get_variance_shape() const { return m_bn_variance_shape; }
const Shape& get_mean_shape() const { return m_bn_mean_shape; }
......
......@@ -775,12 +775,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
auto m_bn = std::dynamic_pointer_cast<op::BatchNorm>(
m.match_root()->get_argument(0)->get_inputs().at(0).get_output().get_node());
if (!m_bn->get_training_flag())
{
NGRAPH_DEBUG << " This is an inference batchnorm, so skipping fusion";
return false;
}
//as of now, only MKLDNN supports this fusion
//and it requires input data's rank to be equal to 4
if (pattern_map[input]->get_shape().size() != 4)
......@@ -825,6 +819,64 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_stats()
{
auto input_shape = Shape{1, 2, 2, 2};
auto input = std::make_shared<pattern::op::Label>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto mean = std::make_shared<pattern::op::Label>(element::f32, mean_shape);
auto var_shape = Shape{2};
auto var = std::make_shared<pattern::op::Label>(element::f32, var_shape);
auto gamma_shape = Shape{2};
auto gamma = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = std::make_shared<pattern::op::Label>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{1, 2, 2, 2};
auto bn = std::make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var);
auto prelu = std::make_shared<op::Relu>(bn);
ngraph::pattern::graph_rewrite_callback callback =
[input, mean, var, gamma, beta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_batch_norm_relu against node = "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto m_bn = std::dynamic_pointer_cast<op::BatchNorm>(
m.match_root()->get_inputs().at(0).get_output().get_node());
//as of now, only MKLDNN supports this fusion
//and it requires input data's rank to be equal to 4
if (pattern_map[input]->get_shape().size() != 4)
{
NGRAPH_DEBUG << " Input data's rank isn't equal to 4. Shape = "
<< pattern_map[input]->get_shape().size();
return false;
}
if (m_bn->get_users().size() > 1)
{
NGRAPH_DEBUG << "Relu isn't the only user of BatchNorm's output";
return false;
}
auto bn_relu = std::make_shared<op::BatchNormRelu>(m_bn->get_eps_value(),
pattern_map[gamma],
pattern_map[beta],
pattern_map[input],
pattern_map[mean],
pattern_map[var],
m_bn->get_training_flag());
ngraph::replace_node(m.match_root(), bn_relu);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(prelu, callback);
this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu()
{
Shape shape{2, 2, 1, 1};
......
......@@ -48,6 +48,7 @@ public:
construct_sigmoid_bprop();
construct_conv_bias();
construct_batch_norm_relu();
construct_batch_norm_relu_global_stats();
construct_conv_relu();
}
......@@ -62,5 +63,6 @@ private:
void construct_zero_padded_conv();
void construct_zero_padded_conv_backprop_filters();
void construct_batch_norm_relu();
void construct_batch_norm_relu_global_stats();
void construct_conv_relu();
};
......@@ -1104,17 +1104,30 @@ namespace ngraph
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
if (!bn->get_training_flag() || bn->get_inputs().size() != 3)
if (bn->get_inputs().size() == 3)
{
throw ngraph_error("Only training batchnorm should have been fused");
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(input_layout);
prim_output_formats.push_back(input_layout);
prim_output_formats.push_back(memory::format::x);
prim_output_formats.push_back(memory::format::x);
}
else if (bn->get_inputs().size() == 5)
{
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(input_layout);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_output_formats.push_back(input_layout);
}
else
{
throw ngraph_error(
"In CPU Layout: unknown number of inputs for BatchNormRelu " +
to_string(bn->get_inputs().size()));
}
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(input_layout);
prim_output_formats.push_back(input_layout);
prim_output_formats.push_back(memory::format::x);
prim_output_formats.push_back(memory::format::x);
node =
insert_input_conversions(external_function, node, prim_input_formats);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment