Commit 5394ad2d authored by Pruthvi, committed by adstraw

Pruthvi/bn inference (#670)

* Added new ctor for bn which supports Inference
- added mkldnn emitter code for bn inference
* Added test case for bn inference
- added support for layout propagation for bn inference
* added sanity checks for gamma, beta, mean, variance shape in bn
* added serializer support for bn inference
parent 6ebc3c8c
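
For reference, a minimal usage sketch (not part of this diff) of the two constructors distinguished by this change; the shapes are illustrative and follow the checks in batch_norm.cpp and the test at the bottom of this page:

    // Sketch only: illustrative shapes, following the new batch_norm.hpp signatures.
    auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 4});
    auto gamma = std::make_shared<op::Parameter>(element::f32, Shape{3});
    auto beta = std::make_shared<op::Parameter>(element::f32, Shape{3});
    auto mean = std::make_shared<op::Parameter>(element::f32, Shape{3});
    auto var = std::make_shared<op::Parameter>(element::f32, Shape{3});

    // Training: mean and variance are computed by the op and exposed as extra outputs.
    auto bn_train = std::make_shared<op::BatchNorm>(0.001, gamma, beta, input);

    // Inference: mean and variance are passed in; the op has a single normalized output.
    auto bn_infer = std::make_shared<op::BatchNorm>(0.001, gamma, beta, input, mean, var);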
@@ -25,10 +25,11 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
    : RequiresTensorViewArgs("BatchNorm", {gamma, beta, input})
    , m_bn_input_shape(input->get_shape())
    , m_epsilon(eps)
    , m_training(true)
{
    if (m_bn_input_shape.size() < 2)
    {
        throw ngraph_error("input tensor to batchnorm must have tensor of at least rank 2");
    }
    else
    {
@@ -39,7 +40,20 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
    if (m_bn_input_shape[1] == 0)
    {
        throw ngraph_error(
            "input tensor must have at least one channel axis for batch normalization");
    }

    auto et = input->get_element_type();
    const char* input_names[] = {"gamma", "beta"};

    for (size_t i = 0; i < 2; i++)
    {
        if (get_input_op(i)->get_element_type() != et)
        {
            auto err_msg = std::string("The element type of ") + input_names[i] +
                           " isn't equal to input data's type";
            throw ngraph_error(err_msg.c_str());
        }
    }

    if ((gamma->get_shape().size() != 1) || (beta->get_shape().size() != 1))

@@ -62,12 +76,93 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
    add_output(input->get_element_type(), m_bn_variance_shape);
}

ngraph::op::BatchNorm::BatchNorm(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance)
: RequiresTensorViewArgs("BatchNorm", {gamma, beta, input, mean, variance})
, m_bn_input_shape(input->get_shape())
, m_bn_variance_shape(variance->get_shape())
, m_bn_mean_shape(mean->get_shape())
, m_epsilon(eps)
, m_training(false)
{
if (m_bn_input_shape.size() < 2)
{
throw ngraph_error("input tensor to batchnorm must have tensor of at least rank 2");
}
if (m_bn_input_shape[1] == 0)
{
throw ngraph_error(
"input tensor must have at least one channel axis for batch normalization");
}
auto et = input->get_element_type();
const char* input_names[] = {"gamma", "beta", "input", "mean", "variance"};
for (size_t i = 0; i < get_input_size(); i++)
{
if (get_input_op(i)->get_element_type() != et)
{
auto err_msg = std::string("The element type of ") + input_names[i] +
" isn't equal to input data's type";
throw ngraph_error(err_msg.c_str());
}
}
for (size_t index = 0; index < get_input_size(); index++)
{
if (index != 2 && get_input_op(index)->get_shape().size() != 1)
{
auto err_msg = std::string(input_names[index]) + " should have rank of 1";
throw ngraph_error(err_msg.c_str());
}
if (index != 2 && get_input_op(index)->get_shape()[0] != m_bn_input_shape[1])
{
auto err_msg = std::string(input_names[index]) +
" shape should match the input channel size (" +
std::to_string(m_bn_input_shape[1]) + ",)";
throw ngraph_error(err_msg.c_str());
}
}
if (variance->get_shape()[0] != mean->get_shape()[0])
{
throw ngraph_error("mean and variance should have same size");
}
add_output(input->get_element_type(), m_bn_input_shape);
}

std::shared_ptr<ngraph::Node>
    ngraph::op::BatchNorm::copy_with_new_args(const NodeVector& new_args) const
{
    if (this->m_training)
    {
        if (new_args.size() != 3)
        {
            throw ngraph_error("Incorrect number of new arguments");
        }
        return std::make_shared<BatchNorm>(
            m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2));
    }
    else
    {
        if (new_args.size() != 5)
        {
            throw ngraph_error("Incorrect number of new arguments");
        }
        return std::make_shared<BatchNorm>(m_epsilon,
                                           new_args.at(0),
                                           new_args.at(1),
                                           new_args.at(2),
                                           new_args.at(3),
                                           new_args.at(4));
    }
}

ngraph::op::BatchNormBackprop::BatchNormBackprop(double eps,
@@ -149,6 +244,10 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
    auto beta = get_input_op(1);
    auto input = get_input_op(2);

    if (!this->get_training_flag())
    {
        throw ngraph_error("generate_adjoints called on BatchNormInference op " + this->get_name());
    }

    //Extract mean and variance outputs from BatchNorm
    //as these are used by BatchNormBackprop.
    //The users of the outputs (GetOutputElements' Inputs) aren't sorted
......
@@ -30,15 +30,25 @@ namespace ngraph
        class BatchNorm : public util::RequiresTensorViewArgs
        {
        public:
            // BatchNorm Training
            BatchNorm(double eps,
                      std::shared_ptr<Node> gamma,
                      std::shared_ptr<Node> beta,
                      std::shared_ptr<Node> input);

            // BatchNorm Inference
            BatchNorm(double eps,
                      std::shared_ptr<ngraph::Node> gamma,
                      std::shared_ptr<ngraph::Node> beta,
                      std::shared_ptr<ngraph::Node> input,
                      std::shared_ptr<ngraph::Node> mean,
                      std::shared_ptr<ngraph::Node> variance);

            const Shape& get_inputs_shape() const { return m_bn_input_shape; }
            const Shape& get_variance_shape() const { return m_bn_variance_shape; }
            const Shape& get_mean_shape() const { return m_bn_mean_shape; }
            double get_eps_value() const { return m_epsilon; }
            bool get_training_flag() const { return m_training; }
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;

@@ -51,6 +61,7 @@ namespace ngraph
            Shape m_bn_variance_shape;
            Shape m_bn_mean_shape;
            double m_epsilon;
            bool m_training;
        };

        class BatchNormBackprop : public util::RequiresTensorViewArgs
......
@@ -384,44 +384,98 @@ namespace ngraph
       << args[1].get_name() << ", "
       << args[1].get_size() * args[1].get_element_type().size() << ");\n";

if (batchnorm->get_training_flag()) //BatchNorm Training
{
    auto input_format =
        runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
    auto result_format =
        runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);
    auto mean_format =
        runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 1);
    auto variance_format =
        runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 2);

    auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
    auto weights_shape = Shape{2, args[0].get_size()};
    auto input_desc =
        mkldnn_emitter->build_memory_descriptor(args[2], input_format);
    auto weights_desc = mkldnn_emitter->build_memory_descriptor(
        weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
    auto results_desc =
        mkldnn_emitter->build_memory_descriptor(out[0], result_format);
    auto mean_desc = mkldnn_emitter->build_memory_descriptor(out[1], mean_format);
    auto variance_desc =
        mkldnn_emitter->build_memory_descriptor(out[2], variance_format);

    auto batchnorm_index =
        mkldnn_emitter->build_batchnorm_forward(input_desc,
                                                weights_desc,
                                                results_desc,
                                                mean_desc,
                                                variance_desc,
                                                batchnorm->get_eps_value(),
                                                batchnorm->get_training_flag());

    auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
    writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
           << ", " << args[2].get_name() << ");\n";
    writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
           << ", bn_weights.data());\n";
    writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
           << ", " << out[0].get_name() << ");\n";
    writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[3])
           << ", " << out[1].get_name() << ");\n";
    writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[4])
           << ", " << out[2].get_name() << ");\n";

    writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
           << to_string(batchnorm_index) << ");\n";
}
else //BatchNorm Inference
{
    auto input_format =
        runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 2);
    auto mean_format = runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 3);
    auto variance_format =
        runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node, 4);
    auto result_format =
        runtime::cpu::mkldnn_utils::get_output_mkldnn_format(node, 0);

    auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
    auto weights_shape = Shape{2, args[0].get_size()};
    auto input_desc =
        mkldnn_emitter->build_memory_descriptor(args[2], input_format);
    auto weights_desc = mkldnn_emitter->build_memory_descriptor(
        weights_shape, args[0].get_element_type(), mkldnn::memory::format::nc);
    auto mean_desc = mkldnn_emitter->build_memory_descriptor(args[3], mean_format);
    auto variance_desc =
        mkldnn_emitter->build_memory_descriptor(args[4], variance_format);
    auto results_desc =
        mkldnn_emitter->build_memory_descriptor(out[0], result_format);

    auto batchnorm_index =
        mkldnn_emitter->build_batchnorm_forward(input_desc,
                                                weights_desc,
                                                results_desc,
                                                mean_desc,
                                                variance_desc,
                                                batchnorm->get_eps_value(),
                                                batchnorm->get_training_flag());

    auto& deps = mkldnn_emitter->get_primitive_deps(batchnorm_index);
    writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
           << ", " << args[2].get_name() << ");\n";
    writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
           << ", " << args[3].get_name() << ");\n";
    writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2])
           << ", " << args[4].get_name() << ");\n";
    writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[3])
           << ", bn_weights.data());\n";
    writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[4])
           << ", " << out[0].get_name() << ");\n";

    writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
           << to_string(batchnorm_index) << ");\n";
}
writer.indent--;
writer << "}\n";
}
......
@@ -572,7 +572,8 @@ size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_
                                              const mkldnn::memory::desc& result_desc,
                                              const mkldnn::memory::desc& mean_desc,
                                              const mkldnn::memory::desc& variance_desc,
                                              const double eps,
                                              bool bn_training_flag)
{
    size_t input_index = build_memory_primitive(input_desc);
    size_t weights_index = build_memory_primitive(weights_desc);

@@ -580,21 +581,43 @@ size_t MKLDNNEmitter::build_batchnorm_forward(const mkldnn::memory::desc& input_
    size_t mean_index = build_memory_primitive(mean_desc);
    size_t variance_index = build_memory_primitive(variance_desc);

    if (bn_training_flag)
    {
        size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_forward(
            {{mkldnn::prop_kind::forward_training,
              input_desc,
              eps,
              mkldnn::batch_normalization_flag::use_scale_shift},
             mkldnn_utils::global_cpu_engine},
            mkldnn::primitive::at(*m_mkldnn_primitives[input_index]),
            mkldnn::primitive::at(*m_mkldnn_primitives[weights_index]),
            static_cast<mkldnn::memory>(*m_mkldnn_primitives[result_index]),
            *m_mkldnn_primitives[mean_index],
            *m_mkldnn_primitives[variance_index]));

        m_primitive_deps[batchnorm_index] = {
            input_index, weights_index, result_index, mean_index, variance_index};
        return batchnorm_index;
    }
    else
    {
        size_t batchnorm_index = insert_primitive(new mkldnn::batch_normalization_forward(
            {{mkldnn::prop_kind::forward_inference,
              input_desc,
              eps,
              mkldnn::batch_normalization_flag::use_scale_shift |
                  mkldnn::batch_normalization_flag::use_global_stats},
             mkldnn_utils::global_cpu_engine},
            mkldnn::primitive::at(*m_mkldnn_primitives[input_index]),
            mkldnn::primitive::at(*m_mkldnn_primitives[mean_index]),
            mkldnn::primitive::at(*m_mkldnn_primitives[variance_index]),
            mkldnn::primitive::at(*m_mkldnn_primitives[weights_index]),
            static_cast<mkldnn::memory>(*m_mkldnn_primitives[result_index])));

        m_primitive_deps[batchnorm_index] = {
            input_index, mean_index, variance_index, weights_index, result_index};
        return batchnorm_index;
    }
}

size_t MKLDNNEmitter::build_batchnorm_backward(const mkldnn::memory::desc& weights_desc,
......
@@ -169,7 +169,8 @@ namespace ngraph
                                          const mkldnn::memory::desc& result_desc,
                                          const mkldnn::memory::desc& mean_desc,
                                          const mkldnn::memory::desc& variance_desc,
                                          const double eps,
                                          bool bn_training_flag);

            size_t build_batchnorm_backward(const mkldnn::memory::desc& weights_desc,
                                            const mkldnn::memory::desc& input_desc,
......
@@ -1009,6 +1009,7 @@ namespace ngraph
                template <>
                void CPULayout::LAYOUT_DECL(ngraph::op::BatchNorm)
                {
                    auto bn = static_cast<const ngraph::op::BatchNorm*>(node.get());
                    if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
                    {
                        auto input_layout =

@@ -1016,12 +1017,25 @@ namespace ngraph
                        vector<memory::format> prim_input_formats;
                        vector<memory::format> prim_output_formats;

                        if (bn->get_training_flag())
                        {
                            prim_input_formats.push_back(memory::format::x);
                            prim_input_formats.push_back(memory::format::x);
                            prim_input_formats.push_back(input_layout);
                            prim_output_formats.push_back(input_layout);
                            prim_output_formats.push_back(memory::format::x);
                            prim_output_formats.push_back(memory::format::x);
                        }
                        else
                        {
                            prim_input_formats.push_back(memory::format::x);
                            prim_input_formats.push_back(memory::format::x);
                            prim_input_formats.push_back(input_layout);
                            prim_input_formats.push_back(memory::format::x);
                            prim_input_formats.push_back(memory::format::x);
                            prim_output_formats.push_back(input_layout);
                        }
                        node =
                            insert_input_conversions(external_function, node, prim_input_formats);
                        set_output_layouts(node, prim_output_formats);
......
@@ -410,7 +410,15 @@ static shared_ptr<ngraph::Function>
    else if (node_op == "BatchNorm")
    {
        auto epsilon = node_js.at("eps").get<double>();
        if (node_js.at("training"))
        {
            node = make_shared<op::BatchNorm>(epsilon, args[0], args[1], args[2]);
        }
        else
        {
            node = make_shared<op::BatchNorm>(
                epsilon, args[0], args[1], args[2], args[3], args[4]);
        }
    }
    else if (node_op == "BatchNormBackprop")
    {

@@ -941,6 +949,7 @@ static json write(const Node& n, bool binary_constant_data)
    {
        auto tmp = dynamic_cast<const op::BatchNorm*>(&n);
        node["eps"] = tmp->get_eps_value();
        node["training"] = tmp->get_training_flag();
    }
    else if (node_op == "BatchNormBackprop")
    {
......
@@ -960,3 +960,55 @@ TEST(cpu_fusion, sigmoid_bprop_n1c1h4)
    vector<float> expected{0.196612f, 0.0176627f, 0.196612f, 0.0176627f};
    EXPECT_TRUE(test::all_close(expected, read_vector<float>(result)));
}

TEST(cpu_fusion, batchnorm_fprop_inference_b2c2h2w1)
{
auto input_shape = Shape{2, 2, 2, 1};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
auto var_shape = Shape{2};
auto var = make_shared<op::Parameter>(element::f32, var_shape);
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{2, 2, 2, 1};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var);
auto f = make_shared<Function>(bn, op::ParameterVector{input, gamma, beta, mean, var});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto _input = backend->make_primary_tensor_view(element::f32, Shape{2, 2, 2, 1});
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
auto _gamma = backend->make_primary_tensor_view(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->make_primary_tensor_view(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto _mean = backend->make_primary_tensor_view(element::f32, mean_shape);
copy_data(_mean, vector<float>{0.583388f, 0.619252f});
auto _var = backend->make_primary_tensor_view(element::f32, var_shape);
copy_data(_var, vector<float>{0.0119972f, 0.0282681f});
auto bn_output = backend->make_primary_tensor_view(element::f32, shape_r);
auto result_mean = backend->make_primary_tensor_view(element::f32, mean_shape);
auto result_variance = backend->make_primary_tensor_view(element::f32, var_shape);
vector<float> expected_result{
-0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
cf->call({bn_output}, {_input, _gamma, _beta, _mean, _var});
ASSERT_TRUE(
ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
}
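
For readers checking the numbers: a standalone sanity check (not part of this commit) that reproduces the expected_result values above with the plain inference formula y = gamma * (x - mean) / sqrt(var + eps) + beta, assuming NCHW ordering so dimension 1 (size 2) is the channel axis:

    // Reference computation of the test's expected_result values.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<float> x{0.54881352f, 0.71518934f, 0.60276335f, 0.54488319f,
                             0.42365479f, 0.64589411f, 0.4375872f,  0.89177299f};
        std::vector<float> mean{0.583388f, 0.619252f};
        std::vector<float> var{0.0119972f, 0.0282681f};
        std::vector<float> gamma{1.0f, 1.0f};
        std::vector<float> beta{0.0f, 0.0f};
        const double eps = 0.001;

        // Shape {2, 2, 2, 1}: N=2, C=2, H=2, W=1, so each consecutive pair shares a channel.
        for (size_t n = 0; n < 2; n++)
        {
            for (size_t c = 0; c < 2; c++)
            {
                for (size_t h = 0; h < 2; h++)
                {
                    size_t i = n * 4 + c * 2 + h;
                    float y = gamma[c] * (x[i] - mean[c]) / std::sqrt(var[c] + eps) + beta[c];
                    std::printf("%g\n", y); // first value ~ -0.30327, matching expected_result[0]
                }
            }
        }
        return 0;
    }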