Commit 722903ac authored by Adam Procter, committed by Robert Kimball

Partial Shapes and Types, Part 4m: BatchNorm and backprops (#1904)

* Implement partial shape/type propagation for Convolution; fail for want of unit tests

* Implement unit tests for partial shapes/types for Convolution

* Implement partial shape/type validation for BatchNormInference, BatchNormTraining, BatchNormTrainingBackprop with unit tests

* Formatting

* Update CPU and GPU backends to deal with elimination of BatchNormBase

* Update BatchNormTrainingWithStats to use templated emit_BatchNorm function

* Restore five-argument BatchNormTraining ctor for now; #1901 will eliminate it

* Replace enum for input indices with symbolic constants

* Update intelgpu/visualize_tree.cpp to deal with removal of BatchNormBase
parent 0ad9ec2c
@@ -27,28 +27,7 @@ namespace ngraph
{
namespace op
{
class BatchNormBase : public Op
{
public:
BatchNormBase(const std::string& node_type, double eps, const NodeVector& args);
void validate_and_infer_types() override;
double get_eps_value() const { return m_epsilon; }
protected:
enum
{
GAMMA,
BETA,
INPUT,
MEAN,
VARIANCE,
DELTA
};
double m_epsilon;
};
class BatchNormTraining : public BatchNormBase
class BatchNormTraining : public Op
{
public:
// In this version of BatchNorm:
@@ -102,15 +81,23 @@ namespace ngraph
void validate_and_infer_types() override;
double get_eps_value() const { return m_epsilon; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
static constexpr size_t INPUT_GAMMA = 0;
static constexpr size_t INPUT_BETA = 1;
static constexpr size_t INPUT_DATA = 2;
double m_epsilon;
};
class BatchNormInference : public BatchNormBase
class BatchNormInference : public Op
{
public:
// In this version of BatchNorm:
@@ -139,6 +126,7 @@ namespace ngraph
void validate_and_infer_types() override;
double get_eps_value() const { return m_epsilon; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
@@ -148,6 +136,15 @@ namespace ngraph
{
throw ngraph_error("Invalid operation");
}
private:
static constexpr size_t INPUT_GAMMA = 0;
static constexpr size_t INPUT_BETA = 1;
static constexpr size_t INPUT_DATA = 2;
static constexpr size_t INPUT_MEAN = 3;
static constexpr size_t INPUT_VARIANCE = 4;
double m_epsilon;
};
class BatchNormTrainingBackprop : public Op
@@ -163,22 +160,19 @@ namespace ngraph
void validate_and_infer_types() override;
double get_eps_value() const { return epsilon; }
double get_eps_value() const { return m_epsilon; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
private:
enum
{
GAMMA,
BETA,
INPUT,
MEAN,
VARIANCE,
DELTA
};
static constexpr size_t INPUT_GAMMA = 0;
static constexpr size_t INPUT_BETA = 1;
static constexpr size_t INPUT_DATA = 2;
static constexpr size_t INPUT_MEAN = 3;
static constexpr size_t INPUT_VARIANCE = 4;
static constexpr size_t INPUT_DELTA = 5;
double epsilon;
double m_epsilon;
};
}
}
@@ -609,7 +609,7 @@ namespace ngraph
auto input_rank = input_shape.size();
if ((input_rank == 4 && node->get_input_element_type(2) == element::f32))
{
auto batchnorm = static_cast<op::BatchNormBase*>(node);
auto batchnorm = static_cast<op::Op*>(node);
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
@@ -833,7 +833,10 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_sta
auto gamma = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = std::make_shared<pattern::op::Label>(element::f32, beta_shape);
auto bn_pred = pattern::has_class<op::BatchNormBase>();
auto bn_pred = [](std::shared_ptr<Node> node) {
return pattern::has_class<op::BatchNormInference>()(node) ||
pattern::has_class<op::BatchNormTraining>()(node);
};
auto bn = std::make_shared<pattern::op::Any>(
input, bn_pred, NodeVector{gamma, beta, input, mean, var});
auto prelu = std::make_shared<op::Relu>(bn);
@@ -328,10 +328,10 @@ void runtime::gpu::GPU_Emitter::emit_AvgPoolBackprop(EMIT_ARGS)
writer.block_end();
}
static void emit_BatchNorm(EMIT_ARGS, runtime::gpu::CUDNNEmitter::Prop direction, bool save_stats)
template <typename T>
void emit_BatchNorm(EMIT_ARGS, runtime::gpu::CUDNNEmitter::Prop direction, bool save_stats)
{
const ngraph::op::BatchNormBase* batchnorm =
static_cast<const ngraph::op::BatchNormBase*>(node);
const T* batchnorm = static_cast<const T*>(node);
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
@@ -361,19 +361,20 @@ static void emit_BatchNorm(EMIT_ARGS, runtime::gpu::CUDNNEmitter::Prop direction
void runtime::gpu::GPU_Emitter::emit_BatchNormInference(EMIT_ARGS)
{
::emit_BatchNorm(
::emit_BatchNorm<ngraph::op::BatchNormInference>(
external_function, writer, node, args, out, CUDNNEmitter::Prop::Inference, false);
}
void runtime::gpu::GPU_Emitter::emit_BatchNormTraining(EMIT_ARGS)
{
::emit_BatchNorm(
::emit_BatchNorm<ngraph::op::BatchNormTraining>(
external_function, writer, node, args, out, CUDNNEmitter::Prop::Forward, false);
}
void runtime::gpu::GPU_Emitter::emit_BatchNormTrainingWithStats(EMIT_ARGS)
{
::emit_BatchNorm(external_function, writer, node, args, out, CUDNNEmitter::Prop::Forward, true);
::emit_BatchNorm<ngraph::op::gpu::BatchNormTrainingWithStats>(
external_function, writer, node, args, out, CUDNNEmitter::Prop::Forward, true);
}
void runtime::gpu::GPU_Emitter::emit_BatchNormTrainingBackprop(EMIT_ARGS)
@@ -123,10 +123,17 @@ void print_node_parameters(ostringstream& writer, const shared_ptr<Node>& node)
break;
}
case OP_TYPEID::BatchNormInference:
{
const shared_ptr<op::BatchNormInference> batch_norm =
static_pointer_cast<op::BatchNormInference>(node);
writer << print_table_row_value("EPS", batch_norm->get_eps_value());
break;
}
case OP_TYPEID::BatchNormTraining:
{
const shared_ptr<op::BatchNormBase> batch_norm =
static_pointer_cast<op::BatchNormBase>(node);
const shared_ptr<op::BatchNormTraining> batch_norm =
static_pointer_cast<op::BatchNormTraining>(node);
writer << print_table_row_value("EPS", batch_norm->get_eps_value());
break;
@@ -324,3 +324,123 @@ PartialShape ngraph::infer_batched_pooling_forward(const Node* node,
return data_batch_output_shape;
}
struct ChannelShapedInputSpec
{
element::Type m_element_type;
PartialShape m_shape;
std::string m_input_name;
};
static std::tuple<element::Type, PartialShape, PartialShape> infer_batch_norm_forward_helper(
const Node* node,
element::Type input_element_type,
const PartialShape& input_shape,
const std::vector<ChannelShapedInputSpec>& channel_shaped_inputs)
{
// Build up a slash-separated string naming all the channel-shaped inputs, for use in error
// messages.
std::stringstream ss;
bool first = true;
for (auto& inp : channel_shaped_inputs)
{
if (!first)
{
ss << "/";
}
ss << inp.m_input_name;
first = false;
}
std::string channel_input_names = ss.str();
// Infer output element type.
element::Type et_result{input_element_type};
for (auto& inp : channel_shaped_inputs)
{
NODE_VALIDATION_ASSERT(node, element::Type::merge(et_result, et_result, inp.m_element_type))
<< "Input element types do not match.";
}
// Extract channel dimension from input shape.
Dimension channel_dim{Dimension::dynamic()};
NODE_VALIDATION_ASSERT(node,
input_shape.is_dynamic() || static_cast<size_t>(input_shape.rank()) >= 2)
<< "Input argument must have rank of at least 2 (input argument shape: " << input_shape
<< ").";
if (input_shape.rank().is_static())
{
channel_dim = input_shape[1];
}
// Infer gamma/beta/mu/sigma shape, which must be consistent with a vector of size "channel_dim".
PartialShape channel_shape{PartialShape::dynamic()};
for (auto& inp : channel_shaped_inputs)
{
NODE_VALIDATION_ASSERT(node, PartialShape::merge_into(channel_shape, inp.m_shape))
<< "Shapes for " << channel_input_names << " do not match.";
}
NODE_VALIDATION_ASSERT(node, channel_shape.merge_rank(1)) << "Shape for " << channel_input_names
<< " (" << channel_shape
<< ") does not have rank 1.";
NODE_VALIDATION_ASSERT(node, Dimension::merge(channel_dim, channel_dim, channel_shape[0]))
<< "Input channel dimension (" << channel_dim << ") does not match shape for "
<< channel_input_names << " (" << channel_shape << ").";
NODE_VALIDATION_ASSERT(node, channel_dim.is_dynamic() || static_cast<size_t>(channel_dim) >= 1)
<< "Channel count must be at least 1.";
// Batch result shape is same as the input shape, except we may possibly have inferred more
// information from the channel count via gamma/beta/etc.
PartialShape batch_result_shape{input_shape};
if (batch_result_shape.rank().is_static())
{
batch_result_shape[1] = channel_dim;
}
return std::make_tuple(et_result, batch_result_shape, PartialShape{channel_dim});
}
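// Illustrative trace (not from this diff) of the dynamic-dimension merging above,
// using made-up example shapes:
//
//   input_shape = {64, Dimension::dynamic(), 224, 224}   // channel_dim starts out dynamic
//   gamma_shape = PartialShape::dynamic()                 // rank unknown, merges as a no-op
//   beta_shape  = {3}
//
//   PartialShape::merge_into(channel_shape, gamma_shape)  -> channel_shape stays dynamic
//   PartialShape::merge_into(channel_shape, beta_shape)   -> channel_shape becomes {3}
//   channel_shape.merge_rank(1)                           -> already rank 1, succeeds
//   Dimension::merge(channel_dim, channel_dim, channel_shape[0]) -> channel_dim becomes 3
//
//   Result: batch result shape {64, 3, 224, 224}, channel shape {3}, and an element type
//   obtained by merging the input element types.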
std::tuple<element::Type, PartialShape, PartialShape>
ngraph::infer_batch_norm_forward(const Node* node,
element::Type input_element_type,
element::Type gamma_element_type,
element::Type beta_element_type,
element::Type mean_element_type,
element::Type variance_element_type,
const PartialShape& input_shape,
const PartialShape& gamma_shape,
const PartialShape& beta_shape,
const PartialShape& mean_shape,
const PartialShape& variance_shape)
{
return infer_batch_norm_forward_helper(node,
input_element_type,
input_shape,
{{gamma_element_type, gamma_shape, "gamma"},
{beta_element_type, beta_shape, "beta"},
{mean_element_type, mean_shape, "mean"},
{variance_element_type, variance_shape, "variance"}});
}
std::tuple<element::Type, PartialShape, PartialShape>
ngraph::infer_batch_norm_forward(const Node* node,
element::Type input_element_type,
element::Type gamma_element_type,
element::Type beta_element_type,
const PartialShape& input_shape,
const PartialShape& gamma_shape,
const PartialShape& beta_shape)
{
return infer_batch_norm_forward_helper(
node,
input_element_type,
input_shape,
{{gamma_element_type, gamma_shape, "gamma"}, {beta_element_type, beta_shape, "beta"}});
}
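For context, here is a minimal sketch (an assumption, not code from this commit) of how a forward batch-norm op's validate_and_infer_types() could consume the five-input overload above. The INPUT_* constants are the symbolic indices declared in batch_norm.hpp; the get_input_element_type/get_input_partial_shape/set_output_type calls are the usual ngraph Node accessors and are assumed here rather than quoted from the diff.

void op::BatchNormInference::validate_and_infer_types()
{
    // Sketch only: exact member names and output indices may differ in the actual commit.
    element::Type result_et;
    PartialShape result_batch_shape;
    PartialShape result_channel_shape; // ignored here; only training emits mean/variance outputs

    std::tie(result_et, result_batch_shape, result_channel_shape) =
        infer_batch_norm_forward(this,
                                 get_input_element_type(INPUT_DATA),
                                 get_input_element_type(INPUT_GAMMA),
                                 get_input_element_type(INPUT_BETA),
                                 get_input_element_type(INPUT_MEAN),
                                 get_input_element_type(INPUT_VARIANCE),
                                 get_input_partial_shape(INPUT_DATA),
                                 get_input_partial_shape(INPUT_GAMMA),
                                 get_input_partial_shape(INPUT_BETA),
                                 get_input_partial_shape(INPUT_MEAN),
                                 get_input_partial_shape(INPUT_VARIANCE));

    set_output_type(0, result_et, result_batch_shape);
}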
@@ -52,4 +52,26 @@ namespace ngraph
const PartialShape& window_shape,
const Strides& window_strides,
bool is_window_all_in_padding_allowed);
std::tuple<element::Type, PartialShape, PartialShape>
infer_batch_norm_forward(const Node* node,
element::Type input_element_type,
element::Type gamma_element_type,
element::Type beta_element_type,
element::Type mean_element_type,
element::Type variance_element_type,
const PartialShape& input_shape,
const PartialShape& gamma_shape,
const PartialShape& beta_shape,
const PartialShape& mean_shape,
const PartialShape& variance_shape);
std::tuple<element::Type, PartialShape, PartialShape>
infer_batch_norm_forward(const Node* node,
element::Type input_element_type,
element::Type gamma_element_type,
element::Type beta_element_type,
const PartialShape& input_shape,
const PartialShape& gamma_shape,
const PartialShape& beta_shape);
}