Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
015e1da8
Unverified
Commit
015e1da8
authored
Apr 03, 2018
by
adstraw
Committed by
GitHub
Apr 03, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
more flexible tensor mask (#803)
parent
07cc9616
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
70 additions
and
86 deletions
+70
-86
CMakeLists.txt
src/ngraph/CMakeLists.txt
+0
-1
tensor_mask.cpp
src/ngraph/builder/tensor_mask.cpp
+0
-83
tensor_mask.hpp
src/ngraph/builder/tensor_mask.hpp
+68
-1
builder.cpp
test/builder.cpp
+2
-1
No files found.
src/ngraph/CMakeLists.txt
View file @
015e1da8
...
...
@@ -19,7 +19,6 @@ set (SRC
builder/autobroadcast.cpp
builder/numpy_transpose.cpp
builder/reduce_ops.cpp
builder/tensor_mask.cpp
coordinate_transform.cpp
descriptor/input.cpp
descriptor/layout/dense_tensor_view_layout.cpp
...
...
src/ngraph/builder/tensor_mask.cpp
deleted
100644 → 0
View file @
07cc9616
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <numeric>
#include "ngraph/builder/tensor_mask.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/reshape.hpp"
using
namespace
ngraph
;
std
::
shared_ptr
<
Node
>
ngraph
::
builder
::
tensor_mask
(
const
std
::
shared_ptr
<
Node
>&
sequence_lengths
,
size_t
sequence_axis
,
size_t
batch_axis
,
Shape
mask_shape
)
{
if
(
sequence_axis
>=
mask_shape
.
size
())
{
throw
ngraph_error
(
"Sequence axis must be in range 0..mask_shape rank"
);
}
if
(
batch_axis
>=
mask_shape
.
size
())
{
throw
ngraph_error
(
"Sequence axis must be in range 0..mask_shape rank"
);
}
// all axes except the sequence axis
AxisSet
non_sequence_axes
;
// all axes except the batch axis
AxisSet
non_batch_axes
;
for
(
auto
axis
=
0
;
axis
<
mask_shape
.
size
();
++
axis
)
{
if
(
axis
!=
sequence_axis
)
{
non_sequence_axes
.
insert
(
axis
);
}
if
(
axis
!=
batch_axis
)
{
non_batch_axes
.
insert
(
axis
);
}
}
// broadcast sequence lengths to mask shape along all non-batch axes
auto
broadcast_sequence_lengths
=
std
::
make_shared
<
op
::
Broadcast
>
(
sequence_lengths
,
mask_shape
,
non_batch_axes
);
// create sequence data [0, ..., max_sequence_length]
auto
max_sequence_length
=
mask_shape
[
sequence_axis
];
std
::
vector
<
uint32_t
>
sequence_data
(
max_sequence_length
);
std
::
iota
(
sequence_data
.
begin
(),
sequence_data
.
end
(),
0
);
// create sequence constant
auto
sequence
=
std
::
make_shared
<
op
::
Constant
>
(
element
::
u32
,
Shape
{
max_sequence_length
},
sequence_data
);
// convert sequence to input type
auto
convert_sequence
=
std
::
make_shared
<
op
::
Convert
>
(
sequence
,
sequence_lengths
->
get_element_type
());
// broadcast sequence to mask shape along all non-sequence axes
auto
broadcast_sequence
=
std
::
make_shared
<
op
::
Broadcast
>
(
convert_sequence
,
mask_shape
,
non_sequence_axes
);
// mask = sequence_length < sequence
return
std
::
make_shared
<
op
::
Less
>
(
broadcast_sequence
,
broadcast_sequence_lengths
);
}
src/ngraph/builder/tensor_mask.hpp
View file @
015e1da8
...
...
@@ -18,15 +18,82 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"
namespace
ngraph
{
namespace
builder
{
// batch_size = mask_shape on the batch_axis
// max_sequence_length = mask_shape on the sequence_axis
// sequence_lengths = list of lengths < max_sequence_length of shape batch_size
// a mask is created by...
// 1. creating a sequence starting at sequence_begin of shape max_sequence_length
// 2. broadcasting that sequence along all non-sequence axes to mask_shape
// 3. broadcasting sequence_lengths along all non-batch axes to mask_shape
// 4. returning the specified binary element-wise operation T #2 and #3
template
<
class
T
>
std
::
shared_ptr
<
Node
>
tensor_mask
(
const
std
::
shared_ptr
<
Node
>&
sequence_lengths
,
size_t
sequence_axis
,
size_t
batch_axis
,
Shape
mask_shape
);
ngraph
::
Shape
mask_shape
,
uint32_t
sequence_begin
)
{
if
(
sequence_axis
>=
mask_shape
.
size
())
{
throw
ngraph_error
(
"Sequence axis must be in range 0..mask_shape rank"
);
}
if
(
batch_axis
>=
mask_shape
.
size
())
{
throw
ngraph_error
(
"Sequence axis must be in range 0..mask_shape rank"
);
}
// all axes except the sequence axis
ngraph
::
AxisSet
non_sequence_axes
;
// all axes except the batch axis
ngraph
::
AxisSet
non_batch_axes
;
for
(
size_t
axis
=
0
;
axis
<
mask_shape
.
size
();
++
axis
)
{
if
(
axis
!=
sequence_axis
)
{
non_sequence_axes
.
insert
(
axis
);
}
if
(
axis
!=
batch_axis
)
{
non_batch_axes
.
insert
(
axis
);
}
}
// broadcast sequence lengths to mask shape along all non-batch axes
auto
broadcast_sequence_lengths
=
std
::
make_shared
<
ngraph
::
op
::
Broadcast
>
(
sequence_lengths
,
mask_shape
,
non_batch_axes
);
// create sequence data [0, ..., max_sequence_length]
auto
max_sequence_length
=
mask_shape
[
sequence_axis
];
std
::
vector
<
uint32_t
>
sequence_data
(
max_sequence_length
);
std
::
iota
(
sequence_data
.
begin
(),
sequence_data
.
end
(),
sequence_begin
);
// create sequence constant
auto
sequence
=
std
::
make_shared
<
ngraph
::
op
::
Constant
>
(
element
::
u32
,
Shape
{
max_sequence_length
},
sequence_data
);
// convert sequence to input type
auto
convert_sequence
=
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
sequence
,
sequence_lengths
->
get_element_type
());
// broadcast sequence to mask shape along all non-sequence axes
auto
broadcast_sequence
=
std
::
make_shared
<
ngraph
::
op
::
Broadcast
>
(
convert_sequence
,
mask_shape
,
non_sequence_axes
);
// mask = sequence_length < sequence
return
std
::
make_shared
<
T
>
(
broadcast_sequence
,
broadcast_sequence_lengths
);
}
}
}
test/builder.cpp
View file @
015e1da8
...
...
@@ -147,7 +147,8 @@ TEST(builder, tensor_mask)
auto
sequence_lengths
=
make_shared
<
op
::
Parameter
>
(
element
::
u32
,
max_sequence_length
);
Shape
mask_shape
{
3
,
5
};
auto
f
=
make_shared
<
Function
>
(
builder
::
tensor_mask
(
sequence_lengths
,
1
,
0
,
mask_shape
),
auto
f
=
make_shared
<
Function
>
(
builder
::
tensor_mask
<
op
::
Less
>
(
sequence_lengths
,
1
,
0
,
mask_shape
,
0
),
op
::
ParameterVector
{
sequence_lengths
});
auto
manager
=
runtime
::
Manager
::
get
(
"INTERPRETER"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment