Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
e786fcfe
Commit
e786fcfe
authored
5 years ago
by
Ivan Tikhonov
Committed by
Michał Karzyński
5 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
TensorIterator: reshape support (#4038)
parent
c8988ca9
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
53 additions
and
3 deletions
+53
-3
tensor_iterator.cpp
src/ngraph/op/tensor_iterator.cpp
+47
-3
specialize_function.cpp
src/ngraph/specialize_function.cpp
+6
-0
No files found.
src/ngraph/op/tensor_iterator.cpp
View file @
e786fcfe
...
@@ -16,6 +16,8 @@
...
@@ -16,6 +16,8 @@
#include "ngraph/op/tensor_iterator.hpp"
#include "ngraph/op/tensor_iterator.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/pass/get_output_element_elimination.hpp"
#include "ngraph/specialize_function.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
...
@@ -220,7 +222,7 @@ void op::TensorIterator::revalidate_and_infer_types_for_body_ops()
...
@@ -220,7 +222,7 @@ void op::TensorIterator::revalidate_and_infer_types_for_body_ops()
std
::
stack
<
std
::
shared_ptr
<
Node
>
,
std
::
vector
<
std
::
shared_ptr
<
Node
>>>
nodes_to_do
;
std
::
stack
<
std
::
shared_ptr
<
Node
>
,
std
::
vector
<
std
::
shared_ptr
<
Node
>>>
nodes_to_do
;
std
::
unordered_set
<
std
::
shared_ptr
<
Node
>>
nodes_done
;
std
::
unordered_set
<
std
::
shared_ptr
<
Node
>>
nodes_done
;
for
(
auto
r
:
m_body
->
get_results
())
for
(
const
auto
&
r
:
m_body
->
get_results
())
{
{
nodes_to_do
.
push
(
r
);
nodes_to_do
.
push
(
r
);
}
}
...
@@ -281,7 +283,7 @@ void op::TensorIterator::validate_and_infer_types()
...
@@ -281,7 +283,7 @@ void op::TensorIterator::validate_and_infer_types()
// Input
// Input
uint64_t
index_it
=
0
;
uint64_t
index_it
=
0
;
for
(
auto
input_description
:
m_input_descriptions
)
for
(
const
auto
&
input_description
:
m_input_descriptions
)
{
{
auto
index
=
input_description
->
m_input_index
;
auto
index
=
input_description
->
m_input_index
;
NODE_VALIDATION_CHECK
(
this
,
index
==
index_it
,
"Input_index not in order"
);
NODE_VALIDATION_CHECK
(
this
,
index
==
index_it
,
"Input_index not in order"
);
...
@@ -398,7 +400,7 @@ void op::TensorIterator::validate_and_infer_types()
...
@@ -398,7 +400,7 @@ void op::TensorIterator::validate_and_infer_types()
// Output
// Output
index_it
=
0
;
index_it
=
0
;
for
(
auto
output_description
:
m_output_descriptions
)
for
(
const
auto
&
output_description
:
m_output_descriptions
)
{
{
auto
index
=
output_description
->
m_output_index
;
auto
index
=
output_description
->
m_output_index
;
NODE_VALIDATION_CHECK
(
this
,
index
==
index_it
,
"Output_index not in order"
);
NODE_VALIDATION_CHECK
(
this
,
index
==
index_it
,
"Output_index not in order"
);
...
@@ -437,6 +439,48 @@ void op::TensorIterator::validate_and_infer_types()
...
@@ -437,6 +439,48 @@ void op::TensorIterator::validate_and_infer_types()
std
::
shared_ptr
<
Node
>
op
::
TensorIterator
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
std
::
shared_ptr
<
Node
>
op
::
TensorIterator
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
auto
op
=
make_shared
<
op
::
TensorIterator
>
(
as_output_vector
(
new_args
));
auto
op
=
make_shared
<
op
::
TensorIterator
>
(
as_output_vector
(
new_args
));
op
->
set_output_size
(
m_output_descriptions
.
size
());
std
::
vector
<::
ngraph
::
element
::
Type
>
types
(
m_body
->
get_parameters
().
size
());
std
::
vector
<::
ngraph
::
PartialShape
>
new_shapes
(
m_body
->
get_parameters
().
size
());
for
(
size_t
input_index
=
0
;
input_index
<
new_args
.
size
();
++
input_index
)
{
for
(
auto
&
input_description
:
m_input_descriptions
)
{
if
(
input_description
->
m_input_index
==
input_index
)
{
types
[
input_description
->
m_body_parameter_index
]
=
new_args
[
input_index
]
->
get_element_type
();
new_shapes
[
input_description
->
m_body_parameter_index
]
=
new_args
[
input_index
]
->
get_output_partial_shape
(
0
);
if
(
new_shapes
[
input_description
->
m_body_parameter_index
].
is_static
())
{
if
(
auto
slice_in
=
::
ngraph
::
as_type_ptr
<
ngraph
::
op
::
TensorIterator
::
SliceInputDescription
>
(
input_description
))
{
new_shapes
[
slice_in
->
m_body_parameter_index
][
slice_in
->
m_axis
]
=
slice_in
->
m_part_size
;
}
}
}
}
}
auto
func
=
std
::
make_shared
<
Function
>
(
m_body
->
get_results
(),
m_body
->
get_parameters
());
auto
spec_func
=
specialize_function
(
func
,
types
,
new_shapes
,
std
::
vector
<
void
*>
(
new_args
.
size
(),
nullptr
),
false
,
true
);
op
->
m_body
=
std
::
make_shared
<
BodyLambda
>
(
spec_func
->
get_results
(),
spec_func
->
get_parameters
());
// TODO: remove this code after the fix on the nGraph side (GetOutputElements)
::
ngraph
::
pass
::
GetOutputElementElimination
goe_elimination
;
for
(
const
auto
&
n
:
spec_func
->
get_ops
())
{
goe_elimination
.
run_on_node
(
n
);
}
for
(
auto
&
input_description
:
m_input_descriptions
)
for
(
auto
&
input_description
:
m_input_descriptions
)
{
{
op
->
m_input_descriptions
.
push_back
(
input_description
->
copy
());
op
->
m_input_descriptions
.
push_back
(
input_description
->
copy
());
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/specialize_function.cpp
View file @
e786fcfe
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "ngraph/specialize_function.hpp"
#include "ngraph/specialize_function.hpp"
#include <pass/constant_folding.hpp>
#include <pass/constant_folding.hpp>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/tensor_iterator.hpp"
using
namespace
ngraph
;
using
namespace
ngraph
;
...
@@ -84,6 +85,11 @@ std::shared_ptr<Function>
...
@@ -84,6 +85,11 @@ std::shared_ptr<Function>
else
else
{
{
m
[
old_node
.
get
()]
=
old_node
->
copy_with_new_inputs
(
new_args
);
m
[
old_node
.
get
()]
=
old_node
->
copy_with_new_inputs
(
new_args
);
// TODO: workaround for shape inference, delete it after fix
if
(
::
ngraph
::
as_type_ptr
<
ngraph
::
op
::
TensorIterator
>
(
m
[
old_node
.
get
()]))
{
m
[
old_node
.
get
()]
->
validate_and_infer_types
();
}
m
[
old_node
.
get
()]
->
get_rt_info
()
=
old_node
->
get_rt_info
();
m
[
old_node
.
get
()]
->
get_rt_info
()
=
old_node
->
get_rt_info
();
}
}
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment