Unverified Commit 403a09ce authored by Scott Cyphers, committed by GitHub

Cyphers/bnorm back (#2129)

* Fix batchnorm argument order, cleanup some comments, fix backprop

* Merge error

* Clean up training function, organize inference test

* BatchNormInference tests

* Training case

* Training test

* Fix autodiff BatchNorm test

* Cleanup

* Move file to doc checkout

* Update disabled test name in igpu manifest
Fix unused variable

* Unit tests disables

* Review comments
parent ef7e3f3b
@@ -924,9 +924,9 @@ def batch_norm(eps,  # type: float
                # type: (...) -> Node
     """Return batch normalization node."""
     if mean is None and variance is None:
-        return BatchNormTraining(eps, gamma, beta, data)
+        return BatchNormTraining(data, gamma, beta, eps)
     else:
-        return BatchNormInference(eps, gamma, beta, data, mean, variance)
+        return BatchNormInference(data, gamma, beta, mean, variance, eps)


 @nameable_op
@@ -30,10 +30,10 @@ void regclass_pyngraph_op_BatchNormTraining(py::module m)
     batch_norm_training(m, "BatchNormTraining");
     batch_norm_training.doc() =
         "ngraph.impl.op.BatchNormTraining wraps ngraph::op::BatchNormTraining";
-    batch_norm_training.def(py::init<double,
+    batch_norm_training.def(py::init<const std::shared_ptr<ngraph::Node>&,
                                      const std::shared_ptr<ngraph::Node>&,
                                      const std::shared_ptr<ngraph::Node>&,
-                                     const std::shared_ptr<ngraph::Node>&>());
+                                     double>());
 }

 void regclass_pyngraph_op_BatchNormInference(py::module m)
@@ -45,12 +45,12 @@ void regclass_pyngraph_op_BatchNormInference(py::module m)
     batch_norm_inference.doc() =
         "ngraph.impl.op.BatchNormInference wraps ngraph::op::BatchNormInference";
-    batch_norm_inference.def(py::init<double,
+    batch_norm_inference.def(py::init<const std::shared_ptr<ngraph::Node>&,
                                       const std::shared_ptr<ngraph::Node>&,
                                       const std::shared_ptr<ngraph::Node>&,
                                       const std::shared_ptr<ngraph::Node>&,
                                       const std::shared_ptr<ngraph::Node>&,
-                                      const std::shared_ptr<ngraph::Node>&>());
+                                      double>());
 }

 void regclass_pyngraph_op_BatchNormTrainingBackprop(py::module m)
@@ -61,11 +61,11 @@ void regclass_pyngraph_op_BatchNormTrainingBackprop(py::module m)
     batch_norm_training_backprop(m, "BatchNormTrainingBackprop");
     batch_norm_training_backprop.doc() =
         "ngraph.impl.op.BatchNormTrainingBackprop wraps ngraph::op::BatchNormTrainingBackprop";
-    batch_norm_training_backprop.def(py::init<double,
+    batch_norm_training_backprop.def(py::init<const std::shared_ptr<ngraph::Node>&,
                                               const std::shared_ptr<ngraph::Node>&,
                                               const std::shared_ptr<ngraph::Node>&,
                                               const std::shared_ptr<ngraph::Node>&,
                                               const std::shared_ptr<ngraph::Node>&,
                                               const std::shared_ptr<ngraph::Node>&,
-                                              const std::shared_ptr<ngraph::Node>&>());
+                                              double>());
 }
@@ -60,10 +60,10 @@ namespace ngraph
             void replace_output(Output& output);

         protected:
-            /// \return the tensor view for the connected output
+            /// \return the tensor for the connected output
             std::shared_ptr<const Tensor> get_tensor_ptr() const;

-            /// \return the tensor view for the connected output
+            /// \return the tensor for the connected output
             std::shared_ptr<Tensor> get_tensor_ptr();

         public:
@@ -32,7 +32,7 @@ namespace ngraph
 {
     namespace layout
     {
-        /// \brief Interface for describing implementations of tensor views.
+        /// \brief Interface for describing implementations of tensors.
         ///
         /// Kernel selection will need to pay attention to the layout.
         class TensorLayout
@@ -44,7 +44,7 @@ namespace ngraph
         public:
             virtual ~TensorLayout() {}

-            /// Extent of this view in buffer.
+            /// Extent of this tensor in buffer.
             ///
             /// When we support non-linear buffers, this will need to be something other than size_t.
             size_t get_size() const;
@@ -39,7 +39,7 @@ namespace ngraph
         public:
             /// \param node Node that owns this output.
             /// \param index Position of the output tensor in all output tensors
-            /// \param tensor The view of this tensor; where the value will be written
+            /// \param tensor The tensor where the value will be written
             Output(Node* node, size_t index, const std::shared_ptr<Tensor>& tensor);

             std::shared_ptr<Node> get_node() const;
@@ -35,7 +35,7 @@ namespace ngraph
         class TensorLayout;
     }

-    /// \brief Compile-time descriptor of a first-class value that is a view of a tensor.
+    /// \brief Compile-time descriptor of a first-class value that is a tensor.
     class Tensor
     {
         Tensor(const Tensor&) = delete;
@@ -53,11 +53,11 @@ namespace ngraph
                     mean = inputs.at(3);
                     var = inputs.at(4);
                     return {std::make_shared<ngraph::op::BatchNormInference>(
-                        epsilon, scale, bias, x, mean, var)};
+                        x, scale, bias, mean, var, epsilon)};
                 }
                 return {
-                    std::make_shared<ngraph::op::BatchNormTraining>(epsilon, scale, bias, x)};
+                    std::make_shared<ngraph::op::BatchNormTraining>(x, scale, bias, epsilon)};
             }

         } // namespace set_1
@@ -27,9 +27,20 @@ namespace ngraph
 {
     namespace op
     {
+        // \brief Batchnorm for training operation
         class BatchNormTraining : public Op
         {
         public:
+            // \param input Must have rank >= 2, [., C, ...]
+            // \param gamma gamma scaling for normalized value. [C]
+            // \param beta bias added to the scaled normalized value [C]
+            // \param epsilon Avoids division by 0 if input has 0 variance
+            BatchNormTraining(std::shared_ptr<Node> input,
+                              std::shared_ptr<Node> gamma,
+                              std::shared_ptr<Node> beta,
+                              double epsilon);
+
+            // \deprecated
             // In this version of BatchNorm:
             //
             // MEAN AND VARIANCE: computed directly from the content of 'input'.
@@ -75,6 +86,20 @@ namespace ngraph
         class BatchNormInference : public Op
         {
         public:
+            // \param input [., C, ...]
+            // \param gamma gamma scaling for normalized value. [C]
+            // \param beta bias added to the scaled normalized value [C]
+            // \param mean value for mean normalization [C]
+            // \param variance value for variance normalization [C]
+            // \param epsilon Avoids division by 0 if input has 0 variance
+            BatchNormInference(std::shared_ptr<ngraph::Node> input,
+                               std::shared_ptr<ngraph::Node> gamma,
+                               std::shared_ptr<ngraph::Node> beta,
+                               std::shared_ptr<ngraph::Node> mean,
+                               std::shared_ptr<ngraph::Node> variance,
+                               double epsilon);
+
+            // \deprecated
             // In this version of BatchNorm:
             //
             // MEAN AND VARIANCE: provided by the 'mean' and 'variance' parameters.
@@ -125,10 +150,20 @@ namespace ngraph
         class BatchNormTrainingBackprop : public Op
         {
         public:
-            BatchNormTrainingBackprop(double eps,
+            BatchNormTrainingBackprop(std::shared_ptr<Node> input,
+                                      std::shared_ptr<Node> gamma,
+                                      std::shared_ptr<Node> beta,
+                                      std::shared_ptr<Node> mean,
+                                      std::shared_ptr<Node> variance,
+                                      std::shared_ptr<Node> delta,
+                                      double epsilon);
+
+            // \deprecated
+            BatchNormTrainingBackprop(double epsilon,
                                       std::shared_ptr<Node> gamma,
                                       std::shared_ptr<Node> beta,
                                       std::shared_ptr<Node> input,
                                       std::shared_ptr<Node> mean,
                                       std::shared_ptr<Node> variance,
                                       std::shared_ptr<Node> delta);
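For orientation, a minimal construction sketch using the new argument order declared above. This is not part of the commit; the shapes and the function name are illustrative, and the umbrella ngraph header is assumed. Parameter, GetOutputElement, Function, NodeVector and ParameterVector are used exactly as in the tests later in this diff.

    // Sketch only: new argument order is data first, epsilon last.
    #include "ngraph/ngraph.hpp" // assumed umbrella header

    using namespace ngraph;

    std::shared_ptr<Function> make_bn_training_graph() // illustrative helper name
    {
        const Shape input_shape{2, 3, 4, 4};          // [N, C, H, W]
        const Shape channel_shape{input_shape.at(1)}; // [C]
        const double eps = 1e-3;

        auto input = std::make_shared<op::Parameter>(element::f32, input_shape);
        auto gamma = std::make_shared<op::Parameter>(element::f32, channel_shape);
        auto beta = std::make_shared<op::Parameter>(element::f32, channel_shape);

        // The old epsilon-first constructors remain, but are marked deprecated.
        auto bn = std::make_shared<op::BatchNormTraining>(input, gamma, beta, eps);

        // Training has three outputs: normalized data, batch mean, batch variance.
        auto normed = std::make_shared<op::GetOutputElement>(bn, 0);
        return std::make_shared<Function>(NodeVector{normed},
                                          ParameterVector{input, gamma, beta});
    }

    // Inference takes precomputed statistics, again with epsilon last:
    //   std::make_shared<op::BatchNormInference>(input, gamma, beta, mean, variance, eps);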
@@ -172,13 +172,12 @@ namespace ngraph
                {
                    auto& functors = external_function->get_functors();

-                   std::function<decltype(
-                       runtime::cpu::kernel::batch_norm_three_outputs<float>)>
+                   std::function<decltype(runtime::cpu::kernel::batch_norm_training<float>)>
                        kernel;

                    SELECT_KERNEL(kernel,
                                  args[0].get_element_type(),
-                                 runtime::cpu::kernel::batch_norm_three_outputs);
+                                 runtime::cpu::kernel::batch_norm_training);

                    auto arg2_shape = args[2].get_shape();
                    auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
@@ -207,12 +206,12 @@
                {
                    auto& functors = external_function->get_functors();

-                   std::function<decltype(runtime::cpu::kernel::batch_norm_one_output<float>)>
+                   std::function<decltype(runtime::cpu::kernel::batch_norm_inference<float>)>
                        kernel;

                    SELECT_KERNEL(kernel,
                                  args[0].get_element_type(),
-                                 runtime::cpu::kernel::batch_norm_one_output);
+                                 runtime::cpu::kernel::batch_norm_inference);

                    auto arg2_shape = args[2].get_shape();
                    auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
@@ -255,12 +254,12 @@
                    auto& functors = external_function->get_functors();

-                   std::function<decltype(runtime::cpu::kernel::batch_norm_one_output<float>)>
+                   std::function<decltype(runtime::cpu::kernel::batch_norm_inference<float>)>
                        kernel;

                    SELECT_KERNEL(kernel,
                                  args[0].get_element_type(),
-                                 runtime::cpu::kernel::batch_norm_one_output);
+                                 runtime::cpu::kernel::batch_norm_inference);

                    auto arg2_shape = args[2].get_shape();
                    auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
@@ -28,7 +28,7 @@ namespace ngraph
            namespace kernel
            {
                template <typename ElementType>
-               void batch_norm_three_outputs(double eps,
+               void batch_norm_training(double eps,
                                              const void* arg0,
                                              const void* arg1,
                                              const void* arg2,
@@ -37,7 +37,7 @@ namespace ngraph
                                              void* out2,
                                              const Shape& arg2_shape)
                {
-                   reference::batch_norm_three_outputs(eps,
+                   reference::batch_norm_training(eps,
                                                   static_cast<const ElementType*>(arg0),
                                                   static_cast<const ElementType*>(arg1),
                                                   static_cast<const ElementType*>(arg2),
@@ -48,7 +48,7 @@ namespace ngraph
                }

                template <typename ElementType>
-               void batch_norm_one_output(double eps,
+               void batch_norm_inference(double eps,
                                           const void* arg0,
                                           const void* arg1,
                                           const void* arg2,
@@ -57,7 +57,7 @@ namespace ngraph
                                           void* out0,
                                           const Shape& arg2_shape)
                {
-                   reference::batch_norm_one_output(eps,
+                   reference::batch_norm_inference(eps,
                                                static_cast<const ElementType*>(arg0),
                                                static_cast<const ElementType*>(arg1),
                                                static_cast<const ElementType*>(arg2),
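As background (not part of the diff), the renamed kernels compute the standard batch normalization: the training kernel derives per-channel statistics from the batch, while the inference kernel takes mean and variance as inputs. For an input x with channel axis c over an [N, C, H, W] batch:

    \mu_c = \frac{1}{NHW} \sum_{n,h,w} x_{n,c,h,w},
    \qquad
    \sigma_c^2 = \frac{1}{NHW} \sum_{n,h,w} \left(x_{n,c,h,w} - \mu_c\right)^2,
    \qquad
    y_{n,c,h,w} = \gamma_c \,\frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c

Inference substitutes the supplied mean and variance for \mu_c and \sigma_c^2; \epsilon is the epsilon argument that guards against division by zero when the variance is 0.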
@@ -6,8 +6,8 @@ one_hot_vector_1_barely_oob
 one_hot_vector_1_far_oob
 one_hot_vector_1_fp_nonint
-backwards_batch_norm_three_outputs
 backwards_maxpool_n2_c1_hw5_3x3_str2_max_pad1x2_2x3
+backwards_batch_norm_training
 shape_of_scalar
 shape_of_vector
 shape_of_matrix
@@ -30,6 +30,13 @@ backwards_maxpool_n2_c1_hw5_3x3_str2_max
 backwards_avgpool_n1_c1_hw2x2
 backwards_avgpool_n1_c1_hw4x4
 backwards_avgpool_n2_c2_hw4x4
+batch_norm_inference_0eps_f64
+batch_norm_inference_0eps_f32
+batch_norm_inference_f64
+batch_norm_inference_f32
+batch_norm_training_0eps_f64
+batch_norm_training_0eps_f32
+backwards_batch_norm_training
 dequantize
 dequantize_zero_offset
 dequantize_axes
 avg_pool_2d_2channel_2image_padded_only_above_do_not_include_in_computation
 avg_pool_2d_2channel_2image_padded_only_above_include_in_computation
 avg_pool_3d_uneven_strided_padded
-backwards_batch_norm_three_outputs
+backwards_batch_norm_training
 backwards_dot_scalar_tensor
 backwards_dot_tensor3_tensor3
 backwards_dot_tensor_scalar
@@ -15,6 +15,9 @@ backwards_reverse_sequence_n3_c2_h3
 backwards_reverse_sequence_n4d2c3h2w2
 backwards_slice
 backwards_tanh
+batch_norm_inference_0eps_f64
+batch_norm_inference_f64
+batch_norm_training_0eps_f64
 batch_norm_one_output
 batch_norm_three_outputs
 dequantize
@@ -364,9 +364,7 @@ private:
        {
            const ngraph::op::BatchNormTraining* bn =
                static_cast<const ngraph::op::BatchNormTraining*>(&node);
-           if (bn->get_output_size() == 3)
-           {
-               reference::batch_norm_three_outputs<T>(bn->get_eps_value(),
+           reference::batch_norm_training<T>(bn->get_eps_value(),
                                              static_cast<const T*>(args[0]),
                                              static_cast<const T*>(args[1]),
                                              static_cast<const T*>(args[2]),
@@ -374,25 +372,13 @@ private:
                                              static_cast<T*>(out[1]),
                                              static_cast<T*>(out[2]),
                                              node.get_input_shape(2));
-           }
-           else
-           {
-               reference::batch_norm_one_output<T>(bn->get_eps_value(),
-                                                   static_cast<const T*>(args[0]),
-                                                   static_cast<const T*>(args[1]),
-                                                   static_cast<const T*>(args[2]),
-                                                   static_cast<const T*>(args[3]),
-                                                   static_cast<const T*>(args[4]),
-                                                   static_cast<T*>(out[0]),
-                                                   node.get_input_shape(2));
-           }
            break;
        }
        case OP_TYPEID::BatchNormInference:
        {
            const ngraph::op::BatchNormInference* bn =
                static_cast<const ngraph::op::BatchNormInference*>(&node);
-           reference::batch_norm_one_output<T>(bn->get_eps_value(),
+           reference::batch_norm_inference<T>(bn->get_eps_value(),
                                               static_cast<const T*>(args[0]),
                                               static_cast<const T*>(args[1]),
                                               static_cast<const T*>(args[2]),
batchnorm_bprop_n4c3h2w2
batchnorm_fprop_b1c2h2w2
batchnorm_fprop_b2c2h2w1
batchnorm_fprop_globalstats_b2c2w2h1
batchnorm_fprop_inference_b2c2h2w1
batchnorm_fprop_bprop
batchnorm_fprop_bprop_2step
@@ -531,21 +531,24 @@ static shared_ptr<ngraph::Function>
        case OP_TYPEID::BatchNormTraining:
        {
            auto epsilon = node_js.at("eps").get<double>();
-           node = make_shared<op::BatchNormTraining>(epsilon, args[0], args[1], args[2]);
+           // Odd order for back-compatibility
+           node = make_shared<op::BatchNormTraining>(args[2], args[0], args[1], epsilon);
            break;
        }
        case OP_TYPEID::BatchNormInference:
        {
            auto epsilon = node_js.at("eps").get<double>();
+           // Odd order for back-compatibility
            node = make_shared<op::BatchNormInference>(
-               epsilon, args[0], args[1], args[2], args[3], args[4]);
+               args[2], args[0], args[1], args[3], args[4], epsilon);
            break;
        }
        case OP_TYPEID::BatchNormTrainingBackprop:
        {
            auto epsilon = node_js.at("eps").get<double>();
+           // Odd order for back-compatibility
            node = make_shared<op::BatchNormTrainingBackprop>(
-               epsilon, args[0], args[1], args[2], args[3], args[4], args[5]);
+               args[2], args[0], args[1], args[3], args[4], args[5], epsilon);
            break;
        }
        case OP_TYPEID::Broadcast:
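The "odd order for back-compatibility" comments reflect that serialized graphs keep the legacy input order (gamma, beta, data, then any statistics and delta), so the deserializer remaps indices when calling the new constructors. A hypothetical helper making the inference-case mapping explicit (the function name is illustrative and not part of the commit; only the constructor call mirrors the code above):

    // Hypothetical sketch: legacy serialized input order is {gamma, beta, data, mean, variance}.
    #include <memory>
    #include <vector>
    #include "ngraph/ngraph.hpp" // assumed umbrella header

    std::shared_ptr<ngraph::Node> batch_norm_inference_from_legacy_inputs(
        const std::vector<std::shared_ptr<ngraph::Node>>& args, double epsilon)
    {
        // Reorder to the new constructor: data, gamma, beta, mean, variance, epsilon.
        return std::make_shared<ngraph::op::BatchNormInference>(
            args[2] /* data */, args[0] /* gamma */, args[1] /* beta */,
            args[3] /* mean */, args[4] /* variance */, epsilon);
    }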
@@ -1624,39 +1624,42 @@ NGRAPH_TEST(${BACKEND_NAME}, backwards_maxpool_n2c1h5w5_kh3kw3_sh2sw2)
     ASSERT_TRUE(read_vector<float>(output) == expected);
 }

-NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_three_outputs)
+NGRAPH_TEST(${BACKEND_NAME}, backwards_batch_norm_training)
 {
-    auto shape_in = Shape{2, 3, 1, 1};
-    auto shape_mean = Shape{3};
+    const Shape input_shape{5, 3, 2, 2};
+    const Shape channel_shape{input_shape.at(1)};
+    const double eps = 1e-3;
+    const element::Type& et = element::f32;
+    using T = float;

-    // we need to keep GOEs for mean and variance alive
-    // even though those aren't used as outputs for fprop
-    // they are needed for a bprop pass
+    // Need to keep the output elements for mean and variance from going out of scope
+    // and getting freed.
     NodeVector goes;

-    auto make_graph = [&goes, shape_in, shape_mean] {
-        auto A = make_shared<op::Parameter>(element::f64, shape_in);
-        auto B = make_shared<op::Parameter>(element::f64, shape_mean);
-        auto C = make_shared<op::Parameter>(element::f64, shape_mean);
-
-        auto BN = make_shared<op::BatchNormTraining>(1e-3, B, C, A);
-        // make sure we create GOEs for mean and variance needed for bprop
-        goes.push_back(make_shared<op::GetOutputElement>(BN, 1));
-        goes.push_back(make_shared<op::GetOutputElement>(BN, 2));
-
-        auto f = make_shared<Function>(make_shared<op::GetOutputElement>(BN, 0),
-                                       ParameterVector{A, B, C});
+    auto make_graph = [&input_shape, &channel_shape, &eps, &et, &goes] {
+        auto input = make_shared<op::Parameter>(et, input_shape);
+        auto gamma = make_shared<op::Parameter>(et, channel_shape);
+        auto beta = make_shared<op::Parameter>(et, channel_shape);
+        auto BN = make_shared<op::BatchNormTraining>(input, gamma, beta, eps);
+        auto normed_input = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 0));
+        auto mean = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 1));
+        auto variance = make_shared<op::Result>(make_shared<op::GetOutputElement>(BN, 2));
+        goes.push_back(mean);
+        goes.push_back(variance);
+        // TODO autodiff testing with more than one result
+        auto f = make_shared<Function>(ResultVector{normed_input /* , mean, variance*/},
+                                       ParameterVector{input, gamma, beta});
         return f;
     };

     auto backend = runtime::Backend::create("${BACKEND_NAME}");

-    test::Uniform<double> rng(-1.0, 1.0);
-    auto x0 = rng.initialize(backend->create_tensor<double>(shape_in));
-    auto x1 = rng.initialize(backend->create_tensor<double>(shape_mean));
-    auto x2 = rng.initialize(backend->create_tensor<double>(shape_mean));
+    test::Uniform<T> rng(-1.0, 1.0);
+    auto input = rng.initialize(backend->create_tensor<T>(input_shape));
+    auto gamma = rng.initialize(backend->create_tensor<T>(channel_shape));
+    auto beta = rng.initialize(backend->create_tensor<T>(channel_shape));

     EXPECT_TRUE(
-        autodiff_numeric_compare<double>(backend.get(), make_graph, {x0, x1, x2}, .01, .01));
+        autodiff_numeric_compare<T>(backend.get(), make_graph, {input, gamma, beta}, .001, .001));
 }

 NGRAPH_TEST(${BACKEND_NAME}, backwards_reverse_sequence_n3_c2_h3)
@@ -719,7 +719,7 @@ TEST(cpu_fusion, batchnorm_fprop_relu_b1c2h2w2)
     auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
     double eps = 0.001;
     auto shape_r = Shape{1, 2, 2, 2};
-    auto bn = make_shared<op::BatchNormTraining>(eps, gamma, beta, input);
+    auto bn = make_shared<op::BatchNormTraining>(input, gamma, beta, eps);

     auto output_rt = std::make_shared<op::GetOutputElement>(bn, 0);
     // Note, op::Splice is used to break Relu(BatchNorm) fusion
@@ -1083,8 +1083,8 @@ shared_ptr<Function> gen_groupconv_batchnorm(const bool add_goe,

     // Adding a goe will stop fusion since the patterns wont expect to see this op
     auto bn =
-        add_goe ? std::make_shared<op::BatchNormInference>(eps, gamma, beta, goe_bn, mean, var)
-                : std::make_shared<op::BatchNormInference>(eps, gamma, beta, group_conv, mean, var);
+        add_goe ? std::make_shared<op::BatchNormInference>(goe_bn, gamma, beta, mean, var, eps)
+                : std::make_shared<op::BatchNormInference>(group_conv, gamma, beta, mean, var, eps);
     if (with_relu)
     {
         auto prelu = std::make_shared<op::Relu>(bn);
@@ -1768,7 +1768,7 @@ TEST(cpu_fusion, conv_batch_norm_folding)
         auto mean = std::make_shared<op::Parameter>(element::f32, shape_norm);
         auto var = std::make_shared<op::Parameter>(element::f32, shape_norm);
         auto conv = std::make_shared<op::Convolution>(input, weights, Strides{1, 1}, Strides{1, 1});
-        auto bn = std::make_shared<op::BatchNormInference>(eps, gamma, beta, conv, mean, var);
+        auto bn = std::make_shared<op::BatchNormInference>(conv, gamma, beta, mean, var, eps);
         auto f = make_shared<Function>(NodeVector{bn},
                                        ParameterVector{input, weights, gamma, beta, mean, var});
         return f;
@@ -1830,7 +1830,7 @@ TEST(cpu_fusion, convbias_batch_norm_folding)
         auto conv = std::make_shared<op::Convolution>(input, weights, Strides{1, 1}, Strides{1, 1});
         auto convbias =
             conv + std::make_shared<op::Broadcast>(bias, conv->get_shape(), AxisSet{0, 2, 3});
-        auto bn = std::make_shared<op::BatchNormInference>(eps, gamma, beta, convbias, mean, var);
+        auto bn = std::make_shared<op::BatchNormInference>(convbias, gamma, beta, mean, var, eps);
         auto f = make_shared<Function>(
             NodeVector{bn}, ParameterVector{input, weights, bias, gamma, beta, mean, var});
         return f;