Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
8a346fb5
Unverified
Commit
8a346fb5
authored
Sep 05, 2019
by
Scott Cyphers
Committed by
GitHub
Sep 05, 2019
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into cyphers/typename
parents
7c9269c2
bbe4735e
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
215 additions
and
5 deletions
+215
-5
mlir_subgraph_extraction.cpp
src/contrib/mlir/compiler/pass/mlir_subgraph_extraction.cpp
+7
-0
mlir_subgraph_extraction.hpp
src/contrib/mlir/compiler/pass/mlir_subgraph_extraction.hpp
+1
-0
function.cpp
src/ngraph/function.cpp
+12
-0
function.hpp
src/ngraph/function.hpp
+10
-0
graph_util.cpp
src/ngraph/graph_util.cpp
+35
-5
graph_util.hpp
src/ngraph/graph_util.hpp
+28
-0
CMakeLists.txt
test/CMakeLists.txt
+1
-0
replace_node.cpp
test/replace_node.cpp
+121
-0
No files found.
src/contrib/mlir/compiler/pass/mlir_subgraph_extraction.cpp
View file @
8a346fb5
...
...
@@ -132,6 +132,7 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
sanity_check
(
func
,
ck_nodes
);
#endif
clean_up
();
return
true
;
}
...
...
@@ -506,6 +507,12 @@ bool MLIRSubgraphExtractionPass::check_cycles(std::shared_ptr<Node> node,
return
false
;
}
void
MLIRSubgraphExtractionPass
::
clean_up
()
{
m_id_to_graph
.
clear
();
m_node_to_graph
.
clear
();
}
const
std
::
set
<
std
::
type_index
>
MLIRSubgraphExtractionPass
::
m_supported_ops
{
#define MLIR_OP(OP) TI(ngraph::op::OP),
#include "contrib/mlir/compiler/ops_supported.inc"
...
...
src/contrib/mlir/compiler/pass/mlir_subgraph_extraction.hpp
View file @
8a346fb5
...
...
@@ -127,6 +127,7 @@ namespace ngraph
NodeVector
build_ck_nodes
(
std
::
shared_ptr
<
Function
>
func
);
void
sanity_check
(
std
::
shared_ptr
<
Function
>
func
,
NodeVector
&
ck_nodes
);
void
clean_up
();
private
:
using
IDGraphMap
=
std
::
unordered_map
<
int
,
MLIRSubgraph
>
;
...
...
src/ngraph/function.cpp
View file @
8a346fb5
...
...
@@ -306,3 +306,15 @@ bool Function::is_dynamic() const
}
return
false
;
}
void
Function
::
replace_parameter
(
size_t
parameter_index
,
const
shared_ptr
<
op
::
Parameter
>&
parameter
)
{
NGRAPH_CHECK
(
parameter_index
<
m_parameters
.
size
(),
"replace_parameter(): Tried to replace parameter at index "
,
parameter_index
,
" but the function only has "
,
m_parameters
.
size
(),
" parameters."
);
replace_node
(
m_parameters
[
parameter_index
],
parameter
);
m_parameters
[
parameter_index
]
=
parameter
;
}
src/ngraph/function.hpp
View file @
8a346fb5
...
...
@@ -117,6 +117,16 @@ namespace ngraph
/// \brief Returns true if any of the op's defined in the function contains partial shape
bool
is_dynamic
()
const
;
/// \brief Replace the `parameter_index`th parameter of the function with `parameter`.
///
/// All users of the `parameter_index`th parameter are redirected to `parameter`, and the
/// `parameter_index`th entry in the function parameter list is replaced with `parameter`.
///
/// \param parameter_index The index of the parameter to replace.
/// \param parameter The parameter to substitute for the `parameter_index`th parameter.
void
replace_parameter
(
size_t
parameter_index
,
const
std
::
shared_ptr
<
op
::
Parameter
>&
parameter
);
protected
:
ResultVector
m_results
;
ParameterVector
m_parameters
;
...
...
src/ngraph/graph_util.cpp
View file @
8a346fb5
...
...
@@ -28,7 +28,6 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/provenance.hpp"
...
...
@@ -139,10 +138,12 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
throw
ngraph_error
(
"Result nodes cannot be replaced."
);
}
if
(
target
->
get_users
().
empty
())
{
throw
ngraph_error
(
"replacing an unreachable node"
);
}
NGRAPH_CHECK
(
!
target
->
get_users
().
empty
(),
"Attempted to replace unreachable node '"
,
*
target
,
"'. Replacement: '"
,
*
replacement
,
"'"
);
// Fix input/output descriptors
NGRAPH_CHECK
(
target
->
get_output_size
()
==
replacement
->
get_output_size
());
...
...
@@ -179,6 +180,35 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
target
->
clear_control_dependents
();
}
void
ngraph
::
replace_nodes
(
const
std
::
shared_ptr
<
Function
>&
f
,
const
unordered_map
<
shared_ptr
<
op
::
Parameter
>
,
shared_ptr
<
op
::
Parameter
>>&
parameter_replacement_map
,
const
unordered_map
<
shared_ptr
<
Node
>
,
shared_ptr
<
Node
>>&
body_replacement_map
)
{
auto
&
params
=
f
->
get_parameters
();
for
(
size_t
i
=
0
;
i
<
params
.
size
();
i
++
)
{
if
(
parameter_replacement_map
.
count
(
params
[
i
])
!=
0
&&
parameter_replacement_map
.
at
(
params
[
i
])
!=
params
[
i
])
{
f
->
replace_parameter
(
i
,
parameter_replacement_map
.
at
(
params
[
i
]));
}
}
for
(
auto
&
kv
:
body_replacement_map
)
{
auto
&
k
=
kv
.
first
;
auto
&
v
=
kv
.
second
;
if
(
k
!=
v
)
{
f
->
replace_node
(
k
,
v
);
}
}
}
// Check if all paths from X to a result go through Y
bool
ngraph
::
is_post_dominated
(
Node
*
X
,
Node
*
Y
)
{
...
...
src/ngraph/graph_util.hpp
View file @
8a346fb5
...
...
@@ -214,6 +214,34 @@ namespace ngraph
/// replace_node(N, M);
void
replace_node
(
std
::
shared_ptr
<
Node
>
target
,
std
::
shared_ptr
<
Node
>
replacement
);
/// \brief Replace multiple nodes in a function.
/// \param f Function where replacement is taking place.
/// \param parameter_replacement_map A mapping from parameter shared pointers to parameter
/// shared pointers. For each pair (k,v) in the map, parameter
/// k is replaced by parameter v, except if k==v or k is not a
/// parameter bound by f, in which case the pair (k,v) is
/// ignored.
/// \param body_replacement_map A mapping from node shared pointers to node shared pointers.
/// For each pair (k,v) in the map, node k is replaced by node v,
/// except if k==v, the pair (k,v) is ignored.
/// Note that if k is a parameter, its users will be redirected to
/// v, but k will _not_ be replaced in the function's parameter
/// list.
///
/// Limitations:
///
/// - No check is made that the replaced nodes in `parameter_replacement_map` are actually
/// among the bound parameters of `f`. (If a parameter appears in the map that is not
/// bound by `f`, it will be silently ignored.)
/// - If a parameter node appears as a key in both `parameter_replacement_map` _and_ in
/// `body_replacement_map`, behavior is unspecified.
void
replace_nodes
(
const
std
::
shared_ptr
<
Function
>&
f
,
const
std
::
unordered_map
<
std
::
shared_ptr
<
op
::
Parameter
>
,
std
::
shared_ptr
<
op
::
Parameter
>>&
parameter_replacement_map
,
const
std
::
unordered_map
<
std
::
shared_ptr
<
Node
>
,
std
::
shared_ptr
<
Node
>>&
body_replacement_map
);
NodeVector
find_common_args
(
std
::
shared_ptr
<
Node
>
target
,
std
::
shared_ptr
<
Node
>
replacement
);
/// Topological sort of nodes needed to compute root_nodes
...
...
test/CMakeLists.txt
View file @
8a346fb5
...
...
@@ -73,6 +73,7 @@ set(SRC
pass_shape_relevance.cpp
pattern.cpp
provenance.cpp
replace_node.cpp
reshape_elimination.cpp
reshape_sinking.cpp
shape.cpp
...
...
test/replace_node.cpp
0 → 100644
View file @
8a346fb5
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using
namespace
std
;
using
namespace
ngraph
;
//
// Graph before (params in [] brackets, constants in () parens, results in {} braces):
//
// [x] [y] [z]
// \ / |
// Add (k) |
// \ / |
// Mul** |
// \ /
// Sub
// |
// {r}
//
// Param substitutions:
//
// [x] -> [x']
//
// Body substitutions:
//
// (k) -> (k')
// [y] -> (k'')
// [z] -> [x'] + **
//
// After replacement:
//
// [x']---------
// | |
// | (k'') | [z] and [y] is still there, but dead
// \ / |
// Add (k') |
// \ / |
// Mul |
// \ /
// Sub ***
// |
// {r}
//
TEST
(
replace_node
,
replace_nodes
)
{
auto
x
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
});
auto
y
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
});
auto
z
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
});
auto
add
=
x
+
y
;
auto
k
=
make_shared
<
op
::
Constant
>
(
element
::
f32
,
Shape
{
2
},
vector
<
float
>
{
1
,
2
});
auto
mul
=
add
*
k
;
auto
sub
=
mul
-
z
;
auto
f
=
make_shared
<
Function
>
(
NodeVector
{
sub
},
ParameterVector
{
x
,
y
,
z
});
unordered_map
<
shared_ptr
<
op
::
Parameter
>
,
shared_ptr
<
op
::
Parameter
>>
parameter_replacement_map
;
auto
x_replacement
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
});
parameter_replacement_map
[
x
]
=
x_replacement
;
unordered_map
<
shared_ptr
<
Node
>
,
shared_ptr
<
Node
>>
body_replacement_map
;
auto
y_replacement
=
make_shared
<
op
::
Constant
>
(
element
::
f32
,
Shape
{
2
},
vector
<
float
>
{
3
,
4
});
auto
k_replacement
=
make_shared
<
op
::
Constant
>
(
element
::
f32
,
Shape
{
2
},
vector
<
float
>
{
5
,
6
});
auto
z_replacement
=
x_replacement
+
mul
;
body_replacement_map
[
y
]
=
y_replacement
;
body_replacement_map
[
k
]
=
k_replacement
;
body_replacement_map
[
z
]
=
z_replacement
;
replace_nodes
(
f
,
parameter_replacement_map
,
body_replacement_map
);
// Should still have three params.
ASSERT_EQ
(
f
->
get_parameters
().
size
(),
3
);
// The three params be {x_replacement, y, z}.
ASSERT_EQ
(
f
->
get_parameters
()[
0
],
x_replacement
);
ASSERT_EQ
(
f
->
get_parameters
()[
1
],
y
);
ASSERT_EQ
(
f
->
get_parameters
()[
2
],
z
);
// y, z should be dead.
ASSERT_EQ
(
y
->
get_users
(
true
).
size
(),
0
);
ASSERT_EQ
(
z
->
get_users
(
true
).
size
(),
0
);
// Should still have one result.
ASSERT_EQ
(
f
->
get_results
().
size
(),
1
);
// Result node should be sub (unchanged).
ASSERT_EQ
(
f
->
get_results
()[
0
]
->
input
(
0
).
get_source_output
().
get_node_shared_ptr
(),
sub
);
// sub's arguments should be mul (unchanged) and z_replacement.
ASSERT_EQ
(
sub
->
input
(
0
).
get_source_output
().
get_node_shared_ptr
(),
mul
);
ASSERT_EQ
(
sub
->
input
(
1
).
get_source_output
().
get_node_shared_ptr
(),
z_replacement
);
// mul's arguments should be add (unchanged) and k_replacement.
ASSERT_EQ
(
mul
->
input
(
0
).
get_source_output
().
get_node_shared_ptr
(),
add
);
ASSERT_EQ
(
mul
->
input
(
1
).
get_source_output
().
get_node_shared_ptr
(),
k_replacement
);
// add's arguments should be x_replacement and y_replacement.
ASSERT_EQ
(
add
->
input
(
0
).
get_source_output
().
get_node_shared_ptr
(),
x_replacement
);
ASSERT_EQ
(
add
->
input
(
1
).
get_source_output
().
get_node_shared_ptr
(),
y_replacement
);
// z_replacement's arguments should be x_replacement and mul.
ASSERT_EQ
(
z_replacement
->
input
(
0
).
get_source_output
().
get_node_shared_ptr
(),
x_replacement
);
ASSERT_EQ
(
z_replacement
->
input
(
1
).
get_source_output
().
get_node_shared_ptr
(),
mul
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment