Commit 23ac5e5a authored by Chris Sullivan, committed by Scott Cyphers

CUDNN BatchNorm (inference/forward/backward) (#893)

* Added cudnn batch norm operation to GPU transformer.
Brought batchnorm tests out of cpu_tests and into
backend_tests. A JIRA ticket still needs to be added for the
INTERPRETER skips.

* CUDNN batchnorm is implemented. In the ForwardTraining branch,
CUDNN appears to calculate the batch mean correctly but the batch variance incorrectly.
Currently the batchnorm output and mean are calculated correctly for tests:
* GPU.batchnorm_fprop_b2c2h3w3_mean_var
* GPU.batchnorm_fprop_b1c2h2w2
* GPU.batchnorm_fprop_b2c2h2w1
but the variance CUDNN calculates for the batches in these tests is incorrect.

Also added an additional test and cleaned up some of the old tests.

* MKLDNN internally utilizes the biased estimate of the population variance,
and the tests have been crafted to suit MKLDNN. According to the original
batchnorm publication (https://arxiv.org/pdf/1502.03167v3.pdf), population
(unbiased) statistics should be used for inference, and mini-batch (biased)
statistics should be used during training (forward/backward). For the variance this
means utilizing the following equations, respectively:

  (biased)   Var[X] = 1/m * Sum_i (x_i - mu)^2      :: used in training
  (unbiased) Var[X] = 1/(m-1) * Sum_i (x_i - mu)^2  :: used in inference

  where the x_i are elements of X and m = N*D*H*W.

For large batch sizes this difference may not impact results, as m >> 1,
but for small batch sizes it will. CUDNN internally utilizes the unbiased
variance.
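
As a quick numerical check of the conversion factor (m-1)/m, here is a minimal
standalone sketch (not part of this commit; the biased variance value is taken
from the b2c2h2w1 test below, where m = N*H*W = 4):

    #include <cstdio>

    int main()
    {
        // Shape {2, 2, 2, 1}: m = N*H*W = 4 elements per channel
        // (the channel dimension C is excluded from the reduction).
        const float m = 4.0f;
        const float unbiased_var = 0.0376908f;               // what cuDNN reports
        const float biased_var = unbiased_var * (m - 1) / m; // = 0.0282681f (expected)
        std::printf("unbiased: %f -> biased: %f\n", unbiased_var, biased_var);
        return 0;
    }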

Changes:
* Added Multiply op to Forward pass of batchnorm to convert
  the unbiased variance to a biased one. The op utilizes the
  blending scaling factors to apply the bias factor.
* Adds emission for the BatchNormBackprop kernel and cleans up
  the emitter implementation.

* Added hashing to cudnn::batchnorm op.

* Formatting.

* Changed hashing of epsilon in cudnn batchnorm.

* Remove implicit conversion and default case in switch for bn.

* Added skips for IE transformer on batchnorm.

* add cudnn include path to compiler.cpp

* separate the two paths

* PRs #892 and #825, which were recently merged, both omitted skips for the GPU backend.
Adding them in, as the ops they introduce are unimplemented on the GPU.

* The allocation and deletion of primitives was occurring in separate
translation units with raw C pointers. Because of this, it was not
clear that these were being freed appropriately, nor was ownership
of the pointers indicated.

In this commit these raw pointers have been converted over to
std::unique_ptrs such that construction/destruction is managed
automatically. Furthermore, GPUPrimitiveEmitter::insert now only
takes an r-value reference, requiring move semantics to indicate
that, when inserting a primitive, the GPUPrimitiveEmitter takes
ownership of the pointer.

All instances of primitive creation have been modified.

* CUDNN_SAFE_CALL

* Removed redundant comment and made variable names more verbose.

* Change from conditionals to case-switch in pooling to conform to
batchnorm per @fengleitian's suggestion.
parent b0421577
......@@ -150,12 +150,12 @@ size_t runtime::gpu::CUDAEmitter::build_pad(const runtime::gpu::GPURuntimeContex
compiled_kernel = ctx->compiled_kernel_pool->set(hash, writer.get_code());
}
gpu::primitive* pad = nullptr;
std::unique_ptr<gpu::primitive> pad;
// if the pad value is statically provided, the kernel call signature is different
if (pad_value == "") // pad value provided at runtime (dynamic)
{
pad = new gpu::primitive{[=](void** inputs, void** outputs) {
pad.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
void* args_list[] = {&inputs[1], &inputs[0], &outputs[0]};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(nthreads),
......@@ -169,11 +169,11 @@ size_t runtime::gpu::CUDAEmitter::build_pad(const runtime::gpu::GPURuntimeContex
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}};
}});
}
else // pad value provided at compile time (static)
{
pad = new gpu::primitive{[=](void** inputs, void** outputs) {
pad.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
void* args_list[] = {&inputs[0], &outputs[0]};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(nthreads),
......@@ -187,10 +187,10 @@ size_t runtime::gpu::CUDAEmitter::build_pad(const runtime::gpu::GPURuntimeContex
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}};
}});
}
primitive_index = this->m_primitive_emitter->insert(pad);
primitive_index = this->m_primitive_emitter->insert(std::move(pad));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
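
A style note: assuming gpu::primitive is move-constructible (e.g. a std::function
typedef), the reset(new ...) calls above could equivalently be written with
std::make_unique:

    auto pad = std::make_unique<gpu::primitive>(
        gpu::primitive{[=](void** inputs, void** outputs) { /* launch kernel */ }});
    primitive_index = this->m_primitive_emitter->insert(std::move(pad));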
......@@ -259,7 +259,7 @@ size_t runtime::gpu::CUDAEmitter::build_1d_max_pool(const GPURuntimeContext* ctx
compiled_kernel = ctx->compiled_kernel_pool->set(hash, writer.get_code());
}
auto pool = new gpu::primitive{[=](void** inputs, void** outputs) {
std::unique_ptr<gpu::primitive> pool(new gpu::primitive{[=](void** inputs, void** outputs) {
void* args_list[] = {&inputs[0], &outputs[0]};
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<unsigned int>(nthreads),
......@@ -273,9 +273,9 @@ size_t runtime::gpu::CUDAEmitter::build_1d_max_pool(const GPURuntimeContext* ctx
args_list,
0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}};
}});
primitive_index = this->m_primitive_emitter->insert(pool);
primitive_index = this->m_primitive_emitter->insert(std::move(pool));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
......
......@@ -199,7 +199,7 @@ size_t runtime::gpu::CUDNNEmitter::build_reduce_forward(const runtime::gpu::GPUR
};
}
// emit sum reduce operation
auto* reduce = new gpu::primitive{
std::unique_ptr<gpu::primitive> reduce(new gpu::primitive{
[ctx, reduce_op, get_input_desc, get_output_desc](void** inputs, void** outputs) {
auto input_desc = get_input_desc();
auto output_desc = get_output_desc();
......@@ -229,12 +229,12 @@ size_t runtime::gpu::CUDNNEmitter::build_reduce_forward(const runtime::gpu::GPUR
output_desc,
outputs[0]));
free_gpu_buffer(workspace_ptr);
}};
}});
return this->m_primitive_emitter->insert(reduce);
return this->m_primitive_emitter->insert(std::move(reduce));
}
size_t runtime::gpu::CUDNNEmitter::build_pooling(const GPURuntimeContext* ctx,
size_t runtime::gpu::CUDNNEmitter::build_pooling(const runtime::gpu::GPURuntimeContext* ctx,
const cudnnPoolingMode_t& pool_op,
const Prop& direction,
const Shape& input_shape,
......@@ -301,10 +301,14 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const GPURuntimeContext* ctx,
throw std::runtime_error("Pooling currently supports up to 3 spatial dimensions only.");
}
gpu::primitive* pool = nullptr;
if (direction == Prop::Forward)
std::unique_ptr<gpu::primitive> pool;
switch (direction)
{
case (Prop::Inference):
case (Prop::Forward):
{
pool = new gpu::primitive{[=](void** inputs, void** outputs) {
pool.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
float alpha = 1.0, beta = 0.0;
CUDNN_SAFE_CALL(cudnnPoolingForward(*ctx->cudnn_handle,
desc,
......@@ -314,11 +318,12 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const GPURuntimeContext* ctx,
&beta,
output_desc,
outputs[0]));
}};
}});
break;
}
else if (direction == Prop::Backward)
case (Prop::Backward):
{
pool = new gpu::primitive{[=](void** inputs, void** outputs) {
pool.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
float alpha = 1.0, beta = 0.0;
// cuDNN requires the output tensor of the maxpool fprop to be passed even though
// it is not mathematically necessary. It appears, however, that it is not actually
......@@ -339,10 +344,152 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const GPURuntimeContext* ctx,
// adjoint of input
input_desc,
outputs[0]));
}};
}});
break;
}
}
primitive_index = this->m_primitive_emitter->insert(std::move(pool));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
size_t runtime::gpu::CUDNNEmitter::build_batchnorm(const runtime::gpu::GPURuntimeContext* ctx,
const cudnnBatchNormMode_t& bn_op,
const Prop& direction,
const Shape& tensor_shape,
const Shape& param_shape,
double epsilon)
{
// Assumes NC{d1...dN} format
std::stringstream ss;
ss.precision(std::numeric_limits<double>::digits10 + 2);
ss << "bn_op" << bn_op << "_dir" << static_cast<int>(direction) << "_ts"
<< join(tensor_shape, "_") << "_ps" << join(param_shape, "_") << "_eps" << epsilon;
std::string hash = ss.str();
std::replace(hash.begin(), hash.end(), '.', '_');
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
if (epsilon < CUDNN_BN_MIN_EPSILON)
{
throw std::runtime_error("Batch Norm epsilon is less than CUDNN_BN_MIN_EPSILON");
}
cudnnTensorDescriptor_t derived_param_desc;
CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&derived_param_desc));
auto tensor_desc = runtime::gpu::cudnn_util::tensor_descriptor_from_shape(tensor_shape);
CUDNN_SAFE_CALL(cudnnDeriveBNTensorDescriptor(derived_param_desc, tensor_desc, bn_op));
float alpha = 1.0, beta = 0.0;
std::unique_ptr<gpu::primitive> batchnorm;
switch (direction)
{
case Prop::Inference:
{
batchnorm.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
CUDNN_SAFE_CALL(cudnnBatchNormalizationForwardInference(*ctx->cudnn_handle,
bn_op,
&alpha,
&beta,
tensor_desc,
inputs[2], // tensor
tensor_desc,
outputs[0], // tensor
derived_param_desc,
inputs[0], // gain
inputs[1], // bias
inputs[3], // mean
inputs[4], // variance
epsilon));
}});
break;
}
case Prop::Forward:
{
cudnnOpTensorDescriptor_t op_desc;
CUDNN_SAFE_CALL(cudnnCreateOpTensorDescriptor(&op_desc));
CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(
op_desc, CUDNN_OP_TENSOR_MUL, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN));
// currently not using the cudnn moving average
// calculation so this factor needs to be set to 1.0
double exp_avg_factor = 1.0;
// factor to convert the unbiased variance estimate to the biased one:
// mini-batch (biased) statistics should be used during training and
// population (unbiased) statistics during inference.
// see commit note for 3b081ce for more details.
float m = shape_size(tensor_shape) / tensor_shape[1];
float bias_factor = (m - 1) / m;
batchnorm.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
CUDNN_SAFE_CALL(cudnnBatchNormalizationForwardTraining(*ctx->cudnn_handle,
bn_op,
&alpha,
&beta,
tensor_desc,
inputs[2],
tensor_desc,
outputs[0],
derived_param_desc,
inputs[0],
inputs[1],
exp_avg_factor,
outputs[1],
outputs[2],
epsilon,
NULL,
NULL));
// convert to biased variance
CUDNN_SAFE_CALL(cudnnOpTensor(*ctx->cudnn_handle,
op_desc,
&beta,
derived_param_desc,
outputs[2],
&beta,
derived_param_desc,
outputs[2],
&bias_factor,
derived_param_desc,
outputs[2]));
}});
break;
}
case Prop::Backward:
{
batchnorm.reset(new gpu::primitive{[=](void** inputs, void** outputs) {
CUDNN_SAFE_CALL(cudnnBatchNormalizationBackward(
*ctx->cudnn_handle,
bn_op,
&alpha,
&beta,
&alpha,
&beta,
tensor_desc,
inputs[2 /* input tensor x */],
tensor_desc,
inputs[5 /* dy */],
tensor_desc,
outputs[0 /* dx */],
derived_param_desc,
inputs[0 /* gamma */],
outputs[1 /* dgamma */],
outputs[2 /* dbeta */],
epsilon,
NULL, // inputs[3 /* mu batch mean*/],
NULL)); // inputs[4 /* 1/sig**2 batch inverse variance*/]);
}});
break;
}
}
primitive_index = this->m_primitive_emitter->insert(pool);
primitive_index = this->m_primitive_emitter->insert(std::move(batchnorm));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
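
For reference, the argument layout the batchnorm primitive expects, inferred from
the input/output indices above (a comment-only sketch; names are illustrative):

    // Inference: inputs  = { gamma, beta, x, mean, variance }
    //            outputs = { y }
    // Forward:   inputs  = { gamma, beta, x }
    //            outputs = { y, batch_mean, batch_variance }
    // Backward:  inputs  = { gamma, beta, x, mean, variance, dy }
    //            outputs = { dx, dgamma, dbeta }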
......@@ -50,6 +50,7 @@ namespace ngraph
public:
enum class Prop
{
Inference,
Forward,
Backward
};
......@@ -69,6 +70,13 @@ namespace ngraph
const ngraph::Shape& padding_below,
const ngraph::Shape& padding_above);
size_t build_batchnorm(const runtime::gpu::GPURuntimeContext* ctx,
const cudnnBatchNormMode_t& bn_op,
const Prop& direction,
const Shape& tensor_shape,
const Shape& param_shape,
double epsilon);
private:
CUDNNEmitter(GPUPrimitiveEmitter* emitter);
GPUPrimitiveEmitter* m_primitive_emitter;
......
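
A call site for the new build_batchnorm entry point looks like the following
(shapes illustrative; this mirrors the emitter code below):

    auto bn_index = cudnn_emitter->build_batchnorm(external_function->ctx().get(),
                                                   CUDNN_BATCHNORM_SPATIAL,
                                                   CUDNNEmitter::Prop::Forward,
                                                   Shape{2, 2, 2, 1}, // tensor shape
                                                   Shape{2},          // param shape
                                                   1e-3);             // epsilon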
......@@ -1415,6 +1415,100 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
writer.block_end();
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::BatchNorm)
{
const ngraph::op::BatchNorm* batchnorm =
static_cast<const ngraph::op::BatchNorm*>(node);
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
CUDNNEmitter::Prop direction;
if (batchnorm->get_training_flag() && args.size() == 3)
{
direction = CUDNNEmitter::Prop::Forward;
}
else
{
direction = CUDNNEmitter::Prop::Inference;
}
auto bn_index = cudnn_emitter->build_batchnorm(external_function->ctx().get(),
CUDNN_BATCHNORM_SPATIAL,
direction,
args[2].get_shape(),
args[0].get_shape(),
batchnorm->get_eps_value());
writer.block_begin(" // " + node->get_name());
{
writer << "gpu::invoke_primitive(ctx, " << bn_index << ", ";
writer << "std::vector<void*>{" << args.front().get_name();
for (size_t i = 1; i < args.size(); i++)
{
writer << ", " << args[i].get_name();
}
writer << "}.data(), ";
writer << "std::vector<void*>{" << out.front().get_name();
for (size_t i = 1; i < out.size(); i++)
{
writer << ", " << out[i].get_name();
}
writer << "}.data()";
writer << ");\n";
}
writer.block_end();
}
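
The writer above emits a call of roughly this shape into the generated source
(primitive index and tensor names are illustrative):

    // BatchNorm_3
    {
        gpu::invoke_primitive(ctx, 42,
                              std::vector<void*>{arg0, arg1, arg2}.data(),
                              std::vector<void*>{out0, out1, out2}.data());
    }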
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::BatchNormBackprop)
{
const ngraph::op::BatchNormBackprop* batchnorm =
static_cast<const ngraph::op::BatchNormBackprop*>(node);
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
auto bn_index = cudnn_emitter->build_batchnorm(external_function->ctx().get(),
CUDNN_BATCHNORM_SPATIAL,
CUDNNEmitter::Prop::Backward,
args[2].get_shape(),
args[0].get_shape(),
batchnorm->get_eps_value());
writer.block_begin(" // " + node->get_name());
{
writer << "gpu::invoke_primitive(ctx, " << bn_index << ", ";
writer << "std::vector<void*>{" << args.front().get_name();
for (size_t i = 1; i < args.size(); i++)
{
writer << ", " << args[i].get_name();
}
writer << "}.data(), ";
writer << "std::vector<void*>{" << out.front().get_name();
for (size_t i = 1; i < out.size(); i++)
{
writer << ", " << out[i].get_name();
}
writer << "}.data()";
writer << ");\n";
}
writer.block_end();
}
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::GetOutputElement)
{
auto get_tuple_element = static_cast<const ngraph::op::GetOutputElement*>(node);
writer.block_begin(" // " + node->get_name());
writer << "runtime::gpu::cuda_memcpyDtH(" << out[0].get_name() << ", "
<< args[get_tuple_element->get_n()].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.block_end();
}
// assumes NC{d1,d2,d3,...} format
Shape get_padded_shape(const Shape& input_shape,
const Shape& padding_below,
......
......@@ -27,13 +27,7 @@ GPUPrimitiveEmitter::GPUPrimitiveEmitter()
, m_cudnn_emitter(new CUDNNEmitter(this))
{
}
GPUPrimitiveEmitter::~GPUPrimitiveEmitter()
{
for (auto& primitive : m_gpu_primitives)
{
delete primitive;
}
}
std::unique_ptr<CUDAEmitter>& GPUPrimitiveEmitter::get_cuda_emitter()
{
return m_cuda_emitter;
......@@ -42,9 +36,10 @@ std::unique_ptr<CUDNNEmitter>& GPUPrimitiveEmitter::get_cudnn_emitter()
{
return m_cudnn_emitter;
}
size_t GPUPrimitiveEmitter::insert(gpu::primitive* f)
size_t GPUPrimitiveEmitter::insert(std::unique_ptr<gpu::primitive>&& f)
{
m_gpu_primitives.push_back(f);
m_managed_primitives.emplace_back(std::move(f));
m_gpu_primitives.push_back(m_managed_primitives.back().get());
return m_gpu_primitives.size() - 1;
}
size_t GPUPrimitiveEmitter::lookup(std::string hash)
......
......@@ -35,11 +35,10 @@ namespace ngraph
{
public:
GPUPrimitiveEmitter();
~GPUPrimitiveEmitter();
std::unique_ptr<CUDAEmitter>& get_cuda_emitter();
std::unique_ptr<CUDNNEmitter>& get_cudnn_emitter();
std::vector<gpu::primitive*>& get_primitives() { return m_gpu_primitives; }
size_t insert(gpu::primitive* f);
size_t insert(std::unique_ptr<gpu::primitive>&& f);
size_t lookup(std::string hash);
void cache(const std::string& hash, const size_t& index);
......@@ -48,6 +47,7 @@ namespace ngraph
std::unique_ptr<CUDNNEmitter> m_cudnn_emitter;
std::vector<gpu::primitive*> m_gpu_primitives;
std::unordered_map<std::string, size_t> m_primitive_map;
std::vector<std::unique_ptr<gpu::primitive>> m_managed_primitives;
};
}
}
......
......@@ -68,6 +68,11 @@ void runtime::gpu::cuda_memcpyHtD(void* dst, void* src, size_t buffer_size)
cudaMemcpy(dst, src, buffer_size, cudaMemcpyHostToDevice);
}
void runtime::gpu::cuda_memcpyDtH(void* dst, void* src, size_t buffer_size)
{
cudaMemcpy(dst, src, buffer_size, cudaMemcpyDeviceToHost);
}
void runtime::gpu::cuda_memset(void* dst, int value, size_t buffer_size)
{
cudaMemset(dst, value, buffer_size);
......
......@@ -95,6 +95,7 @@ namespace ngraph
void free_gpu_buffer(void* buffer);
void cuda_memcpyDtD(void* dst, void* src, size_t buffer_size);
void cuda_memcpyHtD(void* dst, void* src, size_t buffer_size);
void cuda_memcpyDtH(void* dst, void* src, size_t buffer_size);
void cuda_memset(void* dst, int value, size_t buffer_size);
}
}
......
......@@ -25,6 +25,7 @@
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/serializer.hpp"
#include "util/all_close.hpp"
#include "util/ndarray.hpp"
......@@ -419,6 +420,7 @@ TEST(${BACKEND_NAME}, concat_vector)
TEST(${BACKEND_NAME}, concat_4d_tensor)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{1, 1, 1, 1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
......@@ -444,6 +446,7 @@ TEST(${BACKEND_NAME}, concat_4d_tensor)
TEST(${BACKEND_NAME}, concat_2d_tensor)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{1, 1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
......@@ -1349,6 +1352,7 @@ TEST(${BACKEND_NAME}, notequal)
TEST(${BACKEND_NAME}, select)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::boolean, shape);
......@@ -1766,6 +1770,7 @@ TEST(${BACKEND_NAME}, broadcast_matrix_2)
TEST(${BACKEND_NAME}, convert_int32_float32)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::i32, shape);
auto f =
......@@ -1784,6 +1789,7 @@ TEST(${BACKEND_NAME}, convert_int32_float32)
TEST(${BACKEND_NAME}, convert_int32_bool)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::i32, shape);
auto f = make_shared<Function>(make_shared<op::Convert>(A, element::boolean),
......@@ -1802,6 +1808,7 @@ TEST(${BACKEND_NAME}, convert_int32_bool)
TEST(${BACKEND_NAME}, convert_float32_bool)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::Convert>(A, element::boolean),
......@@ -5156,6 +5163,7 @@ TEST(${BACKEND_NAME}, reduce_window_emulating_max_pool_2d_1channel_1image_stride
//
TEST(${BACKEND_NAME}, select_and_scatter_with_overlap)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_sel_a{};
auto SEL_A = make_shared<op::Parameter>(element::f32, shape_sel_a);
......@@ -5210,6 +5218,7 @@ TEST(${BACKEND_NAME}, select_and_scatter_with_overlap)
//
TEST(${BACKEND_NAME}, select_and_scatter_without_overlap)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_sel_a{};
auto SEL_A = make_shared<op::Parameter>(element::f32, shape_sel_a);
......@@ -5264,6 +5273,7 @@ TEST(${BACKEND_NAME}, select_and_scatter_without_overlap)
//
TEST(${BACKEND_NAME}, select_and_scatter_3d_without_overlap)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape_sel_a{};
auto SEL_A = make_shared<op::Parameter>(element::f32, shape_sel_a);
......@@ -7941,6 +7951,7 @@ TEST(${BACKEND_NAME}, validate_call_output_shape)
TEST(${BACKEND_NAME}, logical_and)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::boolean, shape);
auto B = make_shared<op::Parameter>(element::boolean, shape);
......@@ -7961,6 +7972,7 @@ TEST(${BACKEND_NAME}, logical_and)
TEST(${BACKEND_NAME}, logical_or)
{
SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
Shape shape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::boolean, shape);
auto B = make_shared<op::Parameter>(element::boolean, shape);
......@@ -7978,3 +7990,326 @@ TEST(${BACKEND_NAME}, logical_or)
backend->call(f, {result}, {a, b});
EXPECT_EQ((vector<char>{1, 0, 1, 1, 1, 1, 1, 0}), read_vector<char>(result));
}
TEST(${BACKEND_NAME}, batchnorm_fprop_b1c2h2w2)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
SKIP_TEST_FOR("INTERPRETER", "${BACKEND_NAME}");
auto input_shape = Shape{1, 2, 2, 2};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto var_shape = Shape{2};
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{1, 2, 2, 2};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input);
auto output_rt = std::make_shared<op::GetOutputElement>(bn, 0);
auto mean_rt = std::make_shared<op::GetOutputElement>(bn, 1);
auto variance_rt = std::make_shared<op::GetOutputElement>(bn, 2);
auto f = make_shared<Function>(NodeVector{output_rt, mean_rt, variance_rt},
op::ParameterVector{input, gamma, beta});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto _input = backend->create_tensor(element::f32, Shape{1, 2, 2, 2});
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
auto _gamma = backend->create_tensor(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->create_tensor(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto bn_output = backend->create_tensor(element::f32, shape_r);
auto result_mean = backend->create_tensor(element::f32, mean_shape);
auto result_variance = backend->create_tensor(element::f32, var_shape);
vector<float> expected_result{-0.71498716f,
1.48388731f,
-0.00196938f,
-0.76693159f,
-0.91316032f,
0.23943391f,
-0.84090298f,
1.51462936f};
vector<float> expected_mean{0.602912f, 0.599727f};
vector<float> expected_variance{0.00472505f, 0.0361782f};
backend->call(f, {bn_output, result_mean, result_variance}, {_input, _gamma, _beta});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(bn_output), 1e-5f, 1e-6f));
EXPECT_TRUE(test::all_close(expected_mean, read_vector<float>(result_mean), 1e-5f, 1e-6f));
EXPECT_TRUE(
test::all_close(expected_variance, read_vector<float>(result_variance), 1e-5f, 1e-6f));
}
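
As a sanity check on the expected statistics above (channel 0, biased estimator,
m = N*H*W = 4):

    mean     = (0.54881352 + 0.71518934 + 0.60276335 + 0.54488319) / 4 = 0.602912
    variance = ((-0.054099)^2 + (0.112277)^2 + (-0.000149)^2 + (-0.058029)^2) / 4
             = 0.00472505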
TEST(${BACKEND_NAME}, batchnorm_fprop_b2c2h2w1)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
SKIP_TEST_FOR("INTERPRETER", "${BACKEND_NAME}");
auto input_shape = Shape{2, 2, 2, 1};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto var_shape = Shape{2};
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{2, 2, 2, 1};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input);
auto output_rt = std::make_shared<op::GetOutputElement>(bn, 0);
auto mean_rt = std::make_shared<op::GetOutputElement>(bn, 1);
auto variance_rt = std::make_shared<op::GetOutputElement>(bn, 2);
auto f = make_shared<Function>(NodeVector{output_rt, mean_rt, variance_rt},
op::ParameterVector{input, gamma, beta});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto _input = backend->create_tensor(element::f32, input_shape);
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
auto _gamma = backend->create_tensor(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->create_tensor(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto bn_output = backend->create_tensor(element::f32, shape_r);
auto result_mean = backend->create_tensor(element::f32, mean_shape);
auto result_variance = backend->create_tensor(element::f32, var_shape);
vector<float> expected_result{
-0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
vector<float> expected_mean{0.583388f, 0.619252f};
vector<float> expected_variance{0.0119972f, 0.0282681f};
backend->call(f, {bn_output, result_mean, result_variance}, {_input, _gamma, _beta});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(bn_output)));
EXPECT_TRUE(test::all_close(expected_mean, read_vector<float>(result_mean)));
EXPECT_TRUE(
test::all_close(expected_variance, read_vector<float>(result_variance), 1e-5f, 1e-6f));
}
TEST(${BACKEND_NAME}, bn_bprop_n4c3h2w2)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
SKIP_TEST_FOR("INTERPRETER", "${BACKEND_NAME}");
auto input_shape = Shape{4, 3, 2, 2};
auto shape_mean = Shape{3};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{3};
auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
auto var_shape = Shape{3};
auto var = make_shared<op::Parameter>(element::f32, var_shape);
auto gamma_shape = Shape{3};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{3};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{4, 3, 2, 2};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input);
auto bn_dx = make_shared<op::GetOutputElement>(bn, 0);
auto bn_dgamma = make_shared<op::GetOutputElement>(bn, 1);
auto bn_dbeta = make_shared<op::GetOutputElement>(bn, 2);
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto _input = backend->create_tensor(element::f32, input_shape);
vector<float> dataInput{
10.76331902f, 11.51178265f, 10.31018162f, 12.2993021f, 14.17626667f, 14.63498497f,
13.63494492f, 13.84248161f, 11.34602547f, 13.22014618f, 10.46686649f, 10.39842987f,
12.94806862f, 11.71670246f, 14.94438076f, 13.13236618f, 13.40889645f, 12.76128387f,
11.34430027f, 11.86629677f, 11.11464024f, 10.93221283f, 11.95324039f, 10.96581173f,
13.05455494f, 14.41404247f, 13.11169434f, 11.26559448f, 10.89965153f, 14.08202171f,
11.12685776f, 12.58428574f, 12.59247875f, 13.00187492f, 12.66310215f, 10.06655025f,
12.62048626f, 14.47942352f, 13.84950638f, 10.61425877f, 11.47936344f, 13.06011772f,
13.63069057f, 12.31748772f, 13.84555244f, 10.95815468f, 12.78933334f, 12.75389099f};
copy_data(_input, dataInput);
auto _mean = backend->create_tensor(element::f32, mean_shape);
copy_data(_mean, vector<float>{12.56472874f, 12.80312157f, 11.81676865f});
auto _var = backend->create_tensor(element::f32, var_shape);
copy_data(_var, vector<float>{1.94557643f, 1.32772446f, 1.28163588f});
auto _gamma = backend->create_tensor(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{2.0f, 2.0f, 2.0f});
auto _beta = backend->create_tensor(element::f32, beta_shape);
copy_data(_beta, vector<float>{1.0f, 1.0f, 1.0f});
auto result = backend->create_tensor(element::f32, shape_r);
shared_ptr<runtime::TensorView> _delta = backend->create_tensor(element::f32, shape_r);
vector<float> deltaData(shape_size(shape_r), 20.0f);
copy_data(_delta, deltaData);
auto f = make_shared<Function>(NodeVector{bn_dx, bn_dgamma, bn_dbeta},
op::ParameterVector{mean, var, input, gamma, beta});
auto C = std::make_shared<op::Parameter>(element::f32, shape_r);
auto zero = ngraph::make_zero(bn_dgamma->get_element_type(), bn_dgamma->get_shape());
ngraph::autodiff::Adjoints adjoints(NodeVector{bn_dx, bn_dgamma, bn_dbeta},
NodeVector{C, zero, zero});
auto dinput = adjoints.backprop_node(input);
auto dgamma = adjoints.backprop_node(gamma);
auto dbeta = adjoints.backprop_node(beta);
auto df = make_shared<Function>(NodeVector{dinput, dgamma, dbeta},
op::ParameterVector{mean, var, input, gamma, beta, C});
//roundtrip serialization
string js = serialize(df, 4);
istringstream in(js);
df = deserialize(in);
shared_ptr<runtime::TensorView> _dinput = backend->create_tensor(element::f32, shape_r);
shared_ptr<runtime::TensorView> _dgamma = backend->create_tensor(element::f32, gamma_shape);
shared_ptr<runtime::TensorView> _dbeta = backend->create_tensor(element::f32, beta_shape);
backend->call(df, {_dinput, _dgamma, _dbeta}, {_mean, _var, _input, _gamma, _beta, _delta});
vector<float> expected_input{
8.17051607e-06f, 4.77576657e-06f, 1.02257760e-05f, 1.20387525e-06f, -1.73868522e-06f,
3.84632768e-06f, -1.07932050e-05f, -2.57458956e-06f, -2.22166714e-06f, -8.38779043e-06f,
-2.48082982e-06f, 5.89238360e-06f, -2.52895109e-07f, -8.68433445e-06f, -5.82726737e-06f,
8.84659658e-06f, 3.03944108e-05f, 4.05480879e-05f, 1.84123158e-05f, 2.30061178e-05f,
1.34087590e-05f, -9.26072571e-07f, -3.22908454e-05f, -2.07365116e-05f, -4.21330941e-05f,
2.83083100e-05f, -3.71039101e-05f, -4.84390640e-06f, -2.93012376e-05f, 5.68858087e-06f,
1.83181458e-05f, -1.07494506e-05f, -2.32429103e-06f, 6.92914809e-06f, -6.66512321e-06f,
-7.00302840e-06f, -3.46675184e-06f, -4.36748381e-06f, 6.73822226e-07f, -4.20158993e-06f,
3.83005061e-06f, 5.85143729e-06f, 4.17875243e-06f, -8.64167783e-06f, 1.00170803e-05f,
-4.23939666e-06f, 4.80201680e-06f, 4.62702078e-06f};
ASSERT_TRUE(ngraph::test::all_close(read_vector<float>(_dinput), expected_input, 1e-3f, 1e-4f));
vector<float> expected_dgamma{7.06315041e-05f, -2.35289335e-04f, -5.06639481e-05f};
ASSERT_TRUE(
ngraph::test::all_close(read_vector<float>(_dgamma), expected_dgamma, 1e-2f, 1e-3f));
vector<float> expected_dbeta{320.f, 320.f, 320.f};
ASSERT_TRUE(ngraph::test::all_close(read_vector<float>(_dbeta), expected_dbeta, 1e-4f, 1e-8f));
}
TEST(${BACKEND_NAME}, batchnorm_fprop_inference_b2c2h2w1)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
SKIP_TEST_FOR("INTERPRETER", "${BACKEND_NAME}");
auto input_shape = Shape{2, 2, 2, 1};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
auto var_shape = Shape{2};
auto var = make_shared<op::Parameter>(element::f32, var_shape);
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{2, 2, 2, 1};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var);
auto f = make_shared<Function>(bn, op::ParameterVector{input, gamma, beta, mean, var});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto _input = backend->create_tensor(element::f32, input_shape);
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
auto _gamma = backend->create_tensor(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->create_tensor(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto _mean = backend->create_tensor(element::f32, mean_shape);
copy_data(_mean, vector<float>{0.583388f, 0.619252f});
auto _var = backend->create_tensor(element::f32, var_shape);
copy_data(_var, vector<float>{0.0119972f, 0.0282681f});
auto bn_output = backend->create_tensor(element::f32, shape_r);
auto result_mean = backend->create_tensor(element::f32, mean_shape);
auto result_variance = backend->create_tensor(element::f32, var_shape);
vector<float> expected_result{
-0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
backend->call(f, {bn_output}, {_input, _gamma, _beta, _mean, _var});
ASSERT_TRUE(
ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
}
TEST(${BACKEND_NAME}, batchnorm_fprop_globalstats_b2c2w2h1)
{
SKIP_TEST_FOR("IE", "${BACKEND_NAME}");
SKIP_TEST_FOR("NNP_TESTER", "${BACKEND_NAME}");
SKIP_TEST_FOR("INTERPRETER", "${BACKEND_NAME}");
auto input_shape = Shape{2, 2, 2, 1};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
auto var_shape = Shape{2};
auto var = make_shared<op::Parameter>(element::f32, var_shape);
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{2, 2, 2, 1};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var, true);
auto f = make_shared<Function>(bn, op::ParameterVector{gamma, beta, input, mean, var});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto _input = backend->create_tensor(element::f32, input_shape);
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
auto _gamma = backend->create_tensor(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->create_tensor(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto _mean = backend->create_tensor(element::f32, mean_shape);
copy_data(_mean, vector<float>{0.583388f, 0.619252f});
auto _var = backend->create_tensor(element::f32, var_shape);
copy_data(_var, vector<float>{0.0119972f, 0.0282681f});
auto bn_output = backend->create_tensor(element::f32, shape_r);
auto result_mean = backend->create_tensor(element::f32, mean_shape);
auto result_variance = backend->create_tensor(element::f32, var_shape);
vector<float> expected_result{
-0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
backend->call(f, {bn_output}, {_gamma, _beta, _input, _mean, _var});
ASSERT_TRUE(
ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
}
......@@ -43,120 +43,6 @@
using namespace ngraph;
using namespace std;
//TODO: Move this test to backend_test.in.cpp once we have the INTERPRETER
// implementation for batchnorm
TEST(cpu_test, batchnorm_fprop_b1c2h2w2)
{
auto input_shape = Shape{1, 2, 2, 2};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto var_shape = Shape{2};
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{1, 2, 2, 2};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input);
auto output_rt = std::make_shared<op::GetOutputElement>(bn, 0);
auto mean_rt = std::make_shared<op::GetOutputElement>(bn, 1);
auto variance_rt = std::make_shared<op::GetOutputElement>(bn, 2);
auto f = make_shared<Function>(NodeVector{output_rt, mean_rt, variance_rt},
op::ParameterVector{input, gamma, beta});
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto _input = backend->create_tensor(element::f32, Shape{1, 2, 2, 2});
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
auto _gamma = backend->create_tensor(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->create_tensor(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto bn_output = backend->create_tensor(element::f32, shape_r);
auto result_mean = backend->create_tensor(element::f32, mean_shape);
auto result_variance = backend->create_tensor(element::f32, var_shape);
vector<float> expected_result{-0.71498716f,
1.48388731f,
-0.00196938f,
-0.76693159f,
-0.91316032f,
0.23943391f,
-0.84090298f,
1.51462936f};
vector<float> expected_mean{0.602912f, 0.599727f};
vector<float> expected_variance{0.00472505f, 0.0361782f};
backend->call(f, {bn_output, result_mean, result_variance}, {_input, _gamma, _beta});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(bn_output)));
EXPECT_TRUE(test::all_close(expected_mean, read_vector<float>(result_mean)));
EXPECT_TRUE(test::all_close(expected_variance, read_vector<float>(result_variance)));
}
TEST(cpu_test, batchnorm_fprop_b2c2h2w1)
{
auto input_shape = Shape{2, 2, 2, 1};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto var_shape = Shape{2};
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{2, 2, 2, 1};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input);
auto output_rt = std::make_shared<op::GetOutputElement>(bn, 0);
auto mean_rt = std::make_shared<op::GetOutputElement>(bn, 1);
auto variance_rt = std::make_shared<op::GetOutputElement>(bn, 2);
auto f = make_shared<Function>(NodeVector{output_rt, mean_rt, variance_rt},
op::ParameterVector{input, gamma, beta});
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto _input = backend->create_tensor(element::f32, Shape{2, 2, 2, 1});
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
auto _gamma = backend->create_tensor(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->create_tensor(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto bn_output = backend->create_tensor(element::f32, shape_r);
auto result_mean = backend->create_tensor(element::f32, mean_shape);
auto result_variance = backend->create_tensor(element::f32, var_shape);
vector<float> expected_result{
-0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
vector<float> expected_mean{0.583388f, 0.619252f};
vector<float> expected_variance{0.0119972f, 0.0282681f};
backend->call(f, {bn_output, result_mean, result_variance}, {_input, _gamma, _beta});
EXPECT_TRUE(test::all_close(expected_result, read_vector<float>(bn_output)));
EXPECT_TRUE(test::all_close(expected_mean, read_vector<float>(result_mean)));
EXPECT_TRUE(test::all_close(expected_variance, read_vector<float>(result_variance)));
}
class UnhandledOp : public ngraph::op::Abs
{
public:
......@@ -174,196 +60,3 @@ TEST(cpu_test, unhandled_op)
auto backend = runtime::Backend::create("CPU");
ASSERT_THROW(backend->compile(f), ngraph_error);
}
TEST(cpu_test, bn_bprop_n4c3h2w2)
{
auto input_shape = Shape{4, 3, 2, 2};
auto shape_mean = Shape{3};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{3};
auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
auto var_shape = Shape{3};
auto var = make_shared<op::Parameter>(element::f32, var_shape);
auto gamma_shape = Shape{3};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{3};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{4, 3, 2, 2};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input);
auto bn_dx = make_shared<op::GetOutputElement>(bn, 0);
auto bn_dgamma = make_shared<op::GetOutputElement>(bn, 1);
auto bn_dbeta = make_shared<op::GetOutputElement>(bn, 2);
auto backend = runtime::Backend::create("CPU");
auto _input = backend->create_tensor(element::f32, input_shape);
vector<float> dataInput{
10.76331902f, 11.51178265f, 10.31018162f, 12.2993021f, 14.17626667f, 14.63498497f,
13.63494492f, 13.84248161f, 11.34602547f, 13.22014618f, 10.46686649f, 10.39842987f,
12.94806862f, 11.71670246f, 14.94438076f, 13.13236618f, 13.40889645f, 12.76128387f,
11.34430027f, 11.86629677f, 11.11464024f, 10.93221283f, 11.95324039f, 10.96581173f,
13.05455494f, 14.41404247f, 13.11169434f, 11.26559448f, 10.89965153f, 14.08202171f,
11.12685776f, 12.58428574f, 12.59247875f, 13.00187492f, 12.66310215f, 10.06655025f,
12.62048626f, 14.47942352f, 13.84950638f, 10.61425877f, 11.47936344f, 13.06011772f,
13.63069057f, 12.31748772f, 13.84555244f, 10.95815468f, 12.78933334f, 12.75389099f};
copy_data(_input, dataInput);
auto _mean = backend->create_tensor(element::f32, mean_shape);
copy_data(_mean, vector<float>{12.56472874f, 12.80312157f, 11.81676865f});
auto _var = backend->create_tensor(element::f32, var_shape);
copy_data(_var, vector<float>{1.94557643f, 1.32772446f, 1.28163588f});
auto _gamma = backend->create_tensor(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{2.0f, 2.0f, 2.0f});
auto _beta = backend->create_tensor(element::f32, beta_shape);
copy_data(_beta, vector<float>{1.0f, 1.0f, 1.0f});
auto result = backend->create_tensor(element::f32, shape_r);
shared_ptr<runtime::TensorView> _delta = backend->create_tensor(element::f32, shape_r);
vector<float> deltaData(shape_size(shape_r), 20.0f);
copy_data(_delta, deltaData);
auto f = make_shared<Function>(NodeVector{bn_dx, bn_dgamma, bn_dbeta},
op::ParameterVector{mean, var, input, gamma, beta});
auto C = std::make_shared<op::Parameter>(element::f32, shape_r);
auto zero = ngraph::make_zero(bn_dgamma->get_element_type(), bn_dgamma->get_shape());
ngraph::autodiff::Adjoints adjoints(NodeVector{bn_dx, bn_dgamma, bn_dbeta},
NodeVector{C, zero, zero});
auto dinput = adjoints.backprop_node(input);
auto dgamma = adjoints.backprop_node(gamma);
auto dbeta = adjoints.backprop_node(beta);
auto df = make_shared<Function>(NodeVector{dinput, dgamma, dbeta},
op::ParameterVector{mean, var, input, gamma, beta, C});
//roundtrip serialization
string js = serialize(df, 4);
istringstream in(js);
df = deserialize(in);
shared_ptr<runtime::TensorView> _dinput = backend->create_tensor(element::f32, shape_r);
shared_ptr<runtime::TensorView> _dgamma = backend->create_tensor(element::f32, gamma_shape);
shared_ptr<runtime::TensorView> _dbeta = backend->create_tensor(element::f32, beta_shape);
backend->call(df, {_dinput, _dgamma, _dbeta}, {_mean, _var, _input, _gamma, _beta, _delta});
vector<float> expected_input{
8.17051607e-06f, 4.77576657e-06f, 1.02257760e-05f, 1.20387525e-06f, -1.73868522e-06f,
3.84632768e-06f, -1.07932050e-05f, -2.57458956e-06f, -2.22166714e-06f, -8.38779043e-06f,
-2.48082982e-06f, 5.89238360e-06f, -2.52895109e-07f, -8.68433445e-06f, -5.82726737e-06f,
8.84659658e-06f, 3.03944108e-05f, 4.05480879e-05f, 1.84123158e-05f, 2.30061178e-05f,
1.34087590e-05f, -9.26072571e-07f, -3.22908454e-05f, -2.07365116e-05f, -4.21330941e-05f,
2.83083100e-05f, -3.71039101e-05f, -4.84390640e-06f, -2.93012376e-05f, 5.68858087e-06f,
1.83181458e-05f, -1.07494506e-05f, -2.32429103e-06f, 6.92914809e-06f, -6.66512321e-06f,
-7.00302840e-06f, -3.46675184e-06f, -4.36748381e-06f, 6.73822226e-07f, -4.20158993e-06f,
3.83005061e-06f, 5.85143729e-06f, 4.17875243e-06f, -8.64167783e-06f, 1.00170803e-05f,
-4.23939666e-06f, 4.80201680e-06f, 4.62702078e-06f};
ASSERT_TRUE(ngraph::test::all_close(read_vector<float>(_dinput), expected_input, 1e-3f, 1e-4f));
vector<float> expected_dgamma{7.06315041e-05f, -2.35289335e-04f, -5.06639481e-05f};
ASSERT_TRUE(
ngraph::test::all_close(read_vector<float>(_dgamma), expected_dgamma, 1e-2f, 1e-3f));
vector<float> expected_dbeta{320.f, 320.f, 320.f};
ASSERT_TRUE(ngraph::test::all_close(read_vector<float>(_dbeta), expected_dbeta, 1e-4f, 1e-8f));
}
TEST(cpu_test, batchnorm_fprop_inference_b2c2h2w1)
{
auto input_shape = Shape{2, 2, 2, 1};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
auto var_shape = Shape{2};
auto var = make_shared<op::Parameter>(element::f32, var_shape);
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{2, 2, 2, 1};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var);
auto f = make_shared<Function>(bn, op::ParameterVector{input, gamma, beta, mean, var});
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto _input = backend->create_tensor(element::f32, Shape{2, 2, 2, 1});
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
auto _gamma = backend->create_tensor(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->create_tensor(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto _mean = backend->create_tensor(element::f32, mean_shape);
copy_data(_mean, vector<float>{0.583388f, 0.619252f});
auto _var = backend->create_tensor(element::f32, var_shape);
copy_data(_var, vector<float>{0.0119972f, 0.0282681f});
auto bn_output = backend->create_tensor(element::f32, shape_r);
auto result_mean = backend->create_tensor(element::f32, mean_shape);
auto result_variance = backend->create_tensor(element::f32, var_shape);
vector<float> expected_result{
-0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
backend->call(f, {bn_output}, {_input, _gamma, _beta, _mean, _var});
ASSERT_TRUE(
ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
}
TEST(cpu_test, batchnorm_fprop_globalstats_b2c2w2h1)
{
auto input_shape = Shape{2, 2, 2, 1};
auto input = make_shared<op::Parameter>(element::f32, input_shape);
auto mean_shape = Shape{2};
auto mean = make_shared<op::Parameter>(element::f32, mean_shape);
auto var_shape = Shape{2};
auto var = make_shared<op::Parameter>(element::f32, var_shape);
auto gamma_shape = Shape{2};
auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
auto beta_shape = Shape{2};
auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
double eps = 0.001;
auto shape_r = Shape{2, 2, 2, 1};
auto bn = make_shared<op::BatchNorm>(eps, gamma, beta, input, mean, var, true);
auto f = make_shared<Function>(bn, op::ParameterVector{gamma, beta, input, mean, var});
auto backend = runtime::Backend::create("CPU");
// Create some tensors for input/output
auto _input = backend->create_tensor(element::f32, Shape{2, 2, 2, 1});
copy_data(_input,
vector<float>{0.54881352f,
0.71518934f,
0.60276335f,
0.54488319f,
0.42365479f,
0.64589411f,
0.4375872f,
0.89177299f});
auto _gamma = backend->create_tensor(element::f32, gamma_shape);
copy_data(_gamma, vector<float>{1.0f, 1.0f});
auto _beta = backend->create_tensor(element::f32, beta_shape);
copy_data(_beta, vector<float>{0.0f, 0.0f});
auto _mean = backend->create_tensor(element::f32, mean_shape);
copy_data(_mean, vector<float>{0.583388f, 0.619252f});
auto _var = backend->create_tensor(element::f32, var_shape);
copy_data(_var, vector<float>{0.0119972f, 0.0282681f});
auto bn_output = backend->create_tensor(element::f32, shape_r);
auto result_mean = backend->create_tensor(element::f32, mean_shape);
auto result_variance = backend->create_tensor(element::f32, var_shape);
vector<float> expected_result{
-0.30327f, 1.1561f, -0.0963782f, -0.434702f, -1.4011f, 0.548275f, -1.06187f, 1.59295f};
backend->call(f, {bn_output}, {_gamma, _beta, _input, _mean, _var});
ASSERT_TRUE(
ngraph::test::all_close(expected_result, read_vector<float>(bn_output), 1e-3f, 1e-4f));
}