Unverified Commit 429eae9a authored by Pruthvi's avatar Pruthvi Committed by GitHub

Fix bn constructor (#631)

* Fix bn construtor
    - assert if gamma or beta dont have rank 1
    - remove redundant checks

* - added gaurds to check if the input and delta shape to mkldnn bn fprop and bprop op has a rank of 4
parent bb2b9516
...@@ -42,10 +42,9 @@ ngraph::op::BatchNorm::BatchNorm(double eps, ...@@ -42,10 +42,9 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
"input tensor must have atleast one channel axis for batch normalization"); "input tensor must have atleast one channel axis for batch normalization");
} }
if ((m_bn_mean_shape.size() != 1) && (m_bn_variance_shape.size() != 1) && if ((gamma->get_shape().size() != 1) || (beta->get_shape().size() != 1))
(gamma->get_shape().size() != 1) && (beta->get_shape().size() != 1))
{ {
throw ngraph_error("gamma, beta, mean, variance shoud have all rank 1"); throw ngraph_error("gamma and beta shoud have rank 1");
} }
if (gamma->get_shape().size() != beta->get_shape().size()) if (gamma->get_shape().size() != beta->get_shape().size())
......
...@@ -351,21 +351,35 @@ namespace ngraph ...@@ -351,21 +351,35 @@ namespace ngraph
template <> template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNorm) void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNorm)
{ {
auto batchnorm = static_cast<op::BatchNorm*>(node); auto input_shape = node->get_input_shape(2);
auto op_annotations = auto input_rank = input_shape.size();
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>(); if ((input_rank == 4 && node->get_input_element_type(2) == element::f32))
op_annotations->set_mkldnn_op(true); {
batchnorm->set_op_annotations(op_annotations); auto batchnorm = static_cast<op::BatchNorm*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations);
}
} }
template <> template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNormBackprop) void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNormBackprop)
{ {
auto batchnorm = static_cast<op::BatchNormBackprop*>(node); auto input_shape = node->get_input_shape(2);
auto op_annotations = auto input_rank = input_shape.size();
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>(); auto delta_shape = node->get_input_shape(5);
op_annotations->set_mkldnn_op(true); auto delta_rank = delta_shape.size();
batchnorm->set_op_annotations(op_annotations); if ((input_rank == 4 && delta_rank == 4 &&
node->get_input_element_type(5) == element::f32 &&
node->get_input_element_type(2) == element::f32))
{
auto batchnorm = static_cast<op::BatchNormBackprop*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations);
}
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment