Commit 5394ad2d authored by Pruthvi, committed by adstraw

Pruthvi/bn inference (#670)

* Added new ctor for bn which supports inference
- added mkldnn emitter code for bn inference
* Added test case for bn inference
- added support for layout propagation for bn inference
* added sanity checks for gamma, beta, mean, and variance shapes in bn
* added serializer support for bn inference
parent 6ebc3c8c
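For orientation, a minimal sketch of how the new five-argument (inference) constructor is exercised, mirroring the unit test at the bottom of this diff; the shapes and eps value are illustrative only:

    // Sketch: building BatchNorm in inference mode via the new 5-input ctor.
    // gamma/beta/mean/variance must be rank-1 and match the channel count
    // (axis 1) of the input, as enforced by the sanity checks added below.
    #include "ngraph/ngraph.hpp"
    using namespace ngraph;

    std::shared_ptr<Function> make_bn_inference()
    {
        Shape input_shape{2, 2, 2, 1}; // NCHW, two channels
        Shape chan_shape{2};
        auto input = std::make_shared<op::Parameter>(element::f32, input_shape);
        auto gamma = std::make_shared<op::Parameter>(element::f32, chan_shape);
        auto beta  = std::make_shared<op::Parameter>(element::f32, chan_shape);
        auto mean  = std::make_shared<op::Parameter>(element::f32, chan_shape);
        auto var   = std::make_shared<op::Parameter>(element::f32, chan_shape);
        double eps = 0.001;
        // The 5-input ctor sets m_training = false and adds a single output
        // with the input's shape (no mean/variance outputs).
        auto bn = std::make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var);
        return std::make_shared<Function>(
            bn, op::ParameterVector{input, gamma, beta, mean, var});
    }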
......@@ -25,10 +25,11 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
: RequiresTensorViewArgs("BatchNorm", {gamma, beta, input})
, m_bn_input_shape(input->get_shape())
, m_epsilon(eps)
, m_training(true)
{
if (m_bn_input_shape.size() < 2)
{
throw ngraph_error("input tensor to batchnorm much have tensor of atleast rank 2");
throw ngraph_error("input tensor to batchnorm must have tensor of at least rank 2");
}
else
{
......@@ -39,7 +40,20 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
if (m_bn_input_shape[1] == 0)
{
throw ngraph_error(
"input tensor must have atleast one channel axis for batch normalization");
"input tensor must have at least one channel axis for batch normalization");
}
auto et = input->get_element_type();
const char* input_names[] = {"gamma", "beta"};
for (size_t i = 0; i < 2; i++)
{
if (get_input_op(i)->get_element_type() != et)
{
auto err_msg = std::string("The element type of ") + input_names[i] +
" isn't equal to input data's type";
throw ngraph_error(err_msg.c_str());
}
}
if ((gamma->get_shape().size() != 1) || (beta->get_shape().size() != 1))
......@@ -62,12 +76,93 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
add_output(input->get_element_type(), m_bn_variance_shape);
}
ngraph::op::BatchNorm::BatchNorm(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance)
: RequiresTensorViewArgs("BatchNorm", {gamma, beta, input, mean, variance})
, m_bn_input_shape(input->get_shape())
, m_bn_variance_shape(variance->get_shape())
, m_bn_mean_shape(mean->get_shape())
, m_epsilon(eps)
, m_training(false)
{
if (m_bn_input_shape.size() < 2)
{
throw ngraph_error("input tensor to batchnorm must have tensor of at least rank 2");
}
if (m_bn_input_shape[1] == 0)
{
throw ngraph_error(
"input tensor must have at least one channel axis for batch normalization");
}
auto et = input->get_element_type();
const char* input_names[] = {"gamma", "beta", "input", "mean", "variance"};
for (size_t i = 0; i < get_input_size(); i++)
{
if (get_input_op(i)->get_element_type() != et)
{
auto err_msg = std::string("The element type of ") + input_names[i] +
" isn't equal to input data's type";
throw ngraph_error(err_msg.c_str());
}
}
for (size_t index = 0; index < get_input_size(); index++)
{
if (index != 2 && get_input_op(index)->get_shape().size() != 1)
{
auto err_msg = std::string(input_names[index]) + " should have rank of 1";
throw ngraph_error(err_msg.c_str());
}
if (index != 2 && get_input_op(index)->get_shape()[0] != m_bn_input_shape[1])
{
auto err_msg = std::string(input_names[index]) +
" shape should match the input channel size (" +
std::to_string(m_bn_input_shape[1]) + ",)";
throw ngraph_error(err_msg.c_str());
}
}
if (variance->get_shape()[0] != mean->get_shape()[0])
{
throw ngraph_error("mean and variance should have same size");
}
add_output(input->get_element_type(), m_bn_input_shape);
}
std::shared_ptr<ngraph::Node>
ngraph::op::BatchNorm::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 3)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<BatchNorm>(m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2));
if (this->m_training)
{
if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<BatchNorm>(
m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2));
}
else
{
if (new_args.size() != 5)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<BatchNorm>(m_epsilon,
new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4));
}
}
ngraph::op::BatchNormBackprop::BatchNormBackprop(double eps,
......@@ -149,6 +244,10 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
auto beta = get_input_op(1);
auto input = get_input_op(2);
if (!this->get_training_flag())
{
throw ngraph_error("generate_adjoints called on BatchNormInference op " + this->get_name());
}
//Extract mean and variance outputs from BatchNorm
//as these are used by BatchNormBackprop.
//The users of the outputs (GetOutputElements' Inputs) aren't sorted
......
......@@ -30,15 +30,25 @@ namespace ngraph
class BatchNorm : public util::RequiresTensorViewArgs
{
public:
// BatchNorm Training
BatchNorm(double eps,
std::shared_ptr<Node> gamma,
std::shared_ptr<Node> beta,
std::shared_ptr<Node> input);
//BatchNorm Inference
BatchNorm(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance);
const Shape& get_inputs_shape() const { return m_bn_input_shape; }
const Shape& get_variance_shape() const { return m_bn_variance_shape; }
const Shape& get_mean_shape() const { return m_bn_mean_shape; }
double get_eps_value() const { return m_epsilon; }
bool get_training_flag() const { return m_training; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -51,6 +61,7 @@ namespace ngraph
Shape m_bn_variance_shape;
Shape m_bn_mean_shape;
double m_epsilon;
bool m_training;
};
class BatchNormBackprop : public util::RequiresTensorViewArgs
......
......@@ -384,44 +384,98 @@ namespace ngraph
<< args[1].get_name() << ", "
<< args[1].get_size() * args[1].get_element_type().size() << ");\n";
auto input_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
auto result_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto mean_format = runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 1);
auto variance_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 2);
if (batchnorm->get_training_flag()) //BatchNorm Training
{
auto input_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
auto result_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto mean_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 1);
auto variance_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 2);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto weights_shape = Shape{2, args[0].get_size()};
auto input_desc = mkldnn_emitter->build_memory_descriptor(args[2], input_format);
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
auto results_desc = mkldnn_emitter->build_memory_descriptor(out[0], result_format);
auto mean_desc = mkldnn_emitter->build_memory_descriptor(out[1], mean_format);
auto variance_desc =
mkldnn_emitter->build_memory_descriptor(out[2], variance_format);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto weights_shape = Shape{2, args[0].get_size()};
auto input_desc =
mkldnn_emitter->build_memory_descriptor(args[2], input_format);
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
auto results_desc =
mkldnn_emitter->build_memory_descriptor(out[0], result_format);
auto mean_desc = mkldnn_emitter->build_memory_descriptor(out[1], mean_format);
auto variance_desc =
mkldnn_emitter->build_memory_descriptor(out[2], variance_format);
auto batchnorm_index =
mkldnn_emitter->build_batchnorm_forward(input_desc,
weights_desc,
results_desc,
mean_desc,
variance_desc,
batchnorm->get_eps_value(),
batchnorm->get_training_flag());
auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[2].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", bn_weights.data());\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[3])
<< ", " << out[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[4])
<< ", " << out[2].get_name() << ");\n";
auto batchnorm_index =
mkldnn_emitter->build_batchnorm_forward(input_desc,
weights_desc,
results_desc,
mean_desc,
variance_desc,
batchnorm->get_eps_value());
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(batchnorm_index) << ");\n";
}
else //BatchNorm Inference
{
auto input_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
auto mean_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 3);
auto variance_format =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 4);
auto result_format =
runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto weights_shape = Shape{2, args[0].get_size()};
auto input_desc =
mkldnn_emitter->build_memory_descriptor(args[2], input_format);
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
auto mean_desc = mkldnn_emitter->build_memory_descriptor(args[3], mean_format);
auto variance_desc =
mkldnn_emitter->build_memory_descriptor(args[4], variance_format);
auto results_desc =
mkldnn_emitter->build_memory_descriptor(out[0], result_format);
auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", "
<< args[2].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", bn_weights.data());\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", "
<< out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[3]) << ", "
<< out[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[4]) << ", "
<< out[2].get_name() << ");\n";
auto batchnorm_index =
mkldnn_emitter->build_batchnorm_forward(input_desc,
weights_desc,
results_desc,
mean_desc,
variance_desc,
batchnorm->get_eps_value(),
batchnorm->get_training_flag());
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(batchnorm_index) << ");\n";
auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[2].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << args[3].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
<< ", " << args[4].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[3])
<< ", bn_weights.data());\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[4])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(batchnorm_index) << ");\n";
}
writer.indent--;
writer << "}\n";
}
......
......@@ -572,7 +572,8 @@ size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_
const mkldnn::memory::desc& result_desc,
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const double eps)
const double eps,
bool bn_training_flag)
{
size_t input_index = build_memory_primitive(input_desc);
size_t weights_index = build_memory_primitive(weights_desc);
......@@ -580,21 +581,43 @@ size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_
size_t mean_index = build_memory_primitive(mean_desc);
size_t variance_index = build_memory_primitive(variance_desc);
size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_forward(
{{mkldnn::prop_kind::forward_training,
input_desc,
eps,
mkldnn::batch_normalization_flag::use_scale_shift},
mkldnn_utils::global_cpu_engine},
mkldnn::primitive::at(*m_mkldnn_primitives[input_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[weights_index]),
static_cast<mkldnn::memory>(*m_mkldnn_primitives[result_index]),
*m_mkldnn_primitives[mean_index],
*m_mkldnn_primitives[variance_index]));
m_primitive_deps[batchnorm_index] = {
input_index, weights_index, result_index, mean_index, variance_index};
return batchnorm_index;
if (bn_training_flag)
{
size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_forward(
{{mkldnn::prop_kind::forward_training,
input_desc,
eps,
mkldnn::batch_normalization_flag::use_scale_shift},
mkldnn_utils::global_cpu_engine},
mkldnn::primitive::at(*m_mkldnn_primitives[input_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[weights_index]),
static_cast<mkldnn::memory>(*m_mkldnn_primitives[result_index]),
*m_mkldnn_primitives[mean_index],
*m_mkldnn_primitives[variance_index]));
m_primitive_deps[batchnorm_index] = {
input_index, weights_index, result_index, mean_index, variance_index};
return batchnorm_index;
}
else
{
size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_forward(
{{mkldnn::prop_kind::forward_inference,
input_desc,
eps,
mkldnn::batch_normalization_flag::use_scale_shift |
mkldnn::batch_normalization_flag::use_global_stats},
mkldnn_utils::global_cpu_engine},
mkldnn::primitive::at(*m_mkldnn_primitives[input_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[mean_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[variance_index]),
mkldnn::primitive::at(*m_mkldnn_primitives[weights_index]),
static_cast<mkldnn::memory>(*m_mkldnn_primitives[result_index])));
m_primitive_deps[batchnorm_index] = {
input_index, mean_index, variance_index, weights_index, result_index};
return batchnorm_index;
}
}
size_t MKLDNNEmitter::build_batchnorm_backward(const mkldnn::memory::desc& weights_desc,
......
......@@ -169,7 +169,8 @@ namespace ngraph
const mkldnn::memory::desc& result_desc,
const mkldnn::memory::desc& mean_desc,
const mkldnn::memory::desc& variance_desc,
const double eps);
const double eps,
bool bn_training_flag);
size_t build_batchnorm_backward(const mkldnn::memory::desc& weights_desc,
const mkldnn::memory::desc& input_desc,
......
......@@ -1009,6 +1009,7 @@ namespace ngraph
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::BatchNorm)
{
auto bn = static_cast<const ngraph::op::BatchNorm*>(node.get());
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_layout =
......@@ -1016,12 +1017,25 @@ namespace ngraph
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(input_layout);
prim_output_formats.push_back(input_layout);
prim_output_formats.push_back(memory::format::x);
prim_output_formats.push_back(memory::format::x);
if (bn->get_training_flag())
{
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(input_layout);
prim_output_formats.push_back(input_layout);
prim_output_formats.push_back(memory::format::x);
prim_output_formats.push_back(memory::format::x);
}
else
{
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(input_layout);
prim_input_formats.push_back(memory::format::x);
prim_input_formats.push_back(memory::format::x);
prim_output_formats.push_back(input_layout);
}
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
......
......@@ -410,7 +410,15 @@ static shared_ptr<ngraph::Function>
else if (node_op == "BatchNorm")
{
auto epsilon = node_js.at("eps").get<double>();
node = make_shared<op::BatchNorm>(epsilon, args[0], args[1], args[2]);
if (node_js.at("training"))
{
node = make_shared<op::BatchNorm>(epsilon, args[0], args[1], args[2]);
}
else
{
node = make_shared<op::BatchNorm>(
epsilon, args[0], args[1], args[2], args[3], args[4]);
}
}
else if (node_op == "BatchNormBackprop")
{
......@@ -941,6 +949,7 @@ static json write(const Node& n, bool binary_constant_data)
{
auto tmp = dynamic_cast<const op::BatchNorm*>(&n);
node["eps"] = tmp->get_eps_value();
node["training"] = tmp->get_training_flag();
}
else if (node_op == "BatchNormBackprop")
{
......
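A hedged sketch of how the new "training" field round-trips through the serializer (serialize/deserialize are the existing nGraph serializer entry points; exact signatures may differ slightly from what is shown here):

    // Sketch: a function containing the inference-mode BatchNorm survives a
    // JSON round trip because write() emits node["training"] and the reader
    // dispatches on node_js.at("training") to pick the 3- or 5-input ctor.
    #include "ngraph/serializer.hpp"

    std::shared_ptr<ngraph::Function> roundtrip(const std::shared_ptr<ngraph::Function>& f)
    {
        std::string js = ngraph::serialize(f); // "training": false for inference bn
        return ngraph::deserialize(js);
    }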
......@@ -960,3 +960,55 @@ TEST(cpu_fusion, sigmoid_bprop_n1c1h4)
vector<float> expected{0.196612f, 0.0176627f, 0.196612f, 0.0176627f};
EXPECT_TRUE(test::all_close(expected, read_vector<float>(result)));
}
TEST(cpu_fusion, batchnorm_fprop_inference_b2c2h2w1)
{
auto input_shape = Shape{2, 2, 2, 1};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
auto var_shape = Shape{2};
auto var = make_shared<op::Parameter>(element::f32, var_shape);
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{2, 2, 2, 1};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var);
auto f = make_shared<Function>(bn, op::ParameterVector{input, gamma, beta, mean, var});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto _input = backend->make_primary_tensor_view(element::f32, Shape{2, 2, 2, 1});
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
auto _gamma = backend->make_primary_tensor_view(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->make_primary_tensor_view(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto _mean = backend->make_primary_tensor_view(element::f32, mean_shape);
copy_data(_mean, vector<float>{0.583388f, 0.619252f});
auto _var = backend->make_primary_tensor_view(element::f32, var_shape);
copy_data(_var, vector<float>{0.0119972f, 0.0282681f});
auto bn_output = backend->make_primary_tensor_view(element::f32, shape_r);
auto result_mean = backend->make_primary_tensor_view(element::f32, mean_shape);
auto result_variance = backend->make_primary_tensor_view(element::f32, var_shape);
vector<float> expected_result{
-0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
cf->call({bn_output}, {_input, _gamma, _beta, _mean, _var});
ASSERT_TRUE(
ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
}