Commit 1d80cabe authored by Pruthvi, committed by Scott Cyphers

Pruthvi/bn to support globalstats (#783)

* WIP support bn training for global_stats

(cherry picked from commit eb81a37328ea177b1d58c9eebdbb345e0fa25f0d)

* - Style fix
- Fix test case

* Addressed PR comments
- added support for bn training/inference with the same ctor
- added more verbose comments in bn header

* Fixed bn serializer and default value in bn ctor for bwd compatibility

* proposed docs change

* - Addressed PR comments
  - added support to compute bn inference/training using the same mkldnn kernel with global stats

* fix bn relu unit test
parent df845963
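In short: the five-input BatchNorm constructor gains a trailing `training` flag, so one signature now covers both inference and training with externally supplied (global) statistics. A minimal usage sketch follows, assuming the usual ngraph op headers; it is illustrative only, with parameter nodes and shapes mirroring the unit test added at the bottom of this diff:

```cpp
// Illustrative only -- not part of the patch. Shapes follow the new cpu_test below.
auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 2, 1});
auto gamma = std::make_shared<op::Parameter>(element::f32, Shape{2});
auto beta = std::make_shared<op::Parameter>(element::f32, Shape{2});
auto mean = std::make_shared<op::Parameter>(element::f32, Shape{2});
auto var = std::make_shared<op::Parameter>(element::f32, Shape{2});
double eps = 0.001;

// Training with user-supplied (global) statistics -- new in this commit.
auto bn_global = std::make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var, true);

// Inference -- same constructor; 'training' defaults to false for backward compatibility.
auto bn_infer = std::make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var);
```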
@@ -74,14 +74,17 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
                                  std::shared_ptr<ngraph::Node> beta,
                                  std::shared_ptr<ngraph::Node> input,
                                  std::shared_ptr<ngraph::Node> mean,
-                                 std::shared_ptr<ngraph::Node> variance)
+                                 std::shared_ptr<ngraph::Node> variance,
+                                 bool training)
     : RequiresTensorViewArgs("BatchNorm", {gamma, beta, input, mean, variance})
     , m_bn_input_shape(input->get_shape())
     , m_bn_variance_shape(variance->get_shape())
     , m_bn_mean_shape(mean->get_shape())
     , m_epsilon(eps)
-    , m_training(false)
+    , m_training(training)
 {
+    const size_t INPUT_INDEX = 2;
     if (m_bn_input_shape.size() < 2)
     {
         throw ngraph_error("input tensor to batchnorm must have tensor of at least rank 2");
@@ -105,16 +108,15 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
             throw ngraph_error(err_msg.c_str());
         }
     }
     for (size_t index = 0; index < get_input_size(); index++)
     {
-        if (index != 2 && get_input_op(index)->get_shape().size() != 1)
+        if (index != INPUT_INDEX && get_input_op(index)->get_shape().size() != 1)
         {
             auto err_msg = std::string(input_names[index]) + " should have rank of 1";
             throw ngraph_error(err_msg.c_str());
         }
-        if (index != 2 && get_input_op(index)->get_shape()[0] != m_bn_input_shape[1])
+        if (index != INPUT_INDEX && get_input_op(index)->get_shape()[0] != m_bn_input_shape[1])
         {
             auto err_msg = std::string(input_names[index]) +
                            " shape should match the input channel size (" +
@@ -136,12 +138,25 @@ std::shared_ptr<ngraph::Node>
 {
     if (this->m_training)
     {
-        if (new_args.size() != 3)
+        if (new_args.size() == 3)
+        {
+            return std::make_shared<BatchNorm>(
+                m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2));
+        }
+        else if (new_args.size() == 5)
+        {
+            return std::make_shared<BatchNorm>(m_epsilon,
+                                               new_args.at(0),
+                                               new_args.at(1),
+                                               new_args.at(2),
+                                               new_args.at(3),
+                                               new_args.at(4),
+                                               true);
+        }
+        else
         {
             throw ngraph_error("Incorrect number of new arguments");
         }
-        return std::make_shared<BatchNorm>(
-            m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2));
     }
     else
     {
@@ -154,7 +169,8 @@ std::shared_ptr<ngraph::Node>
                                            new_args.at(1),
                                            new_args.at(2),
                                            new_args.at(3),
-                                           new_args.at(4));
+                                           new_args.at(4),
+                                           false);
     }
 }
@@ -236,6 +252,8 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
     auto gamma = get_input_op(0);
     auto beta = get_input_op(1);
     auto input = get_input_op(2);
+    std::shared_ptr<Node> mean = nullptr;
+    std::shared_ptr<Node> var = nullptr;
     if (!this->get_training_flag())
     {
@@ -247,16 +265,23 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
     //and get_n() is used to sort the inputs in the same order as Batchnorm's outputs
     //Next, Mean and Variance (`at(1)` and `at(2)`) are extracted
     //Please see `add_output` in `BatchNorm::BatchNorm` for more details
     std::vector<std::shared_ptr<Node>> goes(get_outputs().size());
-    for (auto _input : get_output_inputs(0))
+    if (this->get_training_flag() && get_input_size() == 3)
+    {
+        for (auto goe_input : get_output_inputs(0))
+        {
+            auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(goe_input->get_node());
+            goes.at(goe->get_n()) = goe_input->get_node();
+        }
+        mean = goes.at(1);
+        var = goes.at(2);
+    }
+    else // BatchNorm Training with global stats
     {
-        auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(_input->get_node());
-        goes.at(goe->get_n()) = _input->get_node();
+        mean = get_input_op(3);
+        var = get_input_op(4);
     }
-    auto mean = goes.at(1);
-    auto var = goes.at(2);
     auto bbn = std::make_shared<op::BatchNormBackprop>(
         get_eps_value(), gamma, beta, input, mean, var, delta);
     auto dinput = std::make_shared<op::GetOutputElement>(bbn, 0);
...
@@ -30,19 +30,55 @@ namespace ngraph
         class BatchNorm : public util::RequiresTensorViewArgs
         {
         public:
-            // BatchNorm Training
+            // In this version of BatchNorm:
+            //
+            // MEAN AND VARIANCE: computed directly from the content of 'input'.
+            //
+            // OUTPUT VALUE: A tuple with the following structure:
+            //   [0] - The normalization of 'input'.
+            //   [1] - The per-channel means of (pre-normalized) 'input'.
+            //   [2] - The per-channel variances of (pre-normalized) 'input'.
+            //
+            // AUTODIFF SUPPORT: yes: 'generate_adjoints(...)' works as expected.
+            //
+            // SHAPE DETAILS:
+            //   gamma:     must have rank 1, with the same span as input's channel axis.
+            //   beta:      must have rank 1, with the same span as input's channel axis.
+            //   input:     must have rank >= 2. The second dimension represents the channel axis
+            //              and must have a span of at least 1.
+            //   output[0]: shall have the same shape as 'input'.
+            //   output[1]: shall have rank 1, with the same span as input's channel axis.
+            //   output[2]: shall have rank 1, with the same span as input's channel axis.
             BatchNorm(double eps,
                       std::shared_ptr<Node> gamma,
                       std::shared_ptr<Node> beta,
                       std::shared_ptr<Node> input);
-            //BatchNorm Inference
+            // In this version of BatchNorm:
+            //
+            // MEAN AND VARIANCE: provided by the 'mean' and 'variance' parameters.
+            //
+            // OUTPUT VALUE: a single tensor with the normalized value of 'input'.
+            //
+            // AUTODIFF SUPPORT:
+            //   - when 'training' is true, yes: 'generate_adjoints(...)' works as expected.
+            //   - when 'training' is false, no: 'generate_adjoints(...)' may throw an exception.
+            //
+            // SHAPE DETAILS:
+            //   gamma:    must have rank 1, with the same span as input's channel axis.
+            //   beta:     must have rank 1, with the same span as input's channel axis.
+            //   input:    must have rank >= 2. The second dimension represents the channel axis and
+            //             must have a span of at least 1.
+            //   mean:     must have rank 1, with the same span as input's channel axis.
+            //   variance: must have rank 1, with the same span as input's channel axis.
+            //   output:   shall have the same shape as 'input'.
             BatchNorm(double eps,
                       std::shared_ptr<ngraph::Node> gamma,
                       std::shared_ptr<ngraph::Node> beta,
                       std::shared_ptr<ngraph::Node> input,
                       std::shared_ptr<ngraph::Node> mean,
-                      std::shared_ptr<ngraph::Node> variance);
+                      std::shared_ptr<ngraph::Node> variance,
+                      bool training = false);
             const Shape& get_inputs_shape() const { return m_bn_input_shape; }
             const Shape& get_variance_shape() const { return m_bn_variance_shape; }
...
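As a companion to the header comments above, here is a rough sketch of how the two forms are consumed. It is not part of the patch; the GetOutputElement unpacking follows the pattern used in generate_adjoints earlier in this diff, and it reuses the illustrative parameter nodes from the sketch near the top:

```cpp
// Three-input training form: mean/variance are computed from 'input' and the node
// carries three outputs, unpacked with GetOutputElement as documented above.
auto bn_train = std::make_shared<op::BatchNorm>(eps, gamma, beta, input);
auto normalized = std::make_shared<op::GetOutputElement>(bn_train, 0); // same shape as 'input'
auto batch_mean = std::make_shared<op::GetOutputElement>(bn_train, 1); // per-channel means
auto batch_var = std::make_shared<op::GetOutputElement>(bn_train, 2);  // per-channel variances

// Mean/variance form: a single output with the normalized value of 'input'.
// training = true selects global-stats training; the default false selects inference.
auto bn_global = std::make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var, true);
```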
@@ -380,7 +380,7 @@ namespace ngraph
                    << args[1].get_name() << ", "
                    << args[1].get_size() * args[1].get_element_type().size() << ");\n";
-            if (batchnorm->get_training_flag()) //BatchNorm Training
+            if (batchnorm->get_training_flag() && args.size() == 3)
             {
                 auto input_format =
                     runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
@@ -410,6 +410,7 @@ namespace ngraph
                     mean_desc,
                     variance_desc,
                     batchnorm->get_eps_value(),
+                    false,
                     batchnorm->get_training_flag());
                 auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
@@ -427,7 +428,7 @@ namespace ngraph
                 writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
                        << to_string(batchnorm_index) << ");\n";
             }
-            else //BatchNorm Inference
+            else
             {
                 auto input_format =
                     runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
@@ -455,6 +456,7 @@ namespace ngraph
                     mean_desc,
                     variance_desc,
                     batchnorm->get_eps_value(),
+                    true,
                     batchnorm->get_training_flag());
                 auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
@@ -532,6 +534,7 @@ namespace ngraph
                     mean_desc,
                     variance_desc,
                     batchnorm->get_eps_value(),
+                    false,
                     batchnorm->get_training_flag(),
                     ops);
...
@@ -578,6 +578,7 @@ size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_
                                               const mkldnn::memory::desc& mean_desc,
                                               const mkldnn::memory::desc& variance_desc,
                                               const double eps,
+                                              bool use_global_stats,
                                               bool bn_training_flag,
                                               const mkldnn::post_ops& pops)
 {
@@ -590,7 +591,7 @@ size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_
     mkldnn::primitive_attr bn_attr;
     bn_attr.set_post_ops(pops);
-    if (bn_training_flag)
+    if (bn_training_flag && !use_global_stats)
     {
         size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_forward(
             {{mkldnn::prop_kind::forward_training,
@@ -612,7 +613,7 @@ size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_
     else
     {
         size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_forward(
-            {{mkldnn::prop_kind::forward_inference,
+            {{mkldnn::prop_kind::forward_training,
              input_desc,
              eps,
              mkldnn::batch_normalization_flag::use_scale_shift |
...
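To summarize the branch selection in build_batchnorm_forward after this change, a condensed sketch follows. It is not the literal emitter code: the descriptor and primitive plumbing are elided, and the flag ORed onto use_scale_shift in the second branch is assumed to be use_global_stats, matching the truncated line above.

```cpp
// Condensed sketch of the decision only; the real code builds the full
// mkldnn::batch_normalization_forward primitive around these choices.
// Per the hunks above, both branches now use prop_kind::forward_training.
mkldnn::prop_kind kind = mkldnn::prop_kind::forward_training;
unsigned flags = 0;
if (bn_training_flag && !use_global_stats)
{
    // Plain training: the kernel computes mean/variance from the input itself.
    flags = mkldnn::batch_normalization_flag::use_scale_shift;
}
else
{
    // Inference, or training with externally supplied statistics: mean/variance
    // are passed in, so the kernel runs with global stats (assumed flag).
    flags = mkldnn::batch_normalization_flag::use_scale_shift |
            mkldnn::batch_normalization_flag::use_global_stats;
}
```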
@@ -171,6 +171,7 @@ namespace ngraph
                                            const mkldnn::memory::desc& mean_desc,
                                            const mkldnn::memory::desc& variance_desc,
                                            const double eps,
+                                           bool use_global_stats,
                                            bool bn_training_flag,
                                            const mkldnn::post_ops& pops = mkldnn::post_ops());
...
@@ -1026,7 +1026,7 @@ namespace ngraph
                     vector<memory::format> prim_input_formats;
                     vector<memory::format> prim_output_formats;
-                    if (bn->get_training_flag())
+                    if (bn->get_training_flag() && node->get_input_size() == 3)
                     {
                         prim_input_formats.push_back(memory::format::x);
                         prim_input_formats.push_back(memory::format::x);
...
@@ -420,10 +420,15 @@ static shared_ptr<ngraph::Function>
         {
             auto epsilon = node_js.at("eps").get<double>();
             bool training = get_or_default<bool>(node_js, "training", true);
-            if (training)
+            if (training && args.size() == 3)
             {
                 node = make_shared<op::BatchNorm>(epsilon, args[0], args[1], args[2]);
             }
+            else if (training && args.size() == 5)
+            {
+                node = make_shared<op::BatchNorm>(
+                    epsilon, args[0], args[1], args[2], args[3], args[4], true);
+            }
             else
             {
                 node = make_shared<op::BatchNorm>(
...
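For reference, the backward-compatibility hinge here is get_or_default: graphs serialized before this change carry no "training" field and fall back to the default. Below is a hedged sketch of what such a helper does, assuming an nlohmann::json-style node object; the real helper lives elsewhere in the codebase and may differ in detail.

```cpp
#include <nlohmann/json.hpp>
#include <string>

// Sketch only: return the stored value when the key exists, otherwise the default.
// Older serialized BatchNorm nodes lack "training", so deserialization above
// falls back to 'true' and then dispatches on the number of arguments.
template <typename T>
T get_or_default(const nlohmann::json& js, const std::string& key, const T& default_value)
{
    return js.count(key) != 0 ? js.at(key).get<T>() : default_value;
}

// Usage mirroring the hunk above:
//   bool training = get_or_default<bool>(node_js, "training", true);
```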
@@ -928,59 +928,6 @@ TEST(cpu_fusion, conv_relu_n2c1h2w2_2)
     EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0)));
 }
-TEST(cpu_fusion, batchnorm_fprop_inference_b2c2h2w1)
-{
-    auto input_shape = Shape{2, 2, 2, 1};
-    auto input = make_shared<op::Parameter>(element::f32, input_shape);
-    auto mean_shape = Shape{2};
-    auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
-    auto var_shape = Shape{2};
-    auto var = make_shared<op::Parameter>(element::f32, var_shape);
-    auto gamma_shape = Shape{2};
-    auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
-    auto beta_shape = Shape{2};
-    auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
-    double eps = 0.001;
-    auto shape_r = Shape{2, 2, 2, 1};
-    auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var);
-    auto f = make_shared<Function>(bn, op::ParameterVector{input, gamma, beta, mean, var});
-    auto manager = runtime::Manager::get("CPU");
-    auto external = manager->compile(f);
-    auto backend = manager->allocate_backend();
-    auto cf = backend->make_call_frame(external);
-    // Create some tensors for input/output
-    auto _input = backend->make_primary_tensor_view(element::f32, Shape{2, 2, 2, 1});
-    copy_data(_input,
-              vector<float>{0.54881352f,
-                            0.71518934f,
-                            0.60276335f,
-                            0.54488319f,
-                            0.42365479f,
-                            0.64589411f,
-                            0.4375872f,
-                            0.89177299f});
-    auto _gamma = backend->make_primary_tensor_view(element::f32, gamma_shape);
-    copy_data(_gamma, vector<float>{1.0f, 1.0f});
-    auto _beta = backend->make_primary_tensor_view(element::f32, beta_shape);
-    copy_data(_beta, vector<float>{0.0f, 0.0f});
-    auto _mean = backend->make_primary_tensor_view(element::f32, mean_shape);
-    copy_data(_mean, vector<float>{0.583388f, 0.619252f});
-    auto _var = backend->make_primary_tensor_view(element::f32, var_shape);
-    copy_data(_var, vector<float>{0.0119972f, 0.0282681f});
-    auto bn_output = backend->make_primary_tensor_view(element::f32, shape_r);
-    auto result_mean = backend->make_primary_tensor_view(element::f32, mean_shape);
-    auto result_variance = backend->make_primary_tensor_view(element::f32, var_shape);
-    vector<float> expected_result{
-        -0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
-    cf->call({bn_output}, {_input, _gamma, _beta, _mean, _var});
-    ASSERT_TRUE(
-        ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
-}
 std::vector<shared_ptr<runtime::TensorView>>
     rnn_matrix_fusion_eval(const size_t time_steps,
                            const Shape& data_shape,
...
@@ -331,3 +331,55 @@ TEST(cpu_test, batchnorm_fprop_inference_b2c2h2w1)
     ASSERT_TRUE(
         ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
 }
+TEST(cpu_test, batchnorm_fprop_globalstats_b2c2w2h1)
+{
+    auto input_shape = Shape{2, 2, 2, 1};
+    auto input = make_shared<op::Parameter>(element::f32, input_shape);
+    auto mean_shape = Shape{2};
+    auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
+    auto var_shape = Shape{2};
+    auto var = make_shared<op::Parameter>(element::f32, var_shape);
+    auto gamma_shape = Shape{2};
+    auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
+    auto beta_shape = Shape{2};
+    auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
+    double eps = 0.001;
+    auto shape_r = Shape{2, 2, 2, 1};
+    auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var, true);
+    auto f = make_shared<Function>(bn, op::ParameterVector{gamma, beta, input, mean, var});
+    auto manager = runtime::Manager::get("CPU");
+    auto external = manager->compile(f);
+    auto backend = manager->allocate_backend();
+    auto cf = backend->make_call_frame(external);
+    // Create some tensors for input/output
+    auto _input = backend->make_primary_tensor_view(element::f32, Shape{2, 2, 2, 1});
+    copy_data(_input,
+              vector<float>{0.54881352f,
+                            0.71518934f,
+                            0.60276335f,
+                            0.54488319f,
+                            0.42365479f,
+                            0.64589411f,
+                            0.4375872f,
+                            0.89177299f});
+    auto _gamma = backend->make_primary_tensor_view(element::f32, gamma_shape);
+    copy_data(_gamma, vector<float>{1.0f, 1.0f});
+    auto _beta = backend->make_primary_tensor_view(element::f32, beta_shape);
+    copy_data(_beta, vector<float>{0.0f, 0.0f});
+    auto _mean = backend->make_primary_tensor_view(element::f32, mean_shape);
+    copy_data(_mean, vector<float>{0.583388f, 0.619252f});
+    auto _var = backend->make_primary_tensor_view(element::f32, var_shape);
+    copy_data(_var, vector<float>{0.0119972f, 0.0282681f});
+    auto bn_output = backend->make_primary_tensor_view(element::f32, shape_r);
+    auto result_mean = backend->make_primary_tensor_view(element::f32, mean_shape);
+    auto result_variance = backend->make_primary_tensor_view(element::f32, var_shape);
+    vector<float> expected_result{
+        -0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
+    cf->call({bn_output}, {_gamma, _beta, _input, _mean, _var});
+    ASSERT_TRUE(
+        ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
+}