Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
13bdf0ef
Unverified
Commit
13bdf0ef
authored
Jul 09, 2019
by
Scott Cyphers
Committed by
GitHub
Jul 09, 2019
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #3174 from NervanaSystems/rearhart/plaidml
Minor PlaidML fixes
parents
abd69371
4eb946b0
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
87 additions
and
33 deletions
+87
-33
CMakeLists.txt
src/ngraph/CMakeLists.txt
+0
-2
CMakeLists.txt
src/ngraph/runtime/plaidml/CMakeLists.txt
+1
-0
plaidml_compiler.cpp
src/ngraph/runtime/plaidml/plaidml_compiler.cpp
+3
-4
plaidml_config.cpp
src/ngraph/runtime/plaidml/plaidml_config.cpp
+6
-0
plaidml_ops_replicate.cpp
src/ngraph/runtime/plaidml/plaidml_ops_replicate.cpp
+1
-1
plaidml_pass_implicit_broadcast.cpp
...graph/runtime/plaidml/plaidml_pass_implicit_broadcast.cpp
+22
-3
plaidml_pass_lower_convolutions.cpp
...graph/runtime/plaidml/plaidml_pass_lower_convolutions.cpp
+4
-4
plaidml_pass_prefix_reshape_elimination.cpp
...ntime/plaidml/plaidml_pass_prefix_reshape_elimination.cpp
+35
-8
plaidml_pass_prefix_reshape_elimination.hpp
...ntime/plaidml/plaidml_pass_prefix_reshape_elimination.hpp
+11
-7
plaidml_pass_replicate_combination.cpp
...ph/runtime/plaidml/plaidml_pass_replicate_combination.cpp
+3
-3
plaidml_pass_replicate_elision.cpp
...ngraph/runtime/plaidml/plaidml_pass_replicate_elision.cpp
+1
-1
No files found.
src/ngraph/CMakeLists.txt
View file @
13bdf0ef
...
...
@@ -419,8 +419,6 @@ set (SRC
pass/pass.hpp
pass/pass_config.cpp
pass/pass_config.hpp
pass/prefix_reshape_elimination.cpp
pass/prefix_reshape_elimination.hpp
pass/propagate_cacheability.cpp
pass/propagate_cacheability.hpp
pass/reshape_elimination.cpp
...
...
src/ngraph/runtime/plaidml/CMakeLists.txt
View file @
13bdf0ef
...
...
@@ -55,6 +55,7 @@ set(SRC
plaidml_pass_explicit_logicals.cpp
plaidml_pass_implicit_broadcast.cpp
plaidml_pass_lower_convolutions.cpp
plaidml_pass_prefix_reshape_elimination.cpp
plaidml_pass_replicate_combination.cpp
plaidml_pass_replicate_elision.cpp
plaidml_pass_winograd.cpp
...
...
src/ngraph/runtime/plaidml/plaidml_compiler.cpp
View file @
13bdf0ef
...
...
@@ -26,7 +26,6 @@
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/prefix_reshape_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
...
...
@@ -36,6 +35,7 @@
#include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_prefix_reshape_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_winograd.hpp"
...
...
@@ -44,8 +44,7 @@ namespace
{
void
write_debug
(
const
ngraph
::
Node
&
op
)
{
PLAIDML_DEBUG
<<
"Node: name=
\"
"
<<
op
.
get_name
()
<<
"
\"
desc=
\"
"
<<
op
.
description
()
<<
"
\"
"
;
PLAIDML_DEBUG
<<
"Compiling: "
<<
op
;
for
(
const
auto
&
op_input
:
op
.
get_inputs
())
{
ngraph
::
descriptor
::
Tensor
*
tensor
=
op_input
.
get_output
().
get_tensor_ptr
().
get
();
...
...
@@ -104,7 +103,7 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
pass_manager
.
register_pass
<
ngraph
::
runtime
::
plaidml
::
pass
::
ReplicateElision
>
();
pass_manager
.
register_pass
<
ngraph
::
runtime
::
plaidml
::
pass
::
ReplicateCombination
>
();
pass_manager
.
register_pass
<
ngraph
::
runtime
::
plaidml
::
pass
::
ImplicitBroadcast
>
();
pass_manager
.
register_pass
<
ngraph
::
pass
::
PrefixReshapeElimination
>
();
pass_manager
.
register_pass
<
ngraph
::
runtime
::
plaidml
::
pass
::
PrefixReshapeElimination
>
();
pass_manager
.
register_pass
<
ngraph
::
runtime
::
plaidml
::
pass
::
LowerConvolutions
>
();
if
(
pass_manager
.
get_pass_config
().
get_pass_enable
(
"Winograd"
))
{
...
...
src/ngraph/runtime/plaidml/plaidml_config.cpp
View file @
13bdf0ef
...
...
@@ -163,6 +163,12 @@ ngraph::runtime::plaidml::Config
// So to verify that there is a non-zero-length option value, test oval_len
// To verify that there is no option value, test has_oval
if
(
oname_begin
==
oname_end
&&
!
has_oval
)
{
// An empty option; poor style, but advance to the next.
continue
;
}
// Check for verbosity
if
(
is_opt
(
"v"
))
{
...
...
src/ngraph/runtime/plaidml/plaidml_ops_replicate.cpp
View file @
13bdf0ef
...
...
@@ -53,7 +53,7 @@ ngraph::runtime::plaidml::op::Replicate::Replicate(std::shared_ptr<Node> arg,
void
ngraph
::
runtime
::
plaidml
::
op
::
Replicate
::
validate_and_infer_types
()
{
const
auto
&
arg
=
get_arguments
().
a
t
(
0
);
std
::
shared_ptr
<
Node
>
arg
=
get_argumen
t
(
0
);
Shape
shape
=
arg
->
get_shape
();
for
(
auto
rit
=
m_replication_axes
.
begin
(),
sit
=
shape
.
begin
();
rit
!=
m_replication_axes
.
end
();
...
...
src/ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.cpp
View file @
13bdf0ef
...
...
@@ -15,7 +15,7 @@
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/
graph_util
.hpp"
#include "ngraph/
check
.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
...
...
@@ -76,9 +76,28 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
auto
implicit_broadcast
=
std
::
make_shared
<
plaidml
::
op
::
ImplicitBroadcast
>
(
src
,
broadcast
->
get_shape
());
replace_node
(
broadcast
,
implicit_broadcast
);
// N.B. We don't use replace_node() here, since it's important to only replace the broadcast with an
// implicit broadcast when the consuming operation is an elementwise operation, since PlaidML
// contractions don't provide implicit broadcast semantics.
bool
result
=
false
;
for
(
size_t
i
=
0
;
i
<
broadcast
->
get_output_size
();
++
i
)
{
for
(
auto
&
input
:
broadcast
->
output
(
i
).
get_target_inputs
())
{
Node
*
node
=
input
.
get_node
();
if
(
dynamic_cast
<
ngraph
::
op
::
util
::
UnaryElementwiseArithmetic
*>
(
node
)
||
dynamic_cast
<
ngraph
::
op
::
util
::
BinaryElementwiseArithmetic
*>
(
node
))
{
input
.
replace_source_output
(
implicit_broadcast
->
output
(
i
));
result
=
true
;
}
}
}
return
true
;
NGRAPH_CHECK
(
result
,
"Expected at least one elementwise consumer in the PlaidML implicit broadcast "
"rewrite graph pass"
);
return
result
;
};
add_matcher
(
std
::
make_shared
<
pattern
::
Matcher
>
(
target_op
),
callback
);
}
src/ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.cpp
View file @
13bdf0ef
...
...
@@ -75,19 +75,19 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
// op. Using target always works.
AxisVector
out_axes
=
to_axes
(
target
,
output_transpose
);
auto
lhs
=
node
->
get_argument
s
().
at
(
0
);
auto
lhs
=
node
->
get_argument
(
0
);
auto
*
lhs_transpose
=
to_transpose
(
lhs
);
if
(
lhs_transpose
)
{
lhs
=
lhs_transpose
->
get_argument
s
().
at
(
0
);
lhs
=
lhs_transpose
->
get_argument
(
0
);
}
AxisVector
lhs_axes
=
to_axes
(
lhs
,
lhs_transpose
);
auto
rhs
=
node
->
get_argument
s
().
at
(
1
);
auto
rhs
=
node
->
get_argument
(
1
);
auto
*
rhs_transpose
=
to_transpose
(
rhs
);
if
(
rhs_transpose
)
{
rhs
=
rhs_transpose
->
get_argument
s
().
at
(
0
);
rhs
=
rhs_transpose
->
get_argument
(
0
);
}
AxisVector
rhs_axes
=
to_axes
(
rhs
,
rhs_transpose
);
...
...
src/ngraph/
pass/
prefix_reshape_elimination.cpp
→
src/ngraph/
runtime/plaidml/plaidml_pass_
prefix_reshape_elimination.cpp
View file @
13bdf0ef
...
...
@@ -14,7 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/
pass/
prefix_reshape_elimination.hpp"
#include "ngraph/
runtime/plaidml/plaidml_pass_
prefix_reshape_elimination.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
...
...
@@ -23,11 +23,12 @@
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_implicit_broadcast.hpp"
using
namespace
std
;
using
namespace
ngraph
;
pass
::
PrefixReshapeElimination
::
PrefixReshapeElimination
()
runtime
::
plaidml
::
pass
::
PrefixReshapeElimination
::
PrefixReshapeElimination
()
{
auto
src_op
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
i8
,
Shape
{},
[](
shared_ptr
<
Node
>
)
{
return
true
;
});
...
...
@@ -35,7 +36,7 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
element
::
i8
,
Shape
{},
[](
shared_ptr
<
Node
>
node
)
{
op
::
Reshape
*
reshape
=
dynamic_cast
<
op
::
Reshape
*>
(
node
.
get
());
ngraph
::
op
::
Reshape
*
reshape
=
dynamic_cast
<
ngraph
::
op
::
Reshape
*>
(
node
.
get
());
if
(
!
reshape
)
{
return
false
;
...
...
@@ -71,16 +72,42 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
element
::
i8
,
Shape
{},
[](
shared_ptr
<
Node
>
node
)
{
return
pattern
::
has_class
<
op
::
util
::
UnaryElementwiseArithmetic
>
()(
node
)
||
pattern
::
has_class
<
op
::
util
::
BinaryElementwiseArithmetic
>
()(
node
);
return
pattern
::
has_class
<
ngraph
::
op
::
util
::
UnaryElementwiseArithmetic
>
()(
node
)
||
pattern
::
has_class
<
ngraph
::
op
::
util
::
BinaryElementwiseArithmetic
>
()(
node
);
},
NodeVector
{
reshape_op
});
auto
callback
=
[](
pattern
::
Matcher
&
m
)
{
replace_node
(
m
.
get_matched_nodes
().
at
(
1
),
m
.
get_matched_nodes
().
at
(
2
));
return
true
;
auto
src
=
m
.
get_matched_nodes
().
at
(
2
);
auto
prefix_reshape
=
std
::
static_pointer_cast
<
ngraph
::
op
::
Reshape
>
(
m
.
get_matched_nodes
().
at
(
1
));
auto
implicit_broadcast
=
std
::
make_shared
<
op
::
ImplicitBroadcast
>
(
src
,
prefix_reshape
->
get_shape
());
// N.B. We don't use replace_node() here, since it's important to only replace the prefix reshape with
// an implicit broadcast when the consuming operation is an elementwise operation, since PlaidML
// contractions don't provide implicit broadcast semantics.
bool
result
=
false
;
for
(
size_t
i
=
0
;
i
<
prefix_reshape
->
get_output_size
();
++
i
)
{
for
(
auto
&
input
:
prefix_reshape
->
output
(
i
).
get_target_inputs
())
{
Node
*
node
=
input
.
get_node
();
if
(
dynamic_cast
<
ngraph
::
op
::
util
::
UnaryElementwiseArithmetic
*>
(
node
)
||
dynamic_cast
<
ngraph
::
op
::
util
::
BinaryElementwiseArithmetic
*>
(
node
))
{
input
.
replace_source_output
(
implicit_broadcast
->
output
(
i
));
result
=
true
;
}
}
}
NGRAPH_CHECK
(
result
,
"Expected at least one elementwise consumer in the PlaidML implicit broadcast "
"rewrite graph pass"
);
return
result
;
};
add_matcher
(
make_shared
<
pattern
::
Matcher
>
(
target_op
,
"PrefixReshapeElimination"
),
callback
,
PassProperty
::
REQUIRE_STATIC_SHAPE
);
ngraph
::
pass
::
PassProperty
::
REQUIRE_STATIC_SHAPE
);
}
src/ngraph/
pass/
prefix_reshape_elimination.hpp
→
src/ngraph/
runtime/plaidml/plaidml_pass_
prefix_reshape_elimination.hpp
View file @
13bdf0ef
...
...
@@ -20,19 +20,23 @@
namespace
ngraph
{
namespace
runtime
{
namespace
plaidml
{
namespace
pass
{
class
PrefixReshapeElimination
;
}
}
}
}
// A pass to eliminate reshapes whose output shapes are the same as
// their input shape modulo leading size-1 axes.
//
// N.B. This pass MUST only be used by backends that can handle the
// omission of leading size-1 axes, e.g. backends that implement
// NumPy-style broadcast semantics.
class
ngraph
::
pass
::
PrefixReshapeElimination
final
:
public
ngraph
::
pass
::
GraphRewrite
// A pass that matches reshapes whose output shapes are the same as
// their input shape modulo leading size-1 axes, and replaces them with
// ImplicitBroadcast operations (which do the same thing as a passthrough).
class
ngraph
::
runtime
::
plaidml
::
pass
::
PrefixReshapeElimination
final
:
public
ngraph
::
pass
::
GraphRewrite
{
public
:
PrefixReshapeElimination
();
...
...
src/ngraph/runtime/plaidml/plaidml_pass_replicate_combination.cpp
View file @
13bdf0ef
...
...
@@ -47,9 +47,9 @@ ngraph::runtime::plaidml::pass::ReplicateCombination::ReplicateCombination()
*
ait
*=
*
uit
;
}
replace_node
(
lower
,
std
::
make_shared
<
plaidml
::
op
::
Replicate
>
(
upper
->
get_arguments
().
at
(
0
)
,
std
::
move
(
axes
)));
replace_node
(
lower
,
std
::
make_shared
<
plaidml
::
op
::
Replicate
>
(
upper
->
get_argument
(
0
),
std
::
move
(
axes
)));
return
true
;
};
...
...
src/ngraph/runtime/plaidml/plaidml_pass_replicate_elision.cpp
View file @
13bdf0ef
...
...
@@ -74,7 +74,7 @@ ngraph::runtime::plaidml::pass::ReplicateElision::ReplicateElision()
if
(
elidable
)
{
replaced_any
=
true
;
replace_node
(
replicate
,
replicate
->
get_argument
s
().
at
(
0
));
replace_node
(
replicate
,
replicate
->
get_argument
(
0
));
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment