Unverified Commit c9bab318 authored by Adam Procter's avatar Adam Procter Committed by GitHub

Update BatchNorm to set_output_size at construction time (#1963)

parent deabb7bc
...@@ -31,6 +31,7 @@ ngraph::op::BatchNormInference::BatchNormInference(double eps, ...@@ -31,6 +31,7 @@ ngraph::op::BatchNormInference::BatchNormInference(double eps,
: Op("BatchNormInference", check_single_output_args({gamma, beta, input, mean, variance})) : Op("BatchNormInference", check_single_output_args({gamma, beta, input, mean, variance}))
, m_epsilon(eps) , m_epsilon(eps)
{ {
set_output_size(1);
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -41,6 +42,7 @@ ngraph::op::BatchNormTraining::BatchNormTraining(double eps, ...@@ -41,6 +42,7 @@ ngraph::op::BatchNormTraining::BatchNormTraining(double eps,
: Op("BatchNormTraining", check_single_output_args({gamma, beta, input})) : Op("BatchNormTraining", check_single_output_args({gamma, beta, input}))
, m_epsilon(eps) , m_epsilon(eps)
{ {
set_output_size(3);
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -63,7 +65,6 @@ void ngraph::op::BatchNormInference::validate_and_infer_types() ...@@ -63,7 +65,6 @@ void ngraph::op::BatchNormInference::validate_and_infer_types()
get_input_partial_shape(INPUT_MEAN), get_input_partial_shape(INPUT_MEAN),
get_input_partial_shape(INPUT_VARIANCE)); get_input_partial_shape(INPUT_VARIANCE));
set_output_size(1);
set_output_type(0, result_et, result_batch_shape); set_output_type(0, result_et, result_batch_shape);
} }
...@@ -82,7 +83,6 @@ void ngraph::op::BatchNormTraining::validate_and_infer_types() ...@@ -82,7 +83,6 @@ void ngraph::op::BatchNormTraining::validate_and_infer_types()
get_input_partial_shape(INPUT_GAMMA), get_input_partial_shape(INPUT_GAMMA),
get_input_partial_shape(INPUT_BETA)); get_input_partial_shape(INPUT_BETA));
set_output_size(3);
set_output_type(0, result_et, result_batch_shape); set_output_type(0, result_et, result_batch_shape);
set_output_type(1, result_et, result_channel_shape); set_output_type(1, result_et, result_channel_shape);
set_output_type(2, result_et, result_channel_shape); set_output_type(2, result_et, result_channel_shape);
...@@ -117,6 +117,7 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop( ...@@ -117,6 +117,7 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
, m_epsilon(eps) , m_epsilon(eps)
{ {
set_output_size(3);
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -157,7 +158,6 @@ void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types() ...@@ -157,7 +158,6 @@ void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types()
get_input_partial_shape(INPUT_MEAN), get_input_partial_shape(INPUT_MEAN),
get_input_partial_shape(INPUT_VARIANCE)); get_input_partial_shape(INPUT_VARIANCE));
set_output_size(3);
set_output_type(0, result_et, result_batch_shape); set_output_type(0, result_et, result_batch_shape);
set_output_type(1, result_et, result_channel_shape); set_output_type(1, result_et, result_channel_shape);
set_output_type(2, result_et, result_channel_shape); set_output_type(2, result_et, result_channel_shape);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment