Commit 13bdf0ef (unverified)
Authored 5 years ago by Scott Cyphers; committed 5 years ago by GitHub.

Merge pull request #3174 from NervanaSystems/rearhart/plaidml

Minor PlaidML fixes

Parents: abd69371 4eb946b0

Showing 11 changed files with 89 additions and 35 deletions (+89 -35):
src/ngraph/CMakeLists.txt                                                +0   -2
src/ngraph/runtime/plaidml/CMakeLists.txt                                +1   -0
src/ngraph/runtime/plaidml/plaidml_compiler.cpp                          +3   -4
src/ngraph/runtime/plaidml/plaidml_config.cpp                            +6   -0
src/ngraph/runtime/plaidml/plaidml_ops_replicate.cpp                     +1   -1
src/ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.cpp           +22  -3
src/ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.cpp           +4   -4
src/ngraph/runtime/plaidml/plaidml_pass_prefix_reshape_elimination.cpp   +35  -8
src/ngraph/runtime/plaidml/plaidml_pass_prefix_reshape_elimination.hpp   +13  -9
src/ngraph/runtime/plaidml/plaidml_pass_replicate_combination.cpp        +3   -3
src/ngraph/runtime/plaidml/plaidml_pass_replicate_elision.cpp            +1   -1
src/ngraph/CMakeLists.txt
@@ -419,8 +419,6 @@ set (SRC
     pass/pass.hpp
     pass/pass_config.cpp
     pass/pass_config.hpp
-    pass/prefix_reshape_elimination.cpp
-    pass/prefix_reshape_elimination.hpp
     pass/propagate_cacheability.cpp
     pass/propagate_cacheability.hpp
     pass/reshape_elimination.cpp
src/ngraph/runtime/plaidml/CMakeLists.txt
@@ -55,6 +55,7 @@ set(SRC
     plaidml_pass_explicit_logicals.cpp
     plaidml_pass_implicit_broadcast.cpp
     plaidml_pass_lower_convolutions.cpp
+    plaidml_pass_prefix_reshape_elimination.cpp
     plaidml_pass_replicate_combination.cpp
     plaidml_pass_replicate_elision.cpp
     plaidml_pass_winograd.cpp
src/ngraph/runtime/plaidml/plaidml_compiler.cpp
@@ -26,7 +26,6 @@
 #include "ngraph/pass/liveness.hpp"
 #include "ngraph/pass/manager.hpp"
 #include "ngraph/pass/nop_elimination.hpp"
-#include "ngraph/pass/prefix_reshape_elimination.hpp"
 #include "ngraph/pass/visualize_tree.hpp"
 #include "ngraph/pass/zero_dim_tensor_elimination.hpp"
 #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
@@ -36,6 +35,7 @@
 #include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp"
 #include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
 #include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp"
+#include "ngraph/runtime/plaidml/plaidml_pass_prefix_reshape_elimination.hpp"
 #include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp"
 #include "ngraph/runtime/plaidml/plaidml_pass_replicate_elision.hpp"
 #include "ngraph/runtime/plaidml/plaidml_pass_winograd.hpp"
@@ -44,8 +44,7 @@ namespace
 {
     void write_debug(const ngraph::Node& op)
     {
-        PLAIDML_DEBUG << "Node: name=\"" << op.get_name() << "\" desc=\"" << op.description()
-                      << "\"";
+        PLAIDML_DEBUG << "Compiling: " << op;
         for (const auto& op_input : op.get_inputs())
         {
             ngraph::descriptor::Tensor* tensor = op_input.get_output().get_tensor_ptr().get();
@@ -104,7 +103,7 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
     pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateElision>();
     pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateCombination>();
     pass_manager.register_pass<ngraph::runtime::plaidml::pass::ImplicitBroadcast>();
-    pass_manager.register_pass<ngraph::pass::PrefixReshapeElimination>();
+    pass_manager.register_pass<ngraph::runtime::plaidml::pass::PrefixReshapeElimination>();
    pass_manager.register_pass<ngraph::runtime::plaidml::pass::LowerConvolutions>();
     if (pass_manager.get_pass_config().get_pass_enable("Winograd"))
     {
src/ngraph/runtime/plaidml/plaidml_config.cpp
@@ -163,6 +163,12 @@ ngraph::runtime::plaidml::Config
         // So to verify that there is a non-zero-length option value, test oval_len
         // To verify that there is no option value, test has_oval

+        if (oname_begin == oname_end && !has_oval)
+        {
+            // An empty option; poor style, but advance to the next.
+            continue;
+        }
+
         // Check for verbosity
         if (is_opt("v"))
         {
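The two comments kept as context above the new block state the parser's invariants: oval_len distinguishes an empty value from a non-empty one, and has_oval distinguishes "no value at all" from "an empty value". The added check simply skips a token that has neither a name nor a value. A self-contained sketch of that skip logic (hypothetical; not the actual nGraph parser, which walks the option string with iterators rather than splitting it):

    #include <iostream>
    #include <sstream>
    #include <string>

    // Hypothetical illustration: parse a comma-separated option string such as "v=3,,winograd".
    // The stray ",," yields a token with no name and no value, which is skipped, mirroring the
    // `oname_begin == oname_end && !has_oval` check added in this commit.
    int main()
    {
        std::string config = "v=3,,winograd";
        std::istringstream stream{config};
        std::string token;
        while (std::getline(stream, token, ','))
        {
            auto eq = token.find('=');
            bool has_oval = (eq != std::string::npos); // was there an '=' at all?
            std::string oname = has_oval ? token.substr(0, eq) : token;
            std::string oval = has_oval ? token.substr(eq + 1) : "";
            if (oname.empty() && !has_oval)
            {
                continue; // An empty option; poor style, but advance to the next.
            }
            std::cout << "option '" << oname << "' value '" << oval << "'\n";
        }
    }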
src/ngraph/runtime/plaidml/plaidml_ops_replicate.cpp
@@ -53,7 +53,7 @@ ngraph::runtime::plaidml::op::Replicate::Replicate(std::shared_ptr<Node> arg,
 void ngraph::runtime::plaidml::op::Replicate::validate_and_infer_types()
 {
-    const auto& arg = get_arguments().at(0);
+    std::shared_ptr<Node> arg = get_argument(0);
     Shape shape = arg->get_shape();
     for (auto rit = m_replication_axes.begin(), sit = shape.begin();
          rit != m_replication_axes.end();
src/ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.cpp
@@ -15,7 +15,7 @@
 //*****************************************************************************
 #include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
-#include "ngraph/graph_util.hpp"
+#include "ngraph/check.hpp"
 #include "ngraph/op/broadcast.hpp"
 #include "ngraph/op/reshape.hpp"
 #include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
@@ -76,9 +76,28 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
         auto implicit_broadcast =
             std::make_shared<plaidml::op::ImplicitBroadcast>(src, broadcast->get_shape());
-        replace_node(broadcast, implicit_broadcast);
+        // N.B. We don't use replace_node() here, since it's important to only replace the broadcast with an
+        // implicit broadcast when the consuming operation is an elementwise operation, since PlaidML
+        // contractions don't provide implicit broadcast semantics.
+        bool result = false;
+        for (size_t i = 0; i < broadcast->get_output_size(); ++i)
+        {
+            for (auto& input : broadcast->output(i).get_target_inputs())
+            {
+                Node* node = input.get_node();
+                if (dynamic_cast<ngraph::op::util::UnaryElementwiseArithmetic*>(node) ||
+                    dynamic_cast<ngraph::op::util::BinaryElementwiseArithmetic*>(node))
+                {
+                    input.replace_source_output(implicit_broadcast->output(i));
+                    result = true;
+                }
+            }
+        }
-        return true;
+        NGRAPH_CHECK(result,
+                     "Expected at least one elementwise consumer in the PlaidML implicit broadcast "
+                     "rewrite graph pass");
+        return result;
     };
     add_matcher(std::make_shared<pattern::Matcher>(target_op), callback);
 }
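Note on the rewrite above: replace_node() would splice the ImplicitBroadcast in front of every consumer of the Broadcast node, including PlaidML contractions, which have no implicit-broadcast semantics; the new loop moves only elementwise consumers. A minimal sketch of that distinction, using hypothetical placeholder nodes old_node and new_node and only the API calls that already appear in this diff:

    // Whole-node replacement: every consumer of old_node would now read from new_node.
    // replace_node(old_node, new_node);

    // Selective rewiring, as in this pass: only elementwise arithmetic consumers move;
    // any other consumer (e.g. a contraction) keeps reading from old_node.
    for (size_t i = 0; i < old_node->get_output_size(); ++i)
    {
        for (auto& input : old_node->output(i).get_target_inputs())
        {
            Node* node = input.get_node();
            if (dynamic_cast<ngraph::op::util::UnaryElementwiseArithmetic*>(node) ||
                dynamic_cast<ngraph::op::util::BinaryElementwiseArithmetic*>(node))
            {
                input.replace_source_output(new_node->output(i));
            }
        }
    }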
src/ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.cpp
@@ -75,19 +75,19 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
                 // op. Using target always works.
                 AxisVector out_axes = to_axes(target, output_transpose);

-                auto lhs = node->get_arguments().at(0);
+                auto lhs = node->get_argument(0);
                 auto* lhs_transpose = to_transpose(lhs);
                 if (lhs_transpose)
                 {
-                    lhs = lhs_transpose->get_arguments().at(0);
+                    lhs = lhs_transpose->get_argument(0);
                 }
                 AxisVector lhs_axes = to_axes(lhs, lhs_transpose);

-                auto rhs = node->get_arguments().at(1);
+                auto rhs = node->get_argument(1);
                 auto* rhs_transpose = to_transpose(rhs);
                 if (rhs_transpose)
                 {
-                    rhs = rhs_transpose->get_arguments().at(0);
+                    rhs = rhs_transpose->get_argument(0);
                 }
                 AxisVector rhs_axes = to_axes(rhs, rhs_transpose);
src/ngraph/pass/prefix_reshape_elimination.cpp → src/ngraph/runtime/plaidml/plaidml_pass_prefix_reshape_elimination.cpp
@@ -14,7 +14,7 @@
 // limitations under the License.
 //*****************************************************************************
-#include "ngraph/pass/prefix_reshape_elimination.hpp"
+#include "ngraph/runtime/plaidml/plaidml_pass_prefix_reshape_elimination.hpp"

 #include "ngraph/graph_util.hpp"
 #include "ngraph/op/reshape.hpp"
 #include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
@@ -23,11 +23,12 @@
 #include "ngraph/pattern/op/any.hpp"
 #include "ngraph/pattern/op/any_of.hpp"
 #include "ngraph/pattern/op/label.hpp"
+#include "ngraph/runtime/plaidml/plaidml_ops_implicit_broadcast.hpp"

 using namespace std;
 using namespace ngraph;

-pass::PrefixReshapeElimination::PrefixReshapeElimination()
+runtime::plaidml::pass::PrefixReshapeElimination::PrefixReshapeElimination()
 {
     auto src_op = make_shared<pattern::op::Label>(
         element::i8, Shape{}, [](shared_ptr<Node>) { return true; });
@@ -35,7 +36,7 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
         element::i8,
         Shape{},
         [](shared_ptr<Node> node) {
-            op::Reshape* reshape = dynamic_cast<op::Reshape*>(node.get());
+            ngraph::op::Reshape* reshape = dynamic_cast<ngraph::op::Reshape*>(node.get());
             if (!reshape)
             {
                 return false;
@@ -71,16 +72,42 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
         element::i8,
         Shape{},
         [](shared_ptr<Node> node) {
-            return pattern::has_class<op::util::UnaryElementwiseArithmetic>()(node) ||
-                   pattern::has_class<op::util::BinaryElementwiseArithmetic>()(node);
+            return pattern::has_class<ngraph::op::util::UnaryElementwiseArithmetic>()(node) ||
+                   pattern::has_class<ngraph::op::util::BinaryElementwiseArithmetic>()(node);
         },
         NodeVector{reshape_op});

     auto callback = [](pattern::Matcher& m) {
-        replace_node(m.get_matched_nodes().at(1), m.get_matched_nodes().at(2));
-        return true;
+        auto src = m.get_matched_nodes().at(2);
+        auto prefix_reshape =
+            std::static_pointer_cast<ngraph::op::Reshape>(m.get_matched_nodes().at(1));
+        auto implicit_broadcast =
+            std::make_shared<op::ImplicitBroadcast>(src, prefix_reshape->get_shape());
+
+        // N.B. We don't use replace_node() here, since it's important to only replace the prefix reshape with
+        // an implicit broadcast when the consuming operation is an elementwise operation, since PlaidML
+        // contractions don't provide implicit broadcast semantics.
+        bool result = false;
+        for (size_t i = 0; i < prefix_reshape->get_output_size(); ++i)
+        {
+            for (auto& input : prefix_reshape->output(i).get_target_inputs())
+            {
+                Node* node = input.get_node();
+                if (dynamic_cast<ngraph::op::util::UnaryElementwiseArithmetic*>(node) ||
+                    dynamic_cast<ngraph::op::util::BinaryElementwiseArithmetic*>(node))
+                {
+                    input.replace_source_output(implicit_broadcast->output(i));
+                    result = true;
+                }
+            }
+        }
+
+        NGRAPH_CHECK(result,
+                     "Expected at least one elementwise consumer in the PlaidML implicit broadcast "
+                     "rewrite graph pass");
+        return result;
     };
     add_matcher(make_shared<pattern::Matcher>(target_op, "PrefixReshapeElimination"),
                 callback,
-                PassProperty::REQUIRE_STATIC_SHAPE);
+                ngraph::pass::PassProperty::REQUIRE_STATIC_SHAPE);
 }
src/ngraph/pass/prefix_reshape_elimination.hpp → src/ngraph/runtime/plaidml/plaidml_pass_prefix_reshape_elimination.hpp
@@ -20,19 +20,23 @@
 namespace ngraph
 {
-    namespace pass
+    namespace runtime
     {
-        class PrefixReshapeElimination;
+        namespace plaidml
+        {
+            namespace pass
+            {
+                class PrefixReshapeElimination;
+            }
+        }
     }
 }

-// A pass to eliminate reshapes whose output shapes are the same as
-// their input shape modulo leading size-1 axes.
-//
-// N.B. This pass MUST only be used by backends that can handle the
-// omission of leading size-1 axes, e.g. backends that implement
-// NumPy-style broadcast semantics.
-class ngraph::pass::PrefixReshapeElimination final : public ngraph::pass::GraphRewrite
+// A pass that matches reshapes whose output shapes are the same as
+// their input shape modulo leading size-1 axes, and replaces them with
+// ImplicitBroadcast operations (which do the same thing as a passthrough).
+class ngraph::runtime::plaidml::pass::PrefixReshapeElimination final
+    : public ngraph::pass::GraphRewrite
 {
 public:
     PrefixReshapeElimination();
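For concreteness, a "prefix reshape" in the sense of this header only prepends size-1 axes to the input shape; a reshape that alters the existing axes is not touched. A hypothetical illustration (not part of this commit), assuming the nGraph 0.x op::Reshape constructor (arg, input_order, output_shape):

    // Input tensor of shape {3, 4}.
    auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{3, 4});

    // Prefix reshape: {3, 4} -> {1, 1, 3, 4} with the identity input order. Under NumPy-style
    // broadcasting the leading size-1 axes are redundant, so the PlaidML pass can replace this
    // node with an ImplicitBroadcast of `data` for elementwise consumers.
    auto prefix = std::make_shared<ngraph::op::Reshape>(
        data, ngraph::AxisVector{0, 1}, ngraph::Shape{1, 1, 3, 4});

    // Not a prefix reshape: {3, 4} -> {4, 3} rearranges the existing axes, so the pass leaves it alone.
    auto transpose = std::make_shared<ngraph::op::Reshape>(
        data, ngraph::AxisVector{1, 0}, ngraph::Shape{4, 3});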
src/ngraph/runtime/plaidml/plaidml_pass_replicate_combination.cpp
@@ -47,9 +47,9 @@ ngraph::runtime::plaidml::pass::ReplicateCombination::ReplicateCombination()
                 *ait *= *uit;
             }

-            replace_node(lower,
-                         std::make_shared<plaidml::op::Replicate>(upper->get_arguments().at(0),
-                                                                  std::move(axes)));
+            replace_node(lower,
+                         std::make_shared<plaidml::op::Replicate>(upper->get_argument(0),
+                                                                  std::move(axes)));

             return true;
         };
src/ngraph/runtime/plaidml/plaidml_pass_replicate_elision.cpp
@@ -74,7 +74,7 @@ ngraph::runtime::plaidml::pass::ReplicateElision::ReplicateElision()
                 if (elidable)
                 {
                     replaced_any = true;
-                    replace_node(replicate, replicate->get_arguments().at(0));
+                    replace_node(replicate, replicate->get_argument(0));
                 }
             }