Commit 722903ac authored by Adam Procter, committed by Robert Kimball

Partial Shapes and Types, Part 4m: BatchNorm and backprops (#1904)

* Implement partial shape/type propagation for Convolution; fail for want of unit tests

* Implement unit tests for partial shapes/types for Convolution

* Implement partial shape/type validation for BatchNormInference, BatchNormTraining, BatchNormTrainingBackprop with unit tests

* Formatting

* Update CPU and GPU backends to deal with elimination of BatchNormBase

* Update BatchNormTrainingWithStats to use templated emit_BatchNorm function

* Restore five-argument BatchNormTraining ctor for now; #1901 will eliminate it

* Replace enum for input indices with symbolic constants

* Update intelgpu/visualize_tree.cpp to deal with removal of BatchNormBase
parent 0ad9ec2c
@@ -27,28 +27,7 @@ namespace ngraph
 {
     namespace op
     {
-        class BatchNormBase : public Op
-        {
-        public:
-            BatchNormBase(const std::string& node_type, double eps, const NodeVector& args);
-
-            void validate_and_infer_types() override;
-
-            double get_eps_value() const { return m_epsilon; }
-        protected:
-            enum
-            {
-                GAMMA,
-                BETA,
-                INPUT,
-                MEAN,
-                VARIANCE,
-                DELTA
-            };
-
-            double m_epsilon;
-        };
-
-        class BatchNormTraining : public BatchNormBase
+        class BatchNormTraining : public Op
         {
         public:
             // In this version of BatchNorm:
@@ -102,15 +81,23 @@
             void validate_and_infer_types() override;
 
+            double get_eps_value() const { return m_epsilon; }
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
 
         protected:
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                            const NodeVector& deltas) override;
+
+        private:
+            static constexpr size_t INPUT_GAMMA = 0;
+            static constexpr size_t INPUT_BETA = 1;
+            static constexpr size_t INPUT_DATA = 2;
+
+            double m_epsilon;
         };
 
-        class BatchNormInference : public BatchNormBase
+        class BatchNormInference : public Op
         {
         public:
             // In this version of BatchNorm:
@@ -139,6 +126,7 @@
             void validate_and_infer_types() override;
 
+            double get_eps_value() const { return m_epsilon; }
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
@@ -148,6 +136,15 @@
             {
                 throw ngraph_error("Invalid operation");
             }
+
+        private:
+            static constexpr size_t INPUT_GAMMA = 0;
+            static constexpr size_t INPUT_BETA = 1;
+            static constexpr size_t INPUT_DATA = 2;
+            static constexpr size_t INPUT_MEAN = 3;
+            static constexpr size_t INPUT_VARIANCE = 4;
+
+            double m_epsilon;
         };
 
         class BatchNormTrainingBackprop : public Op
@@ -163,22 +160,19 @@
             void validate_and_infer_types() override;
 
-            double get_eps_value() const { return epsilon; }
+            double get_eps_value() const { return m_epsilon; }
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
 
         private:
-            enum
-            {
-                GAMMA,
-                BETA,
-                INPUT,
-                MEAN,
-                VARIANCE,
-                DELTA
-            };
-
-            double epsilon;
+            static constexpr size_t INPUT_GAMMA = 0;
+            static constexpr size_t INPUT_BETA = 1;
+            static constexpr size_t INPUT_DATA = 2;
+            static constexpr size_t INPUT_MEAN = 3;
+            static constexpr size_t INPUT_VARIANCE = 4;
+            static constexpr size_t INPUT_DELTA = 5;
+
+            double m_epsilon;
         };
     }
 }
@@ -609,7 +609,7 @@ namespace ngraph
                     auto input_rank = input_shape.size();
                     if ((input_rank == 4 && node->get_input_element_type(2) == element::f32))
                     {
-                        auto batchnorm = static_cast<op::BatchNormBase*>(node);
+                        auto batchnorm = static_cast<op::Op*>(node);
                         auto op_annotations =
                             std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
                         op_annotations->set_mkldnn_op(true);
...
@@ -833,7 +833,10 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_stats()
     auto gamma = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
     auto beta_shape = Shape{2};
     auto beta = std::make_shared<pattern::op::Label>(element::f32, beta_shape);
-    auto bn_pred = pattern::has_class<op::BatchNormBase>();
+    auto bn_pred = [](std::shared_ptr<Node> node) {
+        return pattern::has_class<op::BatchNormInference>()(node) ||
+               pattern::has_class<op::BatchNormTraining>()(node);
+    };
     auto bn = std::make_shared<pattern::op::Any>(
         input, bn_pred, NodeVector{gamma, beta, input, mean, var});
     auto prelu = std::make_shared<op::Relu>(bn);
...
@@ -328,10 +328,10 @@ void runtime::gpu::GPU_Emitter::emit_AvgPoolBackprop(EMIT_ARGS)
     writer.block_end();
 }
 
-static void emit_BatchNorm(EMIT_ARGS, runtime::gpu::CUDNNEmitter::Prop direction, bool save_stats)
+template <typename T>
+void emit_BatchNorm(EMIT_ARGS, runtime::gpu::CUDNNEmitter::Prop direction, bool save_stats)
 {
-    const ngraph::op::BatchNormBase* batchnorm =
-        static_cast<const ngraph::op::BatchNormBase*>(node);
+    const T* batchnorm = static_cast<const T*>(node);
 
     auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
@@ -361,19 +361,20 @@
 void runtime::gpu::GPU_Emitter::emit_BatchNormInference(EMIT_ARGS)
 {
-    ::emit_BatchNorm(
+    ::emit_BatchNorm<ngraph::op::BatchNormInference>(
         external_function, writer, node, args, out, CUDNNEmitter::Prop::Inference, false);
 }
 
 void runtime::gpu::GPU_Emitter::emit_BatchNormTraining(EMIT_ARGS)
 {
-    ::emit_BatchNorm(
+    ::emit_BatchNorm<ngraph::op::BatchNormTraining>(
         external_function, writer, node, args, out, CUDNNEmitter::Prop::Forward, false);
 }
 
 void runtime::gpu::GPU_Emitter::emit_BatchNormTrainingWithStats(EMIT_ARGS)
 {
-    ::emit_BatchNorm(external_function, writer, node, args, out, CUDNNEmitter::Prop::Forward, true);
+    ::emit_BatchNorm<ngraph::op::gpu::BatchNormTrainingWithStats>(
+        external_function, writer, node, args, out, CUDNNEmitter::Prop::Forward, true);
 }
 
 void runtime::gpu::GPU_Emitter::emit_BatchNormTrainingBackprop(EMIT_ARGS)
...
@@ -123,10 +123,17 @@ void print_node_parameters(ostringstream& writer, const shared_ptr<Node>& node)
         break;
     }
     case OP_TYPEID::BatchNormInference:
+    {
+        const shared_ptr<op::BatchNormInference> batch_norm =
+            static_pointer_cast<op::BatchNormInference>(node);
+        writer << print_table_row_value("EPS", batch_norm->get_eps_value());
+        break;
+    }
     case OP_TYPEID::BatchNormTraining:
     {
-        const shared_ptr<op::BatchNormBase> batch_norm =
-            static_pointer_cast<op::BatchNormBase>(node);
+        const shared_ptr<op::BatchNormTraining> batch_norm =
+            static_pointer_cast<op::BatchNormTraining>(node);
         writer << print_table_row_value("EPS", batch_norm->get_eps_value());
         break;
...
@@ -324,3 +324,123 @@ PartialShape ngraph::infer_batched_pooling_forward(const Node* node,
     return data_batch_output_shape;
 }
+
+struct ChannelShapedInputSpec
+{
+    element::Type m_element_type;
+    PartialShape m_shape;
+    std::string m_input_name;
+};
+
+static std::tuple<element::Type, PartialShape, PartialShape> infer_batch_norm_forward_helper(
+    const Node* node,
+    element::Type input_element_type,
+    const PartialShape& input_shape,
+    const std::vector<ChannelShapedInputSpec>& channel_shaped_inputs)
+{
+    // Built up a slash-separated string naming all the channel-shaped inputs, for use in error
+    // messages.
+    std::stringstream ss;
+    bool first = true;
+    for (auto& inp : channel_shaped_inputs)
+    {
+        if (!first)
+        {
+            ss << "/";
+        }
+        ss << inp.m_input_name;
+        first = false;
+    }
+    std::string channel_input_names = ss.str();
+
+    // Infer output element type.
+    element::Type et_result{input_element_type};
+
+    for (auto& inp : channel_shaped_inputs)
+    {
+        NODE_VALIDATION_ASSERT(node, element::Type::merge(et_result, et_result, inp.m_element_type))
+            << "Input element types do not match.";
+    }
+
+    // Extract channel dimension from input shape.
+    Dimension channel_dim{Dimension::dynamic()};
+
+    NODE_VALIDATION_ASSERT(node,
+                           input_shape.is_dynamic() || static_cast<size_t>(input_shape.rank()) >= 2)
+        << "Input argument must have rank of at least 2 (input argument shape: " << input_shape
+        << ").";
+
+    if (input_shape.rank().is_static())
+    {
+        channel_dim = input_shape[1];
+    }
+
+    // Infer gamma/beta/mu/sigma shape, which must be consistent with a vector of size "channel_dim".
+    PartialShape channel_shape{PartialShape::dynamic()};
+
+    for (auto& inp : channel_shaped_inputs)
+    {
+        NODE_VALIDATION_ASSERT(node, PartialShape::merge_into(channel_shape, inp.m_shape))
+            << "Shapes for " << channel_input_names << " do not match.";
+    }
+
+    NODE_VALIDATION_ASSERT(node, channel_shape.merge_rank(1)) << "Shape for " << channel_input_names
+                                                              << " (" << channel_shape
+                                                              << ") does not have rank 1.";
+
+    NODE_VALIDATION_ASSERT(node, Dimension::merge(channel_dim, channel_dim, channel_shape[0]))
+        << "Input channel dimension (" << channel_dim << ") does not match shape for "
+        << channel_input_names << " (" << channel_shape << ").";
+
+    NODE_VALIDATION_ASSERT(node, channel_dim.is_dynamic() || static_cast<size_t>(channel_dim) >= 1)
+        << "Channel count must be at least 1.";
+
+    // Batch result shape is same as the input shape, except we may possibly have inferred more
+    // information from the channel count via gamma/beta/etc.
+    PartialShape batch_result_shape{input_shape};
+
+    if (batch_result_shape.rank().is_static())
+    {
+        batch_result_shape[1] = channel_dim;
+    }
+
+    return std::make_tuple(et_result, batch_result_shape, PartialShape{channel_dim});
+}
+
+std::tuple<element::Type, PartialShape, PartialShape>
+    ngraph::infer_batch_norm_forward(const Node* node,
+                                     element::Type input_element_type,
+                                     element::Type gamma_element_type,
+                                     element::Type beta_element_type,
+                                     element::Type mean_element_type,
+                                     element::Type variance_element_type,
+                                     const PartialShape& input_shape,
+                                     const PartialShape& gamma_shape,
+                                     const PartialShape& beta_shape,
+                                     const PartialShape& mean_shape,
+                                     const PartialShape& variance_shape)
+{
+    return infer_batch_norm_forward_helper(node,
+                                           input_element_type,
+                                           input_shape,
+                                           {{gamma_element_type, gamma_shape, "gamma"},
+                                            {beta_element_type, beta_shape, "beta"},
+                                            {mean_element_type, mean_shape, "mean"},
+                                            {variance_element_type, variance_shape, "variance"}});
+}
+
+std::tuple<element::Type, PartialShape, PartialShape>
+    ngraph::infer_batch_norm_forward(const Node* node,
+                                     element::Type input_element_type,
+                                     element::Type gamma_element_type,
+                                     element::Type beta_element_type,
+                                     const PartialShape& input_shape,
+                                     const PartialShape& gamma_shape,
+                                     const PartialShape& beta_shape)
+{
+    return infer_batch_norm_forward_helper(
+        node,
+        input_element_type,
+        input_shape,
+        {{gamma_element_type, gamma_shape, "gamma"}, {beta_element_type, beta_shape, "beta"}});
+}
@@ -52,4 +52,26 @@ namespace ngraph
                                                 const PartialShape& window_shape,
                                                 const Strides& window_strides,
                                                 bool is_window_all_in_padding_allowed);
+
+    std::tuple<element::Type, PartialShape, PartialShape>
+        infer_batch_norm_forward(const Node* node,
+                                 element::Type input_element_type,
+                                 element::Type gamma_element_type,
+                                 element::Type beta_element_type,
+                                 element::Type mean_element_type,
+                                 element::Type variance_element_type,
+                                 const PartialShape& input_shape,
+                                 const PartialShape& gamma_shape,
+                                 const PartialShape& beta_shape,
+                                 const PartialShape& mean_shape,
+                                 const PartialShape& variance_shape);
+
+    std::tuple<element::Type, PartialShape, PartialShape>
+        infer_batch_norm_forward(const Node* node,
+                                 element::Type input_element_type,
+                                 element::Type gamma_element_type,
+                                 element::Type beta_element_type,
+                                 const PartialShape& input_shape,
+                                 const PartialShape& gamma_shape,
+                                 const PartialShape& beta_shape);
 }
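
Editor's note: the infer_batch_norm_forward overloads declared above return a tuple of (output element type, batch output shape, channel shape). As a minimal sketch only, and not code from this commit, the following shows how an op such as BatchNormTraining might consume the three-input overload inside validate_and_infer_types(), assuming the Node API of this era (get_input_element_type, get_input_partial_shape, set_output_size, set_output_type) and the INPUT_* constants introduced in the diff; the exact body and output ordering here are assumptions.

// Hypothetical sketch, not the commit's implementation.
void ngraph::op::BatchNormTraining::validate_and_infer_types()
{
    element::Type result_et;
    PartialShape result_batch_shape;
    PartialShape result_channel_shape; // shared by the mean/variance outputs

    set_output_size(3);

    std::tie(result_et, result_batch_shape, result_channel_shape) =
        infer_batch_norm_forward(this,
                                 get_input_element_type(INPUT_DATA),
                                 get_input_element_type(INPUT_GAMMA),
                                 get_input_element_type(INPUT_BETA),
                                 get_input_partial_shape(INPUT_DATA),
                                 get_input_partial_shape(INPUT_GAMMA),
                                 get_input_partial_shape(INPUT_BETA));

    set_output_type(0, result_et, result_batch_shape);   // normalized data
    set_output_type(1, result_et, result_channel_shape); // batch mean
    set_output_type(2, result_et, result_channel_shape); // batch variance
}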