Commit 879f9492 authored by Christian Convey, committed by Scott Cyphers

Add naive int64-indexing to EmbeddingLookup (#2644)

parent 857093c1
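
For context, the reference kernel that both new branches dispatch to is essentially a row gather: each entry of the index tensor (arg0) selects one row of the embedding table (arg1), which is copied into the corresponding row of the output. The sketch below only illustrates that behaviour under stated assumptions; the function name embedding_sketch, the row_size parameter, and the flattened-pointer interface are inventions for this example, not the actual ngraph::runtime::reference::embedding code.

// Hypothetical sketch of the gather done by the reference embedding kernel.
// Names (embedding_sketch, row_size) are illustrative assumptions, not nGraph API.
#include <cstddef>
#include <cstdint>
#include <cstring>

template <typename ElemT, typename IndexT>
void embedding_sketch(const IndexT* indices, // arg0: one lookup index per output row
                      const ElemT* weights,  // arg1: embedding table, row-major
                      ElemT* out,            // output: index_count rows of row_size elements
                      std::size_t index_count,
                      std::size_t row_size)
{
    for (std::size_t i = 0; i < index_count; ++i)
    {
        // Copy the selected table row; int64_t indices behave exactly like int32
        // ones, they are simply wide enough to address much larger tables.
        std::size_t row = static_cast<std::size_t>(indices[i]);
        std::memcpy(out + i * row_size, weights + row * row_size, row_size * sizeof(ElemT));
    }
}

With ElemT = float and IndexT = int64_t this mirrors the first branch added below; the second branch is the same gather instantiated for int data.
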
@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include <cstdint>
#include <cstring>
#include "ngraph/op/embedding_lookup.hpp"
@@ -77,6 +78,19 @@ namespace ngraph
in_shape);
};
}
else if (index_element_type == element::i64)
{
functor = [&, in_shape, element_count](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::embedding<float, int64_t>(
static_cast<int64_t*>(arg0_tensor),
static_cast<float*>(arg1_tensor),
static_cast<float*>(out_tensor),
element_count,
in_shape);
};
}
else
{
throw ngraph_error(
@@ -111,6 +125,19 @@ namespace ngraph
in_shape);
};
}
else if (index_element_type == element::i64)
{
functor = [&, in_shape, element_count](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::embedding<int, int64_t>(
static_cast<int64_t*>(arg0_tensor),
static_cast<int*>(arg1_tensor),
static_cast<int*>(out_tensor),
element_count,
in_shape);
};
}
else
{
throw ngraph_error(
@@ -119,7 +146,7 @@ namespace ngraph
}
else
{
throw ngraph_error("Unsupported type in CPU Builder for ArgMin");
throw ngraph_error("Unsupported type in CPU Builder for EmbeddingLookup");
}
functors.emplace_back(functor);
@@ -14,6 +14,7 @@ backwards_avgpool_n2_c2_hw4x4
embedding_lookup_4x5_reverse
embedding_lookup_10x1_arbitrary
embedding_lookup_10x1_arbitrary_index_type_int
embedding_lookup_10x1_arbitrary_index_type_int64
batch_norm_inference_0eps_f64
batch_norm_inference_0eps_f32
batch_norm_inference_f64
@@ -8,6 +8,7 @@ backwards_slice
batch_norm_bprop_n4c3h2w2
embedding_lookup_10x1_arbitrary
embedding_lookup_10x1_arbitrary_index_type_int
embedding_lookup_10x1_arbitrary_index_type_int64
embedding_lookup_4x5_reverse
generate_mask
replace_slice_3d
@@ -100,4 +100,5 @@ sum_stable_acc_double # To debug: precision errors
embedding_lookup_4x5_reverse
embedding_lookup_10x1_arbitrary
embedding_lookup_10x1_arbitrary_index_type_int
embedding_lookup_10x1_arbitrary_index_type_int64
floor_int32
@@ -106,3 +106,27 @@ NGRAPH_TEST(${BACKEND_NAME}, embedding_lookup_10x1_arbitrary_index_type_int)
vector<float> expected{9.5, 2.5, 1.5, 0.5, 3.5, 5.5, 4.5, 6.5, 8.5, 7.5};
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result0), MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, embedding_lookup_10x1_arbitrary_index_type_int64)
{
Shape shape{10};
Shape rshape{10, 1};
auto A = make_shared<op::Parameter>(element::i64, shape);
auto B = make_shared<op::Parameter>(element::f32, rshape);
auto embed = make_shared<op::EmbeddingLookup>(A, B);
auto f0 = make_shared<Function>(NodeVector{embed}, ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::i64, shape);
copy_data(a, vector<int64_t>{9, 2, 1, 0, 3, 5, 4, 6, 8, 7});
auto b = backend->create_tensor(element::f32, rshape);
copy_data(b, vector<float>{0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5});
auto result0 = backend->create_tensor(element::f32, rshape);
auto handle = backend->compile(f0);
handle->call_with_validate({result0}, {a, b});
vector<float> expected{9.5, 2.5, 1.5, 0.5, 3.5, 5.5, 4.5, 6.5, 8.5, 7.5};
EXPECT_TRUE(test::all_close_f(expected, read_vector<float>(result0), MIN_FLOAT_TOLERANCE_BITS));
}