Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
cb431144
Unverified
Commit
cb431144
authored
Jul 10, 2019
by
Scott Cyphers
Committed by
GitHub
Jul 10, 2019
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into gauri/macos_cast
parents
dba33d02
13bdf0ef
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
356 additions
and
33 deletions
+356
-33
CMakeLists.txt
src/ngraph/CMakeLists.txt
+0
-2
cpu_fusion.cpp
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
+137
-0
cpu_fusion.hpp
src/ngraph/runtime/cpu/pass/cpu_fusion.hpp
+2
-0
CMakeLists.txt
src/ngraph/runtime/plaidml/CMakeLists.txt
+1
-0
plaidml_compiler.cpp
src/ngraph/runtime/plaidml/plaidml_compiler.cpp
+3
-4
plaidml_config.cpp
src/ngraph/runtime/plaidml/plaidml_config.cpp
+6
-0
plaidml_ops_replicate.cpp
src/ngraph/runtime/plaidml/plaidml_ops_replicate.cpp
+1
-1
plaidml_pass_implicit_broadcast.cpp
...graph/runtime/plaidml/plaidml_pass_implicit_broadcast.cpp
+22
-3
plaidml_pass_lower_convolutions.cpp
...graph/runtime/plaidml/plaidml_pass_lower_convolutions.cpp
+4
-4
plaidml_pass_prefix_reshape_elimination.cpp
...ntime/plaidml/plaidml_pass_prefix_reshape_elimination.cpp
+35
-8
plaidml_pass_prefix_reshape_elimination.hpp
...ntime/plaidml/plaidml_pass_prefix_reshape_elimination.hpp
+11
-7
plaidml_pass_replicate_combination.cpp
...ph/runtime/plaidml/plaidml_pass_replicate_combination.cpp
+3
-3
plaidml_pass_replicate_elision.cpp
...ngraph/runtime/plaidml/plaidml_pass_replicate_elision.cpp
+1
-1
cpu_fusion.cpp
test/cpu_fusion.cpp
+130
-0
No files found.
src/ngraph/CMakeLists.txt
View file @
cb431144
...
...
@@ -419,8 +419,6 @@ set (SRC
pass/pass.hpp
pass/pass_config.cpp
pass/pass_config.hpp
pass/prefix_reshape_elimination.cpp
pass/prefix_reshape_elimination.hpp
pass/propagate_cacheability.cpp
pass/propagate_cacheability.hpp
pass/reshape_elimination.cpp
...
...
src/ngraph/runtime/cpu/pass/cpu_fusion.cpp
View file @
cb431144
...
...
@@ -650,6 +650,143 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_relu_global_sta
this
->
add_matcher
(
m
,
callback
);
}
// graph before this fusion:
// input mean var gamma beta broadcast1_input broadcast2_input
// \ \ | / / / \
// BatchNormInference Broadcast1 Broadcast2
// \ / /
// Multiply /
// \ /
// Add
// |
// Relu
//
//
// graph after this fusion:
// input mean var gamma broadcast1_input beta broadcast2_input
// \ \ | \ / \ / /
// \ \ | Multiply1 Multiply2 /
// \ \ | / \ /
// \ \ | / newAdd
// \ \| / /
// BatchNormInferenceRelu
//
// Multiply1, Multiply2, and newAdd operate on vectors, while Multiply and Add operate on multi-dimensional matrices.
// Multiply1, Multiply2, and newAdd may be folded away with constant folding pass later.
// Registers a matcher that rewrites
//     Relu(Add(Multiply(BatchNormInference, Broadcast(b1)), Broadcast(b2)))
// into a single BatchNormInferenceRelu whose scale and shift absorb the
// Multiply/Add:
//     new_gamma = gamma * b1
//     new_beta  = beta * b1 + b2
// The rewrite is rejected when the BatchNorm or Multiply output has other
// users, when either Broadcast does not broadcast over axes {0, 2, ..., rank-1},
// or when MKLDNN cannot handle the matched BatchNormInference.
void ngraph::runtime::cpu::pass::CPUFusion::construct_batch_norm_infer_relu_with_multiply_add()
{
    // The concrete shapes below only serve to build a well-formed pattern graph.
    auto input_shape = Shape{1, 3, 2, 2};
    auto input = std::make_shared<pattern::op::Label>(element::f32, input_shape);
    auto mean_shape = Shape{3};
    auto mean = std::make_shared<pattern::op::Label>(element::f32, mean_shape);
    auto var_shape = Shape{3};
    auto var = std::make_shared<pattern::op::Label>(element::f32, var_shape);
    auto gamma_shape = Shape{3};
    auto gamma = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
    auto beta_shape = Shape{3};
    auto beta = std::make_shared<pattern::op::Label>(element::f32, beta_shape);
    double eps = 0.001;
    auto bn = std::make_shared<ngraph::op::BatchNormInference>(eps, gamma, beta, input, mean, var);
    auto bn_label = std::make_shared<pattern::op::Label>(bn, nullptr, NodeVector{bn});

    // Multiply(BatchNormInference, Broadcast1)
    auto broadcast1_input = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
    auto broadcast1 =
        std::make_shared<ngraph::op::Broadcast>(broadcast1_input, input_shape, AxisSet{0, 2, 3});
    auto broadcast1_label =
        std::make_shared<pattern::op::Label>(broadcast1, nullptr, NodeVector{broadcast1});
    auto multiply = std::make_shared<ngraph::op::Multiply>(bn_label, broadcast1_label);
    auto multi_label = std::make_shared<pattern::op::Label>(multiply, nullptr, NodeVector{multiply});

    // Add(Multiply, Broadcast2) followed by Relu
    auto broadcast2_input = std::make_shared<pattern::op::Label>(element::f32, gamma_shape);
    auto broadcast2 =
        std::make_shared<ngraph::op::Broadcast>(broadcast2_input, input_shape, AxisSet{0, 2, 3});
    auto broadcast2_label =
        std::make_shared<pattern::op::Label>(broadcast2, nullptr, NodeVector{broadcast2});
    auto add = std::make_shared<ngraph::op::Add>(multi_label, broadcast2_label);
    auto prelu = std::make_shared<ngraph::op::Relu>(add);

    auto callback = [input,
                     mean,
                     var,
                     gamma,
                     beta,
                     bn_label,
                     multi_label,
                     broadcast1_input,
                     broadcast2_input,
                     broadcast1_label,
                     broadcast2_label](pattern::Matcher& m) {
        NGRAPH_DEBUG
            << "In callback for construct_batch_norm_infer_relu_with_multi_add against node = "
            << m.get_match_root()->get_name();

        auto pattern_map = m.get_pattern_map();

        // Folding the Multiply/Add into the BatchNorm is only valid when no
        // other node consumes the intermediate results.
        auto bn_match = pattern_map[bn_label];
        if (bn_match->get_users().size() > 1)
        {
            NGRAPH_DEBUG << "Multiply isn't the only user of BatchNorm's output";
            return false;
        }
        auto multi_match = pattern_map[multi_label];
        if (multi_match->get_users().size() > 1)
        {
            NGRAPH_DEBUG << "Add isn't the only user of Multiply's output";
            return false;
        }

        // Expected broadcast axes: every axis of the input except axis 1.
        const size_t input_rank = pattern_map[input]->output(0).get_shape().size();
        std::vector<size_t> vec{0};
        for (size_t i = 2; i < input_rank; i++)
        {
            vec.push_back(i);
        }
        AxisSet axisSet{vec};
        if (std::static_pointer_cast<ngraph::op::Broadcast>(pattern_map[broadcast1_label])
                    ->get_broadcast_axes() != axisSet ||
            std::static_pointer_cast<ngraph::op::Broadcast>(pattern_map[broadcast2_label])
                    ->get_broadcast_axes() != axisSet)
        {
            NGRAPH_DEBUG << "Broadcast axes is not {0, 2, ...}";
            return false;
        }

        auto bn_inference = std::dynamic_pointer_cast<ngraph::op::BatchNormInference>(bn_match);
        if (!bn_inference)
        {
            return false;
        }
        if (!mkldnn_utils::can_use_mkldnn_batchnorm_fprop(bn_inference.get()))
        {
            return false;
        }

        // Build the replacement subgraph only after every bail-out check has
        // passed, so a rejected match does not construct throwaway nodes.
        auto new_gamma = std::make_shared<ngraph::op::Multiply>(pattern_map[gamma],
                                                                pattern_map[broadcast1_input]);
        auto new_multi = std::make_shared<ngraph::op::Multiply>(pattern_map[beta],
                                                                pattern_map[broadcast1_input]);
        auto new_beta =
            std::make_shared<ngraph::op::Add>(new_multi, pattern_map[broadcast2_input]);

        std::shared_ptr<Node> bn_relu =
            std::make_shared<ngraph::op::BatchNormInferenceRelu>(bn_inference->get_eps_value(),
                                                                 new_gamma,
                                                                 new_beta,
                                                                 pattern_map[input],
                                                                 pattern_map[mean],
                                                                 pattern_map[var]);
        ngraph::replace_node(m.get_match_root(), bn_relu);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(prelu,
                                                        "CPUFusion.BatchNormInferReluWithMultiAdd");
    this->add_matcher(m, callback);
}
void
ngraph
::
runtime
::
cpu
::
pass
::
CPUFusion
::
construct_conv_relu
()
{
Shape
shape
{
2
,
2
,
1
,
1
};
...
...
src/ngraph/runtime/cpu/pass/cpu_fusion.hpp
View file @
cb431144
...
...
@@ -78,6 +78,7 @@ public:
construct_deconvolution_affine_folding_relu
();
}
construct_dropout
();
construct_batch_norm_infer_relu_with_multiply_add
();
}
}
...
...
@@ -90,6 +91,7 @@ private:
void
construct_sigmoid_multiply
();
void
construct_batch_norm_relu
();
void
construct_batch_norm_relu_global_stats
();
void
construct_batch_norm_infer_relu_with_multiply_add
();
void
construct_conv_relu
();
void
construct_conv_bias_relu
();
void
construct_conv_bias_add
();
...
...
src/ngraph/runtime/plaidml/CMakeLists.txt
View file @
cb431144
...
...
@@ -55,6 +55,7 @@ set(SRC
plaidml_pass_explicit_logicals.cpp
plaidml_pass_implicit_broadcast.cpp
plaidml_pass_lower_convolutions.cpp
plaidml_pass_prefix_reshape_elimination.cpp
plaidml_pass_replicate_combination.cpp
plaidml_pass_replicate_elision.cpp
plaidml_pass_winograd.cpp
...
...
src/ngraph/runtime/plaidml/plaidml_compiler.cpp
View file @
cb431144
...
...
@@ -26,7 +26,6 @@
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/prefix_reshape_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
...
...
@@ -36,6 +35,7 @@
#include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_prefix_reshape_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_winograd.hpp"
...
...
@@ -44,8 +44,7 @@ namespace
{
void
write_debug
(
const
ngraph
::
Node
&
op
)
{
PLAIDML_DEBUG
<<
"Node: name=
\"
"
<<
op
.
get_name
()
<<
"
\"
desc=
\"
"
<<
op
.
description
()
<<
"
\"
"
;
PLAIDML_DEBUG
<<
"Compiling: "
<<
op
;
for
(
const
auto
&
op_input
:
op
.
get_inputs
())
{
ngraph
::
descriptor
::
Tensor
*
tensor
=
op_input
.
get_output
().
get_tensor_ptr
().
get
();
...
...
@@ -104,7 +103,7 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
pass_manager
.
register_pass
<
ngraph
::
runtime
::
plaidml
::
pass
::
ReplicateElision
>
();
pass_manager
.
register_pass
<
ngraph
::
runtime
::
plaidml
::
pass
::
ReplicateCombination
>
();
pass_manager
.
register_pass
<
ngraph
::
runtime
::
plaidml
::
pass
::
ImplicitBroadcast
>
();
pass_manager
.
register_pass
<
ngraph
::
pass
::
PrefixReshapeElimination
>
();
pass_manager
.
register_pass
<
ngraph
::
runtime
::
plaidml
::
pass
::
PrefixReshapeElimination
>
();
pass_manager
.
register_pass
<
ngraph
::
runtime
::
plaidml
::
pass
::
LowerConvolutions
>
();
if
(
pass_manager
.
get_pass_config
().
get_pass_enable
(
"Winograd"
))
{
...
...
src/ngraph/runtime/plaidml/plaidml_config.cpp
View file @
cb431144
...
...
@@ -163,6 +163,12 @@ ngraph::runtime::plaidml::Config
// So to verify that there is a non-zero-length option value, test oval_len
// To verify that there is no option value, test has_oval
if
(
oname_begin
==
oname_end
&&
!
has_oval
)
{
// An empty option; poor style, but advance to the next.
continue
;
}
// Check for verbosity
if
(
is_opt
(
"v"
))
{
...
...
src/ngraph/runtime/plaidml/plaidml_ops_replicate.cpp
View file @
cb431144
...
...
@@ -53,7 +53,7 @@ ngraph::runtime::plaidml::op::Replicate::Replicate(std::shared_ptr<Node> arg,
void
ngraph
::
runtime
::
plaidml
::
op
::
Replicate
::
validate_and_infer_types
()
{
const
auto
&
arg
=
get_arguments
().
a
t
(
0
);
std
::
shared_ptr
<
Node
>
arg
=
get_argumen
t
(
0
);
Shape
shape
=
arg
->
get_shape
();
for
(
auto
rit
=
m_replication_axes
.
begin
(),
sit
=
shape
.
begin
();
rit
!=
m_replication_axes
.
end
();
...
...
src/ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.cpp
View file @
cb431144
...
...
@@ -15,7 +15,7 @@
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/
graph_util
.hpp"
#include "ngraph/
check
.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
...
...
@@ -76,9 +76,28 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
auto
implicit_broadcast
=
std
::
make_shared
<
plaidml
::
op
::
ImplicitBroadcast
>
(
src
,
broadcast
->
get_shape
());
replace_node
(
broadcast
,
implicit_broadcast
);
// N.B. We don't use replace_node() here, since it's important to only replace the broadcast with an
// implicit broadcast when the consuming operation is an elementwise operation, since PlaidML
// contractions don't provide implicit broadcast semantics.
bool
result
=
false
;
for
(
size_t
i
=
0
;
i
<
broadcast
->
get_output_size
();
++
i
)
{
for
(
auto
&
input
:
broadcast
->
output
(
i
).
get_target_inputs
())
{
Node
*
node
=
input
.
get_node
();
if
(
dynamic_cast
<
ngraph
::
op
::
util
::
UnaryElementwiseArithmetic
*>
(
node
)
||
dynamic_cast
<
ngraph
::
op
::
util
::
BinaryElementwiseArithmetic
*>
(
node
))
{
input
.
replace_source_output
(
implicit_broadcast
->
output
(
i
));
result
=
true
;
}
}
}
return
true
;
NGRAPH_CHECK
(
result
,
"Expected at least one elementwise consumer in the PlaidML implicit broadcast "
"rewrite graph pass"
);
return
result
;
};
add_matcher
(
std
::
make_shared
<
pattern
::
Matcher
>
(
target_op
),
callback
);
}
src/ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.cpp
View file @
cb431144
...
...
@@ -75,19 +75,19 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
// op. Using target always works.
AxisVector
out_axes
=
to_axes
(
target
,
output_transpose
);
auto
lhs
=
node
->
get_argument
s
().
at
(
0
);
auto
lhs
=
node
->
get_argument
(
0
);
auto
*
lhs_transpose
=
to_transpose
(
lhs
);
if
(
lhs_transpose
)
{
lhs
=
lhs_transpose
->
get_argument
s
().
at
(
0
);
lhs
=
lhs_transpose
->
get_argument
(
0
);
}
AxisVector
lhs_axes
=
to_axes
(
lhs
,
lhs_transpose
);
auto
rhs
=
node
->
get_argument
s
().
at
(
1
);
auto
rhs
=
node
->
get_argument
(
1
);
auto
*
rhs_transpose
=
to_transpose
(
rhs
);
if
(
rhs_transpose
)
{
rhs
=
rhs_transpose
->
get_argument
s
().
at
(
0
);
rhs
=
rhs_transpose
->
get_argument
(
0
);
}
AxisVector
rhs_axes
=
to_axes
(
rhs
,
rhs_transpose
);
...
...
src/ngraph/
pass/
prefix_reshape_elimination.cpp
→
src/ngraph/
runtime/plaidml/plaidml_pass_
prefix_reshape_elimination.cpp
View file @
cb431144
...
...
@@ -14,7 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/
pass/
prefix_reshape_elimination.hpp"
#include "ngraph/
runtime/plaidml/plaidml_pass_
prefix_reshape_elimination.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
...
...
@@ -23,11 +23,12 @@
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_implicit_broadcast.hpp"
using
namespace
std
;
using
namespace
ngraph
;
pass
::
PrefixReshapeElimination
::
PrefixReshapeElimination
()
runtime
::
plaidml
::
pass
::
PrefixReshapeElimination
::
PrefixReshapeElimination
()
{
auto
src_op
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
i8
,
Shape
{},
[](
shared_ptr
<
Node
>
)
{
return
true
;
});
...
...
@@ -35,7 +36,7 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
element
::
i8
,
Shape
{},
[](
shared_ptr
<
Node
>
node
)
{
op
::
Reshape
*
reshape
=
dynamic_cast
<
op
::
Reshape
*>
(
node
.
get
());
ngraph
::
op
::
Reshape
*
reshape
=
dynamic_cast
<
ngraph
::
op
::
Reshape
*>
(
node
.
get
());
if
(
!
reshape
)
{
return
false
;
...
...
@@ -71,16 +72,42 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
element
::
i8
,
Shape
{},
[](
shared_ptr
<
Node
>
node
)
{
return
pattern
::
has_class
<
op
::
util
::
UnaryElementwiseArithmetic
>
()(
node
)
||
pattern
::
has_class
<
op
::
util
::
BinaryElementwiseArithmetic
>
()(
node
);
return
pattern
::
has_class
<
ngraph
::
op
::
util
::
UnaryElementwiseArithmetic
>
()(
node
)
||
pattern
::
has_class
<
ngraph
::
op
::
util
::
BinaryElementwiseArithmetic
>
()(
node
);
},
NodeVector
{
reshape_op
});
auto
callback
=
[](
pattern
::
Matcher
&
m
)
{
replace_node
(
m
.
get_matched_nodes
().
at
(
1
),
m
.
get_matched_nodes
().
at
(
2
));
return
true
;
auto
src
=
m
.
get_matched_nodes
().
at
(
2
);
auto
prefix_reshape
=
std
::
static_pointer_cast
<
ngraph
::
op
::
Reshape
>
(
m
.
get_matched_nodes
().
at
(
1
));
auto
implicit_broadcast
=
std
::
make_shared
<
op
::
ImplicitBroadcast
>
(
src
,
prefix_reshape
->
get_shape
());
// N.B. We don't use replace_node() here, since it's important to only replace the prefix reshape with
// an implicit broadcast when the consuming operation is an elementwise operation, since PlaidML
// contractions don't provide implicit broadcast semantics.
bool
result
=
false
;
for
(
size_t
i
=
0
;
i
<
prefix_reshape
->
get_output_size
();
++
i
)
{
for
(
auto
&
input
:
prefix_reshape
->
output
(
i
).
get_target_inputs
())
{
Node
*
node
=
input
.
get_node
();
if
(
dynamic_cast
<
ngraph
::
op
::
util
::
UnaryElementwiseArithmetic
*>
(
node
)
||
dynamic_cast
<
ngraph
::
op
::
util
::
BinaryElementwiseArithmetic
*>
(
node
))
{
input
.
replace_source_output
(
implicit_broadcast
->
output
(
i
));
result
=
true
;
}
}
}
NGRAPH_CHECK
(
result
,
"Expected at least one elementwise consumer in the PlaidML implicit broadcast "
"rewrite graph pass"
);
return
result
;
};
add_matcher
(
make_shared
<
pattern
::
Matcher
>
(
target_op
,
"PrefixReshapeElimination"
),
callback
,
PassProperty
::
REQUIRE_STATIC_SHAPE
);
ngraph
::
pass
::
PassProperty
::
REQUIRE_STATIC_SHAPE
);
}
src/ngraph/
pass/
prefix_reshape_elimination.hpp
→
src/ngraph/
runtime/plaidml/plaidml_pass_
prefix_reshape_elimination.hpp
View file @
cb431144
...
...
@@ -20,19 +20,23 @@
namespace
ngraph
{
namespace
runtime
{
namespace
plaidml
{
namespace
pass
{
class
PrefixReshapeElimination
;
}
}
}
}
// A pass to eliminate reshapes whose output shapes are the same as
// their input shape modulo leading size-1 axes.
//
// N.B. This pass MUST only be used by backends that can handle the
// omission of leading size-1 axes, e.g. backends that implement
// NumPy-style broadcast semantics.
class
ngraph
::
pass
::
PrefixReshapeElimination
final
:
public
ngraph
::
pass
::
GraphRewrite
// A pass that matches reshapes whose output shapes are the same as
// their input shape modulo leading size-1 axes, and replaces them with
// ImplicitBroadcast operations (which do the same thing as a passthrough).
class
ngraph
::
runtime
::
plaidml
::
pass
::
PrefixReshapeElimination
final
:
public
ngraph
::
pass
::
GraphRewrite
{
public
:
PrefixReshapeElimination
();
...
...
src/ngraph/runtime/plaidml/plaidml_pass_replicate_combination.cpp
View file @
cb431144
...
...
@@ -47,9 +47,9 @@ ngraph::runtime::plaidml::pass::ReplicateCombination::ReplicateCombination()
*
ait
*=
*
uit
;
}
replace_node
(
lower
,
std
::
make_shared
<
plaidml
::
op
::
Replicate
>
(
upper
->
get_arguments
().
at
(
0
)
,
std
::
move
(
axes
)));
replace_node
(
lower
,
std
::
make_shared
<
plaidml
::
op
::
Replicate
>
(
upper
->
get_argument
(
0
),
std
::
move
(
axes
)));
return
true
;
};
...
...
src/ngraph/runtime/plaidml/plaidml_pass_replicate_elision.cpp
View file @
cb431144
...
...
@@ -74,7 +74,7 @@ ngraph::runtime::plaidml::pass::ReplicateElision::ReplicateElision()
if
(
elidable
)
{
replaced_any
=
true
;
replace_node
(
replicate
,
replicate
->
get_argument
s
().
at
(
0
));
replace_node
(
replicate
,
replicate
->
get_argument
(
0
));
}
}
...
...
test/cpu_fusion.cpp
View file @
cb431144
...
...
@@ -560,6 +560,136 @@ TEST(cpu_fusion, conv_bias_bprop)
ASSERT_EQ
(
ccg
,
1
);
}
// Builds BatchNormInference -> Multiply -> Add -> Relu on an input of the given
// shape, with both Broadcasts over axes {0, 2, ..., rank-1}. Runs the graph on
// the INTERPRETER and CPU backends with identical random arguments, checks that
// the results agree, and checks that the CPU function ends up with exactly one
// BatchNormInferenceRelu op.
static void test_batchnorm_multiply_add_relu(Shape input_shape)
{
    auto make_bn_relu_function = [&]() {
        auto c_axis = input_shape[1];
        auto input = make_shared<op::Parameter>(element::f32, input_shape);
        auto mean_shape = Shape{c_axis};
        auto mean = std::make_shared<op::Parameter>(element::f32, mean_shape);
        auto var_shape = Shape{c_axis};
        auto var = std::make_shared<op::Parameter>(element::f32, var_shape);
        auto gamma_shape = Shape{c_axis};
        auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
        auto beta_shape = Shape{c_axis};
        auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
        double eps = 0.001;
        auto bn =
            std::make_shared<ngraph::op::BatchNormInference>(eps, gamma, beta, input, mean, var);

        // Broadcast over every axis except axis 1.
        std::vector<size_t> vec{0};
        for (size_t i = 2; i < input_shape.size(); i++)
        {
            vec.push_back(i);
        }

        auto broadcast1_input = std::make_shared<op::Parameter>(element::f32, gamma_shape);
        auto broadcast1 =
            std::make_shared<ngraph::op::Broadcast>(broadcast1_input, input_shape, AxisSet(vec));
        auto multiply = std::make_shared<ngraph::op::Multiply>(bn, broadcast1);

        auto broadcast2_input = std::make_shared<op::Parameter>(element::f32, gamma_shape);
        auto broadcast2 =
            std::make_shared<ngraph::op::Broadcast>(broadcast2_input, input_shape, AxisSet(vec));
        auto add = std::make_shared<ngraph::op::Add>(multiply, broadcast2);
        auto relu = std::make_shared<ngraph::op::Relu>(add);

        auto f = make_shared<Function>(
            relu,
            ParameterVector{gamma, beta, input, mean, var, broadcast1_input, broadcast2_input});
        return f;
    };

    // Two independent copies of the same graph, one per backend.
    auto cpu_f = make_bn_relu_function();
    auto int_f = make_bn_relu_function();

    test::Uniform<float> rng(1.0f, 10.0f);
    vector<vector<float>> args;
    for (const auto& param : int_f->get_parameters())
    {
        vector<float> tensor_val(shape_size(param->get_shape()));
        rng.initialize(tensor_val);
        args.push_back(tensor_val);
    }

    auto int_results = execute(int_f, args, "INTERPRETER");
    auto cpu_results = execute(cpu_f, args, "CPU");
    for (size_t i = 0; i < cpu_results.size(); i++)
    {
        EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
    }

    size_t bn_relu = count_ops_of_type<op::BatchNormInferenceRelu>(cpu_f);
    ASSERT_EQ(bn_relu, 1);
}
// Runs the BatchNorm/Multiply/Add/Relu fusion check on one 4-D and two 5-D
// input shapes.
TEST(cpu_fusion, batchnorm_multiply_add_relu)
{
    const std::vector<Shape> input_shapes{
        Shape{1, 3, 2, 2}, Shape{1, 2, 2, 2, 2}, Shape{2, 2, 2, 4, 4}};
    for (const Shape& input_shape : input_shapes)
    {
        test_batchnorm_multiply_add_relu(input_shape);
    }
}
// Negative case for the BatchNorm/Multiply/Add/Relu fusion: the Broadcasts use
// axes {1, 2, 3} instead of {0, 2, 3}. The CPU and INTERPRETER results must
// still agree, and the CPU function must contain no BatchNormInferenceRelu op.
TEST(cpu_fusion, batchnorm_multiply_add_relu_no_fusion)
{
    auto input_shape = Shape{3, 3, 2, 2};
    auto make_bn_relu_function = [&]() {
        auto c_axis = input_shape[1];
        auto input = make_shared<op::Parameter>(element::f32, input_shape);
        auto mean_shape = Shape{c_axis};
        auto mean = std::make_shared<op::Parameter>(element::f32, mean_shape);
        auto var_shape = Shape{c_axis};
        auto var = std::make_shared<op::Parameter>(element::f32, var_shape);
        auto gamma_shape = Shape{c_axis};
        auto gamma = make_shared<op::Parameter>(element::f32, gamma_shape);
        auto beta_shape = Shape{c_axis};
        auto beta = make_shared<op::Parameter>(element::f32, beta_shape);
        double eps = 0.001;
        auto bn =
            std::make_shared<ngraph::op::BatchNormInference>(eps, gamma, beta, input, mean, var);

        // Broadcast over axes {1, ..., rank-1}; only axis 0 is kept.
        std::vector<size_t> vec;
        for (size_t i = 1; i < input_shape.size(); i++)
        {
            vec.push_back(i);
        }

        // Shape{3} equals the size of the kept axis 0 of input_shape {3, 3, 2, 2}.
        auto broadcast1_input = std::make_shared<op::Parameter>(element::f32, Shape{3});
        auto broadcast1 =
            std::make_shared<ngraph::op::Broadcast>(broadcast1_input, input_shape, AxisSet(vec));
        auto multiply = std::make_shared<ngraph::op::Multiply>(bn, broadcast1);

        auto broadcast2_input = std::make_shared<op::Parameter>(element::f32, Shape{3});
        auto broadcast2 =
            std::make_shared<ngraph::op::Broadcast>(broadcast2_input, input_shape, AxisSet(vec));
        auto add = std::make_shared<ngraph::op::Add>(multiply, broadcast2);
        auto relu = std::make_shared<ngraph::op::Relu>(add);

        auto f = make_shared<Function>(
            relu,
            ParameterVector{gamma, beta, input, mean, var, broadcast1_input, broadcast2_input});
        return f;
    };

    // Two independent copies of the same graph, one per backend.
    auto cpu_f = make_bn_relu_function();
    auto int_f = make_bn_relu_function();

    test::Uniform<float> rng(1.0f, 10.0f);
    vector<vector<float>> args;
    for (const auto& param : int_f->get_parameters())
    {
        vector<float> tensor_val(shape_size(param->get_shape()));
        rng.initialize(tensor_val);
        args.push_back(tensor_val);
    }

    auto int_results = execute(int_f, args, "INTERPRETER");
    auto cpu_results = execute(cpu_f, args, "CPU");
    for (size_t i = 0; i < cpu_results.size(); i++)
    {
        EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
    }

    size_t bn_relu = count_ops_of_type<op::BatchNormInferenceRelu>(cpu_f);
    ASSERT_EQ(bn_relu, 0);
}
TEST
(
cpu_fusion
,
batchnorm_fprop_relu_b1c2h2w2
)
{
auto
input_shape
=
Shape
{
1
,
2
,
2
,
2
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment