Unverified Commit 13bdf0ef authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #3174 from NervanaSystems/rearhart/plaidml

Minor PlaidML fixes
parents abd69371 4eb946b0
......@@ -419,8 +419,6 @@ set (SRC
pass/pass.hpp
pass/pass_config.cpp
pass/pass_config.hpp
pass/prefix_reshape_elimination.cpp
pass/prefix_reshape_elimination.hpp
pass/propagate_cacheability.cpp
pass/propagate_cacheability.hpp
pass/reshape_elimination.cpp
......
......@@ -55,6 +55,7 @@ set(SRC
plaidml_pass_explicit_logicals.cpp
plaidml_pass_implicit_broadcast.cpp
plaidml_pass_lower_convolutions.cpp
plaidml_pass_prefix_reshape_elimination.cpp
plaidml_pass_replicate_combination.cpp
plaidml_pass_replicate_elision.cpp
plaidml_pass_winograd.cpp
......
......@@ -26,7 +26,6 @@
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/prefix_reshape_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
......@@ -36,6 +35,7 @@
#include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_prefix_reshape_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_winograd.hpp"
......@@ -44,8 +44,7 @@ namespace
{
void write_debug(const ngraph::Node& op)
{
PLAIDML_DEBUG << "Node: name=\"" << op.get_name() << "\" desc=\"" << op.description()
<< "\"";
PLAIDML_DEBUG << "Compiling: " << op;
for (const auto& op_input : op.get_inputs())
{
ngraph::descriptor::Tensor* tensor = op_input.get_output().get_tensor_ptr().get();
......@@ -104,7 +103,7 @@ std::shared_ptr<ngraph::runtime::plaidml::PlaidML_Executable>
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateCombination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ImplicitBroadcast>();
pass_manager.register_pass<ngraph::pass::PrefixReshapeElimination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::PrefixReshapeElimination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::LowerConvolutions>();
if (pass_manager.get_pass_config().get_pass_enable("Winograd"))
{
......
......@@ -163,6 +163,12 @@ ngraph::runtime::plaidml::Config
// So to verify that there is a non-zero-length option value, test oval_len
// To verify that there is no option value, test has_oval
if (oname_begin == oname_end && !has_oval)
{
// An empty option; poor style, but advance to the next.
continue;
}
// Check for verbosity
if (is_opt("v"))
{
......
......@@ -53,7 +53,7 @@ ngraph::runtime::plaidml::op::Replicate::Replicate(std::shared_ptr<Node> arg,
void ngraph::runtime::plaidml::op::Replicate::validate_and_infer_types()
{
const auto& arg = get_arguments().at(0);
std::shared_ptr<Node> arg = get_argument(0);
Shape shape = arg->get_shape();
for (auto rit = m_replication_axes.begin(), sit = shape.begin();
rit != m_replication_axes.end();
......
......@@ -15,7 +15,7 @@
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/check.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
......@@ -76,9 +76,28 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
auto implicit_broadcast =
std::make_shared<plaidml::op::ImplicitBroadcast>(src, broadcast->get_shape());
replace_node(broadcast, implicit_broadcast);
// N.B. We don't use replace_node() here, since it's important to only replace the broadcast with an
// implicit broadcast when the consuming operation is an elementwise operation, since PlaidML
// contractions don't provide implicit broadcast semantics.
bool result = false;
for (size_t i = 0; i < broadcast->get_output_size(); ++i)
{
for (auto& input : broadcast->output(i).get_target_inputs())
{
Node* node = input.get_node();
if (dynamic_cast<ngraph::op::util::UnaryElementwiseArithmetic*>(node) ||
dynamic_cast<ngraph::op::util::BinaryElementwiseArithmetic*>(node))
{
input.replace_source_output(implicit_broadcast->output(i));
result = true;
}
}
}
return true;
NGRAPH_CHECK(result,
"Expected at least one elementwise consumer in the PlaidML implicit broadcast "
"rewrite graph pass");
return result;
};
add_matcher(std::make_shared<pattern::Matcher>(target_op), callback);
}
......@@ -75,19 +75,19 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
// op. Using target always works.
AxisVector out_axes = to_axes(target, output_transpose);
auto lhs = node->get_arguments().at(0);
auto lhs = node->get_argument(0);
auto* lhs_transpose = to_transpose(lhs);
if (lhs_transpose)
{
lhs = lhs_transpose->get_arguments().at(0);
lhs = lhs_transpose->get_argument(0);
}
AxisVector lhs_axes = to_axes(lhs, lhs_transpose);
auto rhs = node->get_arguments().at(1);
auto rhs = node->get_argument(1);
auto* rhs_transpose = to_transpose(rhs);
if (rhs_transpose)
{
rhs = rhs_transpose->get_arguments().at(0);
rhs = rhs_transpose->get_argument(0);
}
AxisVector rhs_axes = to_axes(rhs, rhs_transpose);
......
......@@ -14,7 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/prefix_reshape_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_prefix_reshape_elimination.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
......@@ -23,11 +23,12 @@
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_implicit_broadcast.hpp"
using namespace std;
using namespace ngraph;
pass::PrefixReshapeElimination::PrefixReshapeElimination()
runtime::plaidml::pass::PrefixReshapeElimination::PrefixReshapeElimination()
{
auto src_op = make_shared<pattern::op::Label>(
element::i8, Shape{}, [](shared_ptr<Node>) { return true; });
......@@ -35,7 +36,7 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
element::i8,
Shape{},
[](shared_ptr<Node> node) {
op::Reshape* reshape = dynamic_cast<op::Reshape*>(node.get());
ngraph::op::Reshape* reshape = dynamic_cast<ngraph::op::Reshape*>(node.get());
if (!reshape)
{
return false;
......@@ -71,16 +72,42 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
element::i8,
Shape{},
[](shared_ptr<Node> node) {
return pattern::has_class<op::util::UnaryElementwiseArithmetic>()(node) ||
pattern::has_class<op::util::BinaryElementwiseArithmetic>()(node);
return pattern::has_class<ngraph::op::util::UnaryElementwiseArithmetic>()(node) ||
pattern::has_class<ngraph::op::util::BinaryElementwiseArithmetic>()(node);
},
NodeVector{reshape_op});
auto callback = [](pattern::Matcher& m) {
replace_node(m.get_matched_nodes().at(1), m.get_matched_nodes().at(2));
return true;
auto src = m.get_matched_nodes().at(2);
auto prefix_reshape =
std::static_pointer_cast<ngraph::op::Reshape>(m.get_matched_nodes().at(1));
auto implicit_broadcast =
std::make_shared<op::ImplicitBroadcast>(src, prefix_reshape->get_shape());
// N.B. We don't use replace_node() here, since it's important to only replace the prefix reshape with
// an implicit broadcast when the consuming operation is an elementwise operation, since PlaidML
// contractions don't provide implicit broadcast semantics.
bool result = false;
for (size_t i = 0; i < prefix_reshape->get_output_size(); ++i)
{
for (auto& input : prefix_reshape->output(i).get_target_inputs())
{
Node* node = input.get_node();
if (dynamic_cast<ngraph::op::util::UnaryElementwiseArithmetic*>(node) ||
dynamic_cast<ngraph::op::util::BinaryElementwiseArithmetic*>(node))
{
input.replace_source_output(implicit_broadcast->output(i));
result = true;
}
}
}
NGRAPH_CHECK(result,
"Expected at least one elementwise consumer in the PlaidML implicit broadcast "
"rewrite graph pass");
return result;
};
add_matcher(make_shared<pattern::Matcher>(target_op, "PrefixReshapeElimination"),
callback,
PassProperty::REQUIRE_STATIC_SHAPE);
ngraph::pass::PassProperty::REQUIRE_STATIC_SHAPE);
}
......@@ -20,19 +20,23 @@
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace pass
{
class PrefixReshapeElimination;
}
}
}
}
// A pass to eliminate reshapes whose output shapes are the same as
// their input shape modulo leading size-1 axes.
//
// N.B. This pass MUST only be used by backends that can handle the
// omission of leading size-1 axes, e.g. backends that implement
// NumPy-style broadcast semantics.
class ngraph::pass::PrefixReshapeElimination final : public ngraph::pass::GraphRewrite
// A pass that matches reshapes whose output shapes are the same as
// their input shape modulo leading size-1 axes, and replaces them with
// ImplicitBroadcast operations (which do the same thing as a passthrough).
class ngraph::runtime::plaidml::pass::PrefixReshapeElimination final
: public ngraph::pass::GraphRewrite
{
public:
PrefixReshapeElimination();
......
......@@ -47,9 +47,9 @@ ngraph::runtime::plaidml::pass::ReplicateCombination::ReplicateCombination()
*ait *= *uit;
}
replace_node(lower,
std::make_shared<plaidml::op::Replicate>(upper->get_arguments().at(0),
std::move(axes)));
replace_node(
lower,
std::make_shared<plaidml::op::Replicate>(upper->get_argument(0), std::move(axes)));
return true;
};
......
......@@ -74,7 +74,7 @@ ngraph::runtime::plaidml::pass::ReplicateElision::ReplicateElision()
if (elidable)
{
replaced_any = true;
replace_node(replicate, replicate->get_arguments().at(0));
replace_node(replicate, replicate->get_argument(0));
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment