Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
94f4dfec
Commit
94f4dfec
authored
May 17, 2019
by
Adam Procter
Committed by
Robert Kimball
May 17, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Remove ShapeSpecialization pass and attendant as_constants machinery (#2948)
parent
d99ac8ce
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
0 additions
and
475 deletions
+0
-475
CMakeLists.txt
src/ngraph/CMakeLists.txt
+0
-2
node.hpp
src/ngraph/node.hpp
+0
-16
concat.cpp
src/ngraph/op/concat.cpp
+0
-53
concat.hpp
src/ngraph/op/concat.hpp
+0
-2
shape_of.cpp
src/ngraph/op/experimental/shape_of.cpp
+0
-13
shape_of.hpp
src/ngraph/op/experimental/shape_of.hpp
+0
-3
shape_specialization.cpp
src/ngraph/pass/shape_specialization.cpp
+0
-155
shape_specialization.hpp
src/ngraph/pass/shape_specialization.hpp
+0
-36
CMakeLists.txt
test/CMakeLists.txt
+0
-1
pass_shape_specialization.cpp
test/pass_shape_specialization.cpp
+0
-194
No files found.
src/ngraph/CMakeLists.txt
View file @
94f4dfec
...
...
@@ -390,8 +390,6 @@ set (SRC
pass/serialize.hpp
pass/shape_relevance.cpp
pass/shape_relevance.hpp
pass/shape_specialization.cpp
pass/shape_specialization.hpp
pass/validate_graph.cpp
pass/validate_graph.hpp
pass/visualize_tree.cpp
...
...
src/ngraph/node.hpp
View file @
94f4dfec
...
...
@@ -113,22 +113,6 @@ namespace ngraph
// Called after transition
void
delayed_validate_and_infer_types
();
/// \brief Produce a vector of constant nodes (one for each of this node's outputs) that
/// can replace this node's outputs. May return an empty vector to signal that
/// conversion to constants is not possible or not supported.
/// \returns If conversion is successful, a vector of op::Constant nodes, corresponding
/// to this node's outputs in order. If unsuccessful, an empty vector.
///
/// Conversion does not have to be complete. That means that subclasses *may* override
/// as_constants, but do not have to. It is allowed for as_constants to return an empty
/// vector even in cases where the output values are statically computable. Thus, any user
/// of as_constants must allow for the possibility that conversion will fail (i.e.,
/// as_constants will return {}).
///
/// Conversion must be sound. That means that if as_constants returns a non-empty vector,
/// the value of each constant in the vector must be exactly the value that would have
/// been returned for the corresponding output at runtime.
virtual
std
::
vector
<
std
::
shared_ptr
<
op
::
Constant
>>
as_constants
()
const
{
return
{};
}
/// \brief Get the string name for the type of the node, such as `Add` or `Multiply`.
/// The class name, must not contain spaces as it is used for codegen.
/// \returns A const reference to the node's type name
...
...
src/ngraph/op/concat.cpp
View file @
94f4dfec
...
...
@@ -17,7 +17,6 @@
#include <memory>
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/slice.hpp"
using
namespace
std
;
...
...
@@ -93,58 +92,6 @@ shared_ptr<Node> op::Concat::copy_with_new_args(const NodeVector& new_args) cons
return
make_shared
<
Concat
>
(
new_args
,
m_concatenation_axis
);
}
std
::
vector
<
std
::
shared_ptr
<
op
::
Constant
>>
op
::
Concat
::
as_constants
()
const
{
if
(
get_concatenation_axis
()
!=
0
)
{
return
{};
}
size_t
total_elements
=
0
;
for
(
size_t
i
=
0
;
i
<
get_input_size
();
i
++
)
{
//
// For the time being we will only support int64 here, since that's all that's needed for
// static shape propagation.
//
if
(
get_input_element_type
(
i
)
!=
element
::
i64
)
{
return
{};
}
if
(
!
(
get_argument
(
i
)
->
is_constant
()))
{
return
{};
}
if
(
get_input_shape
(
i
).
size
()
!=
1
)
{
return
{};
}
total_elements
+=
shape_size
(
get_input_shape
(
i
));
}
std
::
vector
<
int64_t
>
values
(
total_elements
);
size_t
pos
=
0
;
for
(
size_t
i
=
0
;
i
<
get_input_size
();
i
++
)
{
auto
const_node
=
static_pointer_cast
<
op
::
Constant
>
(
get_argument
(
i
));
// A little extra paranoia ahead of the memcpy.
NGRAPH_CHECK
(
get_input_shape
(
i
)
==
const_node
->
get_shape
()
&&
const_node
->
get_output_element_type
(
0
)
==
element
::
i64
);
// This memcpy should be safe, because values was initialized to have space for
// sum(0 <= j < num_inputs)(shape_size(get_input_shape(j))) elements, and pos is
// sum(0 <= j < i)(shape_size(get_input_shape(j))).
memcpy
(
values
.
data
()
+
pos
,
const_node
->
get_data_ptr
(),
shape_size
(
const_node
->
get_shape
())
*
sizeof
(
int64_t
));
pos
+=
shape_size
(
const_node
->
get_shape
());
}
return
{
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
total_elements
},
values
)};
}
void
op
::
Concat
::
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
{
auto
delta
=
deltas
.
at
(
0
);
...
...
src/ngraph/op/concat.hpp
View file @
94f4dfec
...
...
@@ -39,8 +39,6 @@ namespace ngraph
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
virtual
std
::
vector
<
std
::
shared_ptr
<
op
::
Constant
>>
as_constants
()
const
override
;
/// \return The concatenation axis.
size_t
get_concatenation_axis
()
const
{
return
m_concatenation_axis
;
}
protected
:
...
...
src/ngraph/op/experimental/shape_of.cpp
View file @
94f4dfec
...
...
@@ -32,19 +32,6 @@ void op::ShapeOf::validate_and_infer_types()
set_output_type
(
0
,
element
::
i64
,
PartialShape
{
get_input_partial_shape
(
0
).
rank
()});
}
std
::
vector
<
std
::
shared_ptr
<
op
::
Constant
>>
op
::
ShapeOf
::
as_constants
()
const
{
if
(
get_input_partial_shape
(
0
).
is_static
())
{
return
{
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
get_input_shape
(
0
).
size
()},
get_input_shape
(
0
))};
}
else
{
return
{};
}
}
shared_ptr
<
Node
>
op
::
ShapeOf
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/experimental/shape_of.hpp
View file @
94f4dfec
...
...
@@ -16,7 +16,6 @@
#pragma once
#include "ngraph/op/constant.hpp"
#include "ngraph/op/op.hpp"
namespace
ngraph
...
...
@@ -33,8 +32,6 @@ namespace ngraph
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
virtual
std
::
vector
<
std
::
shared_ptr
<
op
::
Constant
>>
as_constants
()
const
override
;
protected
:
void
validate_and_infer_types
()
override
;
};
...
...
src/ngraph/pass/shape_specialization.cpp
deleted
100644 → 0
View file @
d99ac8ce
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/shape_specialization.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/constant.hpp"
using
namespace
ngraph
;
//
// The shape specialization pass transforms a function by replacing, wherever possible, all
// shape-relevant inputs with constants. For example,
//
//
// _________
// | Param |
// | 2x2x3 |
// |_________|
// |
// ____|____
// | 0 |
// | ShapeOf |
// |_________|
// | |
// __|_____|___
// | 0 1 |
// | DynReshape |
// |____________|
// |
//
// (Where 0 is the data input and 1 is the shape input) would be replaced with:
//
// _____________
// | |
// | Constant |
// | val=[2,2,3] |
// |_____________|
// | |
// __|_____|___
// | 0 1 |
// | DynReshape |
// |____________|
// |
//
// Note that replacement will only be attempted on shape-relevant inputs, and will only be
// successful if the input's value is entirely determined by nodes that can be converted with
// as_constants().
//
bool
pass
::
ShapeSpecialization
::
run_on_function
(
std
::
shared_ptr
<
Function
>
f
)
{
// TODO(amprocte): We are probably reinventing the wheel with the graph traversal here; the
// reason is that we need to cut the traversal short in cases where input values are
// irrelevant. See if there is a way to reduce this duplication.
// Set of nodes that must be evaluated to determine the value of shape-relevant inputs.
std
::
set
<
Node
*>
shape_determinants
;
// Step 1: Find root nodes (these are nodes with an output connected to a shape-relevant
// input).
for
(
auto
&
n
:
f
->
get_ops
())
{
for
(
auto
&
output
:
n
->
outputs
())
{
for
(
auto
&
input
:
output
.
get_target_inputs
())
{
if
(
input
.
get_is_relevant_to_shapes
())
{
shape_determinants
.
insert
(
n
.
get
());
break
;
}
}
}
}
// Step 2: Find all shape determinants. This is the transitive closure of R, where n1 R n2
// iff there is a data flow edge from n2 to n1 and that data flow edge is not
// value-irrelevant.
{
std
::
list
<
Node
*>
to_visit
{
shape_determinants
.
begin
(),
shape_determinants
.
end
()};
std
::
set
<
Node
*>
already_visited
;
while
(
!
to_visit
.
empty
())
{
auto
node
=
to_visit
.
front
();
to_visit
.
pop_front
();
if
(
already_visited
.
count
(
node
)
>
0
)
{
continue
;
}
shape_determinants
.
insert
(
node
);
already_visited
.
insert
(
node
);
for
(
size_t
i
=
0
;
i
<
node
->
get_input_size
();
i
++
)
{
if
(
!
node
->
input
(
i
).
get_is_relevant_to_values
())
{
continue
;
}
auto
source_node
=
node
->
input
(
i
).
get_source_output
().
get_node
();
if
(
already_visited
.
count
(
source_node
)
==
0
)
{
to_visit
.
push_front
(
source_node
);
}
}
}
}
// Step 3: For each shape determinant in topological order, try to replace the determinant
// with constants.
bool
changes_made
=
false
;
for
(
auto
n
:
f
->
get_ordered_ops
())
{
if
(
shape_determinants
.
count
(
n
.
get
())
>
0
)
{
std
::
vector
<
std
::
shared_ptr
<
op
::
Constant
>>
replacement_constants
=
n
->
as_constants
();
if
(
replacement_constants
.
size
()
>
0
)
{
NGRAPH_CHECK
(
n
->
get_output_size
()
==
replacement_constants
.
size
());
for
(
size_t
i
=
0
;
i
<
n
->
get_output_size
();
i
++
)
{
NGRAPH_CHECK
(
n
->
get_output_partial_shape
(
i
).
relaxes
(
replacement_constants
[
i
]
->
get_output_partial_shape
(
0
)));
NGRAPH_CHECK
(
n
->
get_output_element_type
(
i
).
is_dynamic
()
||
n
->
get_output_element_type
(
i
)
==
replacement_constants
[
i
]
->
get_output_element_type
(
0
));
for
(
auto
&
input
:
n
->
output
(
i
).
get_target_inputs
())
{
input
.
replace_source_output
(
replacement_constants
.
at
(
i
)
->
output
(
0
));
changes_made
=
true
;
}
}
}
}
}
return
changes_made
;
}
src/ngraph/pass/shape_specialization.hpp
deleted
100644 → 0
View file @
d99ac8ce
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace
ngraph
{
namespace
pass
{
class
ShapeSpecialization
:
public
FunctionPass
{
public
:
ShapeSpecialization
()
:
FunctionPass
()
{
set_property
(
PassProperty
::
CHANGE_DYNAMIC_STATE
,
true
);
}
virtual
bool
run_on_function
(
std
::
shared_ptr
<
ngraph
::
Function
>
f
)
override
;
};
}
}
test/CMakeLists.txt
View file @
94f4dfec
...
...
@@ -57,7 +57,6 @@ set(SRC
pass_manager.cpp
pass_memory_layout.cpp
pass_shape_relevance.cpp
pass_shape_specialization.cpp
pattern.cpp
provenance.cpp
reshape_elimination.cpp
...
...
test/pass_shape_specialization.cpp
deleted
100644 → 0
View file @
d99ac8ce
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/shape_specialization.hpp"
using
namespace
ngraph
;
using
namespace
std
;
TEST
(
shape_specialization
,
as_constants_shape_of
)
{
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
boolean
,
Shape
{
2
,
4
,
6
,
8
});
auto
shape_of
=
make_shared
<
op
::
ShapeOf
>
(
param
);
vector
<
shared_ptr
<
op
::
Constant
>>
replacements
=
shape_of
->
as_constants
();
ASSERT_EQ
(
replacements
.
size
(),
1
);
ASSERT_EQ
(
replacements
[
0
]
->
get_shape
(),
Shape
{
4
});
ASSERT_EQ
(
replacements
[
0
]
->
get_element_type
(),
element
::
i64
);
ASSERT_EQ
(
replacements
[
0
]
->
get_vector
<
int64_t
>
(),
(
vector
<
int64_t
>
{
2
,
4
,
6
,
8
}));
}
TEST
(
shape_specialization
,
specialization_pass_shape_of_transpose
)
{
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
boolean
,
Shape
{
4
,
6
});
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
boolean
,
Shape
{
1
,
0
});
auto
shape_of
=
make_shared
<
op
::
ShapeOf
>
(
param1
);
auto
transpose
=
make_shared
<
op
::
Transpose
>
(
param0
,
shape_of
);
auto
f
=
make_shared
<
Function
>
(
transpose
,
ParameterVector
{
param0
,
param1
});
pass
::
Manager
manager
;
manager
.
register_pass
<
pass
::
ShapeSpecialization
>
();
manager
.
run_passes
(
f
);
auto
transpose_after
=
dynamic_pointer_cast
<
op
::
Transpose
>
(
f
->
get_results
().
at
(
0
)
->
get_argument
(
0
));
ASSERT_NE
(
transpose_after
,
nullptr
);
auto
constant_after
=
dynamic_pointer_cast
<
op
::
Constant
>
(
transpose_after
->
get_argument
(
1
));
ASSERT_NE
(
constant_after
,
nullptr
);
ASSERT_EQ
(
constant_after
->
get_shape
(),
Shape
{
2
});
ASSERT_EQ
(
constant_after
->
get_element_type
(),
element
::
i64
);
ASSERT_EQ
(
constant_after
->
get_vector
<
int64_t
>
(),
(
vector
<
int64_t
>
{
1
,
0
}));
}
TEST
(
shape_specialization
,
as_constants_concat
)
{
auto
k0
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
4
},
{
1
,
2
,
3
,
4
});
auto
k1
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
3
},
{
2
,
5
,
1
});
auto
k2
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
0
},
std
::
vector
<
int64_t
>
{});
auto
concat
=
make_shared
<
op
::
Concat
>
(
NodeVector
{
k0
,
k1
,
k2
},
0
);
vector
<
shared_ptr
<
op
::
Constant
>>
replacements
=
concat
->
as_constants
();
ASSERT_EQ
(
replacements
.
size
(),
1
);
ASSERT_EQ
(
replacements
[
0
]
->
get_shape
(),
Shape
{
7
});
ASSERT_EQ
(
replacements
[
0
]
->
get_element_type
(),
element
::
i64
);
ASSERT_EQ
(
replacements
[
0
]
->
get_vector
<
int64_t
>
(),
(
vector
<
int64_t
>
{
1
,
2
,
3
,
4
,
2
,
5
,
1
}));
}
TEST
(
shape_specialization
,
as_constants_concat_noni64
)
{
auto
k0
=
op
::
Constant
::
create
(
element
::
i32
,
Shape
{
4
},
{
1
,
2
,
3
,
4
});
auto
k1
=
op
::
Constant
::
create
(
element
::
i32
,
Shape
{
3
},
{
2
,
5
,
1
});
auto
k2
=
op
::
Constant
::
create
(
element
::
i32
,
Shape
{
0
},
std
::
vector
<
int32_t
>
{});
auto
concat
=
make_shared
<
op
::
Concat
>
(
NodeVector
{
k0
,
k1
,
k2
},
0
);
vector
<
shared_ptr
<
op
::
Constant
>>
replacements
=
concat
->
as_constants
();
ASSERT_EQ
(
replacements
.
size
(),
0
);
}
TEST
(
shape_specialization
,
as_constants_concat_nonvec_dim0
)
{
auto
k0
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
2
,
2
},
{
1
,
2
,
3
,
4
});
auto
k1
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
1
,
2
},
{
2
,
5
});
auto
k2
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
0
,
2
},
std
::
vector
<
int64_t
>
{});
auto
concat
=
make_shared
<
op
::
Concat
>
(
NodeVector
{
k0
,
k1
,
k2
},
0
);
vector
<
shared_ptr
<
op
::
Constant
>>
replacements
=
concat
->
as_constants
();
ASSERT_EQ
(
replacements
.
size
(),
0
);
}
TEST
(
shape_specialization
,
as_constants_concat_nonvec_dim1
)
{
auto
k0
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
2
,
2
},
{
1
,
2
,
3
,
4
});
auto
k1
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
2
,
1
},
{
2
,
5
});
auto
k2
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
2
,
0
},
std
::
vector
<
int64_t
>
{});
auto
concat
=
make_shared
<
op
::
Concat
>
(
NodeVector
{
k0
,
k1
,
k2
},
1
);
vector
<
shared_ptr
<
op
::
Constant
>>
replacements
=
concat
->
as_constants
();
ASSERT_EQ
(
replacements
.
size
(),
0
);
}
TEST
(
shape_specialization
,
as_constants_concat_nonconst
)
{
auto
k0
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
2
,
2
},
{
1
,
2
,
3
,
4
});
auto
k1
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
2
,
2
},
{
2
,
5
,
2
,
5
});
auto
add
=
k0
+
k1
;
auto
concat
=
make_shared
<
op
::
Concat
>
(
NodeVector
{
k0
,
k1
,
add
},
0
);
vector
<
shared_ptr
<
op
::
Constant
>>
replacements
=
concat
->
as_constants
();
ASSERT_EQ
(
replacements
.
size
(),
0
);
}
TEST
(
shape_specialization
,
specialization_pass_concat_transpose
)
{
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
boolean
,
Shape
{
4
,
6
});
auto
k0
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
1
},
{
0
});
auto
k1
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
1
},
{
1
});
auto
concat
=
make_shared
<
op
::
Concat
>
(
NodeVector
{
k1
,
k0
},
0
);
auto
transpose
=
make_shared
<
op
::
Transpose
>
(
param0
,
concat
);
auto
f
=
make_shared
<
Function
>
(
transpose
,
ParameterVector
{
param0
});
pass
::
Manager
manager
;
manager
.
register_pass
<
pass
::
ShapeSpecialization
>
();
manager
.
run_passes
(
f
);
auto
transpose_after
=
dynamic_pointer_cast
<
op
::
Transpose
>
(
f
->
get_results
().
at
(
0
)
->
get_argument
(
0
));
ASSERT_NE
(
transpose_after
,
nullptr
);
auto
constant_after
=
dynamic_pointer_cast
<
op
::
Constant
>
(
transpose_after
->
get_argument
(
1
));
ASSERT_NE
(
constant_after
,
nullptr
);
ASSERT_EQ
(
constant_after
->
get_shape
(),
Shape
{
2
});
ASSERT_EQ
(
constant_after
->
get_element_type
(),
element
::
i64
);
ASSERT_EQ
(
constant_after
->
get_vector
<
int64_t
>
(),
(
vector
<
int64_t
>
{
1
,
0
}));
}
// Slight variation on the above test, where the "Concat" does not already have constants going
// into it. (The permutation is Concat(Const<1>,Concat(Const<>,Const<0>)) rather than simply
// Concat(Const<1>,Const<0>).)
TEST
(
shape_specialization
,
specialization_pass_add_concat_transpose
)
{
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
boolean
,
Shape
{
4
,
6
});
auto
k0
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
1
},
{
0
});
auto
k1
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
1
},
{
1
});
auto
kempty
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{
0
},
vector
<
int64_t
>
{});
auto
concat
=
make_shared
<
op
::
Concat
>
(
NodeVector
{
k1
,
make_shared
<
op
::
Concat
>
(
NodeVector
{
kempty
,
k0
},
0
)},
0
);
auto
transpose
=
make_shared
<
op
::
Transpose
>
(
param0
,
concat
);
auto
f
=
make_shared
<
Function
>
(
transpose
,
ParameterVector
{
param0
});
pass
::
Manager
manager
;
manager
.
register_pass
<
pass
::
ShapeSpecialization
>
();
manager
.
run_passes
(
f
);
auto
transpose_after
=
dynamic_pointer_cast
<
op
::
Transpose
>
(
f
->
get_results
().
at
(
0
)
->
get_argument
(
0
));
ASSERT_NE
(
transpose_after
,
nullptr
);
auto
constant_after
=
dynamic_pointer_cast
<
op
::
Constant
>
(
transpose_after
->
get_argument
(
1
));
ASSERT_NE
(
constant_after
,
nullptr
);
ASSERT_EQ
(
constant_after
->
get_shape
(),
Shape
{
2
});
ASSERT_EQ
(
constant_after
->
get_element_type
(),
element
::
i64
);
ASSERT_EQ
(
constant_after
->
get_vector
<
int64_t
>
(),
(
vector
<
int64_t
>
{
1
,
0
}));
}
TEST
(
shape_specialization
,
pass_property
)
{
auto
pass
=
std
::
make_shared
<
ngraph
::
pass
::
ShapeSpecialization
>
();
ASSERT_EQ
(
false
,
pass
->
get_property
(
pass
::
PassProperty
::
REQUIRE_STATIC_SHAPE
));
ASSERT_EQ
(
true
,
pass
->
get_property
(
pass
::
PassProperty
::
CHANGE_DYNAMIC_STATE
));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment