Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
879f9492
Commit
879f9492
authored
Mar 27, 2019
by
Christian Convey
Committed by
Scott Cyphers
Mar 27, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add naive int64-indexing to EmbeddingLookup (#2644)
parent
857093c1
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
55 additions
and
1 deletion
+55
-1
embedding_lookup.cpp
src/ngraph/runtime/cpu/builder/embedding_lookup.cpp
+28
-1
unit_test.manifest
src/ngraph/runtime/gpu/unit_test.manifest
+1
-0
unit_test.manifest
src/ngraph/runtime/intelgpu/unit_test.manifest
+1
-0
unit_test.manifest
src/ngraph/runtime/plaidml/unit_test.manifest
+1
-0
backend_embedding_lookup.in.cpp
test/backend_embedding_lookup.in.cpp
+24
-0
No files found.
src/ngraph/runtime/cpu/builder/embedding_lookup.cpp
View file @
879f9492
...
...
@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include <cstdint>
#include <cstring>
#include "ngraph/op/embedding_lookup.hpp"
...
...
@@ -77,6 +78,19 @@ namespace ngraph
in_shape
);
};
}
else
if
(
index_element_type
==
element
::
i64
)
{
functor
=
[
&
,
in_shape
,
element_count
](
CPURuntimeContext
*
ctx
,
CPUExecutionContext
*
ectx
)
{
ngraph
::
runtime
::
reference
::
embedding
<
float
,
int64_t
>
(
static_cast
<
int64_t
*>
(
arg0_tensor
),
static_cast
<
float
*>
(
arg1_tensor
),
static_cast
<
float
*>
(
out_tensor
),
element_count
,
in_shape
);
};
}
else
{
throw
ngraph_error
(
...
...
@@ -111,6 +125,19 @@ namespace ngraph
in_shape
);
};
}
else
if
(
index_element_type
==
element
::
i64
)
{
functor
=
[
&
,
in_shape
,
element_count
](
CPURuntimeContext
*
ctx
,
CPUExecutionContext
*
ectx
)
{
ngraph
::
runtime
::
reference
::
embedding
<
int
,
int64_t
>
(
static_cast
<
int64_t
*>
(
arg0_tensor
),
static_cast
<
int
*>
(
arg1_tensor
),
static_cast
<
int
*>
(
out_tensor
),
element_count
,
in_shape
);
};
}
else
{
throw
ngraph_error
(
...
...
@@ -119,7 +146,7 @@ namespace ngraph
}
else
{
throw
ngraph_error
(
"Unsupported type in CPU Builder for
ArgMin
"
);
throw
ngraph_error
(
"Unsupported type in CPU Builder for
EmbeddingLookup
"
);
}
functors
.
emplace_back
(
functor
);
...
...
src/ngraph/runtime/gpu/unit_test.manifest
View file @
879f9492
...
...
@@ -14,6 +14,7 @@ backwards_avgpool_n2_c2_hw4x4
embedding_lookup_4x5_reverse
embedding_lookup_10x1_arbitrary
embedding_lookup_10x1_arbitrary_index_type_int
embedding_lookup_10x1_arbitrary_index_type_int64
batch_norm_inference_0eps_f64
batch_norm_inference_0eps_f32
batch_norm_inference_f64
...
...
src/ngraph/runtime/intelgpu/unit_test.manifest
View file @
879f9492
...
...
@@ -8,6 +8,7 @@ backwards_slice
batch_norm_bprop_n4c3h2w2
embedding_lookup_10x1_arbitrary
embedding_lookup_10x1_arbitrary_index_type_int
embedding_lookup_10x1_arbitrary_index_type_int64
embedding_lookup_4x5_reverse
generate_mask
replace_slice_3d
...
...
src/ngraph/runtime/plaidml/unit_test.manifest
View file @
879f9492
...
...
@@ -100,4 +100,5 @@ sum_stable_acc_double # To debug: precision errors
embedding_lookup_4x5_reverse
embedding_lookup_10x1_arbitrary
embedding_lookup_10x1_arbitrary_index_type_int
embedding_lookup_10x1_arbitrary_index_type_int64
floor_int32
test/backend_embedding_lookup.in.cpp
View file @
879f9492
...
...
@@ -106,3 +106,27 @@ NGRAPH_TEST(${BACKEND_NAME}, embedding_lookup_10x1_arbitrary_index_type_int)
vector
<
float
>
expected
{
9.5
,
2.5
,
1.5
,
0.5
,
3.5
,
5.5
,
4.5
,
6.5
,
8.5
,
7.5
};
EXPECT_TRUE
(
test
::
all_close_f
(
expected
,
read_vector
<
float
>
(
result0
),
MIN_FLOAT_TOLERANCE_BITS
));
}
NGRAPH_TEST
(
$
{
BACKEND_NAME
},
embedding_lookup_10x1_arbitrary_index_type_int64
)
{
Shape
shape
{
10
};
Shape
rshape
{
10
,
1
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
i64
,
shape
);
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
rshape
);
auto
embed
=
make_shared
<
op
::
EmbeddingLookup
>
(
A
,
B
);
auto
f0
=
make_shared
<
Function
>
(
NodeVector
{
embed
},
ParameterVector
{
A
,
B
});
auto
backend
=
runtime
::
Backend
::
create
(
"${BACKEND_NAME}"
);
// Create some tensors for input/output
auto
a
=
backend
->
create_tensor
(
element
::
i64
,
shape
);
copy_data
(
a
,
vector
<
int64_t
>
{
9
,
2
,
1
,
0
,
3
,
5
,
4
,
6
,
8
,
7
});
auto
b
=
backend
->
create_tensor
(
element
::
f32
,
rshape
);
copy_data
(
b
,
vector
<
float
>
{
0.5
,
1.5
,
2.5
,
3.5
,
4.5
,
5.5
,
6.5
,
7.5
,
8.5
,
9.5
});
auto
result0
=
backend
->
create_tensor
(
element
::
f32
,
rshape
);
auto
handle
=
backend
->
compile
(
f0
);
handle
->
call_with_validate
({
result0
},
{
a
,
b
});
//vector<float> expected{9.5, 2.5, 1.5, 0.5, 3.5, 5.5, 4.5, 6.5, 8.5, 7.5};
vector
<
float
>
expected
{
9.5
,
2.5
,
1.5
,
0.5
,
3.5
,
5.5
,
4.5
,
6.5
,
8.5
,
7.5
};
EXPECT_TRUE
(
test
::
all_close_f
(
expected
,
read_vector
<
float
>
(
result0
),
MIN_FLOAT_TOLERANCE_BITS
));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment