Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
080d4f95
Commit
080d4f95
authored
Jun 06, 2019
by
Adam Procter
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Implement DynElimination for DynSlice; simple test passing, but more needed
parent
f561c937
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
304 additions
and
0 deletions
+304
-0
dyn_elimination.cpp
src/ngraph/pass/dyn_elimination.cpp
+303
-0
dyn_elimination.hpp
src/ngraph/pass/dyn_elimination.hpp
+1
-0
No files found.
src/ngraph/pass/dyn_elimination.cpp
View file @
080d4f95
...
...
@@ -15,8 +15,11 @@
//*****************************************************************************
#include "dyn_elimination.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
...
...
@@ -27,6 +30,7 @@ pass::DynElimination::DynElimination()
:
GraphRewrite
()
{
construct_transpose
();
construct_dyn_reshape
();
}
void
pass
::
DynElimination
::
construct_transpose
()
...
...
@@ -74,3 +78,302 @@ void pass::DynElimination::construct_transpose()
auto
transpose_matcher
=
make_shared
<
pattern
::
Matcher
>
(
transpose
,
"DynElimination.Transpose"
);
add_matcher
(
transpose_matcher
,
transpose_callback
,
all_pass_property_off
);
}
//
// We eliminate DynSlice by converting it to a sequence of ops:
//
// Slice (to do the basic slicing)
// |
// v
// Reshape (non-transposing, to handle shrinks)
// |
// vconst
// Reverse (to emulate backwards stride)
//
// (The Reshape, Reverse, or both may be omitted if they would just be identities.)
//
// A SlicePlan is used to collect parameters for these ops.
//
struct
SlicePlan
{
// Parameters for the Slice
std
::
vector
<
int64_t
>
begins
;
std
::
vector
<
int64_t
>
ends
;
std
::
vector
<
int64_t
>
strides
;
// Shapes coming into, and going out of, the Reshape.
Shape
reshape_in_shape
;
Shape
reshape_out_shape
;
// Parameters for the Reverse
std
::
set
<
size_t
>
reverse_axes
;
};
static
SlicePlan
make_plan
(
const
Shape
&
input_shape
,
const
std
::
vector
<
int64_t
>&
begins
,
const
std
::
vector
<
int64_t
>&
ends
,
const
std
::
vector
<
int64_t
>&
strides
,
const
AxisSet
&
lower_bounds_mask
,
const
AxisSet
&
upper_bounds_mask
,
const
AxisSet
&
new_axis_mask
,
const
AxisSet
&
shrink_axis_mask
,
const
AxisSet
&
ellipsis_mask
)
{
NGRAPH_CHECK
(
begins
.
size
()
==
ends
.
size
());
NGRAPH_CHECK
(
ends
.
size
()
==
strides
.
size
());
size_t
num_slice_indices
=
begins
.
size
();
size_t
num_real_axes
=
0
;
size_t
num_shrink_axes
=
0
;
size_t
num_new_axes
=
0
;
bool
ellipsis_found
=
false
;
// Make a pass over the original slices to make sure there is at most one
// ellipsis, and to count up the number of shrink axes, the number of
// "newaxis"es, and the number of "real" axes (axes that are not newaxis
// and are not the ellipsis).
for
(
size_t
i
=
0
;
i
<
num_slice_indices
;
i
++
)
{
if
(
ellipsis_mask
.
count
(
i
))
{
NGRAPH_CHECK
(
!
ellipsis_found
);
ellipsis_found
=
true
;
}
else
if
(
new_axis_mask
.
count
(
i
))
{
num_new_axes
++
;
}
else
{
if
(
shrink_axis_mask
.
count
(
i
))
{
num_shrink_axes
++
;
}
num_real_axes
++
;
}
}
NGRAPH_CHECK
(
num_real_axes
<=
input_shape
.
size
());
// Figure out how many axes need to be inserted when the ellipsis (which
// may be an implicit ellipsis at the end) is expanded.
size_t
ellipsis_size
=
input_shape
.
size
()
-
num_real_axes
;
// Initialize our slice plan.
SlicePlan
p
;
p
.
begins
=
std
::
vector
<
int64_t
>
(
num_real_axes
+
ellipsis_size
);
p
.
ends
=
std
::
vector
<
int64_t
>
(
num_real_axes
+
ellipsis_size
);
p
.
strides
=
std
::
vector
<
int64_t
>
(
num_real_axes
+
ellipsis_size
);
p
.
reshape_in_shape
=
Shape
(
num_real_axes
+
ellipsis_size
);
p
.
reshape_out_shape
=
Shape
(
num_new_axes
+
num_real_axes
+
ellipsis_size
-
num_shrink_axes
);
p
.
reverse_axes
=
AxisSet
{};
// Begin a maddeningly delicate loop to desugar the original slice specs.
//
// * i_in is iterating over the axes of the input shape, which are also the axes of
// p.reshape_in_shape.
// * i_out is iterating over the axes of p.reshape_out_shape
size_t
i_in
=
0
;
size_t
i_out
=
0
;
// If no actual ellipsis exists, there is an "implicit" one at the end,
// which we will handle after the loop. So the logic is wrapped up here,
// allowing it to be used both during and after the loop.
auto
expand_ellipsis
=
[
&
]()
{
for
(
size_t
i
=
0
;
i
<
ellipsis_size
;
i
++
)
{
p
.
begins
[
i_in
]
=
0
;
p
.
ends
[
i_in
]
=
int64_t
(
input_shape
[
i_in
]);
p
.
strides
[
i_in
]
=
1
;
p
.
reshape_in_shape
[
i_in
]
=
input_shape
[
i_in
];
p
.
reshape_out_shape
[
i_out
]
=
input_shape
[
i_in
];
i_in
++
;
i_out
++
;
}
};
for
(
size_t
i
=
0
;
i
<
num_slice_indices
;
i
++
)
{
// If this is a "newaxis", we throw a 1 into the final shape, but it
// will not be present in the intermediate shape and does not
// correspond to anything in the original shape.
if
(
new_axis_mask
.
count
(
i
))
{
p
.
reshape_out_shape
[
i_out
]
=
1
;
i_out
++
;
}
// If this is a "shrunken" axis, the intermediate shape will have a
// "1" here, but nothing will be there in the final shape.
else
if
(
shrink_axis_mask
.
count
(
i
))
{
int64_t
begin
=
begins
[
i
];
// Note that clipping is not used for "shrunken" axes: an
// out-of-bounds index is an error.
NGRAPH_CHECK
(
begin
>=
-
(
int64_t
(
input_shape
[
i_in
]))
&&
begin
<
int64_t
(
input_shape
[
i_in
]));
if
(
begin
<
0
)
{
begin
+=
int64_t
(
input_shape
[
i_in
]);
}
p
.
begins
[
i_in
]
=
begin
;
p
.
ends
[
i_in
]
=
begin
+
1
;
p
.
strides
[
i_in
]
=
1
;
p
.
reshape_in_shape
[
i_in
]
=
1
;
i_in
++
;
}
// If this is the ellipsis, expand it (see expand_ellipsis above for
// details).
else
if
(
ellipsis_mask
.
count
(
i
))
{
expand_ellipsis
();
}
// In other cases, we have a nice, ordinary (begin:end:stride) slice.
// We need to adjust for begin/end being masked, and begin/end/stride
// being negative or out of bounds.
else
{
bool
is_reverse
=
strides
[
i
]
<
0
;
// Adjust the beginning for from-the-right indexing, and clip.
int64_t
real_begin
=
begins
[
i
];
if
(
lower_bounds_mask
.
count
(
i
))
{
real_begin
=
(
is_reverse
?
int64_t
(
input_shape
[
i_in
]
-
1
)
:
0
);
}
else
if
(
real_begin
<
0
)
{
real_begin
+=
int64_t
(
input_shape
[
i_in
]);
}
int64_t
max_real_begin
=
int64_t
(
input_shape
[
i_in
])
-
(
is_reverse
?
1
:
0
);
real_begin
=
std
::
max
(
int64_t
(
0
),
std
::
min
(
max_real_begin
,
real_begin
));
// Adjust the ending for from-the-right indexing, and clip.
int64_t
real_end
=
ends
[
i
];
if
(
upper_bounds_mask
.
count
(
i
))
{
real_end
=
(
is_reverse
?
-
1
:
int64_t
(
input_shape
[
i_in
]));
}
else
if
(
real_end
<
0
)
{
real_end
+=
int64_t
(
input_shape
[
i_in
]);
}
int64_t
min_real_end
=
(
is_reverse
?
-
1
:
0
);
real_end
=
std
::
max
(
min_real_end
,
std
::
min
(
int64_t
(
input_shape
[
i_in
]),
real_end
));
// Adjust the stride for backwards slicing.
int64_t
real_stride
=
std
::
abs
(
strides
[
i
]);
// Adjust for reversal if needed. This isn't quite as simple as swapping begin and
// end, due to striding; we have to adjust the end point to be the _actual_ leftmost
// element, in cases where the stride does not evenly divide the span between begin
// and end.
if
(
is_reverse
)
{
real_end
+=
std
::
max
(
int64_t
(
0
),
real_begin
-
real_end
-
1
)
%
real_stride
;
std
::
swap
(
real_begin
,
real_end
);
real_begin
++
;
real_end
++
;
p
.
reverse_axes
.
insert
(
i_out
);
}
// Compute output dimension.
size_t
dim
=
(
real_end
<=
real_begin
?
0
:
size_t
(
real_end
-
real_begin
-
1
)
/
size_t
(
real_stride
)
+
1
);
p
.
reshape_in_shape
[
i_in
]
=
dim
;
p
.
reshape_out_shape
[
i_out
]
=
dim
;
// Set up the begin/end/stride.
p
.
begins
[
i_in
]
=
real_begin
;
p
.
ends
[
i_in
]
=
real_end
;
p
.
strides
[
i_in
]
=
real_stride
;
i_in
++
;
i_out
++
;
}
}
// If there was no ellipsis explicitly given, there is an implicit one at
// the end (it might encompass zero axes, but that's fine).
if
(
!
ellipsis_found
)
{
expand_ellipsis
();
}
return
p
;
}
void
pass
::
DynElimination
::
construct_dyn_reshape
()
{
auto
data_arg_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
1
,
2
,
3
});
auto
begins_arg_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
i64
,
Shape
{
3
},
pattern
::
has_class
<
op
::
Constant
>
());
auto
ends_arg_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
i64
,
Shape
{
3
},
pattern
::
has_class
<
op
::
Constant
>
());
auto
strides_arg_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
i64
,
Shape
{
3
},
pattern
::
has_class
<
op
::
Constant
>
());
auto
dyn_slice_pat
=
make_shared
<
op
::
DynSlice
>
(
data_arg_label
,
begins_arg_label
,
ends_arg_label
,
strides_arg_label
,
AxisSet
{},
AxisSet
{},
AxisSet
{},
AxisSet
{},
AxisSet
{});
auto
dyn_slice_callback
=
[
data_arg_label
,
begins_arg_label
,
ends_arg_label
,
strides_arg_label
](
pattern
::
Matcher
&
m
)
{
auto
pattern_map
=
m
.
get_pattern_map
();
auto
data_arg
=
pattern_map
[
data_arg_label
];
auto
begins_arg
=
static_pointer_cast
<
op
::
Constant
>
(
pattern_map
[
begins_arg_label
]);
auto
ends_arg
=
static_pointer_cast
<
op
::
Constant
>
(
pattern_map
[
ends_arg_label
]);
auto
strides_arg
=
static_pointer_cast
<
op
::
Constant
>
(
pattern_map
[
strides_arg_label
]);
auto
dyn_slice
=
static_pointer_cast
<
op
::
DynSlice
>
(
m
.
get_match_root
());
if
(
data_arg
->
get_output_partial_shape
(
0
).
is_dynamic
()
||
begins_arg
->
get_element_type
()
!=
element
::
i64
||
ends_arg
->
get_element_type
()
!=
element
::
i64
||
strides_arg
->
get_element_type
()
!=
element
::
i64
)
{
return
false
;
}
SlicePlan
p
=
make_plan
(
data_arg
->
get_output_shape
(
0
),
begins_arg
->
get_vector
<
int64_t
>
(),
ends_arg
->
get_vector
<
int64_t
>
(),
strides_arg
->
get_vector
<
int64_t
>
(),
dyn_slice
->
get_lower_bounds_mask
(),
dyn_slice
->
get_upper_bounds_mask
(),
dyn_slice
->
get_new_axis
(),
dyn_slice
->
get_shrink_axis
(),
dyn_slice
->
get_ellipsis_mask
());
shared_ptr
<
Node
>
replacement
=
make_shared
<
op
::
Slice
>
(
data_arg
,
Coordinate
(
p
.
begins
.
begin
(),
p
.
begins
.
end
()),
Coordinate
(
p
.
ends
.
begin
(),
p
.
ends
.
end
()),
Strides
(
p
.
strides
.
begin
(),
p
.
strides
.
end
()));
if
(
p
.
reshape_in_shape
!=
p
.
reshape_out_shape
)
{
replacement
=
make_shared
<
op
::
Reshape
>
(
replacement
,
ngraph
::
get_default_order
(
p
.
reshape_in_shape
),
p
.
reshape_out_shape
);
}
if
(
!
p
.
reverse_axes
.
empty
())
{
replacement
=
make_shared
<
op
::
Reverse
>
(
replacement
,
p
.
reverse_axes
);
}
replace_node
(
m
.
get_match_root
(),
replacement
);
return
true
;
};
auto
dyn_slice_matcher
=
make_shared
<
pattern
::
Matcher
>
(
dyn_slice_pat
,
"DynElimination.DynShape"
);
add_matcher
(
dyn_slice_matcher
,
dyn_slice_callback
,
all_pass_property_off
);
}
src/ngraph/pass/dyn_elimination.hpp
View file @
080d4f95
...
...
@@ -30,6 +30,7 @@ namespace ngraph
private
:
void
construct_transpose
();
void
construct_dyn_reshape
();
};
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment