Commit f79b40a7 authored by Rob Earhart's avatar Rob Earhart Committed by Robert Kimball

Minor PlaidML updates (#2283)

* Use static cast where possible

* Tensor API update

* Move prefix reshape elision to be a general pass

* Use pass config to select Winograd optimization

* Use get_is_transpose() to detect transposes

* Use get_default_order to build AxisVectors
parent d74ea190
......@@ -147,6 +147,7 @@ set (SRC
pass/nop_elimination.cpp
pass/pass.cpp
pass/pass_config.cpp
pass/prefix_reshape_elimination.cpp
pass/propagate_cacheability.cpp
pass/reshape_elimination.cpp
pass/reshape_sinking.cpp
......
......@@ -14,7 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_pass_reshape_elision.hpp"
#include "ngraph/pass/prefix_reshape_elimination.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
......@@ -24,7 +24,7 @@
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
ngraph::runtime::plaidml::pass::ReshapeElision::ReshapeElision()
ngraph::pass::PrefixReshapeElimination::PrefixReshapeElimination()
{
auto src_op = std::make_shared<pattern::op::Label>(
element::i8, Shape{}, [](std::shared_ptr<Node>) { return true; });
......@@ -39,13 +39,10 @@ ngraph::runtime::plaidml::pass::ReshapeElision::ReshapeElision()
}
// Validate that this isn't a reordering-reshape.
for (std::size_t idx = 0; idx < reshape->get_input_order().size(); ++idx)
{
if (idx != reshape->get_input_order().at(idx))
if (reshape->get_is_transpose())
{
return false;
}
}
// Make sure that logical dimension sizes match.
const Shape& src_shape = reshape->get_input_shape(0);
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -20,23 +20,20 @@
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace pass
{
class ReshapeElision;
}
}
class PrefixReshapeElimination;
}
}
// A minor pass to elide unnecessary reshapes. A reshape is
// considered unnecessary if its output shape is the same as its input
// shape, modulo leading size-1 axes.
class ngraph::runtime::plaidml::pass::ReshapeElision final : public ngraph::pass::GraphRewrite
// A pass to eliminate reshapes whose output shapes are the same as
// their input shape modulo leading size-1 axes.
//
// N.B. This pass MUST only be used by backends that can handle the
// omission of leading size-1 axes, e.g. backends that implement
// NumPy-style broadcast semantics.
class ngraph::pass::PrefixReshapeElimination final : public ngraph::pass::GraphRewrite
{
public:
ReshapeElision();
PrefixReshapeElimination();
};
......@@ -54,7 +54,6 @@ set(SRC
plaidml_pass_lower_convolutions.cpp
plaidml_pass_replicate_combination.cpp
plaidml_pass_replicate_elision.cpp
plaidml_pass_reshape_elision.cpp
plaidml_pass_winograd.cpp
plaidml_tensor.cpp
plaidml_translate.cpp
......
......@@ -31,14 +31,15 @@ ngraph::runtime::plaidml::PlaidML_Backend::PlaidML_Backend(const char* configura
std::shared_ptr<ngraph::runtime::Tensor> ngraph::runtime::plaidml::PlaidML_Backend::create_tensor(
const ngraph::element::Type& element_type, const ngraph::Shape& shape)
{
return std::make_shared<PlaidML_Tensor>(&m_config, element_type, shape, "direct_data", nullptr);
return std::make_shared<PlaidML_Tensor>(
this, &m_config, element_type, shape, "direct_data", nullptr);
}
std::shared_ptr<ngraph::runtime::Tensor> ngraph::runtime::plaidml::PlaidML_Backend::create_tensor(
const ngraph::element::Type& element_type, const Shape& shape, void* memory_pointer)
{
return std::make_shared<PlaidML_Tensor>(
&m_config, element_type, shape, "direct_data", memory_pointer);
this, &m_config, element_type, shape, "direct_data", memory_pointer);
}
std::shared_ptr<ngraph::Function>
......
......@@ -26,6 +26,7 @@
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/prefix_reshape_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
......@@ -36,7 +37,6 @@
#include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_reshape_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_winograd.hpp"
namespace
......@@ -101,9 +101,9 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateCombination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ImplicitBroadcast>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReshapeElision>();
pass_manager.register_pass<ngraph::pass::PrefixReshapeElimination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::LowerConvolutions>();
if (m_config->winograd)
if (pass_manager.get_pass_config().get_pass_enable("Winograd"))
{
pass_manager.register_pass<ngraph::runtime::plaidml::pass::Winograd>();
}
......
......@@ -77,7 +77,6 @@ ngraph::runtime::plaidml::Config
bool help = false;
bool list = false;
bool debug = false;
bool winograd = false;
std::size_t device_idx = 0;
std::string eventlog_config;
std::string graphviz;
......@@ -242,14 +241,6 @@ ngraph::runtime::plaidml::Config
continue;
}
// Check for Winograd. (Winograd is sometimes a performance
// boost, but not always, so we make it optional.)
if (is_opt("winograd"))
{
winograd = true;
continue;
}
// Reject unknown options
err = true;
}
......@@ -257,7 +248,7 @@ ngraph::runtime::plaidml::Config
constexpr char help_text[] =
"PlaidML Backend Specification: \""
"PlaidML[:[device_index][,debug][,help][,list_devices][,"
"eventlog=<filename>][,graphviz=<filename>][,winograd]]\". For example: \"PlaidML\", \""
"eventlog=<filename>][,graphviz=<filename>]]\". For example: \"PlaidML\", \""
"PlaidML:0,list_devices\"";
if (err)
{
......@@ -292,7 +283,5 @@ ngraph::runtime::plaidml::Config
result.graphviz = graphviz;
result.winograd = winograd;
return result;
}
......@@ -39,6 +39,5 @@ struct ngraph::runtime::plaidml::Config
std::shared_ptr<vertexai::ctx> ctx;
std::shared_ptr<vertexai::plaidml::device> dev;
bool debug;
bool winograd;
std::string graphviz;
};
......@@ -166,7 +166,7 @@ namespace ngraph
{
Impl impl;
impl.set_build(build);
impl.set_op(dynamic_cast<const typename Impl::Op*>(op));
impl.set_op(static_cast<const typename Impl::Op*>(op));
impl.Apply();
}
};
......
......@@ -33,7 +33,7 @@ ngraph::runtime::plaidml::pass::ConcatElision::ConcatElision()
});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto concat = std::dynamic_pointer_cast<ngraph::op::Concat>(m.get_match_root());
auto concat = std::static_pointer_cast<ngraph::op::Concat>(m.get_match_root());
auto args = concat->get_arguments();
// Elide one-argument concats.
......
......@@ -45,7 +45,7 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
// for the broadcast axes.
auto src = m.get_matched_nodes().at(2);
Shape src_shape = src->get_shape();
auto broadcast = std::dynamic_pointer_cast<op::Broadcast>(m.get_matched_nodes().at(1));
auto broadcast = std::static_pointer_cast<op::Broadcast>(m.get_matched_nodes().at(1));
AxisVector reshape_order;
Shape reshape_shape;
......
......@@ -53,9 +53,7 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
{
return reshape->get_input_order();
}
AxisVector result(node->get_shape().size());
std::iota(result.begin(), result.end(), 0);
return result;
return get_default_order(node->get_shape());
};
std::shared_ptr<Node> node = m.get_match_root();
......
......@@ -37,8 +37,8 @@ ngraph::runtime::plaidml::pass::ReplicateCombination::ReplicateCombination()
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto nodes = m.get_matched_nodes();
auto lower = std::dynamic_pointer_cast<plaidml::op::Replicate>(nodes.at(0));
auto upper = std::dynamic_pointer_cast<plaidml::op::Replicate>(nodes.at(1));
auto lower = std::static_pointer_cast<plaidml::op::Replicate>(nodes.at(0));
auto upper = std::static_pointer_cast<plaidml::op::Replicate>(nodes.at(1));
std::vector<size_t> axes = lower->get_replication_axes();
const std::vector<size_t>& upper_axes = upper->get_replication_axes();
auto uit = upper_axes.begin();
......
......@@ -51,7 +51,7 @@ ngraph::runtime::plaidml::pass::ReplicateElision::ReplicateElision()
for (auto nit = nodes.begin() + 1; nit != nodes.end(); ++nit)
{
auto replicate = std::dynamic_pointer_cast<plaidml::op::Replicate>(*nit);
auto replicate = std::static_pointer_cast<plaidml::op::Replicate>(*nit);
const auto& replicate_axes = replicate->get_replication_axes();
bool elidable = true;
for (std::size_t idx = 0; idx < dim_limit; ++idx)
......
......@@ -112,7 +112,7 @@ ngraph::runtime::plaidml::pass::Winograd::Winograd()
});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto conv = std::dynamic_pointer_cast<plaidml::op::Convolution>(m.get_match_root());
auto conv = std::static_pointer_cast<plaidml::op::Convolution>(m.get_match_root());
NodeVector args = conv->get_arguments();
std::shared_ptr<ngraph::op::Constant> a;
std::shared_ptr<ngraph::op::Constant> b;
......
......@@ -22,12 +22,13 @@
namespace vp = vertexai::plaidml;
ngraph::runtime::plaidml::PlaidML_Tensor::PlaidML_Tensor(Config* config,
ngraph::runtime::plaidml::PlaidML_Tensor::PlaidML_Tensor(Backend* parent,
Config* config,
const ngraph::element::Type& element_type,
const ngraph::Shape& shape,
const std::string& name,
void* memory)
: Tensor{std::make_shared<ngraph::descriptor::Tensor>(element_type, shape, name)}
: Tensor{std::make_shared<ngraph::descriptor::Tensor>(element_type, shape, name), parent}
, m_tensor{config->dev->allocate(
to_plaidml(config->ctx, element_type, shape, ConversionUse::FOR_IO))}
, m_memory{memory}
......
......@@ -35,7 +35,8 @@ namespace ngraph
class ngraph::runtime::plaidml::PlaidML_Tensor final : public ngraph::runtime::Tensor
{
public:
PlaidML_Tensor(Config* config,
PlaidML_Tensor(Backend* parent,
Config* config,
const ngraph::element::Type& element_type,
const ngraph::Shape& shape,
const std::string& name,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment