Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
015e1da8
Unverified
Commit
015e1da8
authored
7 years ago
by
adstraw
Committed by
GitHub
7 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
more flexible tensor mask (#803)
parent
07cc9616
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
71 additions
and
87 deletions
+71
-87
CMakeLists.txt
src/ngraph/CMakeLists.txt
+0
-1
tensor_mask.cpp
src/ngraph/builder/tensor_mask.cpp
+0
-83
tensor_mask.hpp
src/ngraph/builder/tensor_mask.hpp
+68
-1
builder.cpp
test/builder.cpp
+3
-2
No files found.
src/ngraph/CMakeLists.txt
View file @
015e1da8
...
...
@@ -19,7 +19,6 @@ set (SRC
builder/autobroadcast.cpp
builder/numpy_transpose.cpp
builder/reduce_ops.cpp
builder/tensor_mask.cpp
coordinate_transform.cpp
descriptor/input.cpp
descriptor/layout/dense_tensor_view_layout.cpp
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/builder/tensor_mask.cpp
deleted
100644 → 0
View file @
07cc9616
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <numeric>
#include "ngraph/builder/tensor_mask.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/reshape.hpp"
using
namespace
ngraph
;
std
::
shared_ptr
<
Node
>
ngraph
::
builder
::
tensor_mask
(
const
std
::
shared_ptr
<
Node
>&
sequence_lengths
,
size_t
sequence_axis
,
size_t
batch_axis
,
Shape
mask_shape
)
{
if
(
sequence_axis
>=
mask_shape
.
size
())
{
throw
ngraph_error
(
"Sequence axis must be in range 0..mask_shape rank"
);
}
if
(
batch_axis
>=
mask_shape
.
size
())
{
throw
ngraph_error
(
"Sequence axis must be in range 0..mask_shape rank"
);
}
// all axes except the sequence axis
AxisSet
non_sequence_axes
;
// all axes except the batch axis
AxisSet
non_batch_axes
;
for
(
auto
axis
=
0
;
axis
<
mask_shape
.
size
();
++
axis
)
{
if
(
axis
!=
sequence_axis
)
{
non_sequence_axes
.
insert
(
axis
);
}
if
(
axis
!=
batch_axis
)
{
non_batch_axes
.
insert
(
axis
);
}
}
// broadcast sequence lengths to mask shape along all non-batch axes
auto
broadcast_sequence_lengths
=
std
::
make_shared
<
op
::
Broadcast
>
(
sequence_lengths
,
mask_shape
,
non_batch_axes
);
// create sequence data [0, ..., max_sequence_length]
auto
max_sequence_length
=
mask_shape
[
sequence_axis
];
std
::
vector
<
uint32_t
>
sequence_data
(
max_sequence_length
);
std
::
iota
(
sequence_data
.
begin
(),
sequence_data
.
end
(),
0
);
// create sequence constant
auto
sequence
=
std
::
make_shared
<
op
::
Constant
>
(
element
::
u32
,
Shape
{
max_sequence_length
},
sequence_data
);
// convert sequence to input type
auto
convert_sequence
=
std
::
make_shared
<
op
::
Convert
>
(
sequence
,
sequence_lengths
->
get_element_type
());
// broadcast sequence to mask shape along all non-sequence axes
auto
broadcast_sequence
=
std
::
make_shared
<
op
::
Broadcast
>
(
convert_sequence
,
mask_shape
,
non_sequence_axes
);
// mask = sequence_length < sequence
return
std
::
make_shared
<
op
::
Less
>
(
broadcast_sequence
,
broadcast_sequence_lengths
);
}
This diff is collapsed.
Click to expand it.
src/ngraph/builder/tensor_mask.hpp
View file @
015e1da8
...
...
@@ -18,15 +18,82 @@
#include "ngraph/axis_set.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"
namespace
ngraph
{
namespace
builder
{
// batch_size = mask_shape on the batch_axis
// max_sequence_length = mask_shape on the sequence_axis
// sequence_lengths = list of lengths < max_sequence_length of shape batch_size
// a mask is created by...
// 1. creating a sequence starting at sequence_begin of shape max_sequence_length
// 2. broadcasting that sequence along all non-sequence axes to mask_shape
// 3. broadcasting sequence_lengths along all non-batch axes to mask_shape
// 4. returning the specified binary element-wise operation T #2 and #3
template
<
class
T
>
std
::
shared_ptr
<
Node
>
tensor_mask
(
const
std
::
shared_ptr
<
Node
>&
sequence_lengths
,
size_t
sequence_axis
,
size_t
batch_axis
,
Shape
mask_shape
);
ngraph
::
Shape
mask_shape
,
uint32_t
sequence_begin
)
{
if
(
sequence_axis
>=
mask_shape
.
size
())
{
throw
ngraph_error
(
"Sequence axis must be in range 0..mask_shape rank"
);
}
if
(
batch_axis
>=
mask_shape
.
size
())
{
throw
ngraph_error
(
"Sequence axis must be in range 0..mask_shape rank"
);
}
// all axes except the sequence axis
ngraph
::
AxisSet
non_sequence_axes
;
// all axes except the batch axis
ngraph
::
AxisSet
non_batch_axes
;
for
(
size_t
axis
=
0
;
axis
<
mask_shape
.
size
();
++
axis
)
{
if
(
axis
!=
sequence_axis
)
{
non_sequence_axes
.
insert
(
axis
);
}
if
(
axis
!=
batch_axis
)
{
non_batch_axes
.
insert
(
axis
);
}
}
// broadcast sequence lengths to mask shape along all non-batch axes
auto
broadcast_sequence_lengths
=
std
::
make_shared
<
ngraph
::
op
::
Broadcast
>
(
sequence_lengths
,
mask_shape
,
non_batch_axes
);
// create sequence data [0, ..., max_sequence_length]
auto
max_sequence_length
=
mask_shape
[
sequence_axis
];
std
::
vector
<
uint32_t
>
sequence_data
(
max_sequence_length
);
std
::
iota
(
sequence_data
.
begin
(),
sequence_data
.
end
(),
sequence_begin
);
// create sequence constant
auto
sequence
=
std
::
make_shared
<
ngraph
::
op
::
Constant
>
(
element
::
u32
,
Shape
{
max_sequence_length
},
sequence_data
);
// convert sequence to input type
auto
convert_sequence
=
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
sequence
,
sequence_lengths
->
get_element_type
());
// broadcast sequence to mask shape along all non-sequence axes
auto
broadcast_sequence
=
std
::
make_shared
<
ngraph
::
op
::
Broadcast
>
(
convert_sequence
,
mask_shape
,
non_sequence_axes
);
// mask = sequence_length < sequence
return
std
::
make_shared
<
T
>
(
broadcast_sequence
,
broadcast_sequence_lengths
);
}
}
}
This diff is collapsed.
Click to expand it.
test/builder.cpp
View file @
015e1da8
...
...
@@ -147,8 +147,9 @@ TEST(builder, tensor_mask)
auto
sequence_lengths
=
make_shared
<
op
::
Parameter
>
(
element
::
u32
,
max_sequence_length
);
Shape
mask_shape
{
3
,
5
};
auto
f
=
make_shared
<
Function
>
(
builder
::
tensor_mask
(
sequence_lengths
,
1
,
0
,
mask_shape
),
op
::
ParameterVector
{
sequence_lengths
});
auto
f
=
make_shared
<
Function
>
(
builder
::
tensor_mask
<
op
::
Less
>
(
sequence_lengths
,
1
,
0
,
mask_shape
,
0
),
op
::
ParameterVector
{
sequence_lengths
});
auto
manager
=
runtime
::
Manager
::
get
(
"INTERPRETER"
);
auto
external
=
manager
->
compile
(
f
);
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment