Commit b7551a60 authored by Rob Earhart

ReshapeElimination -> PlaidML; fixup ImplicitBcast

The PrefixReshapeElimination pass doesn't work correctly when the input/output shape compatibility of the
network's operations is validated; an implicit broadcast operation is required to get the shapes right.
This change moves the implementation back to PlaidML (it seems less generally useful than expected, and
PlaidML has an implicit broadcast operation), and fixes both it and the implicit broadcast pass to apply
only when the downstream operation is an elementwise operation, since PlaidML doesn't support automatic
broadcasting for contractions.
parent 9b53073c
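
For context, the consumer gate that both rewrites below now share looks roughly like the following sketch
(condensed from the two hunks in this diff; `producer` stands in for the matched Broadcast or prefix
Reshape node, and names like `rewrote_any` are illustrative, not the literal committed code):

    // Minimal sketch of the shared pattern: rewire an output into the
    // implicit broadcast only for elementwise consumers; any other consumer
    // (e.g. a contraction) keeps reading the original producer, since PlaidML
    // contractions don't provide implicit broadcast semantics.
    bool rewrote_any = false;
    for (size_t i = 0; i < producer->get_output_size(); ++i)
    {
        for (auto& input : producer->output(i).get_target_inputs())
        {
            ngraph::Node* consumer = input.get_node();
            if (dynamic_cast<ngraph::op::util::UnaryElementwiseArithmetic*>(consumer) ||
                dynamic_cast<ngraph::op::util::BinaryElementwiseArithmetic*>(consumer))
            {
                input.replace_source_output(implicit_broadcast->output(i));
                rewrote_any = true;
            }
        }
    }
    NGRAPH_CHECK(rewrote_any, "expected at least one elementwise consumer");
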
@@ -410,8 +410,6 @@ set (SRC
pass/pass.hpp
pass/pass_config.cpp
pass/pass_config.hpp
-pass/prefix_reshape_elimination.cpp
-pass/prefix_reshape_elimination.hpp
pass/propagate_cacheability.cpp
pass/propagate_cacheability.hpp
pass/reshape_elimination.cpp
......
@@ -55,6 +55,7 @@ set(SRC
plaidml_pass_explicit_logicals.cpp
plaidml_pass_implicit_broadcast.cpp
plaidml_pass_lower_convolutions.cpp
+plaidml_pass_prefix_reshape_elimination.cpp
plaidml_pass_replicate_combination.cpp
plaidml_pass_replicate_elision.cpp
plaidml_pass_winograd.cpp
......
@@ -26,7 +26,6 @@
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/prefix_reshape_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
@@ -36,6 +35,7 @@
#include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_prefix_reshape_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_winograd.hpp"
@@ -44,8 +44,7 @@ namespace
{
void write_debug(const ngraph::Node& op)
{
PLAIDML_DEBUG << "Node: name=\"" << op.get_name() << "\" desc=\"" << op.description()
<< "\"";
PLAIDML_DEBUG << "Compiling: " << op;
for (const auto& op_input : op.get_inputs())
{
ngraph::descriptor::Tensor* tensor = op_input.get_output().get_tensor_ptr().get();
@@ -104,7 +103,7 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateCombination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ImplicitBroadcast>();
-pass_manager.register_pass<ngraph::pass::PrefixReshapeElimination>();
+pass_manager.register_pass<ngraph::runtime::plaidml::pass::PrefixReshapeElimination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::LowerConvolutions>();
if (pass_manager.get_pass_config().get_pass_enable("Winograd"))
{
......
@@ -15,7 +15,7 @@
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/check.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
@@ -76,9 +76,28 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
auto implicit_broadcast =
std::make_shared<plaidml::op::ImplicitBroadcast>(src, broadcast->get_shape());
-replace_node(broadcast, implicit_broadcast);
+// N.B. We don't use replace_node() here, since it's important to only replace the broadcast with an
+// implicit broadcast when the consuming operation is an elementwise operation, since PlaidML
+// contractions don't provide implicit broadcast semantics.
+bool result = false;
+for (size_t i = 0; i < broadcast->get_output_size(); ++i)
+{
+for (auto& input : broadcast->output(i).get_target_inputs())
+{
+Node* node = input.get_node();
+if (dynamic_cast<ngraph::op::util::UnaryElementwiseArithmetic*>(node) ||
+dynamic_cast<ngraph::op::util::BinaryElementwiseArithmetic*>(node))
+{
+input.replace_source_output(implicit_broadcast->output(i));
+result = true;
+}
+}
+}
-return true;
+NGRAPH_CHECK(result,
+"Expected at least one elementwise consumer in the PlaidML implicit broadcast "
+"rewrite graph pass");
+return result;
};
add_matcher(std::make_shared<pattern::Matcher>(target_op), callback);
}
@@ -14,7 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/prefix_reshape_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_prefix_reshape_elimination.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
@@ -23,11 +23,12 @@
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_implicit_broadcast.hpp"
using namespace std;
using namespace ngraph;
-pass::PrefixReshapeElimination::PrefixReshapeElimination()
+runtime::plaidml::pass::PrefixReshapeElimination::PrefixReshapeElimination()
{
auto src_op = make_shared<pattern::op::Label>(
element::i8, Shape{}, [](shared_ptr<Node>) { return true; });
@@ -35,7 +36,7 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
element::i8,
Shape{},
[](shared_ptr<Node> node) {
-op::Reshape* reshape = dynamic_cast<op::Reshape*>(node.get());
+ngraph::op::Reshape* reshape = dynamic_cast<ngraph::op::Reshape*>(node.get());
if (!reshape)
{
return false;
@@ -71,16 +72,42 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
element::i8,
Shape{},
[](shared_ptr<Node> node) {
-return pattern::has_class<op::util::UnaryElementwiseArithmetic>()(node) ||
-pattern::has_class<op::util::BinaryElementwiseArithmetic>()(node);
+return pattern::has_class<ngraph::op::util::UnaryElementwiseArithmetic>()(node) ||
+pattern::has_class<ngraph::op::util::BinaryElementwiseArithmetic>()(node);
},
NodeVector{reshape_op});
auto callback = [](pattern::Matcher& m) {
-replace_node(m.get_matched_nodes().at(1), m.get_matched_nodes().at(2));
-return true;
+auto src = m.get_matched_nodes().at(2);
+auto prefix_reshape =
+std::static_pointer_cast<ngraph::op::Reshape>(m.get_matched_nodes().at(1));
+auto implicit_broadcast =
+std::make_shared<op::ImplicitBroadcast>(src, prefix_reshape->get_shape());
+// N.B. We don't use replace_node() here, since it's important to only replace the prefix reshape with
+// an implicit broadcast when the consuming operation is an elementwise operation, since PlaidML
+// contractions don't provide implicit broadcast semantics.
+bool result = false;
+for (size_t i = 0; i < prefix_reshape->get_output_size(); ++i)
+{
+for (auto& input : prefix_reshape->output(i).get_target_inputs())
+{
+Node* node = input.get_node();
+if (dynamic_cast<ngraph::op::util::UnaryElementwiseArithmetic*>(node) ||
+dynamic_cast<ngraph::op::util::BinaryElementwiseArithmetic*>(node))
+{
+input.replace_source_output(implicit_broadcast->output(i));
+result = true;
+}
+}
+}
+NGRAPH_CHECK(result,
+"Expected at least one elementwise consumer in the PlaidML implicit broadcast "
+"rewrite graph pass");
+return result;
};
add_matcher(make_shared<pattern::Matcher>(target_op, "PrefixReshapeElimination"),
callback,
-PassProperty::REQUIRE_STATIC_SHAPE);
+ngraph::pass::PassProperty::REQUIRE_STATIC_SHAPE);
}
@@ -20,19 +20,23 @@
namespace ngraph
{
-namespace pass
+namespace runtime
{
-class PrefixReshapeElimination;
+namespace plaidml
+{
+namespace pass
+{
+class PrefixReshapeElimination;
+}
+}
}
}
-// A pass to eliminate reshapes whose output shapes are the same as
-// their input shape modulo leading size-1 axes.
-//
-// N.B. This pass MUST only be used by backends that can handle the
-// omission of leading size-1 axes, e.g. backends that implement
-// NumPy-style broadcast semantics.
-class ngraph::pass::PrefixReshapeElimination final : public ngraph::pass::GraphRewrite
+// A pass that matches reshapes whose output shapes are the same as
+// their input shape modulo leading size-1 axes, and replaces them with
+// ImplicitBroadcast operations (which do the same thing as a passthrough).
+class ngraph::runtime::plaidml::pass::PrefixReshapeElimination final
+: public ngraph::pass::GraphRewrite
{
public:
PrefixReshapeElimination();
......