Commit 66198b33 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Support for batchnorm+relu fusion for all batchnorm variants. (#903)

parent 82c19d24
...@@ -366,8 +366,12 @@ namespace ngraph ...@@ -366,8 +366,12 @@ namespace ngraph
writer.block_end(); writer.block_end();
} }
template <> void CPU_Emitter::emitBatchNorm(CPU_ExternalFunction* external_function,
void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNorm) codegen::CodeWriter& writer,
const ngraph::Node* node,
const std::vector<TensorViewWrapper>& args,
const std::vector<TensorViewWrapper>& out,
bool append_relu)
{ {
const ngraph::op::BatchNorm* batchnorm = const ngraph::op::BatchNorm* batchnorm =
static_cast<const ngraph::op::BatchNorm*>(node); static_cast<const ngraph::op::BatchNorm*>(node);
...@@ -382,6 +386,17 @@ namespace ngraph ...@@ -382,6 +386,17 @@ namespace ngraph
<< args[1].get_name() << ", " << args[1].get_name() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n"; << args[1].get_size() * args[1].get_element_type().size() << ");\n";
const float ops_scale = 1.f;
const float ops_alpha = -0.f; // relu negative slope
const float ops_beta = 0.f;
mkldnn::post_ops ops;
if (append_relu)
{
ops.append_eltwise(
ops_scale, mkldnn::algorithm::eltwise_relu, ops_alpha, ops_beta);
}
if (batchnorm->get_training_flag() && args.size() == 3) if (batchnorm->get_training_flag() && args.size() == 3)
{ {
auto input_format = auto input_format =
...@@ -413,7 +428,8 @@ namespace ngraph ...@@ -413,7 +428,8 @@ namespace ngraph
variance_desc, variance_desc,
batchnorm->get_eps_value(), batchnorm->get_eps_value(),
false, false,
batchnorm->get_training_flag()); batchnorm->get_training_flag(),
ops);
auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index); auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
...@@ -459,7 +475,8 @@ namespace ngraph ...@@ -459,7 +475,8 @@ namespace ngraph
variance_desc, variance_desc,
batchnorm->get_eps_value(), batchnorm->get_eps_value(),
true, true,
batchnorm->get_training_flag()); batchnorm->get_training_flag(),
ops);
auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index); auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
...@@ -480,83 +497,23 @@ namespace ngraph ...@@ -480,83 +497,23 @@ namespace ngraph
} }
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNormRelu) void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNorm)
{ {
if (!mkldnn_utils::use_mkldnn_kernel(node)) if (!mkldnn_utils::use_mkldnn_kernel(node))
{ {
throw ngraph_error("BatchNormRelu is only supported with MKLDNN kernel."); throw ngraph_error("BatchNorm is only supported with 4-D MKLDNN kernel.");
}
emitBatchNorm(external_function, writer, node, args, out, false);
} }
const ngraph::op::BatchNormRelu* batchnorm = template <>
static_cast<const ngraph::op::BatchNormRelu*>(node); void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchNormRelu)
if (!batchnorm->get_training_flag() || batchnorm->get_inputs().size() != 3)
{ {
throw ngraph_error("Only training batchnorm should have been fused"); if (!mkldnn_utils::use_mkldnn_kernel(node))
{
throw ngraph_error("BatchNormRelu is only supported with 4-D MKLDNN kernel.");
} }
emitBatchNorm(external_function, writer, node, args, out, true);
const float ops_scale = 1.f;
const float ops_alpha = -0.f; // relu negative slope
const float ops_beta = 0.f;
mkldnn::post_ops ops;
ops.append_eltwise(ops_scale, mkldnn::algorithm::eltwise_relu, ops_alpha, ops_beta);
writer.block_begin();
writer << "{\n";
// define weights
writer << "std::vector<" << args[0].get_element_type().c_type_string()
<< ">bn_weights(2*" << args[0].get_size() << ");\n";
writer << "memcpy(&bn_weights[0], " << args[0].get_name() << ", "
<< args[0].get_size() * args[0].get_element_type().size() << ");\n";
writer << "memcpy(&bn_weights[0]+" << args[0].get_size() << ", "
<< args[1].get_name() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
auto input_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
auto result_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto mean_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 1);
auto variance_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 2);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto weights_shape = Shape{2, args[0].get_size()};
auto input_desc = mkldnn_emitter->build_memory_descriptor(args[2], input_format);
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
auto results_desc = mkldnn_emitter->build_memory_descriptor(out[0], result_format);
auto mean_desc = mkldnn_emitter->build_memory_descriptor(out[1], mean_format);
auto variance_desc =
mkldnn_emitter->build_memory_descriptor(out[2], variance_format);
auto batchnorm_index =
mkldnn_emitter->build_batchnorm_forward(input_desc,
weights_desc,
results_desc,
mean_desc,
variance_desc,
batchnorm->get_eps_value(),
false,
batchnorm->get_training_flag(),
ops);
auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", "
<< args[2].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", bn_weights.data());\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", "
<< out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[3]) << ", "
<< out[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[4]) << ", "
<< out[2].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(batchnorm_index) << ");\n";
writer.block_end();
writer << "}\n";
} }
template <> template <>
......
...@@ -58,6 +58,13 @@ namespace ngraph ...@@ -58,6 +58,13 @@ namespace ngraph
{ {
} }
static void emitBatchNorm(CPU_ExternalFunction* external_function,
codegen::CodeWriter& writer,
const ngraph::Node* node,
const std::vector<TensorViewWrapper>& args,
const std::vector<TensorViewWrapper>& out,
bool append_relu = false);
private: private:
static std::string emit_vector(const TensorViewWrapper&, static std::string emit_vector(const TensorViewWrapper&,
const std::string& name = ""); const std::string& name = "");
......
...@@ -76,11 +76,104 @@ ngraph::op::BatchNormRelu::BatchNormRelu(double eps, ...@@ -76,11 +76,104 @@ ngraph::op::BatchNormRelu::BatchNormRelu(double eps,
add_output(input->get_element_type(), m_bn_variance_shape); add_output(input->get_element_type(), m_bn_variance_shape);
} }
ngraph::op::BatchNormRelu::BatchNormRelu(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
bool training)
: RequiresTensorViewArgs("BatchNormRelu", {gamma, beta, input, mean, variance})
, m_bn_input_shape(input->get_shape())
, m_bn_variance_shape(variance->get_shape())
, m_bn_mean_shape(mean->get_shape())
, m_epsilon(eps)
, m_training(training)
{
if (m_bn_input_shape.size() != 4)
{
throw ngraph_error("input tensor to batchnorm must have rank 4");
}
else
{
this->m_bn_variance_shape.push_back(input->get_shape()[1]);
this->m_bn_mean_shape.push_back(input->get_shape()[1]);
}
if (m_bn_input_shape[1] == 0)
{
throw ngraph_error(
"input tensor must have at least one channel axis for batch normalization");
}
auto et = input->get_element_type();
const char* input_names[] = {"gamma", "beta"};
for (size_t i = 0; i < 2; i++)
{
if (get_argument(i)->get_element_type() != et)
{
auto err_msg = std::string("The element type of ") + input_names[i] +
" isn't equal to input data's type";
throw ngraph_error(err_msg.c_str());
}
}
if ((gamma->get_shape().size() != 1) || (beta->get_shape().size() != 1))
{
throw ngraph_error("gamma and beta shoud have rank 1");
}
if (gamma->get_shape().size() != beta->get_shape().size())
{
throw ngraph_error("gamma and beta rank does not match");
}
if (gamma->get_element_type() != beta->get_element_type())
{
throw ngraph_error("gamma and beta element type does not match");
}
add_output(input->get_element_type(), m_bn_input_shape);
}
std::shared_ptr<ngraph::Node> std::shared_ptr<ngraph::Node>
ngraph::op::BatchNormRelu::copy_with_new_args(const NodeVector& new_args) const ngraph::op::BatchNormRelu::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 3) if (this->m_training)
throw ngraph_error("Incorrect number of new arguments"); {
if (new_args.size() == 3)
{
return std::make_shared<BatchNormRelu>( return std::make_shared<BatchNormRelu>(
m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2)); m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2));
}
else if (new_args.size() == 5)
{
return std::make_shared<BatchNormRelu>(m_epsilon,
new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
true);
}
else
{
throw ngraph_error("BatchNormRelu: Incorrect number of new arguments");
}
}
else
{
if (new_args.size() != 5)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<BatchNormRelu>(m_epsilon,
new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
false);
}
} }
...@@ -35,6 +35,14 @@ namespace ngraph ...@@ -35,6 +35,14 @@ namespace ngraph
std::shared_ptr<Node> beta, std::shared_ptr<Node> beta,
std::shared_ptr<Node> input); std::shared_ptr<Node> input);
BatchNormRelu(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
bool training = false);
const Shape& get_inputs_shape() const { return m_bn_input_shape; } const Shape& get_inputs_shape() const { return m_bn_input_shape; }
const Shape& get_variance_shape() const { return m_bn_variance_shape; } const Shape& get_variance_shape() const { return m_bn_variance_shape; }
const Shape& get_mean_shape() const { return m_bn_mean_shape; } const Shape& get_mean_shape() const { return m_bn_mean_shape; }
......
...@@ -775,12 +775,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu() ...@@ -775,12 +775,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
auto m_bn = std::dynamic_pointer_cast<op::BatchNorm>( auto m_bn = std::dynamic_pointer_cast<op::BatchNorm>(
m.match_root()->get_argument(0)->get_inputs().at(0).get_output().get_node()); m.match_root()->get_argument(0)->get_inputs().at(0).get_output().get_node());
if (!m_bn->get_training_flag())
{
NGRAPH_DEBUG << " This is an inference batchnorm, so skipping fusion";
return false;
}
//as of now, only MKLDNN supports this fusion //as of now, only MKLDNN supports this fusion
//and it requires input data's rank to be equal to 4 //and it requires input data's rank to be equal to 4
if (pattern_map[input]->get_shape().size() != 4) if (pattern_map[input]->get_shape().size() != 4)
...@@ -825,6 +819,64 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu() ...@@ -825,6 +819,64 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
this->add_matcher(m); this->add_matcher(m);
} }
void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_stats()
{
auto input_shape = Shape{1, 2, 2, 2};
auto input = std::make_shared<pattern::op::Label>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto mean = std::make_shared<pattern::op::Label>(element::f32, mean_shape);
auto var_shape = Shape{2};
auto var = std::make_shared<pattern::op::Label>(element::f32, var_shape);
auto gamma_shape = Shape{2};
auto gamma = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = std::make_shared<pattern::op::Label>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{1, 2, 2, 2};
auto bn = std::make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var);
auto prelu = std::make_shared<op::Relu>(bn);
ngraph::pattern::graph_rewrite_callback callback =
[input, mean, var, gamma, beta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_batch_norm_relu against node = "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto m_bn = std::dynamic_pointer_cast<op::BatchNorm>(
m.match_root()->get_inputs().at(0).get_output().get_node());
//as of now, only MKLDNN supports this fusion
//and it requires input data's rank to be equal to 4
if (pattern_map[input]->get_shape().size() != 4)
{
NGRAPH_DEBUG << " Input data's rank isn't equal to 4. Shape = "
<< pattern_map[input]->get_shape().size();
return false;
}
if (m_bn->get_users().size() > 1)
{
NGRAPH_DEBUG << "Relu isn't the only user of BatchNorm's output";
return false;
}
auto bn_relu = std::make_shared<op::BatchNormRelu>(m_bn->get_eps_value(),
pattern_map[gamma],
pattern_map[beta],
pattern_map[input],
pattern_map[mean],
pattern_map[var],
m_bn->get_training_flag());
ngraph::replace_node(m.match_root(), bn_relu);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(prelu, callback);
this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu() void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_relu()
{ {
Shape shape{2, 2, 1, 1}; Shape shape{2, 2, 1, 1};
......
...@@ -48,6 +48,7 @@ public: ...@@ -48,6 +48,7 @@ public:
construct_sigmoid_bprop(); construct_sigmoid_bprop();
construct_conv_bias(); construct_conv_bias();
construct_batch_norm_relu(); construct_batch_norm_relu();
construct_batch_norm_relu_global_stats();
construct_conv_relu(); construct_conv_relu();
} }
...@@ -62,5 +63,6 @@ private: ...@@ -62,5 +63,6 @@ private:
void construct_zero_padded_conv(); void construct_zero_padded_conv();
void construct_zero_padded_conv_backprop_filters(); void construct_zero_padded_conv_backprop_filters();
void construct_batch_norm_relu(); void construct_batch_norm_relu();
void construct_batch_norm_relu_global_stats();
void construct_conv_relu(); void construct_conv_relu();
}; };
...@@ -1104,17 +1104,30 @@ namespace ngraph ...@@ -1104,17 +1104,30 @@ namespace ngraph
vector<memory::format> prim_input_formats; vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats; vector<memory::format> prim_output_formats;
if (!bn->get_training_flag() || bn->get_inputs().size() != 3) if (bn->get_inputs().size() == 3)
{ {
throw ngraph_error("Only training batchnorm should have been fused");
}
prim_input_formats.push_back(memory::format::x); prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x); prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(input_layout); prim_input_formats.push_back(input_layout);
prim_output_formats.push_back(input_layout); prim_output_formats.push_back(input_layout);
prim_output_formats.push_back(memory::format::x); prim_output_formats.push_back(memory::format::x);
prim_output_formats.push_back(memory::format::x); prim_output_formats.push_back(memory::format::x);
}
else if (bn->get_inputs().size() == 5)
{
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(input_layout);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_output_formats.push_back(input_layout);
}
else
{
throw ngraph_error(
"In CPU Layout: unknown number of inputs for BatchNormRelu " +
to_string(bn->get_inputs().size()));
}
node = node =
insert_input_conversions(external_function, node, prim_input_formats); insert_input_conversions(external_function, node, prim_input_formats);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment