Commit d4f8bfdc authored by Pruthvi's avatar Pruthvi Committed by Robert Kimball

Support for 5D batchnorm (#2055)

* - modified cpu_assignment pass to support bn with input 5D
- added test cases for 5D bn and 5D bn+relu

* - Address PR comments
- used mkldnn_utils to validate bn for mkldnn

* fix compilation error

* Addressed PR comments
- added helpers in mkldnn_utils for assigning ngraph Op as MKLDNN op
- helper function for bn mkldnn assignment

* fix clang error
parent c5dd80be
......@@ -658,3 +658,44 @@ bool runtime::cpu::mkldnn_utils::use_mkldnn_kernel(const ngraph::Node* node)
static_pointer_cast<ngraph::runtime::cpu::CPUOpAnnotations>(op_annotations)
->is_mkldnn_op());
}
// Marks the given node for execution via an MKLDNN kernel by attaching
// CPU op annotations with the mkldnn flag set.
void runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(Node* node)
{
    auto annotations = std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
    annotations->set_mkldnn_op(true);
    // Node does not expose set_op_annotations; the annotation API lives on op::Op.
    static_cast<op::Op*>(node)->set_op_annotations(annotations);
}
// Returns true if MKLDNN can execute the forward batchnorm for this node:
// the data tensor (input index 2) must be f32 and of rank 4 (NCHW) or
// rank 5 (NCDHW).
bool runtime::cpu::mkldnn_utils::can_use_mkldnn_batchnorm_fprop(const ngraph::Node* node)
{
    const auto input_rank = node->get_input_shape(2).size();
    const auto& input_element_type = node->get_input_element_type(2);
    // Collapse the if/return-true/else/return-false into a single boolean
    // expression — same predicate, idiomatic form.
    return (input_rank == 4 || input_rank == 5) && input_element_type == element::f32;
}
// Returns true if MKLDNN can execute the batchnorm backprop for this node:
// the data tensor (input index 2) and the delta tensor (input index 5)
// must both be f32 and share the same rank, which must be 4 or 5.
bool runtime::cpu::mkldnn_utils::can_use_mkldnn_batchnorm_bprop(const ngraph::Node* node)
{
    const auto input_rank = node->get_input_shape(2).size();
    const auto delta_rank = node->get_input_shape(5).size();
    // (rank==4 && rank==4) || (rank==5 && rank==5) is exactly "ranks are
    // equal and the common rank is 4 or 5" — expressed directly here.
    if (input_rank != delta_rank || (input_rank != 4 && input_rank != 5))
    {
        return false;
    }
    return node->get_input_element_type(2) == element::f32 &&
           node->get_input_element_type(5) == element::f32;
}
......@@ -20,7 +20,9 @@
#include "ngraph/axis_vector.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
......@@ -69,8 +71,11 @@ namespace ngraph
const AxisVector& axis_list);
bool is_mkldnn_filter_format(mkldnn::memory::format fmt);
bool is_mkldnn_blocked_data_format(mkldnn::memory::format fmt);
bool can_use_mkldnn_batchnorm_fprop(const ngraph::Node* node);
bool can_use_mkldnn_batchnorm_bprop(const ngraph::Node* node);
bool use_mkldnn_kernel(const ngraph::Node* node);
void assign_mkldnn_kernel(Node* node);
std::map<element::Type, const mkldnn::memory::data_type>&
get_mkldnn_data_type_map();
......
......@@ -29,9 +29,9 @@ ngraph::op::BatchNormTrainingRelu::BatchNormTrainingRelu(double eps,
auto bn_input_shape = get_input_shape(INPUT);
if (bn_input_shape.size() != 4)
if (bn_input_shape.size() != 4 && bn_input_shape.size() != 5)
{
throw ngraph_error("input tensor to batchnorm must have rank 4");
throw ngraph_error("input tensor to batchnorm must have rank 4/rank5");
}
auto channel_shape = Shape{bn_input_shape.at(1)};
......@@ -87,9 +87,10 @@ ngraph::op::BatchNormInferenceRelu::BatchNormInferenceRelu(double eps,
{
constructor_validate_and_infer_types();
auto bn_input_shape = get_input_shape(INPUT);
if (bn_input_shape.size() != 4)
if (bn_input_shape.size() != 4 && bn_input_shape.size() != 5)
{
throw ngraph_error("input tensor to batchnorm must have rank 4");
throw ngraph_error("input tensor to batchnorm must have rank 4/rank5");
}
if (bn_input_shape[1] == 0)
......
......@@ -73,7 +73,6 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Add)
{
auto add = static_cast<op::Add*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg1_shape = node->get_input_shape(1);
auto arg0_rank = arg0_shape.size();
......@@ -87,18 +86,13 @@ namespace ngraph
node->get_input_element_type(1) == element::f32 && arg0_rank == 4 &&
arg1_rank == 4 && src_size > 64000)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
add->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Concat)
{
auto concat = static_cast<op::Concat*>(node);
if (node->get_input_element_type(0) == element::f32 &&
((node->get_input_shape(0)).size() == 4 ||
(node->get_input_shape(0)).size() == 2))
......@@ -118,10 +112,7 @@ namespace ngraph
if (!any_zero)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
concat->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
}
......@@ -129,56 +120,36 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Convolution)
{
auto convolution = static_cast<op::Convolution*>(node);
if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::Convolution>(node))
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
convolution->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::GroupConvolution)
{
auto convolution = static_cast<op::GroupConvolution*>(node);
if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::GroupConvolution>(node))
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
convolution->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::GroupConvolutionBias)
{
auto convolution = static_cast<op::GroupConvolutionBias*>(node);
if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::GroupConvolutionBias>(node))
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
convolution->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionRelu)
{
auto convolution = static_cast<op::ConvolutionRelu*>(node);
if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::ConvolutionRelu>(node))
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
convolution->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -216,28 +187,22 @@ namespace ngraph
}
}
static void assign_batchnorm_relu(Node* node)
{
if (node->get_argument(2 /*input data*/)->get_shape().size() == 4)
{
auto bn_relu = static_cast<op::Op*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
bn_relu->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNormInferenceRelu)
{
assign_batchnorm_relu(node);
if (mkldnn_utils::can_use_mkldnn_batchnorm_fprop(node))
{
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNormTrainingRelu)
{
assign_batchnorm_relu(node);
if (mkldnn_utils::can_use_mkldnn_batchnorm_fprop(node))
{
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
template <>
......@@ -261,10 +226,7 @@ namespace ngraph
(arg0_rank == 5 && arg1_rank == 5)) &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
convolution->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -289,24 +251,16 @@ namespace ngraph
(arg0_rank == 5 && arg1_rank == 5)) &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
convolution->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ConvolutionBias)
{
auto convolution = static_cast<op::ConvolutionBias*>(node);
if (mkldnn_utils::can_use_mkldnn_conv<ngraph::op::ConvolutionBias>(node))
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
convolution->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -329,10 +283,7 @@ namespace ngraph
if (!data_dilated && data_rank == 4 && delta_rank == 4 &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
convolution->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -349,10 +300,7 @@ namespace ngraph
(arg0_rank == 5 && avg_pool->get_window_shape().size() == 3)) &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
avg_pool->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -369,10 +317,7 @@ namespace ngraph
(arg0_rank == 5 && avg_pool->get_window_shape().size() == 3)) &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
avg_pool->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -389,10 +334,7 @@ namespace ngraph
(arg0_rank == 5 && max_pool->get_window_shape().size() == 3)) &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
max_pool->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -408,10 +350,7 @@ namespace ngraph
if (arg0_rank == 4 && max_pool->get_window_shape().size() == 2 &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
max_pool->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -428,10 +367,7 @@ namespace ngraph
(arg1_rank == 5 && max_pool->get_window_shape().size() == 3)) &&
node->get_input_element_type(1) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
max_pool->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -447,10 +383,7 @@ namespace ngraph
if (arg1_rank == 4 && max_pool->get_window_shape().size() == 2 &&
node->get_input_element_type(1) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
max_pool->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -497,51 +430,37 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::LRN)
{
auto lrn = static_cast<op::LRN*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
auto result_shape = node->get_output_shape(0);
if ((arg0_rank == 4) && node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
lrn->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Sigmoid)
{
auto sigmoid = static_cast<op::Sigmoid*>(node);
if (node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
sigmoid->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::SigmoidBackprop)
{
auto sigmoid = static_cast<op::SigmoidBackprop*>(node);
if (node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
sigmoid->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ReluBackprop)
{
auto relu_bprop = static_cast<op::ReluBackprop*>(node);
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
auto result_shape = node->get_output_shape(0);
......@@ -549,55 +468,34 @@ namespace ngraph
if ((arg0_rank == 4 || arg0_rank == 2) &&
node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
relu_bprop->set_op_annotations(op_annotations);
}
}
static void assign_batchnorm(Node* node)
{
auto input_shape = node->get_input_shape(2);
auto input_rank = input_shape.size();
if ((input_rank == 4 && node->get_input_element_type(2) == element::f32))
{
auto batchnorm = static_cast<op::Op*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNormTraining)
{
assign_batchnorm(node);
if (mkldnn_utils::can_use_mkldnn_batchnorm_fprop(node))
{
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNormInference)
{
assign_batchnorm(node);
if (mkldnn_utils::can_use_mkldnn_batchnorm_fprop(node))
{
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNormTrainingBackprop)
{
auto input_shape = node->get_input_shape(2);
auto input_rank = input_shape.size();
auto delta_shape = node->get_input_shape(5);
auto delta_rank = delta_shape.size();
if ((input_rank == 4 && delta_rank == 4 &&
node->get_input_element_type(5) == element::f32 &&
node->get_input_element_type(2) == element::f32))
if (mkldnn_utils::can_use_mkldnn_batchnorm_bprop(node))
{
auto batchnorm = static_cast<op::BatchNormTrainingBackprop*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -614,11 +512,7 @@ namespace ngraph
node->get_input_element_type(0) == element::f32 &&
node->get_input_element_type(1) == element::f32))
{
auto lstm_node = static_cast<op::Lstm*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
lstm_node->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -635,11 +529,7 @@ namespace ngraph
node->get_input_element_type(0) == element::f32 &&
node->get_input_element_type(1) == element::f32))
{
auto rnn_node = static_cast<op::Rnn*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
rnn_node->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -656,10 +546,7 @@ namespace ngraph
node->get_input_element_type(0) == element::f32 &&
softmax->get_axes().size() == 1)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
softmax->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -670,10 +557,7 @@ namespace ngraph
auto strides = slice->get_strides();
if (!is_strided(strides) && node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
slice->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -683,11 +567,7 @@ namespace ngraph
if (node->get_input_element_type(0) == element::u8 ||
node->get_input_element_type(0) == element::i8)
{
auto quantized_mp = static_cast<op::QuantizedMaxPool*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
quantized_mp->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -697,11 +577,7 @@ namespace ngraph
if (node->get_input_element_type(0) == element::u8 ||
node->get_input_element_type(0) == element::i8)
{
auto quantized_ap = static_cast<op::QuantizedAvgPool*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
quantized_ap->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -756,44 +632,30 @@ namespace ngraph
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedConvolution)
{
auto quantized_conv = static_cast<op::QuantizedConvolution*>(node);
if (node->get_input_element_type(0) == element::u8 &&
node->get_input_element_type(1) == element::i8)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
quantized_conv->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedConvolutionRelu)
{
auto quantized_conv_relu = static_cast<op::QuantizedConvolutionRelu*>(node);
if (node->get_input_element_type(0) == element::u8 &&
node->get_input_element_type(1) == element::i8)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
quantized_conv_relu->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::QuantizedConvolutionBias)
{
auto quantized_conv_bias = static_cast<op::QuantizedConvolutionBias*>(node);
if (node->get_input_element_type(0) == element::u8 &&
node->get_input_element_type(1) == element::i8)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
quantized_conv_bias->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
......@@ -821,10 +683,7 @@ namespace ngraph
if (offset[0] != 0)
return;
}
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
dequantize->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
template <>
......@@ -862,10 +721,7 @@ namespace ngraph
return;
}
}
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
quantize->set_op_annotations(op_annotations);
runtime::cpu::mkldnn_utils::assign_mkldnn_kernel(node);
}
}
}
......
......@@ -302,12 +302,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
NGRAPH_DEBUG << "beta: " << pattern_map[beta_label]->get_name() << " "
<< pattern_map[beta_label]->get_shape().size();
// dont fuse if the inout doesnt have 4dims
if (pattern_map[input]->get_shape().size() != 4)
{
NGRAPH_DEBUG << "Input to bn doesnt not have a rank=4, so not fusing";
return false;
}
Shape bn_output_shape{m.get_match_root()->get_shape()};
Shape m_bn_mean_shape{pattern_map[mean_label]->get_shape()};
Shape m_bn_variance_shape{pattern_map[variance_label]->get_shape()};
......@@ -323,6 +317,10 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_fprop_bn()
auto bn_node = std::make_shared<op::BatchNormTraining>(
epsilon, pattern_map[gamma_label], pattern_map[beta_label], pattern_map[input]);
if (!mkldnn_utils::can_use_mkldnn_batchnorm_fprop(bn_node.get()))
{
return false;
}
auto normalized_output = std::shared_ptr<Node>(new op::GetOutputElement(bn_node, 0));
ngraph::replace_node(m.get_match_root(), normalized_output);
......@@ -777,15 +775,10 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu()
auto m_bn = std::static_pointer_cast<op::BatchNormTraining>(
m.get_match_root()->get_argument(0)->get_inputs().at(0).get_output().get_node());
// as of now, only MKLDNN supports this fusion
// and it requires input data's rank to be equal to 4
if (pattern_map[input]->get_shape().size() != 4)
if (!mkldnn_utils::can_use_mkldnn_batchnorm_fprop(m_bn.get()))
{
NGRAPH_DEBUG << " Input data's rank isn't equal to 4. Shape = "
<< pattern_map[input]->get_shape().size();
return false;
}
std::vector<std::shared_ptr<Node>> mgoes(m_bn->get_outputs().size());
for (auto bn_in : m_bn->get_output_inputs(0))
{
......@@ -849,15 +842,6 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_sta
auto pattern_map = m.get_pattern_map();
// as of now, only MKLDNN supports this fusion
// and it requires input data's rank to be equal to 4
if (pattern_map[input]->get_shape().size() != 4)
{
NGRAPH_DEBUG << " Input data's rank isn't equal to 4. Shape = "
<< pattern_map[input]->get_shape().size();
return false;
}
auto bn_match = m.get_match_root()->get_inputs().at(0).get_output().get_node();
if (bn_match->get_users().size() > 1)
{
......@@ -868,6 +852,10 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_sta
std::shared_ptr<Node> bn_relu;
if (auto bn_inference = std::dynamic_pointer_cast<op::BatchNormInference>(bn_match))
{
if (!mkldnn_utils::can_use_mkldnn_batchnorm_fprop(bn_inference.get()))
{
return false;
}
bn_relu = std::make_shared<op::BatchNormInferenceRelu>(bn_inference->get_eps_value(),
pattern_map[gamma],
pattern_map[beta],
......
......@@ -5176,6 +5176,60 @@ NGRAPH_TEST(${BACKEND_NAME}, batch_norm_fprop_b2c2h2w1)
test::all_close(expected_variance, read_vector<float>(result_variance), 1e-5f, 1e-6f));
}
// Forward batchnorm (training mode) on a 5-D (NCDHW) input tensor —
// exercises the newly added rank-5 batchnorm support. Compares normalized
// output, batch mean and batch variance against precomputed values.
NGRAPH_TEST(${BACKEND_NAME}, batchnorm_fprop_b2c2d2h1w1)
{
// batch=2, channels=2, depth=2, height=1, width=1
auto input_shape = Shape{2, 2, 2, 1, 1};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
// mean/variance/gamma/beta are per-channel (C = 2)
auto mean_shape = Shape{2};
auto var_shape = Shape{2};
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{2, 2, 2, 1, 1};
auto bn = make_shared<op::BatchNormTraining>(eps, gamma, beta, input);
// BatchNormTraining yields three outputs: normalized data, mean, variance.
auto output_rt = std::make_shared<op::GetOutputElement>(bn, 0);
auto mean_rt = std::make_shared<op::GetOutputElement>(bn, 1);
auto variance_rt = std::make_shared<op::GetOutputElement>(bn, 2);
auto f = make_shared<Function>(NodeVector{output_rt, mean_rt, variance_rt},
ParameterVector{input, gamma, beta});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto _input = backend->create_tensor(element::f32, input_shape);
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
// gamma=1, beta=0 so the output is the plain normalized input.
auto _gamma = backend->create_tensor(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->create_tensor(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto bn_output = backend->create_tensor(element::f32, shape_r);
auto result_mean = backend->create_tensor(element::f32, mean_shape);
auto result_variance = backend->create_tensor(element::f32, var_shape);
vector<float> expected_result{
-0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
vector<float> expected_mean{0.583388f, 0.619252f};
vector<float> expected_variance{0.0119972f, 0.0282681f};
backend->call_with_validate(
f, {bn_output, result_mean, result_variance}, {_input, _gamma, _beta});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(bn_output)));
EXPECT_TRUE(test::all_close(expected_mean, read_vector<float>(result_mean)));
// variance comparison uses explicit tolerances (1e-5f, 1e-6f) — presumably
// rel/abs per test::all_close's signature; TODO confirm parameter order.
EXPECT_TRUE(
test::all_close(expected_variance, read_vector<float>(result_variance), 1e-5f, 1e-6f));
}
NGRAPH_TEST(${BACKEND_NAME}, batch_norm_bprop_n4c3h2w2)
{
auto input_shape = Shape{4, 3, 2, 2};
......
......@@ -783,6 +783,56 @@ TEST(cpu_fusion, batchnorm_fprop_relu_b1c2h2w2)
read_vector<float>(result_variance_bnr)));
}
// Builds a BatchNormTraining + Relu graph for the given input shape, runs it
// on both the INTERPRETER (reference) and CPU backends with identical random
// inputs, and checks that the CPU results match the reference.
static void test_batchnorm_fprop_relu(Shape input_shape)
{
// Build a fresh function per backend (two independent graphs, same topology).
auto make_bn_relu_function = [&]() {
// channel count is taken from axis 1 of the input (N,C,...) layout
auto c_axis = input_shape[1];
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{c_axis};
auto var_shape = Shape{c_axis};
auto gamma_shape = Shape{c_axis};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{c_axis};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = input_shape;
auto bn = make_shared<op::BatchNormTraining>(eps, gamma, beta, input);
// Relu applied to the normalized output — this is the pattern the
// CPU fusion pass rewrites into BatchNormTrainingRelu.
auto output_rt = std::make_shared<op::GetOutputElement>(bn, 0);
auto output_relu = std::make_shared<op::Relu>(output_rt);
auto mean_rt = std::make_shared<op::GetOutputElement>(bn, 1);
auto variance_rt = std::make_shared<op::GetOutputElement>(bn, 2);
auto f = make_shared<Function>(NodeVector{output_relu, mean_rt, variance_rt},
ParameterVector{input, gamma, beta});
return f;
};
auto cpu_f = make_bn_relu_function();
auto int_f = make_bn_relu_function();
// Fill every parameter (input, gamma, beta) with uniform random values.
test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_f->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_f, args, "INTERPRETER");
auto cpu_results = execute(cpu_f, args, "CPU");
// Compare all three outputs (relu'd data, mean, variance) within tolerance.
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
// Runs the batchnorm+relu check on a 4-D input and on two 5-D inputs —
// the 5-D shapes cover the newly added rank-5 batchnorm support.
TEST(cpu_fusion, batchnorm_fprop_relu)
{
test_batchnorm_fprop_relu(Shape{1, 2, 2, 2});
test_batchnorm_fprop_relu(Shape{1, 2, 2, 2, 2});
test_batchnorm_fprop_relu(Shape{2, 2, 2, 4, 4});
}
TEST(cpu_fusion, fuse_conv_relu)
{
auto A = std::make_shared<op::Parameter>(element::f32, Shape{2, 1, 2, 2});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment