Commit dfa5d4d1 authored by Scott Cyphers

Merge branch 's-barannikov/new_op_form/batch_norm' into cyphers/s-barannikov

parents 7fbdfd5c 77dd3bc2
@@ -22,11 +22,14 @@
 #include "ngraph/op/get_output_element.hpp"
 #include "ngraph/validation_util.hpp"
 
-const std::string ngraph::op::BatchNormTraining::type_name{"BatchNormTraining"};
+using namespace std;
+using namespace ngraph;
+
+const string op::BatchNormTraining::type_name{"BatchNormTraining"};
 
-ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input,
-                                                 Output<ngraph::Node> gamma,
-                                                 Output<ngraph::Node> beta,
+op::BatchNormTraining::BatchNormTraining(const Output<Node>& input,
+                                         const Output<Node>& gamma,
+                                         const Output<Node>& beta,
                                          double epsilon)
     : Op({gamma, beta, input})
     , m_epsilon(epsilon)
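The node arguments now come in as const Output<Node>& instead of by-value Output<Node>. A minimal caller-side sketch, assuming hypothetical Parameter shapes (NCHW input with C = 3), not code from this commit; since shared_ptr<Node> converts implicitly to Output<Node>, existing call sites should compile unchanged:

// Sketch only: the shapes below are illustrative.
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/parameter.hpp"

using namespace ngraph;

std::shared_ptr<Node> make_bn_training()
{
    auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 4});
    auto gamma = std::make_shared<op::Parameter>(element::f32, Shape{3}); // [C]
    auto beta = std::make_shared<op::Parameter>(element::f32, Shape{3});  // [C]
    // shared_ptr<Node> converts implicitly to Output<Node>, so the
    // const-reference signature stays source-compatible for callers.
    return std::make_shared<op::BatchNormTraining>(input, gamma, beta, 1e-5);
}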
@@ -35,17 +38,17 @@
 }
 
 // DEPRECATED
-ngraph::op::BatchNormTraining::BatchNormTraining(double eps,
-                                                 Output<ngraph::Node> gamma,
-                                                 Output<ngraph::Node> beta,
-                                                 Output<ngraph::Node> input)
+op::BatchNormTraining::BatchNormTraining(double eps,
+                                         const Output<Node>& gamma,
+                                         const Output<Node>& beta,
+                                         const Output<Node>& input)
     : Op({gamma, beta, input})
     , m_epsilon(eps)
 {
     constructor_validate_and_infer_types();
 }
 
-void ngraph::op::BatchNormTraining::validate_and_infer_types()
+void op::BatchNormTraining::validate_and_infer_types()
 {
     element::Type result_et;
     PartialShape result_batch_shape;
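The deprecated constructor takes eps first and input last; its replacement moves epsilon to the end and input to the front. A hedged sketch, reusing the hypothetical input/gamma/beta Parameters from above:

// Deprecated ordering (eps, gamma, beta, input) vs. the preferred
// (input, gamma, beta, epsilon); both build the same {gamma, beta, input} op.
auto bn_old = std::make_shared<op::BatchNormTraining>(1e-5, gamma, beta, input);
auto bn_new = std::make_shared<op::BatchNormTraining>(input, gamma, beta, 1e-5);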
@@ -66,15 +69,14 @@
     set_output_type(2, result_et, result_channel_shape);
 }
 
-std::shared_ptr<ngraph::Node>
-    ngraph::op::BatchNormTraining::copy_with_new_args(const NodeVector& new_args) const
+std::shared_ptr<Node> op::BatchNormTraining::copy_with_new_args(const NodeVector& new_args) const
 {
     check_new_args_count(this, new_args);
     return std::make_shared<BatchNormTraining>(
         new_args.at(2), new_args.at(0), new_args.at(1), m_epsilon);
 }
 
-void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoints,
-                                                      const NodeVector& deltas)
+void op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoints,
+                                              const NodeVector& deltas)
 {
     auto gamma = input(0).get_source_output();
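Note the index shuffle in copy_with_new_args: new_args arrives in the op's stored input order, Op({gamma, beta, input}), while the public constructor takes input first, so input is new_args.at(2). A sketch, continuing with the hypothetical Parameters above:

// Sketch: new_args must follow the stored input order {gamma, beta, input};
// the override re-maps it to the constructor's (input, gamma, beta) order.
auto bn = std::make_shared<op::BatchNormTraining>(input, gamma, beta, 1e-5);
auto clone = bn->copy_with_new_args(NodeVector{gamma, beta, input});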
@@ -102,13 +104,13 @@
     adjoints.add_delta(beta, dbeta);
 }
 
-const std::string ngraph::op::BatchNormInference::type_name{"BatchNormInference"};
+const string op::BatchNormInference::type_name{"BatchNormInference"};
 
-ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input,
-                                                   Output<ngraph::Node> gamma,
-                                                   Output<ngraph::Node> beta,
-                                                   Output<ngraph::Node> mean,
-                                                   Output<ngraph::Node> variance,
+op::BatchNormInference::BatchNormInference(const Output<Node>& input,
+                                           const Output<Node>& gamma,
+                                           const Output<Node>& beta,
+                                           const Output<Node>& mean,
+                                           const Output<Node>& variance,
                                            double epsilon)
     : Op({gamma, beta, input, mean, variance})
     , m_epsilon(epsilon)
@@ -117,19 +119,19 @@
 }
 
 // DEPRECATED
-ngraph::op::BatchNormInference::BatchNormInference(double eps,
-                                                   Output<ngraph::Node> gamma,
-                                                   Output<ngraph::Node> beta,
-                                                   Output<ngraph::Node> input,
-                                                   Output<ngraph::Node> mean,
-                                                   Output<ngraph::Node> variance)
+op::BatchNormInference::BatchNormInference(double eps,
+                                           const Output<Node>& gamma,
+                                           const Output<Node>& beta,
+                                           const Output<Node>& input,
+                                           const Output<Node>& mean,
+                                           const Output<Node>& variance)
     : Op({gamma, beta, input, mean, variance})
     , m_epsilon(eps)
 {
     constructor_validate_and_infer_types();
 }
 
-void ngraph::op::BatchNormInference::validate_and_infer_types()
+void op::BatchNormInference::validate_and_infer_types()
 {
     element::Type result_et;
     PartialShape result_batch_shape;
@@ -152,22 +154,21 @@
     set_output_type(0, result_et, result_batch_shape);
 }
 
-std::shared_ptr<ngraph::Node>
-    ngraph::op::BatchNormInference::copy_with_new_args(const NodeVector& new_args) const
+std::shared_ptr<Node> op::BatchNormInference::copy_with_new_args(const NodeVector& new_args) const
 {
     check_new_args_count(this, new_args);
     return std::make_shared<BatchNormInference>(
         new_args.at(2), new_args.at(0), new_args.at(1), new_args.at(3), new_args.at(4), m_epsilon);
 }
 
-const std::string ngraph::op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"};
+const string op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"};
 
-ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::Node> input,
-                                                                 Output<ngraph::Node> gamma,
-                                                                 Output<ngraph::Node> beta,
-                                                                 Output<ngraph::Node> mean,
-                                                                 Output<ngraph::Node> variance,
-                                                                 Output<ngraph::Node> delta,
+op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(const Output<Node>& input,
+                                                         const Output<Node>& gamma,
+                                                         const Output<Node>& beta,
+                                                         const Output<Node>& mean,
+                                                         const Output<Node>& variance,
+                                                         const Output<Node>& delta,
                                                          double epsilon)
     : Op({gamma, beta, input, mean, variance, delta})
     , m_epsilon(epsilon)
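BatchNormInference takes precomputed per-channel statistics rather than deriving them from the batch. A hedged sketch with the same hypothetical shapes as above:

// Sketch only: mean and variance are supplied as rank-1 [C] inputs, and the
// single output has the same shape as 'input'.
auto mean = std::make_shared<op::Parameter>(element::f32, Shape{3});
auto variance = std::make_shared<op::Parameter>(element::f32, Shape{3});
auto bn_inf = std::make_shared<op::BatchNormInference>(
    input, gamma, beta, mean, variance, 1e-5);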
@@ -177,13 +178,13 @@
     constructor_validate_and_infer_types();
 }
 
-ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
-                                                                 Output<ngraph::Node> gamma,
-                                                                 Output<ngraph::Node> beta,
-                                                                 Output<ngraph::Node> input,
-                                                                 Output<ngraph::Node> mean,
-                                                                 Output<ngraph::Node> variance,
-                                                                 Output<ngraph::Node> delta)
+op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
+                                                         const Output<Node>& gamma,
+                                                         const Output<Node>& beta,
+                                                         const Output<Node>& input,
+                                                         const Output<Node>& mean,
+                                                         const Output<Node>& variance,
+                                                         const Output<Node>& delta)
     : Op({gamma, beta, input, mean, variance, delta})
     , m_epsilon(epsilon)
@@ -192,7 +193,7 @@
     constructor_validate_and_infer_types();
 }
 
-void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types()
+void op::BatchNormTrainingBackprop::validate_and_infer_types()
 {
     PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)};
@@ -239,8 +240,8 @@
     set_output_type(2, result_et, result_channel_shape);
 }
 
-std::shared_ptr<ngraph::Node>
-    ngraph::op::BatchNormTrainingBackprop::copy_with_new_args(const NodeVector& new_args) const
+std::shared_ptr<Node>
+    op::BatchNormTrainingBackprop::copy_with_new_args(const NodeVector& new_args) const
 {
     check_new_args_count(this, new_args);
     return std::make_shared<op::BatchNormTrainingBackprop>(new_args.at(2),
...
@@ -39,9 +39,9 @@ namespace ngraph
             /// \param gamma gamma scaling for normalized value. [C]
             /// \param beta bias added to the scaled normalized value [C]
             /// \param epsilon Avoids division by 0 if input has 0 variance
-            BatchNormTraining(Output<Node> input,
-                              Output<Node> gamma,
-                              Output<Node> beta,
+            BatchNormTraining(const Output<Node>& input,
+                              const Output<Node>& gamma,
+                              const Output<Node>& beta,
                               double epsilon);
 
             NGRAPH_DEPRECATED_DOC
@@ -66,9 +66,9 @@ namespace ngraph
             /// output[2]: shall have rank 1, with the same span as input's channel axis.
             NGRAPH_DEPRECATED("Use another constructor")
             BatchNormTraining(double eps,
-                              Output<Node> gamma,
-                              Output<Node> beta,
-                              Output<Node> input);
+                              const Output<Node>& gamma,
+                              const Output<Node>& beta,
+                              const Output<Node>& input);
 
             void validate_and_infer_types() override;
@@ -101,11 +101,11 @@ namespace ngraph
             /// \param mean value for mean normalization [C]
             /// \param variance value for variance normalization [C]
             /// \param epsilon Avoids division by 0 if input has 0 variance
-            BatchNormInference(Output<ngraph::Node> input,
-                               Output<ngraph::Node> gamma,
-                               Output<ngraph::Node> beta,
-                               Output<ngraph::Node> mean,
-                               Output<ngraph::Node> variance,
+            BatchNormInference(const Output<Node>& input,
+                               const Output<Node>& gamma,
+                               const Output<Node>& beta,
+                               const Output<Node>& mean,
+                               const Output<Node>& variance,
                                double epsilon);
 
             NGRAPH_DEPRECATED_DOC
@@ -128,11 +128,11 @@ namespace ngraph
             /// output: shall have the same shape as 'input'.
             NGRAPH_DEPRECATED("Use another constructor")
             BatchNormInference(double eps,
-                               Output<ngraph::Node> gamma,
-                               Output<ngraph::Node> beta,
-                               Output<ngraph::Node> input,
-                               Output<ngraph::Node> mean,
-                               Output<ngraph::Node> variance);
+                               const Output<Node>& gamma,
+                               const Output<Node>& beta,
+                               const Output<Node>& input,
+                               const Output<Node>& mean,
+                               const Output<Node>& variance);
 
             void validate_and_infer_types() override;
@@ -165,24 +165,23 @@ namespace ngraph
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
             BatchNormTrainingBackprop() = default;
-            BatchNormTrainingBackprop(Output<Node> input,
-                                      Output<Node> gamma,
-                                      Output<Node> beta,
-                                      Output<Node> mean,
-                                      Output<Node> variance,
-                                      Output<Node> delta,
+            BatchNormTrainingBackprop(const Output<Node>& input,
+                                      const Output<Node>& gamma,
+                                      const Output<Node>& beta,
+                                      const Output<Node>& mean,
+                                      const Output<Node>& variance,
+                                      const Output<Node>& delta,
                                       double epsilon);
 
             NGRAPH_DEPRECATED_DOC
             NGRAPH_DEPRECATED("Use another constructor")
             BatchNormTrainingBackprop(double epsilon,
-                                      Output<Node> gamma,
-                                      Output<Node> beta,
-                                      Output<Node> input,
-
-                                      Output<Node> mean,
-                                      Output<Node> variance,
-                                      Output<Node> delta);
+                                      const Output<Node>& gamma,
+                                      const Output<Node>& beta,
+                                      const Output<Node>& input,
+                                      const Output<Node>& mean,
+                                      const Output<Node>& variance,
+                                      const Output<Node>& delta);
 
             void validate_and_infer_types() override;
...
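Per the header comments, BatchNormTraining produces three outputs: the normalized batch plus the per-channel batch statistics. A sketch of pulling them apart, assuming Node::output(size_t) from this API family (the included get_output_element.hpp offers op::GetOutputElement as the older route to the same thing):

// Sketch: output 0 has the input's shape; outputs 1 and 2 are rank-1 with
// the span of the channel axis (the batch mean and variance).
auto bn = std::make_shared<op::BatchNormTraining>(input, gamma, beta, 1e-5);
Output<Node> normalized = bn->output(0);
Output<Node> batch_mean = bn->output(1);
Output<Node> batch_variance = bn->output(2);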