Commit 77dd3bc2 authored by Sergei Barannikov's avatar Sergei Barannikov

Make BatchNorm* constructors accept inputs by const reference

Other operations which have already been converted to new op form accept
`Output<Node>`` by const reference. `BatchNorm*` ops were an exception.
parent 5465677f
......@@ -22,12 +22,15 @@
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/validation_util.hpp"
const std::string ngraph::op::BatchNormTraining::type_name{"BatchNormTraining"};
using namespace std;
using namespace ngraph;
ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
double epsilon)
const string op::BatchNormTraining::type_name{"BatchNormTraining"};
op::BatchNormTraining::BatchNormTraining(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
double epsilon)
: Op({gamma, beta, input})
, m_epsilon(epsilon)
{
......@@ -35,17 +38,17 @@ ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input,
}
// DEPRECATED
ngraph::op::BatchNormTraining::BatchNormTraining(double eps,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input)
op::BatchNormTraining::BatchNormTraining(double eps,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& input)
: Op({gamma, beta, input})
, m_epsilon(eps)
{
constructor_validate_and_infer_types();
}
void ngraph::op::BatchNormTraining::validate_and_infer_types()
void op::BatchNormTraining::validate_and_infer_types()
{
element::Type result_et;
PartialShape result_batch_shape;
......@@ -66,16 +69,15 @@ void ngraph::op::BatchNormTraining::validate_and_infer_types()
set_output_type(2, result_et, result_channel_shape);
}
std::shared_ptr<ngraph::Node>
ngraph::op::BatchNormTraining::copy_with_new_args(const NodeVector& new_args) const
std::shared_ptr<Node> op::BatchNormTraining::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return std::make_shared<BatchNormTraining>(
new_args.at(2), new_args.at(0), new_args.at(1), m_epsilon);
}
void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas)
void op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas)
{
auto gamma = input(0).get_source_output();
auto beta = input(1).get_source_output();
......@@ -102,14 +104,14 @@ void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoin
adjoints.add_delta(beta, dbeta);
}
const std::string ngraph::op::BatchNormInference::type_name{"BatchNormInference"};
const string op::BatchNormInference::type_name{"BatchNormInference"};
ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
double epsilon)
op::BatchNormInference::BatchNormInference(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& mean,
const Output<Node>& variance,
double epsilon)
: Op({gamma, beta, input, mean, variance})
, m_epsilon(epsilon)
{
......@@ -117,19 +119,19 @@ ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input,
}
// DEPRECATED
ngraph::op::BatchNormInference::BatchNormInference(double eps,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance)
op::BatchNormInference::BatchNormInference(double eps,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& input,
const Output<Node>& mean,
const Output<Node>& variance)
: Op({gamma, beta, input, mean, variance})
, m_epsilon(eps)
{
constructor_validate_and_infer_types();
}
void ngraph::op::BatchNormInference::validate_and_infer_types()
void op::BatchNormInference::validate_and_infer_types()
{
element::Type result_et;
PartialShape result_batch_shape;
......@@ -152,23 +154,22 @@ void ngraph::op::BatchNormInference::validate_and_infer_types()
set_output_type(0, result_et, result_batch_shape);
}
std::shared_ptr<ngraph::Node>
ngraph::op::BatchNormInference::copy_with_new_args(const NodeVector& new_args) const
std::shared_ptr<Node> op::BatchNormInference::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return std::make_shared<BatchNormInference>(
new_args.at(2), new_args.at(0), new_args.at(1), new_args.at(3), new_args.at(4), m_epsilon);
}
const std::string ngraph::op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"};
const string op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"};
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
Output<ngraph::Node> delta,
double epsilon)
op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& mean,
const Output<Node>& variance,
const Output<Node>& delta,
double epsilon)
: Op({gamma, beta, input, mean, variance, delta})
, m_epsilon(epsilon)
......@@ -177,13 +178,13 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::
constructor_validate_and_infer_types();
}
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
Output<ngraph::Node> delta)
op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& input,
const Output<Node>& mean,
const Output<Node>& variance,
const Output<Node>& delta)
: Op({gamma, beta, input, mean, variance, delta})
, m_epsilon(epsilon)
......@@ -192,7 +193,7 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
constructor_validate_and_infer_types();
}
void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types()
void op::BatchNormTrainingBackprop::validate_and_infer_types()
{
PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)};
......@@ -239,8 +240,8 @@ void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types()
set_output_type(2, result_et, result_channel_shape);
}
std::shared_ptr<ngraph::Node>
ngraph::op::BatchNormTrainingBackprop::copy_with_new_args(const NodeVector& new_args) const
std::shared_ptr<Node>
op::BatchNormTrainingBackprop::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return std::make_shared<op::BatchNormTrainingBackprop>(new_args.at(2),
......
......@@ -39,9 +39,9 @@ namespace ngraph
/// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormTraining(Output<Node> input,
Output<Node> gamma,
Output<Node> beta,
BatchNormTraining(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
double epsilon);
NGRAPH_DEPRECATED_DOC
......@@ -66,9 +66,9 @@ namespace ngraph
/// output[2]: shall have rank 1, with the same span as input's channel axis.
NGRAPH_DEPRECATED("Use another constructor")
BatchNormTraining(double eps,
Output<Node> gamma,
Output<Node> beta,
Output<Node> input);
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& input);
void validate_and_infer_types() override;
......@@ -101,11 +101,11 @@ namespace ngraph
/// \param mean value for mean normalization [C]
/// \param variance value for variance normalization [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormInference(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
BatchNormInference(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& mean,
const Output<Node>& variance,
double epsilon);
NGRAPH_DEPRECATED_DOC
......@@ -128,11 +128,11 @@ namespace ngraph
/// output: shall have the same shape as 'input'.
NGRAPH_DEPRECATED("Use another constructor")
BatchNormInference(double eps,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance);
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& input,
const Output<Node>& mean,
const Output<Node>& variance);
void validate_and_infer_types() override;
......@@ -165,24 +165,23 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
BatchNormTrainingBackprop() = default;
BatchNormTrainingBackprop(Output<Node> input,
Output<Node> gamma,
Output<Node> beta,
Output<Node> mean,
Output<Node> variance,
Output<Node> delta,
BatchNormTrainingBackprop(const Output<Node>& input,
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& mean,
const Output<Node>& variance,
const Output<Node>& delta,
double epsilon);
NGRAPH_DEPRECATED_DOC
NGRAPH_DEPRECATED("Use another constructor")
BatchNormTrainingBackprop(double epsilon,
Output<Node> gamma,
Output<Node> beta,
Output<Node> input,
Output<Node> mean,
Output<Node> variance,
Output<Node> delta);
const Output<Node>& gamma,
const Output<Node>& beta,
const Output<Node>& input,
const Output<Node>& mean,
const Output<Node>& variance,
const Output<Node>& delta);
void validate_and_infer_types() override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment