Unverified Commit 67328703 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Add a new onehot unit test (#2670)

parent db610006
......@@ -321,3 +321,36 @@ NGRAPH_TEST(${BACKEND_NAME}, one_hot_vector_1_fp_nonint)
FAIL() << "Expected a std::out_of_range exception";
}
}
NGRAPH_TEST(${BACKEND_NAME}, one_hot_vector_many_categories)
{
// Imagenet has roughly 20,000 categories
uint32_t category_count = 20000;
Shape shape_a{6};
auto A = make_shared<op::Parameter>(element::i32, shape_a);
Shape shape_r{6, category_count};
auto r = make_shared<op::OneHot>(A, Shape{6, category_count}, 1);
auto f = make_shared<Function>(r, ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::i32, shape_a);
vector<int32_t> input_data{0, 11, 101, 1001, 10001, static_cast<int32_t>(category_count - 1)};
copy_data(a, input_data);
auto result = backend->create_tensor(element::i32, shape_r);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
vector<int32_t> data = read_vector<int32_t>(result);
vector<int32_t> bit_positions;
for (size_t i = 0; i < shape_size(shape_r); ++i)
{
if (data[i] == 1)
{
bit_positions.push_back(i % category_count);
}
}
EXPECT_EQ(bit_positions, input_data);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment