Commit 84de3bf4 authored by Fenglei, committed by Robert Kimball

nvgpu optimize reshape v3 (#1617)

* pass args instead of pointer to array

* add 3D tiled reshape

* working version

* add shared-memory versions of 2D and 3D reshape

* remove unused code

* style

* resolve commits

* add tests for 3D reshape; some 3D reshapes will be treated as 2D (sketched below)
parent 00afd349
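The last bullet refers to the axis-folding pass added to the GPU emitter below: when adjacent input axes stay adjacent and in order under the permutation, they can be folded into a single axis before a kernel is chosen, so a 3D permutation such as {2, 0, 1} reduces to a plain 2D transpose. A minimal standalone sketch of that idea (illustrative names, not the emitter's API):

#include <cassert>
#include <cstddef>
#include <vector>

// Fold adjacent input axes that remain adjacent and in order under `order`.
std::vector<size_t> fold_in_order_axes(const std::vector<size_t>& shape,
                                       const std::vector<size_t>& order,
                                       std::vector<size_t>& new_order)
{
    size_t rank = shape.size();
    std::vector<int> keep_with_next(rank, 0);
    for (size_t i = 0; i + 1 < rank; i++)
    {
        if (order[i + 1] == order[i] + 1)
        {
            keep_with_next[order[i]] = 1;
        }
    }
    std::vector<size_t> folded_shape;
    std::vector<size_t> axis_to_folded(rank, 0);
    size_t run = 1;
    size_t folded_rank = 0;
    for (size_t i = 0; i < rank; i++)
    {
        if (keep_with_next[i])
        {
            run *= shape[i]; // accumulate into the following axis
        }
        else
        {
            folded_shape.push_back(run * shape[i]);
            run = 1;
            axis_to_folded[i] = folded_rank++;
        }
    }
    new_order.clear();
    for (size_t i = 0; i < rank; i++)
    {
        if (!keep_with_next[order[i]])
        {
            new_order.push_back(axis_to_folded[order[i]]);
        }
    }
    return folded_shape;
}

int main()
{
    // {2, 3, 4} permuted by {2, 0, 1}: axes 0 and 1 stay together, so the
    // problem collapses to a plain 2D transpose of a {6, 4} matrix.
    std::vector<size_t> new_order;
    std::vector<size_t> folded = fold_in_order_axes({2, 3, 4}, {2, 0, 1}, new_order);
    assert((folded == std::vector<size_t>{6, 4}));
    assert((new_order == std::vector<size_t>{1, 0}));
    return 0;
}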
@@ -567,6 +567,7 @@ size_t runtime::gpu::CUDAEmitter::build_pad_dynamic(const std::array<std::string
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
size_t runtime::gpu::CUDAEmitter::build_reshape(const std::array<std::string, 2>& dtypes,
NVShape input_shape,
NVShape input_order)
@@ -653,6 +654,186 @@ size_t runtime::gpu::CUDAEmitter::build_reshape(const std::array<std::string, 2>
return primitive_index;
}
size_t runtime::gpu::CUDAEmitter::build_reshape_2d(const std::array<std::string, 2>& dtypes,
NVShape input_shape,
NVShape input_order)
{
auto rank = input_shape.size();
std::stringstream kernel_name;
kernel_name << "reshape_" << join(dtypes, "_");
std::string hash =
kernel_name.str() + "_i_" + join(input_shape, "_") + "_o_" + join(input_order, "_");
// For backwards compatibility we currently use two unordered maps
// 1. one looks up the compiled cuda kernel (CudaFunctionPool)
// 2. the other looks to see if this kernel is already in the primitive list
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
// TODO: the block size is currently fixed at 16; add a tuning method later
uint32_t block_size = 16;
uint32_t aligned_grid_size_x = align_to_block_size(input_shape[1], block_size);
uint32_t aligned_grid_size_y = align_to_block_size(input_shape[0], block_size);
NVShape input_strides = row_major_strides(input_shape);
NVShape output_strides(rank);
NVShape trans_strides(rank);
int stride = 1;
for (int64_t i = rank - 1; i >= 0; i--)
{
output_strides[i] = stride;
stride *= input_shape[input_order[i]];
}
for (int64_t i = 0; i < rank; i++)
{
trans_strides[input_order[i]] = output_strides[i];
}
// register the kernel arguments; the in/out placeholders are resolved to tensor pointers at launch time
auto args = m_primitive_emitter->add_kernel_args();
args.add_placeholder(dtypes[0], "in")
.add_placeholder(dtypes[1], "out")
.add("input_strides", input_strides)
.add("trans_strides", trans_strides)
.add("nx", input_shape[1])
.add("ny", input_shape[0]);
// check if the kernel has already been compiled; if so, reuse it and only
// create a launch primitive for this input tensor shape. otherwise, compile
// the kernel first and then create the primitive
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name.str());
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_reshape_op_2d(
writer, kernel_name.str(), args, dtypes[1], block_size);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
}
// create the launch primitive
std::unique_ptr<gpu::primitive> kernel_launch(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void** args_list = args.resolve_placeholder(0, &inputs[0])
.resolve_placeholder(1, &outputs[0])
.get_argument_list();
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
aligned_grid_size_y,
1, // grid dim
block_size,
block_size,
1, // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
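All of the permutation is folded into trans_strides, so the generated kernel can read the input in row-major order and scatter its writes through trans_strides. A hand-traced check of the stride set-up above for a {3, 4} input transposed with order {1, 0}, written as a standalone sketch with plain std::vector standing in for NVShape:

#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    std::vector<uint32_t> input_shape{3, 4};
    std::vector<uint32_t> input_order{1, 0};
    std::vector<uint32_t> input_strides{4, 1}; // row-major strides of {3, 4}

    // output_strides are the row-major strides of the permuted shape {4, 3}.
    std::vector<uint32_t> output_strides(2);
    std::vector<uint32_t> trans_strides(2);
    uint32_t stride = 1;
    for (int i = 1; i >= 0; i--)
    {
        output_strides[i] = stride;
        stride *= input_shape[input_order[i]];
    }
    for (int i = 0; i < 2; i++)
    {
        trans_strides[input_order[i]] = output_strides[i];
    }
    // Element (i0, i1) of the input lands at output offset
    // trans_strides0 * i0 + trans_strides1 * i1 = 1 * i0 + 3 * i1,
    // which is exactly position (i1, i0) of the {4, 3} result.
    assert((trans_strides == std::vector<uint32_t>{1, 3}));
    return 0;
}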
size_t runtime::gpu::CUDAEmitter::build_reshape_3d(const std::array<std::string, 2>& dtypes,
NVShape input_shape,
NVShape input_order)
{
auto rank = input_shape.size();
std::stringstream kernel_name;
kernel_name << "reshape_" << join(dtypes, "_") << "_r_" << join(input_order, "_");
std::string hash = kernel_name.str() + "_i_" + join(input_shape, "_");
// For backwards compatibility we currently use two unordered maps
// 1. one looks up the compiled cuda kernel (CudaFunctionPool)
// 2. the other looks to see if this kernel is already in the primitive list
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
std::vector<uint32_t> block_size(3, 0);
// TODO: the block size is currently fixed at 16; add a tuning method later
uint32_t block_size_x = 16;
block_size[0] = block_size_x; //x
block_size[2] = (input_order[2] == 0) ? block_size_x : 1; //z
block_size[1] = (block_size[2] == block_size_x) ? 1 : block_size_x; //y
uint32_t aligned_grid_size_x = align_to_block_size(input_shape[2], block_size[0]);
uint32_t aligned_grid_size_y = align_to_block_size(input_shape[1], block_size[1]);
uint32_t aligned_grid_size_z = align_to_block_size(input_shape[0], block_size[2]);
NVShape input_strides = row_major_strides(input_shape);
NVShape output_strides(rank);
NVShape trans_strides(rank);
int stride = 1;
for (int64_t i = rank - 1; i >= 0; i--)
{
output_strides[i] = stride;
stride *= input_shape[input_order[i]];
}
for (int64_t i = 0; i < rank; i++)
{
trans_strides[input_order[i]] = output_strides[i];
}
// register the kernel arguments; the in/out placeholders are resolved to tensor pointers at launch time
auto args = m_primitive_emitter->add_kernel_args();
args.add_placeholder(dtypes[0], "in")
.add_placeholder(dtypes[1], "out")
.add("input_strides", input_strides)
.add("trans_strides", trans_strides)
.add("nx", input_shape[2])
.add("ny", input_shape[1])
.add("nz", input_shape[0]);
// check if the kernel has already been compiled; if so, reuse it and only
// create a launch primitive for this input tensor shape. otherwise, compile
// the kernel first and then create the primitive
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name.str());
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_reshape_op_3d(
writer, kernel_name.str(), args, dtypes[1], input_order, block_size);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
}
// create the launch primitive
std::unique_ptr<gpu::primitive> kernel_launch(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void** args_list = args.resolve_placeholder(0, &inputs[0])
.resolve_placeholder(1, &outputs[0])
.get_argument_list();
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
aligned_grid_size_y,
aligned_grid_size_z, // grid dim
block_size[0],
block_size[1],
block_size[2], // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
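The block-size selection above always tiles 16 threads along x (the innermost input axis); the second tiled axis is z when input axis 0 becomes the innermost output axis (input_order[2] == 0), and y otherwise, so the shared-memory tile spans the two axes that trade places. A standalone sketch of the resulting launch configuration; ceil_div is a stand-in for align_to_block_size, assumed here to round the extent up to whole blocks:

#include <cstdint>
#include <cstdio>
#include <vector>

// Number of blocks needed to cover `extent` with `block` threads.
uint32_t ceil_div(uint32_t extent, uint32_t block)
{
    return (extent + block - 1) / block;
}

int main()
{
    std::vector<uint32_t> input_shape{2, 3, 4}; // nz, ny, nx
    std::vector<uint32_t> input_order{2, 1, 0}; // full reversal

    // Mirror of the selection in build_reshape_3d.
    uint32_t block_size_x = 16;
    std::vector<uint32_t> block(3, 0);
    block[0] = block_size_x;
    block[2] = (input_order[2] == 0) ? block_size_x : 1;
    block[1] = (block[2] == block_size_x) ? 1 : block_size_x;

    uint32_t grid_x = ceil_div(input_shape[2], block[0]);
    uint32_t grid_y = ceil_div(input_shape[1], block[1]);
    uint32_t grid_z = ceil_div(input_shape[0], block[2]);
    // For order {2, 1, 0} this prints block (16, 1, 16), grid (1, 3, 1).
    std::printf("block (%u, %u, %u), grid (%u, %u, %u)\n",
                block[0], block[1], block[2], grid_x, grid_y, grid_z);
    return 0;
}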
size_t runtime::gpu::CUDAEmitter::build_slice(const std::array<std::string, 2>& dtypes,
NVShape input_shape,
NVShape lower_bounds,
...
@@ -154,6 +154,14 @@ namespace ngraph
NVShape input_shape,
NVShape input_order);
size_t build_reshape_2d(const std::array<std::string, 2>& dtypes,
NVShape input_shape,
NVShape input_order);
size_t build_reshape_3d(const std::array<std::string, 2>& dtypes,
NVShape input_shape,
NVShape input_order);
size_t build_convolution(const std::array<std::string, 3>& dtypes,
NVShape input_shape,
NVShape filter_shape,
...
@@ -481,6 +481,122 @@ void runtime::gpu::CudaKernelBuilder::get_reshape_op(codegen::CodeWriter& writer
writer.block_end();
}
void runtime::gpu::CudaKernelBuilder::get_reshape_op_2d(codegen::CodeWriter& writer,
const std::string& name,
runtime::gpu::GPUKernelArgs& args,
const std::string& data_type,
uint32_t block_size)
{
writer << "extern \"C\" __global__ void cuda_" << name << args.get_input_signature();
writer.block_begin();
{
writer << "__shared__ " << data_type << " tile[" << block_size << "][" << block_size + 1
<< "];\n";
writer << "uint32_t base1 = blockIdx.x * blockDim.x;\n";
writer << "uint32_t base0 = blockIdx.y * blockDim.y;\n";
writer << "uint32_t tid1 = threadIdx.x;\n";
writer << "uint32_t tid0 = threadIdx.y;\n";
writer << "uint32_t idx1 = base1 + tid1;\n";
writer << "uint32_t idx0 = base0 + tid0;\n";
writer << "if (idx1 < nx && idx0 < ny)\n";
writer.block_begin();
{
writer << "uint32_t input_idx = 0;\n";
for (int i = 0; i < 2; i++)
{
writer << "input_idx += input_strides" << i << "* idx" << i << ";\n";
}
writer << "tile[tid0][tid1] = in[input_idx];\n";
}
writer.block_end();
writer << "idx1 = base1 + tid0;\n";
writer << "idx0 = base0 + tid1;\n";
writer << "__syncthreads();\n";
writer << "if (idx1 < nx && idx0 < ny)\n";
writer.block_begin();
{
writer << "uint32_t output_idx = 0;\n";
for (int i = 0; i < 2; i++)
{
writer << "output_idx += trans_strides" << i << "* idx" << i << ";\n";
}
writer << "out[output_idx] = tile[tid1][tid0];\n";
}
writer.block_end();
}
writer.block_end();
}
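For reference, this is roughly the kernel the builder above emits for float data with block_size = 16, written out by hand as a sketch. The kernel name and the flattened parameter list are assumptions for illustration; the real ones come from join(dtypes, "_") and GPUKernelArgs::get_input_signature(), and the emitted source gets uint32_t from add_pod_typedefs rather than <cstdint>.

#include <cstdint>

extern "C" __global__ void cuda_reshape_float_float(float* in,
                                                    float* out,
                                                    uint32_t input_strides0,
                                                    uint32_t input_strides1,
                                                    uint32_t trans_strides0,
                                                    uint32_t trans_strides1,
                                                    uint32_t nx,
                                                    uint32_t ny)
{
    __shared__ float tile[16][17]; // +1 column avoids shared-memory bank conflicts
    uint32_t base1 = blockIdx.x * blockDim.x;
    uint32_t base0 = blockIdx.y * blockDim.y;
    uint32_t tid1 = threadIdx.x;
    uint32_t tid0 = threadIdx.y;
    uint32_t idx1 = base1 + tid1;
    uint32_t idx0 = base0 + tid0;
    if (idx1 < nx && idx0 < ny)
    {
        // coalesced read of a 16x16 input tile into shared memory
        uint32_t input_idx = input_strides0 * idx0 + input_strides1 * idx1;
        tile[tid0][tid1] = in[input_idx];
    }
    // swap the thread roles so the write side also walks the fast axis
    idx1 = base1 + tid0;
    idx0 = base0 + tid1;
    __syncthreads();
    if (idx1 < nx && idx0 < ny)
    {
        // for a 2D transpose trans_strides0 is 1, so consecutive threads
        // store to consecutive output addresses
        uint32_t output_idx = trans_strides0 * idx0 + trans_strides1 * idx1;
        out[output_idx] = tile[tid1][tid0];
    }
}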
void runtime::gpu::CudaKernelBuilder::get_reshape_op_3d(codegen::CodeWriter& writer,
const std::string& name,
runtime::gpu::GPUKernelArgs& args,
const std::string& data_type,
const std::vector<uint32_t>& order,
const std::vector<uint32_t>& block_size)
{
writer << "extern \"C\" __global__ void cuda_" << name << args.get_input_signature();
writer.block_begin();
{
writer << "__shared__ " << data_type << " tile[" << block_size[2] << "][" << block_size[1]
<< "][" << block_size[0] + 1 << "];\n";
writer << "uint32_t base2 = blockIdx.x * blockDim.x;\n";
writer << "uint32_t base1 = blockIdx.y * blockDim.y;\n";
writer << "uint32_t base0 = blockIdx.z * blockDim.z;\n";
writer << "uint32_t tid2 = threadIdx.x;\n";
writer << "uint32_t tid1 = threadIdx.y;\n";
writer << "uint32_t tid0 = threadIdx.z;\n";
writer << "uint32_t otid2 = tid2;\n";
writer << "uint32_t otid1 = tid1;\n";
writer << "uint32_t otid0 = tid0;\n";
writer << "uint32_t idx2 = base2 + tid2;\n";
writer << "uint32_t idx1 = base1 + tid1;\n";
writer << "uint32_t idx0 = base0 + tid0;\n";
writer << "if (idx2 < nx && idx1 < ny && idx0 < nz)\n";
writer.block_begin();
{
writer << "uint32_t input_idx = 0;\n";
for (int i = 0; i < 3; i++)
{
writer << "input_idx += input_strides" << i << "* idx" << i << ";\n";
}
writer << "tile[tid0][tid1][tid2] = in[input_idx];\n";
}
writer.block_end();
if (order[2] == 1)
{
writer << "otid2 = tid1;\n";
writer << "otid1 = tid2;\n";
}
else if (order[2] == 0)
{
writer << "otid2 = tid0;\n";
writer << "otid0 = tid2;\n";
}
writer << "idx2 = base2 + otid2;\n";
writer << "idx1 = base1 + otid1;\n";
writer << "idx0 = base0 + otid0;\n";
writer << "__syncthreads();\n";
writer << "if (idx2 < nx && idx1 < ny && idx0 < nz)\n";
writer.block_begin();
{
writer << "uint32_t output_idx = 0;\n";
for (int i = 0; i < 3; i++)
{
writer << "output_idx += trans_strides" << i << "* idx" << i << ";\n";
}
writer << "out[output_idx] = tile[otid0][otid1][otid2];\n";
}
writer.block_end();
}
writer.block_end();
}
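Both tiled kernels declare the innermost tile dimension one element wider than the block (tile[...][block_size + 1]). The usual reason for this padding is shared-memory bank conflicts: without it, walking a column of a 16-wide tile keeps hitting the same banks and the accesses serialize. A small standalone check of that effect, assuming the common configuration of 32 four-byte banks:

#include <cstdio>
#include <initializer_list>

int main()
{
    const int banks = 32;
    for (int pad : {0, 1})
    {
        int row_width = 16 + pad;
        std::printf("row width %2d -> banks hit by column 0:", row_width);
        for (int row = 0; row < 16; row++)
        {
            // bank index of element tile[row][0] for a float tile
            std::printf(" %2d", (row * row_width) % banks);
        }
        std::printf("\n");
        // width 16: only banks 0 and 16 are hit; width 17: all 16 banks differ.
    }
    return 0;
}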
void runtime::gpu::CudaKernelBuilder::get_concat_op(codegen::CodeWriter& writer,
const std::string& name,
const std::vector<std::string>& data_types,
...
@@ -60,6 +60,19 @@ namespace ngraph
const std::array<std::string, 2>& data_types,
size_t rank);
static void get_reshape_op_3d(codegen::CodeWriter& writer,
const std::string& name,
runtime::gpu::GPUKernelArgs& args,
const std::string& data_type,
const std::vector<uint32_t>& order,
const std::vector<uint32_t>& block_size);
static void get_reshape_op_2d(codegen::CodeWriter& writer,
const std::string& name,
runtime::gpu::GPUKernelArgs& args,
const std::string& data_type,
uint32_t block_size);
static void get_reduce_to_nd_op(codegen::CodeWriter& writer,
const std::string& name,
runtime::gpu::GPUKernelArgs& args,
...
@@ -533,50 +533,129 @@ namespace ngraph
return;
}
auto arg_shape = args[0].get_shape();
auto arg_rank = arg_shape.size();
auto result_shape = out[0].get_shape();
auto input_order = reshape->get_input_order();
size_t result_shape_product = shape_size(result_shape);
// if there is no layout change, or the result has fewer than two elements, just do a copy
if (!reshape->get_is_transpose() || result_shape_product < 2)
{
writer.block_begin();
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
}
writer.block_end();
return;
}
// combine dimensions that stay adjacent and in order after the reorder;
// update the output shape and input order accordingly
Shape in_order_map(arg_rank, 0);
for (int i = 0; i < arg_rank - 1; i++)
{
if (static_cast<int64_t>(input_order[i + 1]) - static_cast<int64_t>(input_order[i]) == 1)
{
in_order_map[input_order[i]] = 1;
}
}
Shape combine_arg_shape;
Shape combine_idx_map(arg_rank, 0);
Shape combine_input_order;
size_t shape_i = 1;
size_t combine_rank = 0;
for (int i = 0; i < arg_rank; i++)
{
if (in_order_map[i] == 1)
{
shape_i *= arg_shape[i];
}
else
{
combine_arg_shape.push_back(shape_i * arg_shape[i]);
shape_i = 1;
combine_idx_map[i] = combine_rank++;
}
}
for (int i = 0; i < arg_rank; i++)
{
if (in_order_map[input_order[i]] == 0)
{
combine_input_order.push_back(combine_idx_map[input_order[i]]);
}
}
// eliminate dimensions of size 1; update the input order and output shape
Shape new_arg_shape;
Shape new_result_shape;
Shape new_idx_map(combine_rank, 0);
Shape new_input_order;
size_t new_rank = 0;
for (int i = 0; i < combine_rank; i++)
{
if (combine_arg_shape[i] != 1)
{
new_arg_shape.push_back(combine_arg_shape[i]);
new_idx_map[i] = new_rank++;
}
}
for (int i = 0; i < combine_rank; i++)
{
if (combine_arg_shape[combine_input_order[i]] != 1)
{
new_input_order.push_back(new_idx_map[combine_input_order[i]]);
}
}
for (int i = 0; i < new_rank; i++)
{
new_result_shape.push_back(new_arg_shape[new_input_order[i]]);
}
// If there is no layout change left after collapsing, we can just copy.
writer.block_begin();
{
bool same_layout = is_sorted(new_input_order.begin(), new_input_order.end());
if (same_layout)
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
}
// If there *is* a layout change, dispatch to a reshape kernel based on the collapsed rank.
else
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
size_t index;
// 2D case: transpose the input with the shared-memory tiled kernel.
if (new_rank == 2)
{
index = cuda_emitter->build_reshape_2d(
{{args[0].get_type(), out[0].get_type()}}, new_arg_shape, new_input_order);
}
// 3D case: use the 3D tiled reshape.
else if (new_rank == 3)
{
index = cuda_emitter->build_reshape_3d(
{{args[0].get_type(), out[0].get_type()}}, new_arg_shape, new_input_order);
}
// Other cases (reordering of axes for tensors with rank > 3).
else
{
index = cuda_emitter->build_reshape(
{{args[0].get_type(), out[0].get_type()}}, new_arg_shape, new_input_order);
}
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
}
writer.block_end();
}
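The dispatch above only ever sees the shape produced by the two clean-up passes: adjacent in-order axes are folded together and size-1 axes are dropped, which is why many rank > 2 reshapes end up in the 2D or 3D tiled kernels. A minimal standalone sketch of the size-1 elimination pass with a hand-checked example (illustrative names, not the emitter's API):

#include <cassert>
#include <cstddef>
#include <vector>

// Drop axes of size 1 and renumber the permutation accordingly.
void eliminate_unit_axes(const std::vector<size_t>& shape,
                         const std::vector<size_t>& order,
                         std::vector<size_t>& new_shape,
                         std::vector<size_t>& new_order)
{
    std::vector<size_t> remap(shape.size(), 0);
    size_t new_rank = 0;
    for (size_t i = 0; i < shape.size(); i++)
    {
        if (shape[i] != 1)
        {
            new_shape.push_back(shape[i]);
            remap[i] = new_rank++;
        }
    }
    for (size_t i = 0; i < order.size(); i++)
    {
        if (shape[order[i]] != 1)
        {
            new_order.push_back(remap[order[i]]);
        }
    }
}

int main()
{
    // {1, 2, 3, 4} permuted by {0, 2, 1, 3}: nothing folds in the first pass,
    // but the leading size-1 axis drops out, leaving a rank-3 problem
    // {2, 3, 4} with order {1, 0, 2}, which is handled by build_reshape_3d.
    std::vector<size_t> new_shape;
    std::vector<size_t> new_order;
    eliminate_unit_axes({1, 2, 3, 4}, {0, 2, 1, 3}, new_shape, new_order);
    assert((new_shape == std::vector<size_t>{2, 3, 4}));
    assert((new_order == std::vector<size_t>{1, 0, 2}));
    return 0;
}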
...
@@ -2851,16 +2851,16 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_m2m_dim_change_transpose)
EXPECT_EQ((vector<float>{1, 3, 5, 2, 4, 6}), read_vector<float>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, reshape_3d_transpose_021)
{
vector<float> a_data(2 * 3 * 4);
for (int i = 0; i < 2 * 3 * 4; i++)
{
a_data[i] = float(i + 1);
}
Shape shape_a{2, 3, 4};
Shape shape_r{2, 4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto r = make_shared<op::Reshape>(A, AxisVector{0, 2, 1}, shape_r);
auto f = make_shared<Function>(r, op::ParameterVector{A});
@@ -2873,8 +2873,116 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_3d_transpose)
auto result = backend->create_tensor(element::f32, shape_r);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((vector<float>{1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12,
13, 17, 21, 14, 18, 22, 15, 19, 23, 16, 20, 24}),
read_vector<float>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, reshape_3d_transpose_210)
{
vector<float> a_data(2 * 3 * 4);
for (int i = 0; i < 2 * 3 * 4; i++)
{
a_data[i] = float(i + 1);
}
Shape shape_a{2, 3, 4};
Shape shape_r{4, 3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto r = make_shared<op::Reshape>(A, AxisVector{2, 1, 0}, shape_r);
auto f = make_shared<Function>(r, op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, a_data);
auto result = backend->create_tensor(element::f32, shape_r);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((vector<float>{1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22,
3, 15, 7, 19, 11, 23, 4, 16, 8, 20, 12, 24}),
read_vector<float>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, reshape_3d_transpose_201)
{
vector<float> a_data(2 * 3 * 4);
for (int i = 0; i < 2 * 3 * 4; i++)
{
a_data[i] = float(i + 1);
}
Shape shape_a{2, 3, 4};
Shape shape_r{4, 2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto r = make_shared<op::Reshape>(A, AxisVector{2, 0, 1}, shape_r);
auto f = make_shared<Function>(r, op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, a_data);
auto result = backend->create_tensor(element::f32, shape_r);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((vector<float>{1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22,
3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24}),
read_vector<float>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, reshape_3d_transpose_102)
{
vector<float> a_data(2 * 3 * 4);
for (int i = 0; i < 2 * 3 * 4; i++)
{
a_data[i] = float(i + 1);
}
Shape shape_a{2, 3, 4};
Shape shape_r{3, 2, 4};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto r = make_shared<op::Reshape>(A, AxisVector{1, 0, 2}, shape_r);
auto f = make_shared<Function>(r, op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, a_data);
auto result = backend->create_tensor(element::f32, shape_r);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((vector<float>{1, 2, 3, 4, 13, 14, 15, 16, 5, 6, 7, 8,
17, 18, 19, 20, 9, 10, 11, 12, 21, 22, 23, 24}),
read_vector<float>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, reshape_3d_transpose_120)
{
vector<float> a_data(2 * 3 * 4);
for (int i = 0; i < 2 * 3 * 4; i++)
{
a_data[i] = float(i + 1);
}
Shape shape_a{2, 3, 4};
Shape shape_r{3, 4, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto r = make_shared<op::Reshape>(A, AxisVector{1, 2, 0}, shape_r);
auto f = make_shared<Function>(r, op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, a_data);
auto result = backend->create_tensor(element::f32, shape_r);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((vector<float>{1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18,
7, 19, 8, 20, 9, 21, 10, 22, 11, 23, 12, 24}),
read_vector<float>(result));
}
...