Commit 1d80cabe authored by Pruthvi, committed by Scott Cyphers

Pruthvi/bn to support globalstats (#783)

* WIP support bn training for global_stats

(cherry picked from commit eb81a37328ea177b1d58c9eebdbb345e0fa25f0d)

* - Style fix
  - Fix test case

* Addressed PR comments
  - added support for bn training/inference with the same ctor
  - added more verbose comments in bn header

* Fixed bn serializer and default value in bn ctor for bwd compatibility

* proposed docs change

* Addressed PR comments
  - added support to compute bn inference/training using the same mkldnn kernel with global stats

* fix bn relu unit test
parent df845963
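
For orientation, here is a minimal sketch of how the op is constructed in each mode after this change. The helper name `make_bn` and the include paths are illustrative assumptions, not part of the commit:

// Illustrative sketch only: helper name and include paths are assumptions
// based on the ngraph source tree of this era.
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/parameter.hpp"

using namespace ngraph;

std::shared_ptr<op::BatchNorm> make_bn(bool global_stats)
{
    auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 2, 1});
    auto gamma = std::make_shared<op::Parameter>(element::f32, Shape{2});
    auto beta = std::make_shared<op::Parameter>(element::f32, Shape{2});
    if (!global_stats)
    {
        // Plain training: mean and variance are computed from 'input' and
        // exposed as outputs 1 and 2 of the op.
        return std::make_shared<op::BatchNorm>(0.001, gamma, beta, input);
    }
    auto mean = std::make_shared<op::Parameter>(element::f32, Shape{2});
    auto var = std::make_shared<op::Parameter>(element::f32, Shape{2});
    // Training with global stats: the same five-input ctor as inference,
    // with the new trailing 'training' flag set (it defaults to false).
    return std::make_shared<op::BatchNorm>(0.001, gamma, beta, input, mean, var, true);
}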
@@ -74,14 +74,17 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
                                  std::shared_ptr<ngraph::Node> beta,
                                  std::shared_ptr<ngraph::Node> input,
                                  std::shared_ptr<ngraph::Node> mean,
-                                 std::shared_ptr<ngraph::Node> variance)
+                                 std::shared_ptr<ngraph::Node> variance,
+                                 bool training)
     : RequiresTensorViewArgs("BatchNorm", {gamma, beta, input, mean, variance})
     , m_bn_input_shape(input->get_shape())
     , m_bn_variance_shape(variance->get_shape())
     , m_bn_mean_shape(mean->get_shape())
     , m_epsilon(eps)
-    , m_training(false)
+    , m_training(training)
 {
+    const size_t INPUT_INDEX = 2;
     if (m_bn_input_shape.size() < 2)
     {
         throw ngraph_error("input tensor to batchnorm must have tensor of at least rank 2");
@@ -105,16 +108,15 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
             throw ngraph_error(err_msg.c_str());
         }
     }
     for (size_t index = 0; index < get_input_size(); index++)
     {
-        if (index != 2 && get_input_op(index)->get_shape().size() != 1)
+        if (index != INPUT_INDEX && get_input_op(index)->get_shape().size() != 1)
         {
             auto err_msg = std::string(input_names[index]) + " should have rank of 1";
             throw ngraph_error(err_msg.c_str());
         }
-        if (index != 2 && get_input_op(index)->get_shape()[0] != m_bn_input_shape[1])
+        if (index != INPUT_INDEX && get_input_op(index)->get_shape()[0] != m_bn_input_shape[1])
         {
             auto err_msg = std::string(input_names[index]) +
                            " shape should match the input channel size (" +
@@ -136,12 +138,25 @@ std::shared_ptr<ngraph::Node>
 {
     if (this->m_training)
     {
-        if (new_args.size() != 3)
+        if (new_args.size() == 3)
         {
+            return std::make_shared<BatchNorm>(
+                m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2));
+        }
+        else if (new_args.size() == 5)
+        {
+            return std::make_shared<BatchNorm>(m_epsilon,
+                                               new_args.at(0),
+                                               new_args.at(1),
+                                               new_args.at(2),
+                                               new_args.at(3),
+                                               new_args.at(4),
+                                               true);
+        }
+        else
+        {
             throw ngraph_error("Incorrect number of new arguments");
         }
-        return std::make_shared<BatchNorm>(
-            m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2));
     }
     else
     {
@@ -154,7 +169,8 @@ std::shared_ptr<ngraph::Node>
             new_args.at(1),
             new_args.at(2),
             new_args.at(3),
-            new_args.at(4));
+            new_args.at(4),
+            false);
     }
 }
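
Design note: rather than introducing a distinct op type for global-stats training, `copy_with_new_args` dispatches on `m_training` plus argument count, so a cloned node keeps its mode. Three arguments reproduce plain training, five arguments on a training node reproduce global-stats training (passing `true`), and the non-training path now passes `false` explicitly.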
@@ -236,6 +252,8 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
     auto gamma = get_input_op(0);
     auto beta = get_input_op(1);
     auto input = get_input_op(2);
+    std::shared_ptr<Node> mean = nullptr;
+    std::shared_ptr<Node> var = nullptr;
     if (!this->get_training_flag())
     {
@@ -247,16 +265,23 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
     //and get_n() is used to sort the inputs in the same order as Batchnorm's outputs
     //Next, Mean and Variance (`at(1)` and `at(2)`) are extracted
     //Please see `add_output` in `BatchNorm::BatchNorm` for more details
     std::vector<std::shared_ptr<Node>> goes(get_outputs().size());
-    for (auto _input : get_output_inputs(0))
+    if (this->get_training_flag() && get_input_size() == 3)
     {
-        auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(_input->get_node());
-        goes.at(goe->get_n()) = _input->get_node();
+        for (auto goe_input : get_output_inputs(0))
+        {
+            auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(goe_input->get_node());
+            goes.at(goe->get_n()) = goe_input->get_node();
+        }
+        mean = goes.at(1);
+        var = goes.at(2);
+    }
+    else // BatchNorm Training with global stats
+    {
+        mean = get_input_op(3);
+        var = get_input_op(4);
     }
-    auto mean = goes.at(1);
-    auto var = goes.at(2);
     auto bbn = std::make_shared<op::BatchNormBackprop>(
         get_eps_value(), gamma, beta, input, mean, var, delta);
     auto dinput = std::make_shared<op::GetOutputElement>(bbn, 0);
...
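
In other words, the adjoint computation now obtains mean and variance from whichever place holds them: for plain training (three inputs) they are the op's own outputs 1 and 2, recovered through the `GetOutputElement` users as the comments above describe; for global-stats training they are inputs 3 and 4. Either way, `BatchNormBackprop` is then constructed identically.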
@@ -30,19 +30,55 @@ namespace ngraph
         class BatchNorm : public util::RequiresTensorViewArgs
         {
         public:
             // BatchNorm Training
+            // In this version of BatchNorm:
+            //
+            // MEAN AND VARIANCE: computed directly from the content of 'input'.
+            //
+            // OUTPUT VALUE: A tuple with the following structure:
+            //   [0] - The normalization of 'input'.
+            //   [1] - The per-channel means of (pre-normalized) 'input'.
+            //   [2] - The per-channel variances of (pre-normalized) 'input'.
+            //
+            // AUTODIFF SUPPORT: yes: 'generate_adjoints(...)' works as expected.
+            //
+            // SHAPE DETAILS:
+            //   gamma:     must have rank 1, with the same span as input's channel axis.
+            //   beta:      must have rank 1, with the same span as input's channel axis.
+            //   input:     must have rank >= 2. The second dimension represents the channel axis
+            //              and must have a span of at least 1.
+            //   output[0]: shall have the same shape as 'input'.
+            //   output[1]: shall have rank 1, with the same span as input's channel axis.
+            //   output[2]: shall have rank 1, with the same span as input's channel axis.
             BatchNorm(double eps,
                       std::shared_ptr<Node> gamma,
                       std::shared_ptr<Node> beta,
                       std::shared_ptr<Node> input);
             // BatchNorm Inference
+            // In this version of BatchNorm:
+            //
+            // MEAN AND VARIANCE: provided by the 'mean' and 'variance' parameters.
+            //
+            // OUTPUT VALUE: a single tensor with the normalized value of 'input'.
+            //
+            // AUTODIFF SUPPORT:
+            //   - when 'training' is true, yes: 'generate_adjoints(...)' works as expected.
+            //   - when 'training' is false, no: 'generate_adjoints(...)' may throw an exception.
+            //
+            // SHAPE DETAILS:
+            //   gamma:    must have rank 1, with the same span as input's channel axis.
+            //   beta:     must have rank 1, with the same span as input's channel axis.
+            //   input:    must have rank >= 2. The second dimension represents the channel axis
+            //             and must have a span of at least 1.
+            //   mean:     must have rank 1, with the same span as input's channel axis.
+            //   variance: must have rank 1, with the same span as input's channel axis.
+            //   output:   shall have the same shape as 'input'.
             BatchNorm(double eps,
                       std::shared_ptr<ngraph::Node> gamma,
                       std::shared_ptr<ngraph::Node> beta,
                       std::shared_ptr<ngraph::Node> input,
                       std::shared_ptr<ngraph::Node> mean,
-                      std::shared_ptr<ngraph::Node> variance);
+                      std::shared_ptr<ngraph::Node> variance,
+                      bool training = false);
             const Shape& get_inputs_shape() const { return m_bn_input_shape; }
             const Shape& get_variance_shape() const { return m_bn_variance_shape; }
...
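
A small sketch of what this documented contract implies for output arity, reusing the hypothetical `make_bn` helper from earlier. The single-output expectation for the global-stats case is inferred from the five-input ctor's documentation above, not stated verbatim in the commit:

// Sketch, reusing the illustrative make_bn() helper from the first example.
auto bn_train = make_bn(false);
// Plain training: normalized input, per-channel mean, per-channel variance.
size_t n_train = bn_train->get_outputs().size(); // expected: 3

auto bn_global = make_bn(true);
// Mean/variance arrive as inputs, so the op should yield one output shaped like 'input'.
size_t n_global = bn_global->get_outputs().size(); // expected: 1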
@@ -380,7 +380,7 @@ namespace ngraph
                            << args[1].get_name() << ", "
                            << args[1].get_size() * args[1].get_element_type().size() << ");\n";
-                if (batchnorm->get_training_flag()) //BatchNorm Training
+                if (batchnorm->get_training_flag() && args.size() == 3)
                 {
                     auto input_format =
                         runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
@@ -410,6 +410,7 @@ namespace ngraph
                         mean_desc,
                         variance_desc,
                         batchnorm->get_eps_value(),
+                        false,
                         batchnorm->get_training_flag());
                     auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
@@ -427,7 +428,7 @@ namespace ngraph
                     writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
                            << to_string(batchnorm_index) << ");\n";
                 }
-                else //BatchNorm Inference
+                else
                 {
                     auto input_format =
                         runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
@@ -455,6 +456,7 @@ namespace ngraph
                         mean_desc,
                         variance_desc,
                         batchnorm->get_eps_value(),
+                        true,
                         batchnorm->get_training_flag());
                     auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
@@ -532,6 +534,7 @@ namespace ngraph
                         mean_desc,
                         variance_desc,
                         batchnorm->get_eps_value(),
+                        false,
                         batchnorm->get_training_flag(),
                         ops);
...
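
The new boolean threaded through each `build_batchnorm_forward` call site is `use_global_stats`: `false` where the primitive computes statistics itself (the three-input training path and the fused bn-relu path), `true` where statistics arrive as inputs (inference, and training with global stats, which share the else branch above).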
@@ -578,6 +578,7 @@ size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_
                                               const mkldnn::memory::desc& mean_desc,
                                               const mkldnn::memory::desc& variance_desc,
                                               const double eps,
+                                              bool use_global_stats,
                                               bool bn_training_flag,
                                               const mkldnn::post_ops& pops)
 {
@@ -590,7 +591,7 @@ size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_
     mkldnn::primitive_attr bn_attr;
     bn_attr.set_post_ops(pops);
-    if (bn_training_flag)
+    if (bn_training_flag && !use_global_stats)
     {
         size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_forward(
             {{mkldnn::prop_kind::forward_training,
@@ -612,7 +613,7 @@ size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_
     else
     {
         size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_forward(
-            {{mkldnn::prop_kind::forward_inference,
+            {{mkldnn::prop_kind::forward_training,
               input_desc,
               eps,
               mkldnn::batch_normalization_flag::use_scale_shift |
...
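
The effect can be summarized by the following sketch of the selection logic under the MKLDNN 0.x API. Note the `use_global_stats` flag in the else branch is inferred from the trailing `|` in the truncated hunk above and from the commit's intent, not shown verbatim:

// Sketch (MKLDNN 0.x API): after this commit both paths build a
// forward_training primitive; the distinction is carried by the flags.
unsigned flags = mkldnn::batch_normalization_flag::use_scale_shift;
if (!(bn_training_flag && !use_global_stats))
{
    // Inference, or training with supplied statistics: the primitive
    // consumes mean/variance as inputs instead of computing them.
    flags |= mkldnn::batch_normalization_flag::use_global_stats; // inferred, see note above
}
auto kind = mkldnn::prop_kind::forward_training;

This unification is what the commit message means by computing bn inference/training with the same mkldnn kernel.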
@@ -171,6 +171,7 @@ namespace ngraph
                                            const mkldnn::memory::desc& mean_desc,
                                            const mkldnn::memory::desc& variance_desc,
                                            const double eps,
+                                           bool use_global_stats,
                                            bool bn_training_flag,
                                            const mkldnn::post_ops& pops = mkldnn::post_ops());
...
@@ -1026,7 +1026,7 @@ namespace ngraph
                 vector<memory::format> prim_input_formats;
                 vector<memory::format> prim_output_formats;
-                if (bn->get_training_flag())
+                if (bn->get_training_flag() && node->get_input_size() == 3)
                 {
                     prim_input_formats.push_back(memory::format::x);
                     prim_input_formats.push_back(memory::format::x);
...
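
Layout note: the extra arity condition means a five-input training node (global stats) now takes the same layout-assignment path as inference rather than the three-input training path shown here.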
@@ -420,10 +420,15 @@ static shared_ptr<ngraph::Function>
     {
         auto epsilon = node_js.at("eps").get<double>();
         bool training = get_or_default<bool>(node_js, "training", true);
-        if (training)
+        if (training && args.size() == 3)
         {
             node = make_shared<op::BatchNorm>(epsilon, args[0], args[1], args[2]);
         }
+        else if (training && args.size() == 5)
+        {
+            node = make_shared<op::BatchNorm>(
+                epsilon, args[0], args[1], args[2], args[3], args[4], true);
+        }
         else
         {
             node = make_shared<op::BatchNorm>(
...
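
Backward-compatibility note: `training` defaults to true for graphs serialized before the flag existed, and the argument count then distinguishes plain training (three inputs) from the statistics-bearing forms (five inputs), so older serialized graphs continue to deserialize into the intended variant.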
@@ -928,59 +928,6 @@ TEST(cpu_fusion, conv_relu_n2c1h2w2_2)
     EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0)));
 }
-TEST(cpu_fusion, batchnorm_fprop_inference_b2c2h2w1)
-{
-    auto input_shape = Shape{2, 2, 2, 1};
-    auto input = make_shared<op::Parameter>(element::f32, input_shape);
-    auto mean_shape = Shape{2};
-    auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
-    auto var_shape = Shape{2};
-    auto var = make_shared<op::Parameter>(element::f32, var_shape);
-    auto gamma_shape = Shape{2};
-    auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
-    auto beta_shape = Shape{2};
-    auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
-    double eps = 0.001;
-    auto shape_r = Shape{2, 2, 2, 1};
-    auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var);
-    auto f = make_shared<Function>(bn, op::ParameterVector{input, gamma, beta, mean, var});
-    auto manager = runtime::Manager::get("CPU");
-    auto external = manager->compile(f);
-    auto backend = manager->allocate_backend();
-    auto cf = backend->make_call_frame(external);
-    // Create some tensors for input/output
-    auto _input = backend->make_primary_tensor_view(element::f32, Shape{2, 2, 2, 1});
-    copy_data(_input,
-              vector<float>{0.54881352f,
-                            0.71518934f,
-                            0.60276335f,
-                            0.54488319f,
-                            0.42365479f,
-                            0.64589411f,
-                            0.4375872f,
-                            0.89177299f});
-    auto _gamma = backend->make_primary_tensor_view(element::f32, gamma_shape);
-    copy_data(_gamma, vector<float>{1.0f, 1.0f});
-    auto _beta = backend->make_primary_tensor_view(element::f32, beta_shape);
-    copy_data(_beta, vector<float>{0.0f, 0.0f});
-    auto _mean = backend->make_primary_tensor_view(element::f32, mean_shape);
-    copy_data(_mean, vector<float>{0.583388f, 0.619252f});
-    auto _var = backend->make_primary_tensor_view(element::f32, var_shape);
-    copy_data(_var, vector<float>{0.0119972f, 0.0282681f});
-    auto bn_output = backend->make_primary_tensor_view(element::f32, shape_r);
-    auto result_mean = backend->make_primary_tensor_view(element::f32, mean_shape);
-    auto result_variance = backend->make_primary_tensor_view(element::f32, var_shape);
-    vector<float> expected_result{
-        -0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
-    cf->call({bn_output}, {_input, _gamma, _beta, _mean, _var});
-    ASSERT_TRUE(
-        ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
-}
 std::vector<shared_ptr<runtime::TensorView>>
     rnn_matrix_fusion_eval(const size_t time_steps,
                            const Shape& data_shape,
...
@@ -331,3 +331,55 @@ TEST(cpu_test, batchnorm_fprop_inference_b2c2h2w1)
     ASSERT_TRUE(
         ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
 }
+TEST(cpu_test, batchnorm_fprop_globalstats_b2c2w2h1)
+{
+    auto input_shape = Shape{2, 2, 2, 1};
+    auto input = make_shared<op::Parameter>(element::f32, input_shape);
+    auto mean_shape = Shape{2};
+    auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
+    auto var_shape = Shape{2};
+    auto var = make_shared<op::Parameter>(element::f32, var_shape);
+    auto gamma_shape = Shape{2};
+    auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
+    auto beta_shape = Shape{2};
+    auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
+    double eps = 0.001;
+    auto shape_r = Shape{2, 2, 2, 1};
+    auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var, true);
+    auto f = make_shared<Function>(bn, op::ParameterVector{gamma, beta, input, mean, var});
+    auto manager = runtime::Manager::get("CPU");
+    auto external = manager->compile(f);
+    auto backend = manager->allocate_backend();
+    auto cf = backend->make_call_frame(external);
+    // Create some tensors for input/output
+    auto _input = backend->make_primary_tensor_view(element::f32, Shape{2, 2, 2, 1});
+    copy_data(_input,
+              vector<float>{0.54881352f,
+                            0.71518934f,
+                            0.60276335f,
+                            0.54488319f,
+                            0.42365479f,
+                            0.64589411f,
+                            0.4375872f,
+                            0.89177299f});
+    auto _gamma = backend->make_primary_tensor_view(element::f32, gamma_shape);
+    copy_data(_gamma, vector<float>{1.0f, 1.0f});
+    auto _beta = backend->make_primary_tensor_view(element::f32, beta_shape);
+    copy_data(_beta, vector<float>{0.0f, 0.0f});
+    auto _mean = backend->make_primary_tensor_view(element::f32, mean_shape);
+    copy_data(_mean, vector<float>{0.583388f, 0.619252f});
+    auto _var = backend->make_primary_tensor_view(element::f32, var_shape);
+    copy_data(_var, vector<float>{0.0119972f, 0.0282681f});
+    auto bn_output = backend->make_primary_tensor_view(element::f32, shape_r);
+    auto result_mean = backend->make_primary_tensor_view(element::f32, mean_shape);
+    auto result_variance = backend->make_primary_tensor_view(element::f32, var_shape);
+    vector<float> expected_result{
+        -0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
+    cf->call({bn_output}, {_gamma, _beta, _input, _mean, _var});
+    ASSERT_TRUE(
+        ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
+}
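
Note that this new test mirrors the inference test removed from cpu_fusion: identical inputs, statistics, and expected values. It therefore checks that training with supplied global statistics is numerically identical to inference; the only differences are the `true` flag passed to the ctor and the parameter order (gamma, beta, input, mean, var) in the ParameterVector and the call.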