Commit d901446d authored by Ayan Moitra's avatar Ayan Moitra Committed by Robert Kimball

Support TopK for NvidiaGPU backend (#1908)

* fresh commit for the changes

* Working topk on ndims for nvGPU

* fix

* clang

* Added unit test, improved kernel hash and Bob's comment

* int64 test+clang

* Moved argReduce and topk tests to a separate file

* TopK unsupported for IntelGPU

* addressed Fenglei and Chris's comments

* addressed Fenglei and Chris's comments
parent 239322e0
...@@ -212,6 +212,138 @@ size_t runtime::gpu::CUDAEmitter::build_concat(const std::string& dtype, ...@@ -212,6 +212,138 @@ size_t runtime::gpu::CUDAEmitter::build_concat(const std::string& dtype,
return this->m_primitive_emitter->register_primitive(kernel_launch, hash.str()); return this->m_primitive_emitter->register_primitive(kernel_launch, hash.str());
} }
size_t runtime::gpu::CUDAEmitter::build_topk(const std::vector<element::Type>& dtypes,
const NVShape& input_shape,
const size_t topk_axis,
size_t topk_k,
const element::Type index_elem_type,
bool compute_max)
{
NGRAPH_ASSERT(dtypes[1] == index_elem_type)
<< " The index element type does not match out[0] type";
uint32_t rank = static_cast<uint32_t>(input_shape.size());
NGRAPH_ASSERT(rank <= 2) << " The input tensor should be of either rank 1 or rank 2";
NGRAPH_ASSERT(topk_axis == rank - 1)
<< " The axis along which topk is computed should be the last axis";
size_t num_cols = input_shape[rank - 1];
size_t num_rows = ((rank == 2) ? input_shape[0] : 1);
std::vector<std::string> dtypes_string;
for (auto& dtype : dtypes)
{
dtypes_string.push_back(dtype.c_type_string());
}
/* The struct 'Entry' used in the kernel looks like this:
struct Entry
{
size_t index;
float value;
__device__ size_t get_index(){return index;}
__device__ void set_index(size_t id) {index = id;}
__device__ float get_value(){return value;}
__device__ void set_value(float val){value = val;}
};
Based on the datatypes, the max size of the struct can be 16 bytes. Any arbitrary size of the struct can
therfore be given by 'shared_struct_bytes' as calculated below accounting for structure padding*/
size_t shared_struct_bytes = (((dtypes[0].size() + index_elem_type.size()) <= 8) ? 8 : 16);
size_t shared_data_bytes = num_cols * shared_struct_bytes;
// Use global memory when each row size exceeds shared mem allowed per block
int device_num = 0;
CUDA_RT_SAFE_CALL(cudaGetDevice(&device_num));
cudaDeviceProp prop;
CUDA_RT_SAFE_CALL(cudaGetDeviceProperties(&prop, device_num));
bool use_malloc = ((shared_data_bytes > prop.sharedMemPerBlock) ? true : false);
std::stringstream kernel_name;
kernel_name << "topk_" << join(dtypes_string, "_") << "_cm_" << compute_max << "_use_malloc_"
<< use_malloc;
std::string hash = kernel_name.str() + "_i_" + join(input_shape, "_");
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
uint32_t block_size_x = 32;
uint32_t aligned_grid_size_x = num_rows;
auto args = m_primitive_emitter->add_kernel_args();
args.add_placeholder(dtypes_string[0], "in")
.add_placeholder(dtypes_string[1], "out_id")
.add_placeholder(dtypes_string[2], "out_val");
if (use_malloc)
{
args.add_placeholder("Entry", "entry");
}
args.add("num_cols", num_cols).add("topk_k", topk_k);
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name.str());
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
runtime::gpu::CudaKernelBuilder::get_topk(
writer, kernel_name.str(), dtypes_string, compute_max, args, use_malloc);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
}
if (use_malloc)
{
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
size_t heap_workspace_id = allocator.reserve_workspace(num_rows * shared_data_bytes);
std::unique_ptr<gpu::primitive> kernel_launch(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void* buffer = runtime::gpu::invoke_memory_primitive(m_ctx, heap_workspace_id);
void** args_list = args.resolve_placeholder(0, &inputs[0])
.resolve_placeholder(1, &outputs[0])
.resolve_placeholder(2, &outputs[1])
.resolve_placeholder(3, &buffer)
.get_argument_list();
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
1,
1,
block_size_x,
1,
1,
0,
NULL, // stream
args_list,
0)); // arguments
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
}
else
{
std::unique_ptr<gpu::primitive> kernel_launch(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void** args_list = args.resolve_placeholder(0, &inputs[0])
.resolve_placeholder(1, &outputs[0])
.resolve_placeholder(2, &outputs[1])
.get_argument_list();
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
1,
1,
block_size_x,
1,
1,
shared_data_bytes, // shared mem
NULL, //stream
args_list,
0)); // arguments
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
}
return primitive_index;
}
size_t runtime::gpu::CUDAEmitter::build_onehot(const std::array<std::string, 2>& dtypes, size_t runtime::gpu::CUDAEmitter::build_onehot(const std::array<std::string, 2>& dtypes,
NVShape input_shape, NVShape input_shape,
NVShape output_shape, NVShape output_shape,
...@@ -2165,7 +2297,6 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_window(const OpName op_name, ...@@ -2165,7 +2297,6 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_window(const OpName op_name,
args_list.data(), args_list.data(),
0)); // arguments 0)); // arguments
debug_sync(); debug_sync();
}}); }});
return this->m_primitive_emitter->register_primitive(f, hash); return this->m_primitive_emitter->register_primitive(f, hash);
...@@ -2656,7 +2787,6 @@ size_t runtime::gpu::CUDAEmitter::build_convolution(const std::array<std::string ...@@ -2656,7 +2787,6 @@ size_t runtime::gpu::CUDAEmitter::build_convolution(const std::array<std::string
std::unique_ptr<gpu::primitive> conv( std::unique_ptr<gpu::primitive> conv(
new gpu::primitive{[=](void** inputs, void** outputs) mutable { new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void** args_list = args.resolve_placeholder(0, &inputs[0]) void** args_list = args.resolve_placeholder(0, &inputs[0])
.resolve_placeholder(1, &inputs[1]) .resolve_placeholder(1, &inputs[1])
.resolve_placeholder(2, &outputs[0]) .resolve_placeholder(2, &outputs[0])
......
...@@ -50,6 +50,13 @@ namespace ngraph ...@@ -50,6 +50,13 @@ namespace ngraph
size_t build_primitive(const op::ReplaceSlice* node, bool in_place_op); size_t build_primitive(const op::ReplaceSlice* node, bool in_place_op);
public: public:
size_t build_topk(const std::vector<element::Type>& dtypes,
const NVShape& input_shape,
const size_t topk_axis,
size_t topk_k,
const element::Type index_elem_type,
bool compute_max);
size_t build_pad(const std::vector<std::string>& dtypes, size_t build_pad(const std::vector<std::string>& dtypes,
NVShape input_shape, NVShape input_shape,
NVShape output_shape, NVShape output_shape,
......
...@@ -202,6 +202,130 @@ void runtime::gpu::CudaKernelBuilder::get_ew_collective_op( ...@@ -202,6 +202,130 @@ void runtime::gpu::CudaKernelBuilder::get_ew_collective_op(
return; return;
} }
void runtime::gpu::CudaKernelBuilder::get_topk(codegen::CodeWriter& writer,
const std::string& name,
const std::vector<std::string>& dtypes,
bool compute_max,
runtime::gpu::GPUKernelArgs& args,
bool use_malloc)
{
writer << "struct Entry\n";
writer.block_begin();
{
writer << dtypes[0] << " value;\n";
writer << dtypes[1] << " index;\n";
writer << "__device__ " << dtypes[1] << " get_index() {return index;}\n";
writer << "__device__ "
<< "void set_index(" << dtypes[1] << " id) {index = id;}\n";
writer << "__device__ " << dtypes[0] << " get_value() {return value;}\n";
writer << "__device__ "
<< "void set_value(" << dtypes[0] << " val) {value = val;}\n";
}
writer.block_end();
writer << ";\n";
writer << "__device__ void swap(Entry& a, Entry& b)\n";
writer.block_begin();
{
writer << "Entry t = a;\n";
writer << "a = b;\n";
writer << "b = t;\n";
}
writer.block_end();
writer << "__device__ void heapify(Entry *heap, size_t heap_size, size_t idx)\n";
writer.block_begin();
{
writer << "size_t largest = idx;\n";
writer << "size_t left = (idx << 1) + 1;\n";
writer << "size_t right = (idx + 1) << 1;\n";
std::string g_op = ((compute_max) ? ">" : "<");
writer << "if (left < heap_size && heap[left].get_value() " << g_op
<< " heap[largest].get_value())\n";
writer.block_begin();
{
writer << "largest = left;\n";
}
writer.block_end();
writer << "if (right < heap_size && heap[right].get_value() " << g_op
<< " heap[largest].get_value())\n";
writer.block_begin();
{
writer << "largest = right;\n";
}
writer.block_end();
writer << "if (largest != idx)\n";
writer.block_begin();
{
writer << "swap(heap[largest], heap[idx]);\n";
writer << "heapify(heap, heap_size, largest);\n";
}
writer.block_end();
}
writer.block_end();
writer << "__device__ void create_and_build(Entry *entry, size_t size)\n";
writer.block_begin();
{
writer << "for (int i = (size-2) / 2; i >= 0; --i)\n";
writer.block_begin();
{
writer << "heapify(entry, size, i);\n";
}
writer.block_end();
}
writer.block_end();
writer << "extern \"C\" __global__ void cuda_" << name << args.get_input_signature();
writer.block_begin();
{
writer << "in = in + blockIdx.x * num_cols;\n";
if (use_malloc)
{
writer << "entry = entry + blockIdx.x * num_cols;\n";
}
writer << "out_id = out_id + blockIdx.x * topk_k;\n";
writer << "out_val = out_val + blockIdx.x * topk_k;\n";
if (!use_malloc)
{
writer << "extern __shared__ Entry entry[];\n";
}
writer << "for (size_t i = threadIdx.x; i < num_cols; i += blockDim.x)\n";
writer.block_begin();
{
writer << "entry[i].set_value(in[i]);\n";
writer << "entry[i].set_index(i);\n";
}
writer.block_end();
writer << "__syncthreads();\n";
writer << "if (threadIdx.x == 0)\n";
writer.block_begin();
{
writer << "create_and_build(entry, num_cols);\n";
writer << "size_t changed_size_of_heap = num_cols;\n";
writer << "size_t k = 0;\n";
writer << "while (k++ < topk_k)\n";
writer.block_begin();
{
writer << "swap(*entry, entry[changed_size_of_heap - 1]);\n";
writer << "heapify(entry, --changed_size_of_heap, 0);\n";
}
writer.block_end();
writer << "for (size_t i = threadIdx.x; i < topk_k; i++)\n";
writer.block_begin();
{
writer << "out_val[i] = entry[num_cols - 1 - i].get_value();\n";
writer << "out_id[i] = entry[num_cols - 1 - i].get_index();\n";
}
writer.block_end();
}
writer.block_end();
}
writer.block_end();
}
//each thread calculate the whole reduction of one output //each thread calculate the whole reduction of one output
void runtime::gpu::CudaKernelBuilder::get_reduce_to_nd_op( void runtime::gpu::CudaKernelBuilder::get_reduce_to_nd_op(
codegen::CodeWriter& writer, codegen::CodeWriter& writer,
......
...@@ -85,6 +85,13 @@ namespace ngraph ...@@ -85,6 +85,13 @@ namespace ngraph
size_t out_rank, size_t out_rank,
size_t reduce_rank); size_t reduce_rank);
static void get_topk(codegen::CodeWriter& writer,
const std::string& name,
const std::vector<std::string>& dtypes,
bool compute_max,
runtime::gpu::GPUKernelArgs& args,
bool use_malloc);
//using one block with at most 512 threads to reduce to scalar. //using one block with at most 512 threads to reduce to scalar.
static void get_reduce_to_scalar_op(codegen::CodeWriter& writer, static void get_reduce_to_scalar_op(codegen::CodeWriter& writer,
const std::string& name, const std::string& name,
......
...@@ -1583,7 +1583,32 @@ void runtime::gpu::GPU_Emitter::emit_Tanh(EMIT_ARGS) ...@@ -1583,7 +1583,32 @@ void runtime::gpu::GPU_Emitter::emit_Tanh(EMIT_ARGS)
void runtime::gpu::GPU_Emitter::emit_TopK(EMIT_ARGS) void runtime::gpu::GPU_Emitter::emit_TopK(EMIT_ARGS)
{ {
throw unsupported_op("Unsupported op '" + node->description() + "'"); if (out[0].get_size() == 0)
{
return;
}
auto topk = static_cast<const ngraph::op::TopK*>(node);
size_t topk_axis = topk->get_top_k_axis();
size_t topk_k = topk->get_k();
auto index_elem_type = topk->get_index_element_type();
bool compute_max = topk->get_compute_max();
std::vector<element::Type> dtypes{args[0].get_element_type()};
NGRAPH_ASSERT(out.size() == 2) << "TopK can only have 2 outputs";
for (size_t i = 0; i < out.size(); i++)
{
dtypes.push_back(out[i].get_element_type());
}
auto& input_shape = args[0].get_shape();
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto index = cuda_emitter->build_topk(
dtypes, input_shape, topk_axis, topk_k, index_elem_type, compute_max);
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
writer.block_end();
} }
string runtime::gpu::GPU_Emitter::node_names(const vector<GPUTensorWrapper>& args, string runtime::gpu::GPU_Emitter::node_names(const vector<GPUTensorWrapper>& args,
......
...@@ -572,8 +572,8 @@ void runtime::gpu::GPU_ExternalFunction::compile() ...@@ -572,8 +572,8 @@ void runtime::gpu::GPU_ExternalFunction::compile()
#endif #endif
pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>(); pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>(); pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this); pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
pass_manager.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
pass_manager.register_pass<ngraph::pass::Liveness>(); pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment); pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment);
pass_manager.register_pass<runtime::gpu::pass::TensorMemoryReservation>( pass_manager.register_pass<runtime::gpu::pass::TensorMemoryReservation>(
......
...@@ -81,7 +81,7 @@ namespace ngraph ...@@ -81,7 +81,7 @@ namespace ngraph
// Retrieve the kernel parameter signature given the added kernel arguments. // Retrieve the kernel parameter signature given the added kernel arguments.
// //
std::string get_input_signature(); std::string get_input_signature();
size_t get_size() { return m_argument_list.size(); }
private: private:
// //
// Cache the host argument for persistence, add it to the argument list, // Cache the host argument for persistence, add it to the argument list,
......
...@@ -21,8 +21,10 @@ ...@@ -21,8 +21,10 @@
#include <typeinfo> #include <typeinfo>
#include "gpu_layout.hpp" #include "gpu_layout.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/replace_slice.hpp" #include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/runtime/gpu/gpu_op_annotations.hpp" #include "ngraph/runtime/gpu/gpu_op_annotations.hpp"
using namespace std; using namespace std;
...@@ -79,6 +81,97 @@ namespace ngraph ...@@ -79,6 +81,97 @@ namespace ngraph
reshape->set_op_annotations(op_annotations); reshape->set_op_annotations(op_annotations);
} }
} }
template <>
void GPULayout::LAYOUT_DECL(ngraph::op::TopK)
{
auto topk = std::dynamic_pointer_cast<ngraph::op::TopK>(node);
auto topk_axis = topk->get_top_k_axis();
auto topk_k = topk->get_k();
auto parent_node = topk->get_argument(0);
auto in_shape = topk->get_input_shape(0);
size_t ndim = in_shape.size();
if (in_shape.size() <= 2 && topk_axis == ndim - 1)
{
return;
}
else
{
auto out_shape = in_shape;
out_shape[topk_axis] = topk_k;
AxisVector reshape_axis_order = ngraph::get_default_order(ndim);
reshape_axis_order.erase(reshape_axis_order.begin() + topk_axis);
reshape_axis_order.push_back(topk_axis);
Shape pre_reshape_out;
for (size_t j = 0; j < ndim; j++)
{
pre_reshape_out.push_back(in_shape[reshape_axis_order[j]]);
}
Shape pre_2d_reshape_out(2);
pre_2d_reshape_out[1] = pre_reshape_out[ndim - 1];
pre_2d_reshape_out[0] =
ngraph::shape_size(pre_reshape_out) / pre_2d_reshape_out[1];
auto pre_reshape = make_shared<ngraph::op::Reshape>(
parent_node, reshape_axis_order, pre_reshape_out);
AxisVector axis_order = ngraph::get_default_order(ndim);
auto pre_2d_reshape = make_shared<ngraph::op::Reshape>(
pre_reshape, axis_order, pre_2d_reshape_out);
insert_new_node_between(parent_node, topk, pre_reshape);
insert_new_node_between(pre_reshape, topk, pre_2d_reshape);
NodeVector goes = op::get_output_elements(topk);
auto new_topk =
make_shared<ngraph::op::TopK>(pre_2d_reshape,
1,
topk->get_index_element_type(),
topk->get_k(),
topk->get_compute_max());
ngraph::replace_node(topk, new_topk);
// Replace old goe with new goe based on new topk
NodeVector new_goes;
for (auto& goe : goes)
{
auto out_idx =
std::dynamic_pointer_cast<op::GetOutputElement>(goe)->get_n();
auto new_goe =
std::make_shared<op::GetOutputElement>(new_topk, out_idx);
ngraph::replace_node(goe, new_goe);
new_goes.push_back(new_goe);
}
Shape reordered_out_shape;
for (size_t j = 0; j < ndim; j++)
{
reordered_out_shape.push_back(out_shape[reshape_axis_order[j]]);
}
NodeVector post_2d_reshapes = insert_new_reshape_after(
new_goes, AxisVector{0, 1}, reordered_out_shape);
axis_order.pop_back();
axis_order.insert(axis_order.begin() + topk_axis, 1, ndim - 1);
insert_new_reshape_after(post_2d_reshapes, axis_order, out_shape);
}
}
NodeVector insert_new_reshape_after(NodeVector& parents,
const AxisVector& axis_vector,
const Shape& out_shape)
{
NodeVector reshapes;
for (auto& parent : parents)
{
for (auto node : parent->get_users())
{
for (size_t i = 0; i < node->get_input_size(); i++)
{
if (node->get_argument(i) == parent)
{
auto new_reshape = make_shared<ngraph::op::Reshape>(
parent, axis_vector, out_shape);
node->get_inputs().at(i).replace_output(
new_reshape->get_outputs().at(0));
reshapes.push_back(new_reshape);
}
}
}
}
return reshapes;
}
} }
} }
} }
...@@ -90,6 +183,7 @@ static const runtime::gpu::pass::LayoutOpMap s_dispatcher{ ...@@ -90,6 +183,7 @@ static const runtime::gpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::ReplaceSlice), {TI(ngraph::op::ReplaceSlice),
&runtime::gpu::pass::GPULayout::layout<ngraph::op::ReplaceSlice>}, &runtime::gpu::pass::GPULayout::layout<ngraph::op::ReplaceSlice>},
{TI(ngraph::op::Reshape), &runtime::gpu::pass::GPULayout::layout<ngraph::op::Reshape>}, {TI(ngraph::op::Reshape), &runtime::gpu::pass::GPULayout::layout<ngraph::op::Reshape>},
{TI(ngraph::op::TopK), &runtime::gpu::pass::GPULayout::layout<ngraph::op::TopK>},
}; };
bool runtime::gpu::pass::GPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) bool runtime::gpu::pass::GPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
......
...@@ -54,6 +54,10 @@ namespace ngraph ...@@ -54,6 +54,10 @@ namespace ngraph
private: private:
GPU_ExternalFunction* m_external_function; GPU_ExternalFunction* m_external_function;
}; };
NodeVector insert_new_reshape_after(NodeVector& parents,
const AxisVector& axis_vector,
const Shape& out_shape);
} }
} }
} }
......
...@@ -29,24 +29,6 @@ backwards_maxpool_n2_c1_hw5_3x3_str2_max ...@@ -29,24 +29,6 @@ backwards_maxpool_n2_c1_hw5_3x3_str2_max
backwards_avgpool_n1_c1_hw2x2 backwards_avgpool_n1_c1_hw2x2
backwards_avgpool_n1_c1_hw4x4 backwards_avgpool_n1_c1_hw4x4
backwards_avgpool_n2_c2_hw4x4 backwards_avgpool_n2_c2_hw4x4
topk_1d_max_all
topk_1d_max_partial
topk_1d_max_one
topk_1d_min_all
topk_1d_min_partial
topk_1d_min_one
topk_2d_max_all
topk_2d_max_partial
topk_2d_max_one
topk_2d_min_all
topk_2d_min_partial
topk_2d_min_one
topk_3d_max_all
topk_3d_max_partial
topk_3d_max_one
topk_3d_min_all
topk_3d_min_partial
topk_3d_min_one
quantize quantize
quantize_axes quantize_axes
quantize_int8 quantize_int8
......
...@@ -78,6 +78,8 @@ topk_3d_max_partial ...@@ -78,6 +78,8 @@ topk_3d_max_partial
topk_3d_min_all topk_3d_min_all
topk_3d_min_one topk_3d_min_one
topk_3d_min_partial topk_3d_min_partial
topk_5d_max_partial
topk_int64
zero_sized_abs zero_sized_abs
zero_sized_acos zero_sized_acos
zero_sized_add zero_sized_add
......
...@@ -7,3 +7,4 @@ batchnorm_fprop_inference_b2c2h2w1 ...@@ -7,3 +7,4 @@ batchnorm_fprop_inference_b2c2h2w1
batchnorm_fprop_bprop batchnorm_fprop_bprop
batchnorm_fprop_bprop_2step batchnorm_fprop_bprop_2step
computation_reuse computation_reuse
topk_int64
...@@ -112,6 +112,8 @@ set(MULTI_TEST_SRC ...@@ -112,6 +112,8 @@ set(MULTI_TEST_SRC
backend_reduce.in.cpp backend_reduce.in.cpp
backend_reshape.in.cpp backend_reshape.in.cpp
backend_sum.in.cpp backend_sum.in.cpp
backend_topk.in.cpp
backend_arg_reduce.in.cpp
backend_test.in.cpp backend_test.in.cpp
backend_unary_elementwise.in.cpp backend_unary_elementwise.in.cpp
convolution_test.in.cpp convolution_test.in.cpp
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <random>
#include <string>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
// Trivial case.
NGRAPH_TEST(${BACKEND_NAME}, argmin_trivial)
{
Shape shape{4, 3};
Shape rshape{3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMin>(A, 0, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((vector<int>{3, 2, 1}), read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_axis_3_i64)
{
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
Shape rshape{2, 2, 5};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMin>(A, 3, element::i64), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 4>({{{{0.5f, 1.5f, 0.8f, 2.9f, 1.05f}, // img 0 ch 0
{0.5f, 3.5f, 2.0f, 1.0f, 0.2f},
{2.0f, 0.0f, 2.2f, 0.2f, 1.4f},
{2.9f, 0.0f, 1.52f, 1.2f, 2.22f},
{5.0f, 2.0f, 1.0f, 0.5f, 0.85f}},
{{0.25f, 0.02f, 0.02f, 2.2f, 0.001f}, // img 0 ch 1
{1.0f, 0.2f, 3.0f, 0.25f, 1.14f},
{2.25f, 10.1f, 1.0f, 0.02f, 2.22f},
{3.2f, 1.002f, 0.001f, 0.2f, 6.0f},
{2.0f, 0.0f, 0.0f, 0.0f, 0.0f}}},
{{{0.0f, 2.2f, 1.2f, 1.6f, 0.2f}, // img 1 ch 0
{0.01f, 0.0f, 0.22f, 0.02f, 1.1f},
{0.01f, 0.5f, 1.6f, 0.2f, 3.2f},
{2.4f, 0.5f, 0.0f, 3.0f, 0.1f},
{0.0f, 0.5f, 0.4f, 0.8f, 1.0f}},
{{2.0f, 1.0f, 0.0f, 0.0f, 1.0f}, // img 1 ch 1
{0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
{1.0f, 1.0f, 2.0f, 0.0f, 2.0f},
{1.0f, 1.0f, 1.0f, 0.0f, 1.0f},
{1.0f, 0.0f, 0.0f, 0.0f, 2.0f}}}})
.get_vector());
auto result = backend->create_tensor(element::i64, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int64_t, 3>({{{0, 4, 1, 1, 3}, // ch0
{4, 1, 3, 2, 1}}, //
{{0, 1, 0, 2, 0}, // ch1
{2, 0, 3, 3, 1}}}) //
.get_vector()),
read_vector<int64_t>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_axis_3)
{
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
Shape rshape{2, 2, 5};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMin>(A, 3, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 4>({{{{0.5f, 1.5f, 0.8f, 2.9f, 1.05f}, // img 0 ch 0
{0.5f, 3.5f, 2.0f, 1.0f, 0.2f},
{2.0f, 0.0f, 2.2f, 0.2f, 1.4f},
{2.9f, 0.0f, 1.52f, 1.2f, 2.22f},
{5.0f, 2.0f, 1.0f, 0.5f, 0.85f}},
{{0.25f, 0.02f, 0.02f, 2.2f, 0.001f}, // img 0 ch 1
{1.0f, 0.2f, 3.0f, 0.25f, 1.14f},
{2.25f, 10.1f, 1.0f, 0.02f, 2.22f},
{3.2f, 1.002f, 0.001f, 0.2f, 6.0f},
{2.0f, 0.0f, 0.0f, 0.0f, 0.0f}}},
{{{0.0f, 2.2f, 1.2f, 1.6f, 0.2f}, // img 1 ch 0
{0.01f, 0.0f, 0.22f, 0.02f, 1.1f},
{0.01f, 0.5f, 1.6f, 0.2f, 3.2f},
{2.4f, 0.5f, 0.0f, 3.0f, 0.1f},
{0.0f, 0.5f, 0.4f, 0.8f, 1.0f}},
{{2.0f, 1.0f, 0.0f, 0.0f, 1.0f}, // img 1 ch 1
{0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
{1.0f, 1.0f, 2.0f, 0.0f, 2.0f},
{1.0f, 1.0f, 1.0f, 0.0f, 1.0f},
{1.0f, 0.0f, 0.0f, 0.0f, 2.0f}}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int, 3>({{{0, 4, 1, 1, 3}, // ch0
{4, 1, 3, 2, 1}}, //
{{0, 1, 0, 2, 0}, // ch1
{2, 0, 3, 3, 1}}}) //
.get_vector()),
read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_trivial)
{
Shape shape{4, 3}; // HW -> (0,1)
Shape rshape{3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMax>(A, 0, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((vector<int>{1, 3, 0}), read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_axis_0) // Along Channels
{
Shape shape{3, 4, 2}; // CHW ->(0,1,2)
Shape rshape{4, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMax>(A, 0, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 3>({{{8, 4}, //ch0
{12, 10},
{2, 9},
{1, 5}},
{{6, 7}, //ch1
{11, 3},
{9, 2},
{10, 12}},
{{8, 4}, //ch2
{6, 1},
{5, 3},
{11, 7}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int, 2>({{0, 1}, //r0
{0, 0}, //r1
{1, 0}, //r2
{2, 1}}) //r3
.get_vector()),
read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_axis_1) // Along Height
{
Shape shape{3, 4, 2}; // CHW ->(0,1,2)
Shape rshape{3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMax>(A, 1, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 3>({{{8, 4}, //ch0
{12, 10},
{2, 9},
{1, 5}},
{{6, 7}, //ch1
{11, 3},
{9, 2},
{10, 12}},
{{8, 4}, //ch2
{6, 1},
{5, 3},
{11, 7}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int, 2>({{1, 1}, //
{1, 3}, //
{3, 3}})
.get_vector()),
read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_axis_2) // Along Width
{
Shape shape{3, 4, 2}; // CHW ->(0,1,2)
Shape rshape{3, 4};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMax>(A, 2, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 3>({{{8, 4}, //ch0
{12, 10},
{2, 9},
{1, 5}},
{{6, 7}, //ch1
{11, 3},
{9, 2},
{10, 12}},
{{8, 4}, //ch2
{6, 1},
{5, 3},
{11, 7}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int, 2>({{0, 0, 1, 1}, //
{1, 0, 0, 1}, //
{0, 0, 0, 0}}) //
.get_vector()),
read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_4D_axis_3)
{
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
Shape rshape{2, 2, 5};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMax>(A, 3, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 4>({{{{0, 1, 0, 2, 1}, // img 0 ch 0
{0, 3, 2, 0, 0},
{2, 0, 0, 0, 1},
{2, 0, 1, 1, 2},
{0, 2, 1, 0, 0}},
{{0, 0, 0, 2, 0}, // img 0 ch 1
{0, 2, 3, 0, 1},
{2, 0, 1, 0, 2},
{3, 1, 0, 0, 0},
{2, 0, 0, 0, 0}}},
{{{0, 2, 1, 1, 0}, // img 1 ch 0
{0, 0, 2, 0, 1},
{0, 0, 1, 2, 3},
{2, 0, 0, 3, 0},
{0, 0, 0, 0, 0}},
{{2, 1, 0, 0, 1}, // img 1 ch 1
{0, 2, 0, 0, 0},
{1, 1, 2, 0, 2},
{1, 1, 1, 0, 1},
{1, 0, 0, 0, 2}}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int, 3>({{{3, 1, 0, 0, 1}, {3, 2, 0, 0, 0}}, //ch0
{{1, 2, 4, 3, 0}, {0, 1, 2, 0, 4}}}) //ch1
.get_vector()),
read_vector<int>(result));
}
...@@ -4764,740 +4764,6 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_sequence_n4d2c3h2w2) ...@@ -4764,740 +4764,6 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_sequence_n4d2c3h2w2)
EXPECT_EQ(read_vector<int>(result), expected); EXPECT_EQ(read_vector<int>(result), expected);
} }
// Trivial case.
NGRAPH_TEST(${BACKEND_NAME}, argmin_trivial)
{
Shape shape{4, 3};
Shape rshape{3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMin>(A, 0, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((vector<int>{3, 2, 1}), read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_axis_3)
{
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
Shape rshape{2, 2, 5};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMin>(A, 3, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 4>({{{{0.5f, 1.5f, 0.8f, 2.9f, 1.05f}, // img 0 ch 0
{0.5f, 3.5f, 2.0f, 1.0f, 0.2f},
{2.0f, 0.0f, 2.2f, 0.2f, 1.4f},
{2.9f, 0.0f, 1.52f, 1.2f, 2.22f},
{5.0f, 2.0f, 1.0f, 0.5f, 0.85f}},
{{0.25f, 0.02f, 0.02f, 2.2f, 0.001f}, // img 0 ch 1
{1.0f, 0.2f, 3.0f, 0.25f, 1.14f},
{2.25f, 10.1f, 1.0f, 0.02f, 2.22f},
{3.2f, 1.002f, 0.001f, 0.2f, 6.0f},
{2.0f, 0.0f, 0.0f, 0.0f, 0.0f}}},
{{{0.0f, 2.2f, 1.2f, 1.6f, 0.2f}, // img 1 ch 0
{0.01f, 0.0f, 0.22f, 0.02f, 1.1f},
{0.01f, 0.5f, 1.6f, 0.2f, 3.2f},
{2.4f, 0.5f, 0.0f, 3.0f, 0.1f},
{0.0f, 0.5f, 0.4f, 0.8f, 1.0f}},
{{2.0f, 1.0f, 0.0f, 0.0f, 1.0f}, // img 1 ch 1
{0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
{1.0f, 1.0f, 2.0f, 0.0f, 2.0f},
{1.0f, 1.0f, 1.0f, 0.0f, 1.0f},
{1.0f, 0.0f, 0.0f, 0.0f, 2.0f}}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int, 3>({{{0, 4, 1, 1, 3}, // ch0
{4, 1, 3, 2, 1}}, //
{{0, 1, 0, 2, 0}, // ch1
{2, 0, 3, 3, 1}}}) //
.get_vector()),
read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_axis_3_i64)
{
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
Shape rshape{2, 2, 5};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMin>(A, 3, element::i64), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 4>({{{{0.5f, 1.5f, 0.8f, 2.9f, 1.05f}, // img 0 ch 0
{0.5f, 3.5f, 2.0f, 1.0f, 0.2f},
{2.0f, 0.0f, 2.2f, 0.2f, 1.4f},
{2.9f, 0.0f, 1.52f, 1.2f, 2.22f},
{5.0f, 2.0f, 1.0f, 0.5f, 0.85f}},
{{0.25f, 0.02f, 0.02f, 2.2f, 0.001f}, // img 0 ch 1
{1.0f, 0.2f, 3.0f, 0.25f, 1.14f},
{2.25f, 10.1f, 1.0f, 0.02f, 2.22f},
{3.2f, 1.002f, 0.001f, 0.2f, 6.0f},
{2.0f, 0.0f, 0.0f, 0.0f, 0.0f}}},
{{{0.0f, 2.2f, 1.2f, 1.6f, 0.2f}, // img 1 ch 0
{0.01f, 0.0f, 0.22f, 0.02f, 1.1f},
{0.01f, 0.5f, 1.6f, 0.2f, 3.2f},
{2.4f, 0.5f, 0.0f, 3.0f, 0.1f},
{0.0f, 0.5f, 0.4f, 0.8f, 1.0f}},
{{2.0f, 1.0f, 0.0f, 0.0f, 1.0f}, // img 1 ch 1
{0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
{1.0f, 1.0f, 2.0f, 0.0f, 2.0f},
{1.0f, 1.0f, 1.0f, 0.0f, 1.0f},
{1.0f, 0.0f, 0.0f, 0.0f, 2.0f}}}})
.get_vector());
auto result = backend->create_tensor(element::i64, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int64_t, 3>({{{0, 4, 1, 1, 3}, // ch0
{4, 1, 3, 2, 1}}, //
{{0, 1, 0, 2, 0}, // ch1
{2, 0, 3, 3, 1}}}) //
.get_vector()),
read_vector<int64_t>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_trivial)
{
Shape shape{4, 3}; // HW -> (0,1)
Shape rshape{3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMax>(A, 0, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((vector<int>{1, 3, 0}), read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_axis_0) // Along Channels
{
Shape shape{3, 4, 2}; // CHW ->(0,1,2)
Shape rshape{4, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMax>(A, 0, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 3>({{{8, 4}, //ch0
{12, 10},
{2, 9},
{1, 5}},
{{6, 7}, //ch1
{11, 3},
{9, 2},
{10, 12}},
{{8, 4}, //ch2
{6, 1},
{5, 3},
{11, 7}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int, 2>({{0, 1}, //r0
{0, 0}, //r1
{1, 0}, //r2
{2, 1}}) //r3
.get_vector()),
read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_axis_1) // Along Height
{
Shape shape{3, 4, 2}; // CHW ->(0,1,2)
Shape rshape{3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMax>(A, 1, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 3>({{{8, 4}, //ch0
{12, 10},
{2, 9},
{1, 5}},
{{6, 7}, //ch1
{11, 3},
{9, 2},
{10, 12}},
{{8, 4}, //ch2
{6, 1},
{5, 3},
{11, 7}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int, 2>({{1, 1}, //
{1, 3}, //
{3, 3}})
.get_vector()),
read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_axis_2) // Along Width
{
Shape shape{3, 4, 2}; // CHW ->(0,1,2)
Shape rshape{3, 4};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMax>(A, 2, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 3>({{{8, 4}, //ch0
{12, 10},
{2, 9},
{1, 5}},
{{6, 7}, //ch1
{11, 3},
{9, 2},
{10, 12}},
{{8, 4}, //ch2
{6, 1},
{5, 3},
{11, 7}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int, 2>({{0, 0, 1, 1}, //
{1, 0, 0, 1}, //
{0, 0, 0, 0}}) //
.get_vector()),
read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_4D_axis_3)
{
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
Shape rshape{2, 2, 5};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f =
make_shared<Function>(make_shared<op::ArgMax>(A, 3, element::i32), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a,
test::NDArray<float, 4>({{{{0, 1, 0, 2, 1}, // img 0 ch 0
{0, 3, 2, 0, 0},
{2, 0, 0, 0, 1},
{2, 0, 1, 1, 2},
{0, 2, 1, 0, 0}},
{{0, 0, 0, 2, 0}, // img 0 ch 1
{0, 2, 3, 0, 1},
{2, 0, 1, 0, 2},
{3, 1, 0, 0, 0},
{2, 0, 0, 0, 0}}},
{{{0, 2, 1, 1, 0}, // img 1 ch 0
{0, 0, 2, 0, 1},
{0, 0, 1, 2, 3},
{2, 0, 0, 3, 0},
{0, 0, 0, 0, 0}},
{{2, 1, 0, 0, 1}, // img 1 ch 1
{0, 2, 0, 0, 0},
{1, 1, 2, 0, 2},
{1, 1, 1, 0, 1},
{1, 0, 0, 0, 2}}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
backend->call_with_validate(f, {result}, {a});
EXPECT_EQ((test::NDArray<int, 3>({{{3, 1, 0, 0, 1}, {3, 2, 0, 0, 0}}, //ch0
{{1, 2, 4, 3, 0}, {0, 1, 2, 0, 4}}}) //ch1
.get_vector()),
read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_1d_max_all)
{
Shape shape{6};
Shape rshape{6};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 0, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{1, 2, 3, 4, 5, 6});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{5, 4, 3, 2, 1, 0}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{6, 5, 4, 3, 2, 1}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_1d_max_partial)
{
Shape shape{6};
Shape rshape{3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 3, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{1, 2, 3, 4, 5, 6});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{5, 4, 3}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{6, 5, 4}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_1d_max_one)
{
Shape shape{6};
Shape rshape{1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 1, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{1, 2, 3, 4, 5, 6});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{5}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{6}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_1d_min_all)
{
Shape shape{6};
Shape rshape{6};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 0, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{6, 5, 4, 3, 2, 1});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{5, 4, 3, 2, 1, 0}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{1, 2, 3, 4, 5, 6}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_1d_min_partial)
{
Shape shape{6};
Shape rshape{3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 3, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{6, 5, 4, 3, 2, 1});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{5, 4, 3}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{1, 2, 3}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_1d_min_one)
{
Shape shape{6};
Shape rshape{1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 1, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{6, 5, 4, 3, 2, 1});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{5}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{1}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_3d_max_all)
{
Shape shape{2, 3, 2};
Shape rshape{2, 3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 1, element::i32, 0, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{1, 1, 0, 2, 2, 0, 2, 2, 0, 1, 1, 0}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{10, 12, 9, 4, 8, 2, 11, 7, 6, 3, 5, 1}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_3d_max_partial)
{
Shape shape{2, 3, 2};
Shape rshape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 1, element::i32, 2, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{1, 1, 0, 2, 2, 2, 0, 1}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{10, 12, 9, 4, 11, 7, 6, 3}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_3d_max_one)
{
Shape shape{2, 3, 2};
Shape rshape{2, 1, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 1, element::i32, 1, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{1, 1, 2, 2}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{10, 12, 11, 7}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_3d_min_all)
{
Shape shape{2, 3, 2};
Shape rshape{2, 3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 1, element::i32, 0, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{2, 0, 1, 2, 0, 1, 1, 0, 0, 1, 2, 2}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{8, 2, 10, 4, 12, 9, 5, 1, 6, 3, 11, 7}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_3d_min_partial)
{
Shape shape{2, 3, 2};
Shape rshape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 1, element::i32, 2, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{2, 0, 1, 2, 1, 0, 0, 1}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{8, 2, 10, 4, 5, 1, 6, 3}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_3d_min_one)
{
Shape shape{2, 3, 2};
Shape rshape{2, 1, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 1, element::i32, 1, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{2, 0, 1, 0}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{8, 2, 5, 1}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_2d_max_all)
{
Shape shape{4, 3};
Shape rshape{4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 4, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{1, 3, 0, 0, 1, 3, 2, 0, 2, 3, 2, 1}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{12, 11, 10, 9, 8, 7, 6, 2, 5, 3, 1, 4}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_2d_max_partial)
{
Shape shape{4, 3};
Shape rshape{2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 2, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{1, 3, 0, 0, 1, 3}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{12, 11, 10, 9, 8, 7}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_2d_max_one)
{
Shape shape{4, 3};
Shape rshape{1, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 1, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{1, 3, 0}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{12, 11, 10}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_2d_min_all)
{
Shape shape{4, 3};
Shape rshape{4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 4, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{3, 2, 1, 2, 0, 2, 1, 1, 3, 0, 3, 0}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{3, 1, 4, 6, 2, 5, 9, 8, 7, 12, 11, 10}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_2d_min_partial)
{
Shape shape{4, 3};
Shape rshape{2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 2, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{3, 2, 1, 2, 0, 2}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{3, 1, 4, 6, 2, 5}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_2d_min_one)
{
Shape shape{4, 3};
Shape rshape{1, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 1, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{3, 2, 1}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{3, 1, 4}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, quantize) NGRAPH_TEST(${BACKEND_NAME}, quantize)
{ {
Shape input_shape{4, 3}; Shape input_shape{4, 3};
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <random>
#include <string>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
NGRAPH_TEST(${BACKEND_NAME}, topk_1d_max_all)
{
Shape shape{6};
Shape rshape{6};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 0, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{1, 2, 3, 4, 5, 6});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{5, 4, 3, 2, 1, 0}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{6, 5, 4, 3, 2, 1}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_1d_max_partial)
{
Shape shape{6};
Shape rshape{3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 3, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{1, 2, 3, 4, 5, 6});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{5, 4, 3}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{6, 5, 4}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_1d_max_one)
{
Shape shape{6};
Shape rshape{1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 1, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{1, 2, 3, 4, 5, 6});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{5}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{6}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_1d_min_all)
{
Shape shape{6};
Shape rshape{6};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 0, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{6, 5, 4, 3, 2, 1});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{5, 4, 3, 2, 1, 0}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{1, 2, 3, 4, 5, 6}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_1d_min_partial)
{
Shape shape{6};
Shape rshape{3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 3, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{6, 5, 4, 3, 2, 1});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{5, 4, 3}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{1, 2, 3}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_1d_min_one)
{
Shape shape{6};
Shape rshape{1};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 1, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{6, 5, 4, 3, 2, 1});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{5}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{1}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_3d_max_all)
{
Shape shape{2, 3, 2};
Shape rshape{2, 3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 1, element::i32, 0, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{1, 1, 0, 2, 2, 0, 2, 2, 0, 1, 1, 0}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{10, 12, 9, 4, 8, 2, 11, 7, 6, 3, 5, 1}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_int64)
{
Shape shape{2, 3, 2};
Shape rshape{2, 3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 1, element::i64, 0, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i64, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int64_t>{1, 1, 0, 2, 2, 0, 2, 2, 0, 1, 1, 0}), read_vector<int64_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{10, 12, 9, 4, 8, 2, 11, 7, 6, 3, 5, 1}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_5d_max_partial)
{
Shape shape{2, 6, 3, 2, 4};
Shape rshape{2, 2, 3, 2, 4};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 1, element::i32, 2, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(
a,
vector<float>{
1., 73., 9., 81., 17., 89., 2., 74., 10., 82., 18., 90., 3., 75.,
11., 83., 19., 91., 4., 76., 12., 84., 20., 92., 145., 217., 153., 225.,
161., 233., 146., 218., 154., 226., 162., 234., 147., 219., 155., 227., 163., 235.,
148., 220., 156., 228., 164., 236., 5., 77., 13., 85., 21., 93., 6., 78.,
14., 86., 22., 94., 7., 79., 15., 87., 23., 95., 8., 80., 16., 88.,
24., 96., 149., 221., 157., 229., 165., 27., 150., 222., 158., 230., 166., 23.,
151., 223., 159., 231., 17., 39., 2., 224., 160., 232., 168., 240., 25., 97.,
33., 105., 41., 113., 26., 98., 34., 106., 42., 114., 27., 99., 35., 107.,
43., 115., 28., 100., 36., 108., 44., 116., 169., 241., 177., 249., 185., 25.,
170., 242., 178., 250., 186., 258., 171., 243., 179., 251., 187., 259., 172., 24.,
180., 252., 188., 260., 29., 101., 37., 109., 45., 117., 30., 102., 38., 10.,
46., 118., 31., 103., 39., 111., 47., 119., 32., 104., 40., 112., 48., 20.,
173., 245., 181., 253., 189., 261., 174., 246., 182., 254., 190., 262., 175., 27.,
183., 255., 191., 263., 176., 248., 184., 256., 192., 264., 49., 121., 57., 129.,
65., 137., 50., 122., 58., 130., 66., 138., 51., 123., 59., 131., 67., 139.,
52., 124., 60., 132., 68., 140., 193., 265., 201., 273., 209., 281., 194., 266.,
202., 274., 210., 43., 115., 28., 100., 36., 108., 44., 116., 169., 241., 177.,
212., 284., 53., 125., 61., 133., 69., 141., 54., 126., 62., 134., 70., 142.,
55., 127., 63., 135., 71., 143., 56., 128., 64., 136., 72., 144., 197., 269.,
205., 277., 213., 285., 198., 270., 206., 278., 214., 286., 199., 271., 207., 279.,
215., 287., 200., 272., 208., 280., 216., 288.});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ(
(vector<int32_t>{5, 5, 5, 5, 5, 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 3, 5, 5, 5, 5,
3, 3, 3, 3, 3, 4, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 1, 1, 1, 1, 3, 3, 3, 3,
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 3, 5, 5, 5,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 4, 1, 1, 1, 1, 1, 1, 5, 1, 3, 3}),
read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{169, 241, 177, 249, 185, 233, 170, 242, 178, 250, 186, 258, 171, 243,
179, 251, 187, 259, 172, 224, 180, 252, 188, 260, 149, 221, 157, 229,
165, 113, 150, 222, 158, 230, 166, 234, 151, 223, 159, 231, 163, 235,
148, 220, 160, 232, 168, 240, 197, 269, 205, 277, 213, 285, 198, 270,
206, 278, 214, 286, 199, 271, 207, 279, 215, 287, 200, 272, 241, 280,
216, 288, 193, 265, 201, 273, 209, 281, 194, 266, 202, 274, 210, 262,
175, 127, 183, 255, 191, 263, 176, 248, 208, 256, 212, 284}),
read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_3d_max_partial)
{
Shape shape{2, 3, 2};
Shape rshape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 1, element::i32, 2, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{1, 1, 0, 2, 2, 2, 0, 1}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{10, 12, 9, 4, 11, 7, 6, 3}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_3d_max_one)
{
Shape shape{2, 3, 2};
Shape rshape{2, 1, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 1, element::i32, 1, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{1, 1, 2, 2}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{10, 12, 11, 7}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_3d_min_all)
{
Shape shape{2, 3, 2};
Shape rshape{2, 3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 1, element::i32, 0, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{2, 0, 1, 2, 0, 1, 1, 0, 0, 1, 2, 2}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{8, 2, 10, 4, 12, 9, 5, 1, 6, 3, 11, 7}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_3d_min_partial)
{
Shape shape{2, 3, 2};
Shape rshape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 1, element::i32, 2, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{2, 0, 1, 2, 1, 0, 0, 1}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{8, 2, 10, 4, 5, 1, 6, 3}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_3d_min_one)
{
Shape shape{2, 3, 2};
Shape rshape{2, 1, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 1, element::i32, 1, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{2, 0, 1, 0}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{8, 2, 5, 1}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_2d_max_all)
{
Shape shape{4, 3};
Shape rshape{4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 4, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{1, 3, 0, 0, 1, 3, 2, 0, 2, 3, 2, 1}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{12, 11, 10, 9, 8, 7, 6, 2, 5, 3, 1, 4}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_2d_max_partial)
{
Shape shape{4, 3};
Shape rshape{2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 2, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{1, 3, 0, 0, 1, 3}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{12, 11, 10, 9, 8, 7}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_2d_max_one)
{
Shape shape{4, 3};
Shape rshape{1, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 1, true);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{9, 2, 10, 12, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{1, 3, 0}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{12, 11, 10}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_2d_min_all)
{
Shape shape{4, 3};
Shape rshape{4, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 4, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{3, 2, 1, 2, 0, 2, 1, 1, 3, 0, 3, 0}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{3, 1, 4, 6, 2, 5, 9, 8, 7, 12, 11, 10}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_2d_min_partial)
{
Shape shape{4, 3};
Shape rshape{2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 2, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{3, 2, 1, 2, 0, 2}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{3, 1, 4, 6, 2, 5}), read_vector<float>(result1));
}
NGRAPH_TEST(${BACKEND_NAME}, topk_2d_min_one)
{
Shape shape{4, 3};
Shape rshape{1, 3};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::TopK>(A, 0, element::i32, 1, false);
auto f0 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), op::ParameterVector{A});
auto f1 =
make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(a, vector<float>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::f32, rshape);
backend->call_with_validate(f0, {result0}, {a});
EXPECT_EQ((vector<int32_t>{3, 2, 1}), read_vector<int32_t>(result0));
backend->call_with_validate(f1, {result1}, {a});
EXPECT_EQ((vector<float>{3, 1, 4}), read_vector<float>(result1));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment