Unverified Commit 429eae9a authored by Pruthvi's avatar Pruthvi Committed by GitHub

Fix bn constructor (#631)

* Fix bn construtor
    - assert if gamma or beta dont have rank 1
    - remove redundant checks

* - added gaurds to check if the input and delta shape to mkldnn bn fprop and bprop op has a rank of 4
parent bb2b9516
...@@ -42,10 +42,9 @@ ngraph::op::BatchNorm::BatchNorm(double eps, ...@@ -42,10 +42,9 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
"input tensor must have atleast one channel axis for batch normalization"); "input tensor must have atleast one channel axis for batch normalization");
} }
if ((m_bn_mean_shape.size() != 1) && (m_bn_variance_shape.size() != 1) && if ((gamma->get_shape().size() != 1) || (beta->get_shape().size() != 1))
(gamma->get_shape().size() != 1) && (beta->get_shape().size() != 1))
{ {
throw ngraph_error("gamma, beta, mean, variance shoud have all rank 1"); throw ngraph_error("gamma and beta shoud have rank 1");
} }
if (gamma->get_shape().size() != beta->get_shape().size()) if (gamma->get_shape().size() != beta->get_shape().size())
......
...@@ -350,6 +350,10 @@ namespace ngraph ...@@ -350,6 +350,10 @@ namespace ngraph
template <> template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNorm) void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNorm)
{
auto input_shape = node->get_input_shape(2);
auto input_rank = input_shape.size();
if ((input_rank == 4 && node->get_input_element_type(2) == element::f32))
{ {
auto batchnorm = static_cast<op::BatchNorm*>(node); auto batchnorm = static_cast<op::BatchNorm*>(node);
auto op_annotations = auto op_annotations =
...@@ -357,9 +361,18 @@ namespace ngraph ...@@ -357,9 +361,18 @@ namespace ngraph
op_annotations->set_mkldnn_op(true); op_annotations->set_mkldnn_op(true);
batchnorm->set_op_annotations(op_annotations); batchnorm->set_op_annotations(op_annotations);
} }
}
template <> template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNormBackprop) void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNormBackprop)
{
auto input_shape = node->get_input_shape(2);
auto input_rank = input_shape.size();
auto delta_shape = node->get_input_shape(5);
auto delta_rank = delta_shape.size();
if ((input_rank == 4 && delta_rank == 4 &&
node->get_input_element_type(5) == element::f32 &&
node->get_input_element_type(2) == element::f32))
{ {
auto batchnorm = static_cast<op::BatchNormBackprop*>(node); auto batchnorm = static_cast<op::BatchNormBackprop*>(node);
auto op_annotations = auto op_annotations =
...@@ -370,6 +383,7 @@ namespace ngraph ...@@ -370,6 +383,7 @@ namespace ngraph
} }
} }
} }
}
} }
#define TI(x) type_index(typeid(x)) #define TI(x) type_index(typeid(x))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment