Commit 23c11c3d authored by Dina Suehiro Jones, committed by Scott Cyphers

Update cpu backend to support i32 with topk (#3415)

* Update cpu backend to support i32 with topk

* Add topk i32 unit test

* Minor fix in unit-test

* disable test for plaidml

parent 67367211
@@ -152,9 +152,59 @@ namespace ngraph
};
}
}
else if (element_type == element::i32)
{
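// Input element type is i32; select the reference kernel instantiation
// that matches the requested index element type (i64 or i32).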
if (is_int64)
{
functor = [&,
in_shape,
out_shape,
axis,
k,
compute_max,
arg_buffer_index,
out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::topk<int32_t, int64_t>(
static_cast<int32_t*>(ctx->buffer_data[arg_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[out_indices_buffer_index]),
static_cast<int32_t*>(ctx->buffer_data[out_values_buffer_index]),
in_shape,
out_shape,
axis,
k,
compute_max);
};
}
else
{
functor = [&,
in_shape,
out_shape,
axis,
k,
compute_max,
arg_buffer_index,
out_indices_buffer_index,
out_values_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::topk<int32_t, int32_t>(
static_cast<int32_t*>(ctx->buffer_data[arg_buffer_index]),
static_cast<int32_t*>(ctx->buffer_data[out_indices_buffer_index]),
static_cast<int32_t*>(ctx->buffer_data[out_values_buffer_index]),
in_shape,
out_shape,
axis,
k,
compute_max);
};
}
}
else
{
-throw ngraph_error("Unsupported type in CPU Builder for TopK");
+throw ngraph_error("Unsupported type (" + element_type.get_type_name() +
+                   ") in CPU Builder for TopK");
}
functors.emplace_back(functor);
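
For reference, here is a minimal standalone sketch of the computation that
ngraph::runtime::reference::topk<ValueType, IndexType> performs, reduced to the
1-D case. The function name topk_1d and its structure are illustrative
assumptions, not the actual nGraph kernel:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Illustrative 1-D top-k (assumes k <= n): writes the k largest (or smallest,
// when compute_max is false) values of arg along with their source positions.
// The <ValueType, IndexType> split mirrors the dispatch above, e.g.
// topk_1d<int32_t, int64_t> vs. topk_1d<int32_t, int32_t>.
template <typename ValueType, typename IndexType>
void topk_1d(const ValueType* arg,
             IndexType* out_indices,
             ValueType* out_values,
             size_t n,
             size_t k,
             bool compute_max)
{
    // Sort index positions by the values they point at, stopping after the
    // first k places are settled.
    std::vector<IndexType> idx(n);
    std::iota(idx.begin(), idx.end(), IndexType{0});
    std::partial_sort(idx.begin(),
                      idx.begin() + k,
                      idx.end(),
                      [&](IndexType a, IndexType b) {
                          return compute_max ? arg[a] > arg[b] : arg[a] < arg[b];
                      });
    for (size_t i = 0; i < k; ++i)
    {
        out_indices[i] = idx[i];
        out_values[i] = arg[idx[i]];
    }
}

With arg = {1, 2, 3, 4, 5, 6}, k = 6, and compute_max = true, this yields
indices {5, 4, 3, 2, 1, 0} and values {6, 5, 4, 3, 2, 1}, matching the
expectations in the new unit test below.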
@@ -38,6 +38,8 @@
topk_2d_min_partial # No plans to implement TopK
topk_2d_min_one # No plans to implement TopK
topk_int64 # No plans to implement TopK
topk_5d_max_partial # No plans to implement TopK
topk_1d_i32_max_all # No plans to implement TopK
# Tests that PlaidML might be able to run at some point.
backwards_maxpool_n2_c1_hw5_3x3_str2_max_pad1x2_2x3
@@ -61,6 +61,31 @@ NGRAPH_TEST(${BACKEND_NAME}, topk_1d_max_all)
(vector<float>{6, 5, 4, 3, 2, 1}), read_vector<float>(result1), MIN_FLOAT_TOLERANCE_BITS));
}

NGRAPH_TEST(${BACKEND_NAME}, topk_1d_i32_max_all)
{
Shape shape{6};
Shape rshape{6};
auto A = make_shared<op::Parameter>(element::i32, shape);
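// op::TopK arguments: axis 0, i32 index element type, k = 0 (returns all
// elements along the axis, as the full-size rshape implies), compute_max = true.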
auto B = make_shared<op::TopK>(A, 0, element::i32, 0, true);
auto f0 = make_shared<Function>(make_shared<op::GetOutputElement>(B, 0), ParameterVector{A});
auto f1 = make_shared<Function>(make_shared<op::GetOutputElement>(B, 1), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::i32, shape);
copy_data(a, vector<int32_t>{1, 2, 3, 4, 5, 6});
auto result0 = backend->create_tensor(element::i32, rshape);
auto result1 = backend->create_tensor(element::i32, rshape);
auto h0 = backend->compile(f0);
h0->call_with_validate({result0}, {a});
EXPECT_EQ((vector<int32_t>{5, 4, 3, 2, 1, 0}), read_vector<int32_t>(result0));
auto h1 = backend->compile(f1);
h1->call_with_validate({result1}, {a});
EXPECT_EQ((vector<int32_t>{6, 5, 4, 3, 2, 1}), read_vector<int32_t>(result1));
}

NGRAPH_TEST(${BACKEND_NAME}, topk_1d_max_partial)
{
Shape shape{6};