Unverified Commit 429eae9a authored by Pruthvi, committed by GitHub

Fix bn constructor (#631)

* Fix bn constructor
    - throw if gamma or beta does not have rank 1
    - remove redundant checks

* - added guards to check that the input and delta shapes to the mkldnn bn fprop and bprop ops have rank 4
parent bb2b9516
@@ -42,10 +42,9 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
             "input tensor must have atleast one channel axis for batch normalization");
     }
-    if ((m_bn_mean_shape.size() != 1) && (m_bn_variance_shape.size() != 1) &&
-        (gamma->get_shape().size() != 1) && (beta->get_shape().size() != 1))
+    if ((gamma->get_shape().size() != 1) || (beta->get_shape().size() != 1))
     {
-        throw ngraph_error("gamma, beta, mean, variance shoud have all rank 1");
+        throw ngraph_error("gamma and beta shoud have rank 1");
     }
     if (gamma->get_shape().size() != beta->get_shape().size())
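A minimal sketch of what the tightened check means for callers of the constructor, assuming the training-mode signature BatchNorm(eps, gamma, beta, input) and the stock nGraph Parameter/Shape types; the main() wrapper and the concrete shapes are illustrative only:

#include <memory>
#include "ngraph/ngraph.hpp"

using namespace ngraph;

int main()
{
    // NCHW activation with C = 3 channels (rank 4, float32).
    auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 4});
    // gamma and beta must be rank-1 vectors of length C.
    auto gamma = std::make_shared<op::Parameter>(element::f32, Shape{3});
    auto beta = std::make_shared<op::Parameter>(element::f32, Shape{3});
    // Assumed training-mode constructor: BatchNorm(eps, gamma, beta, input).
    auto bn = std::make_shared<op::BatchNorm>(1e-5, gamma, beta, input); // accepted

    // After this change, any gamma or beta that is not rank 1 is rejected
    // by the constructor with ngraph_error("gamma and beta shoud have rank 1").
    auto bad_gamma = std::make_shared<op::Parameter>(element::f32, Shape{3, 1});
    try
    {
        auto bad_bn = std::make_shared<op::BatchNorm>(1e-5, bad_gamma, beta, input);
    }
    catch (const ngraph_error&)
    {
        // expected: the new rank check fires
    }
    return 0;
}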
@@ -351,21 +351,35 @@ namespace ngraph
                 template <>
                 void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNorm)
                 {
-                    auto batchnorm = static_cast<op::BatchNorm*>(node);
-                    auto op_annotations =
-                        std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
-                    op_annotations->set_mkldnn_op(true);
-                    batchnorm->set_op_annotations(op_annotations);
+                    auto input_shape = node->get_input_shape(2);
+                    auto input_rank = input_shape.size();
+                    if ((input_rank == 4 && node->get_input_element_type(2) == element::f32))
+                    {
+                        auto batchnorm = static_cast<op::BatchNorm*>(node);
+                        auto op_annotations =
+                            std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
+                        op_annotations->set_mkldnn_op(true);
+                        batchnorm->set_op_annotations(op_annotations);
+                    }
                 }

                 template <>
                 void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNormBackprop)
                 {
-                    auto batchnorm = static_cast<op::BatchNormBackprop*>(node);
-                    auto op_annotations =
-                        std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
-                    op_annotations->set_mkldnn_op(true);
-                    batchnorm->set_op_annotations(op_annotations);
+                    auto input_shape = node->get_input_shape(2);
+                    auto input_rank = input_shape.size();
+                    auto delta_shape = node->get_input_shape(5);
+                    auto delta_rank = delta_shape.size();
+                    if ((input_rank == 4 && delta_rank == 4 &&
+                         node->get_input_element_type(5) == element::f32 &&
+                         node->get_input_element_type(2) == element::f32))
+                    {
+                        auto batchnorm = static_cast<op::BatchNormBackprop*>(node);
+                        auto op_annotations =
+                            std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
+                        op_annotations->set_mkldnn_op(true);
+                        batchnorm->set_op_annotations(op_annotations);
+                    }
                 }
             }
         }
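The condition the assignment pass now requires before tagging a forward BatchNorm as an MKLDNN op can be read as a standalone predicate; this is just a restatement of the guard above, with the hypothetical helper name can_assign_mkldnn_batchnorm, not code from the repository:

#include "ngraph/ngraph.hpp"

// Restates the fprop guard from the diff: MKLDNN batch norm is only
// assigned for 4D (NCHW) float32 input tensors; everything else keeps
// the default CPU kernel. Argument 2 of the training-mode BatchNorm op
// is the input tensor, matching node->get_input_shape(2) above.
static bool can_assign_mkldnn_batchnorm(ngraph::Node* node)
{
    auto input_rank = node->get_input_shape(2).size();
    return input_rank == 4 &&
           node->get_input_element_type(2) == ngraph::element::f32;
}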