Commit f79b40a7 authored by Rob Earhart's avatar Rob Earhart Committed by Robert Kimball

Minor PlaidML updates (#2283)

* Use static cast where possible

* Tensor API update

* Move prefix reshape elision to be a general pass

* Use pass config to select Winograd optimization

* Use get_is_transpose() to detect transposes

* Use get_default_order to build AxisVectors
parent d74ea190
...@@ -146,7 +146,8 @@ set (SRC ...@@ -146,7 +146,8 @@ set (SRC
pass/memory_visualize.cpp pass/memory_visualize.cpp
pass/nop_elimination.cpp pass/nop_elimination.cpp
pass/pass.cpp pass/pass.cpp
pass/pass_config.cpp pass/pass_config.cpp
pass/prefix_reshape_elimination.cpp
pass/propagate_cacheability.cpp pass/propagate_cacheability.cpp
pass/reshape_elimination.cpp pass/reshape_elimination.cpp
pass/reshape_sinking.cpp pass/reshape_sinking.cpp
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_pass_reshape_elision.hpp" #include "ngraph/pass/prefix_reshape_elimination.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp" #include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "ngraph/pattern/op/any_of.hpp" #include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
ngraph::runtime::plaidml::pass::ReshapeElision::ReshapeElision() ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination()
{ {
auto src_op = std::make_shared<pattern::op::Label>( auto src_op = std::make_shared<pattern::op::Label>(
element::i8, Shape{}, [](std::shared_ptr<Node>) { return true; }); element::i8, Shape{}, [](std::shared_ptr<Node>) { return true; });
...@@ -39,12 +39,9 @@ ngraph::runtime::plaidml::pass::ReshapeElision::ReshapeElision() ...@@ -39,12 +39,9 @@ ngraph::runtime::plaidml::pass::ReshapeElision::ReshapeElision()
} }
// Validate that this isn't a reordering-reshape. // Validate that this isn't a reordering-reshape.
for (std::size_t idx = 0; idx < reshape->get_input_order().size(); ++idx) if (reshape->get_is_transpose())
{ {
if (idx != reshape->get_input_order().at(idx)) return false;
{
return false;
}
} }
// Make sure that logical dimension sizes match. // Make sure that logical dimension sizes match.
......
//***************************************************************************** //*****************************************************************************
// Copyright 2017-2018 Intel Corporation // Copyright 2017-2019 Intel Corporation
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -20,23 +20,20 @@ ...@@ -20,23 +20,20 @@
namespace ngraph namespace ngraph
{ {
namespace runtime namespace pass
{ {
namespace plaidml class PrefixReshapeElimination;
{
namespace pass
{
class ReshapeElision;
}
}
} }
} }
// A minor pass to elide unnecessary reshapes. A reshape is // A pass to eliminate reshapes whose output shapes are the same as
// considered unnecessary if its output shape is the same as its input // their input shape modulo leading size-1 axes.
// shape, modulo leading size-1 axes. //
class ngraph::runtime::plaidml::pass::ReshapeElision final : public ngraph::pass::GraphRewrite // N.B. This pass MUST only be used by backends that can handle the
// omission of leading size-1 axes, e.g. backends that implement
// NumPy-style broadcast semantics.
class ngraph::pass::PrefixReshapeElimination final : public ngraph::pass::GraphRewrite
{ {
public: public:
ReshapeElision(); PrefixReshapeElimination();
}; };
...@@ -54,7 +54,6 @@ set(SRC ...@@ -54,7 +54,6 @@ set(SRC
plaidml_pass_lower_convolutions.cpp plaidml_pass_lower_convolutions.cpp
plaidml_pass_replicate_combination.cpp plaidml_pass_replicate_combination.cpp
plaidml_pass_replicate_elision.cpp plaidml_pass_replicate_elision.cpp
plaidml_pass_reshape_elision.cpp
plaidml_pass_winograd.cpp plaidml_pass_winograd.cpp
plaidml_tensor.cpp plaidml_tensor.cpp
plaidml_translate.cpp plaidml_translate.cpp
......
...@@ -31,14 +31,15 @@ ngraph::runtime::plaidml::PlaidML_Backend::PlaidML_Backend(const char* configura ...@@ -31,14 +31,15 @@ ngraph::runtime::plaidml::PlaidML_Backend::PlaidML_Backend(const char* configura
std::shared_ptr<ngraph::runtime::Tensor> ngraph::runtime::plaidml::PlaidML_Backend::create_tensor( std::shared_ptr<ngraph::runtime::Tensor> ngraph::runtime::plaidml::PlaidML_Backend::create_tensor(
const ngraph::element::Type& element_type, const ngraph::Shape& shape) const ngraph::element::Type& element_type, const ngraph::Shape& shape)
{ {
return std::make_shared<PlaidML_Tensor>(&m_config, element_type, shape, "direct_data", nullptr); return std::make_shared<PlaidML_Tensor>(
this, &m_config, element_type, shape, "direct_data", nullptr);
} }
std::shared_ptr<ngraph::runtime::Tensor> ngraph::runtime::plaidml::PlaidML_Backend::create_tensor( std::shared_ptr<ngraph::runtime::Tensor> ngraph::runtime::plaidml::PlaidML_Backend::create_tensor(
const ngraph::element::Type& element_type, const Shape& shape, void* memory_pointer) const ngraph::element::Type& element_type, const Shape& shape, void* memory_pointer)
{ {
return std::make_shared<PlaidML_Tensor>( return std::make_shared<PlaidML_Tensor>(
&m_config, element_type, shape, "direct_data", memory_pointer); this, &m_config, element_type, shape, "direct_data", memory_pointer);
} }
std::shared_ptr<ngraph::Function> std::shared_ptr<ngraph::Function>
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/nop_elimination.hpp" #include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/prefix_reshape_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp" #include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp" #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
...@@ -36,7 +37,6 @@ ...@@ -36,7 +37,6 @@
#include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp" #include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp" #include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_elision.hpp" #include "ngraph/runtime/plaidml/plaidml_pass_replicate_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_reshape_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_winograd.hpp" #include "ngraph/runtime/plaidml/plaidml_pass_winograd.hpp"
namespace namespace
...@@ -101,9 +101,9 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction> ...@@ -101,9 +101,9 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateElision>(); pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateCombination>(); pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateCombination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ImplicitBroadcast>(); pass_manager.register_pass<ngraph::runtime::plaidml::pass::ImplicitBroadcast>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReshapeElision>(); pass_manager.register_pass<ngraph::pass::PrefixReshapeElimination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::LowerConvolutions>(); pass_manager.register_pass<ngraph::runtime::plaidml::pass::LowerConvolutions>();
if (m_config->winograd) if (pass_manager.get_pass_config().get_pass_enable("Winograd"))
{ {
pass_manager.register_pass<ngraph::runtime::plaidml::pass::Winograd>(); pass_manager.register_pass<ngraph::runtime::plaidml::pass::Winograd>();
} }
......
...@@ -77,7 +77,6 @@ ngraph::runtime::plaidml::Config ...@@ -77,7 +77,6 @@ ngraph::runtime::plaidml::Config
bool help = false; bool help = false;
bool list = false; bool list = false;
bool debug = false; bool debug = false;
bool winograd = false;
std::size_t device_idx = 0; std::size_t device_idx = 0;
std::string eventlog_config; std::string eventlog_config;
std::string graphviz; std::string graphviz;
...@@ -242,14 +241,6 @@ ngraph::runtime::plaidml::Config ...@@ -242,14 +241,6 @@ ngraph::runtime::plaidml::Config
continue; continue;
} }
// Check for Winograd. (Winograd is sometimes a performance
// boost, but not always, so we make it optional.)
if (is_opt("winograd"))
{
winograd = true;
continue;
}
// Reject unknown options // Reject unknown options
err = true; err = true;
} }
...@@ -257,7 +248,7 @@ ngraph::runtime::plaidml::Config ...@@ -257,7 +248,7 @@ ngraph::runtime::plaidml::Config
constexpr char help_text[] = constexpr char help_text[] =
"PlaidML Backend Specification: \"" "PlaidML Backend Specification: \""
"PlaidML[:[device_index][,debug][,help][,list_devices][," "PlaidML[:[device_index][,debug][,help][,list_devices][,"
"eventlog=<filename>][,graphviz=<filename>][,winograd]]\". For example: \"PlaidML\", \"" "eventlog=<filename>][,graphviz=<filename>]]\". For example: \"PlaidML\", \""
"PlaidML:0,list_devices\""; "PlaidML:0,list_devices\"";
if (err) if (err)
{ {
...@@ -292,7 +283,5 @@ ngraph::runtime::plaidml::Config ...@@ -292,7 +283,5 @@ ngraph::runtime::plaidml::Config
result.graphviz = graphviz; result.graphviz = graphviz;
result.winograd = winograd;
return result; return result;
} }
...@@ -39,6 +39,5 @@ struct ngraph::runtime::plaidml::Config ...@@ -39,6 +39,5 @@ struct ngraph::runtime::plaidml::Config
std::shared_ptr<vertexai::ctx> ctx; std::shared_ptr<vertexai::ctx> ctx;
std::shared_ptr<vertexai::plaidml::device> dev; std::shared_ptr<vertexai::plaidml::device> dev;
bool debug; bool debug;
bool winograd;
std::string graphviz; std::string graphviz;
}; };
...@@ -166,7 +166,7 @@ namespace ngraph ...@@ -166,7 +166,7 @@ namespace ngraph
{ {
Impl impl; Impl impl;
impl.set_build(build); impl.set_build(build);
impl.set_op(dynamic_cast<const typename Impl::Op*>(op)); impl.set_op(static_cast<const typename Impl::Op*>(op));
impl.Apply(); impl.Apply();
} }
}; };
......
...@@ -33,7 +33,7 @@ ngraph::runtime::plaidml::pass::ConcatElision::ConcatElision() ...@@ -33,7 +33,7 @@ ngraph::runtime::plaidml::pass::ConcatElision::ConcatElision()
}); });
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) { pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto concat = std::dynamic_pointer_cast<ngraph::op::Concat>(m.get_match_root()); auto concat = std::static_pointer_cast<ngraph::op::Concat>(m.get_match_root());
auto args = concat->get_arguments(); auto args = concat->get_arguments();
// Elide one-argument concats. // Elide one-argument concats.
......
...@@ -45,7 +45,7 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast() ...@@ -45,7 +45,7 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
// for the broadcast axes. // for the broadcast axes.
auto src = m.get_matched_nodes().at(2); auto src = m.get_matched_nodes().at(2);
Shape src_shape = src->get_shape(); Shape src_shape = src->get_shape();
auto broadcast = std::dynamic_pointer_cast<op::Broadcast>(m.get_matched_nodes().at(1)); auto broadcast = std::static_pointer_cast<op::Broadcast>(m.get_matched_nodes().at(1));
AxisVector reshape_order; AxisVector reshape_order;
Shape reshape_shape; Shape reshape_shape;
......
...@@ -53,9 +53,7 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions() ...@@ -53,9 +53,7 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
{ {
return reshape->get_input_order(); return reshape->get_input_order();
} }
AxisVector result(node->get_shape().size()); return get_default_order(node->get_shape());
std::iota(result.begin(), result.end(), 0);
return result;
}; };
std::shared_ptr<Node> node = m.get_match_root(); std::shared_ptr<Node> node = m.get_match_root();
......
...@@ -37,8 +37,8 @@ ngraph::runtime::plaidml::pass::ReplicateCombination::ReplicateCombination() ...@@ -37,8 +37,8 @@ ngraph::runtime::plaidml::pass::ReplicateCombination::ReplicateCombination()
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) { pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto nodes = m.get_matched_nodes(); auto nodes = m.get_matched_nodes();
auto lower = std::dynamic_pointer_cast<plaidml::op::Replicate>(nodes.at(0)); auto lower = std::static_pointer_cast<plaidml::op::Replicate>(nodes.at(0));
auto upper = std::dynamic_pointer_cast<plaidml::op::Replicate>(nodes.at(1)); auto upper = std::static_pointer_cast<plaidml::op::Replicate>(nodes.at(1));
std::vector<size_t> axes = lower->get_replication_axes(); std::vector<size_t> axes = lower->get_replication_axes();
const std::vector<size_t>& upper_axes = upper->get_replication_axes(); const std::vector<size_t>& upper_axes = upper->get_replication_axes();
auto uit = upper_axes.begin(); auto uit = upper_axes.begin();
......
...@@ -51,7 +51,7 @@ ngraph::runtime::plaidml::pass::ReplicateElision::ReplicateElision() ...@@ -51,7 +51,7 @@ ngraph::runtime::plaidml::pass::ReplicateElision::ReplicateElision()
for (auto nit = nodes.begin() + 1; nit != nodes.end(); ++nit) for (auto nit = nodes.begin() + 1; nit != nodes.end(); ++nit)
{ {
auto replicate = std::dynamic_pointer_cast<plaidml::op::Replicate>(*nit); auto replicate = std::static_pointer_cast<plaidml::op::Replicate>(*nit);
const auto& replicate_axes = replicate->get_replication_axes(); const auto& replicate_axes = replicate->get_replication_axes();
bool elidable = true; bool elidable = true;
for (std::size_t idx = 0; idx < dim_limit; ++idx) for (std::size_t idx = 0; idx < dim_limit; ++idx)
......
...@@ -112,7 +112,7 @@ ngraph::runtime::plaidml::pass::Winograd::Winograd() ...@@ -112,7 +112,7 @@ ngraph::runtime::plaidml::pass::Winograd::Winograd()
}); });
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) { pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto conv = std::dynamic_pointer_cast<plaidml::op::Convolution>(m.get_match_root()); auto conv = std::static_pointer_cast<plaidml::op::Convolution>(m.get_match_root());
NodeVector args = conv->get_arguments(); NodeVector args = conv->get_arguments();
std::shared_ptr<ngraph::op::Constant> a; std::shared_ptr<ngraph::op::Constant> a;
std::shared_ptr<ngraph::op::Constant> b; std::shared_ptr<ngraph::op::Constant> b;
......
...@@ -22,12 +22,13 @@ ...@@ -22,12 +22,13 @@
namespace vp = vertexai::plaidml; namespace vp = vertexai::plaidml;
ngraph::runtime::plaidml::PlaidML_Tensor::PlaidML_Tensor(Config* config, ngraph::runtime::plaidml::PlaidML_Tensor::PlaidML_Tensor(Backend* parent,
Config* config,
const ngraph::element::Type& element_type, const ngraph::element::Type& element_type,
const ngraph::Shape& shape, const ngraph::Shape& shape,
const std::string& name, const std::string& name,
void* memory) void* memory)
: Tensor{std::make_shared<ngraph::descriptor::Tensor>(element_type, shape, name)} : Tensor{std::make_shared<ngraph::descriptor::Tensor>(element_type, shape, name), parent}
, m_tensor{config->dev->allocate( , m_tensor{config->dev->allocate(
to_plaidml(config->ctx, element_type, shape, ConversionUse::FOR_IO))} to_plaidml(config->ctx, element_type, shape, ConversionUse::FOR_IO))}
, m_memory{memory} , m_memory{memory}
......
...@@ -35,7 +35,8 @@ namespace ngraph ...@@ -35,7 +35,8 @@ namespace ngraph
class ngraph::runtime::plaidml::PlaidML_Tensor final : public ngraph::runtime::Tensor class ngraph::runtime::plaidml::PlaidML_Tensor final : public ngraph::runtime::Tensor
{ {
public: public:
PlaidML_Tensor(Config* config, PlaidML_Tensor(Backend* parent,
Config* config,
const ngraph::element::Type& element_type, const ngraph::element::Type& element_type,
const ngraph::Shape& shape, const ngraph::Shape& shape,
const std::string& name, const std::string& name,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment