Commit f75e10c3 authored by Rob Earhart, committed by Scott Cyphers

Improve PlaidML concat perf (#2399)

parent 4fd8228d
...@@ -48,6 +48,7 @@ set(SRC ...@@ -48,6 +48,7 @@ set(SRC
plaidml_ops_transcendental.cpp plaidml_ops_transcendental.cpp
plaidml_ops_winograd.cpp plaidml_ops_winograd.cpp
plaidml_pass_concat_elision.cpp plaidml_pass_concat_elision.cpp
plaidml_pass_concat_split.cpp
plaidml_pass_explicit_logicals.cpp plaidml_pass_explicit_logicals.cpp
plaidml_pass_implicit_broadcast.cpp plaidml_pass_implicit_broadcast.cpp
plaidml_pass_lower_convolutions.cpp plaidml_pass_lower_convolutions.cpp
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "ngraph/runtime/plaidml/plaidml_impl.hpp" #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_logger.hpp" #include "ngraph/runtime/plaidml/plaidml_logger.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_concat_elision.hpp" #include "ngraph/runtime/plaidml/plaidml_pass_concat_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_concat_split.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp" #include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp" #include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp" #include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp"
...@@ -96,6 +97,7 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction> ...@@ -96,6 +97,7 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
pass_manager.register_pass<ngraph::pass::Liveness>(); pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ExplicitLogicals>(); pass_manager.register_pass<ngraph::runtime::plaidml::pass::ExplicitLogicals>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ConcatElision>(); pass_manager.register_pass<ngraph::runtime::plaidml::pass::ConcatElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ConcatSplit>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateElision>(); pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateCombination>(); pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateCombination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ImplicitBroadcast>(); pass_manager.register_pass<ngraph::runtime::plaidml::pass::ImplicitBroadcast>();
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_pass_concat_split.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
// Upper bound on the number of inputs a single Concat may keep; Concat nodes with more
// inputs than this are matched by the pass below and rebuilt as a tree of smaller
// Concats.  Declared constexpr because it is a fixed tuning value that is never written.
static constexpr std::size_t kMaxConcatInputs = 8;
// Builds and registers the rewrite that breaks up wide Concat nodes.
//
// The pattern matches any Concat whose input count exceeds kMaxConcatInputs.  The
// callback then repeatedly partitions the current argument list into consecutive runs
// of at most kMaxConcatInputs nodes, wrapping each multi-node run in a new Concat along
// the original concatenation axis, until only one node is left.  That node replaces the
// matched Concat.  Runs are taken in order, so the relative order of the original
// arguments is preserved at every level of the resulting tree.
//
// NOTE: the reduction loop only makes progress when kMaxConcatInputs is at least 2; with
// a limit of 1 every run would be a single pass-through node and the list would never
// shrink.
ngraph::runtime::plaidml::pass::ConcatSplit::ConcatSplit()
{
    // Predicate-only label: accept Concat nodes that have too many inputs.
    auto concat_op =
        std::make_shared<pattern::op::Label>(element::i8, Shape{}, [](std::shared_ptr<Node> node) {
            auto op = dynamic_cast<ngraph::op::Concat*>(node.get());
            return op != nullptr && kMaxConcatInputs < op->get_input_size();
        });

    pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
        auto concat = std::static_pointer_cast<ngraph::op::Concat>(m.get_match_root());
        // The axis is the same for every Concat we create, so read it once.
        const auto axis = concat->get_concatenation_axis();
        auto args = concat->get_arguments();

        // Each pass shrinks the argument list by grouping; stop once a single root remains.
        while (1 < args.size())
        {
            NodeVector new_args;
            auto b = args.begin();
            const auto e = args.end();
            while (b != e)
            {
                // e - b is positive inside this loop; convert explicitly so the comparison
                // against the unsigned limit is not a mixed signed/unsigned compare.
                const auto remaining = static_cast<std::size_t>(e - b);
                const auto p = remaining < kMaxConcatInputs ? e : b + kMaxConcatInputs;

                if (p - b == 1)
                {
                    // A lone trailing node needs no wrapper; carry it up unchanged.
                    new_args.emplace_back(*b);
                }
                else
                {
                    NodeVector sub_args;
                    sub_args.insert(sub_args.end(), b, p);
                    new_args.emplace_back(
                        std::make_shared<ngraph::op::Concat>(std::move(sub_args), axis));
                }
                b = p;
            }
            args = std::move(new_args);
        }

        replace_node(std::move(concat), args[0]);
        return true;
    };

    add_matcher(std::make_shared<pattern::Matcher>(concat_op, callback));
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
// Forward-declare ConcatSplit inside its namespaces so that the class itself can be
// defined below using its fully qualified name.
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace pass
{
class ConcatSplit;
}
}
}
}
// A pass to split concats.
//
// PlaidML's concat operator is remarkably inefficient.  To make it
// slightly less awful, we split concats into groups: a Concat with more
// inputs than the limit defined in plaidml_pass_concat_split.cpp is
// replaced by a tree of smaller Concats along the same axis, each taking
// at most that many inputs, with the original argument order preserved.
class ngraph::runtime::plaidml::pass::ConcatSplit final : public ngraph::pass::GraphRewrite
{
public:
// Registers the matcher and rewrite callback that perform the split.
ConcatSplit();
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment