Commit 39c8cc7f authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

remove 5arg BNTraining (#1901)

* remove 5arg BNTraining

* Remove 5-arg BNWithStats from gpu/op.
parent 8ac7fecd
......@@ -44,18 +44,6 @@ ngraph::op::BatchNormTraining::BatchNormTraining(double eps,
constructor_validate_and_infer_types();
}
ngraph::op::BatchNormTraining::BatchNormTraining(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance)
: Op("BatchNormTraining", check_single_output_args({gamma, beta, input, mean, variance}))
, m_epsilon(eps)
{
constructor_validate_and_infer_types();
}
void ngraph::op::BatchNormInference::validate_and_infer_types()
{
element::Type result_et;
......
......@@ -54,31 +54,6 @@ namespace ngraph
std::shared_ptr<Node> beta,
std::shared_ptr<Node> input);
// In this version of BatchNorm:
//
// MEAN AND VARIANCE: provided by the 'mean' and 'variance' parameters.
//
// OUTPUT VALUE: a single tensor with the normalized value of 'input'.
// mean and variance will also be updated inplace
//
// AUTODIFF SUPPORT:
// 'generate_adjoints(...)' works as expected.
//
// SHAPE DETAILS:
// gamma: must have rank 1, with the same span as input's channel axis.
// beta: must have rank 1, with the same span as input's channel axis.
// input: must have rank >= 2. The second dimension represents the channel axis and
// must have a span of at least 1.
// mean: must have rank 1, with the same span as input's channel axis.
// variance: must have rank 1, with the same span as input's channel axis.
// output: shall have the same shape as 'input'.
BatchNormTraining(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance);
void validate_and_infer_types() override;
double get_eps_value() const { return m_epsilon; }
......
......@@ -32,25 +32,6 @@ ngraph::op::gpu::BatchNormTrainingWithStats::BatchNormTrainingWithStats(
set_output_type(output_index++, input->get_element_type(), channel_shape);
}
ngraph::op::gpu::BatchNormTrainingWithStats::BatchNormTrainingWithStats(
double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
bool training)
: ngraph::op::BatchNormTraining(eps, gamma, beta, input, mean, variance)
{
auto output_index = get_output_size();
set_output_size(output_index + 2);
Shape channel_shape{input->get_shape()[1]};
// saved batch mean
set_output_type(output_index++, input->get_element_type(), channel_shape);
// saved batch inverse variance
set_output_type(output_index++, input->get_element_type(), channel_shape);
}
std::shared_ptr<ngraph::Node> ngraph::op::gpu::BatchNormTrainingWithStats::copy_with_new_args(
const NodeVector& new_args) const
{
......
......@@ -39,14 +39,6 @@ namespace ngraph
std::shared_ptr<Node> beta,
std::shared_ptr<Node> input);
BatchNormTrainingWithStats(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
bool training = false);
protected:
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment