Commit 8cb48d37 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Weight Fusion (#853)

* CPU weight fusion initial version

* add tests for weight_fusion

* address @jbobba's feedback

* before cleaning up convolution_weight_optimization.cpp

* clean up, rename, fix perms, fix format
parent 3562da83
......@@ -223,6 +223,7 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
runtime/cpu/pass/cpu_layout.cpp
runtime/cpu/pass/cpu_nop_elimination.cpp
runtime/cpu/pass/cpu_rnn_mat_fusion.cpp
runtime/cpu/pass/cpu_post_layout_optimizations.cpp
)
# LLVM binary builds are typically built without RTTI
# The built-in headers are in a version-specific directory
......
......@@ -120,6 +120,7 @@
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_layout.hpp"
#include "ngraph/runtime/cpu/pass/cpu_nop_elimination.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/op/allreduce.hpp"
......@@ -307,6 +308,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
pass_manager.register_pass<ngraph::pass::ResultCopyElimination>();
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
pass_manager.register_pass<ngraph::pass::Liveness>();
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <typeindex>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
using namespace ngraph;
using namespace std;
#define TI(x) std::type_index(typeid(x))
void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_weight_fusion()
{
auto param = std::make_shared<pattern::op::Label>(element::f32, Shape{64});
auto reshape_conv =
std::make_shared<ngraph::op::Reshape>(param, AxisVector{0}, Shape{16, 4, 1, 1});
auto data_conv = std::make_shared<pattern::op::Label>(element::f32, Shape{16, 4, 7, 7});
auto tvt = reshape_conv->get_outputs().at(0).get_tensor_view().get();
auto lt_desc = std::make_shared<runtime::cpu::LayoutDescriptor>(*tvt, AxisVector{0, 1, 2, 3});
auto cvt_lt_conv = std::make_shared<runtime::cpu::op::ConvertLayout>(reshape_conv, lt_desc);
auto conv = std::make_shared<ngraph::op::Convolution>(
data_conv, cvt_lt_conv, Strides{1, 1}, Strides{1, 1});
pattern::graph_rewrite_callback callback = [param](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_weight against " << m.match_root()->get_name();
auto m_cvt_lt = m.match_root()->get_argument(1);
auto m_reshape_conv = m_cvt_lt->get_argument(0);
std::shared_ptr<Node> m_conv_bprop;
std::vector<std::type_index> user_pattern = {TI(ngraph::op::Reshape),
TI(runtime::cpu::op::ConvertLayout),
TI(ngraph::op::ConvolutionBackpropData)};
for (auto u : m.get_pattern_map()[param]->get_users())
{
if (u != m_reshape_conv)
{
size_t num_matches = 0;
auto ui = u;
for (; num_matches < user_pattern.size(); num_matches++)
{
const Node& user_ref = *ui;
if (TI(user_ref) != user_pattern.at(num_matches))
{
NGRAPH_DEBUG << "the type for user " << ui->get_name()
<< " doesn't match the type at " << num_matches;
break;
}
if (ui->get_users().size() != 1)
{
NGRAPH_DEBUG << u->get_name() << " has more than one user";
break;
}
ui = ui->get_users().at(0);
}
if (num_matches == user_pattern.size())
{
m_conv_bprop = u->get_users().at(0)->get_users().at(0);
NGRAPH_DEBUG << " m_conv_bprop is set to " << m_conv_bprop->get_name();
break;
}
}
}
if (!m_conv_bprop)
{
return false;
}
auto m_cvt_lt_bprop = m_conv_bprop->get_argument(0);
auto m_reshape_bprop = m_cvt_lt_bprop->get_argument(0);
NGRAPH_DEBUG << "Replacing input "
<< m_cvt_lt_bprop->get_inputs().at(0).get_output().get_node()->get_name()
<< " to " << m_cvt_lt_bprop->get_name() << " with "
<< m_cvt_lt->get_outputs().at(0).get_node()->get_name();
m_cvt_lt_bprop->get_inputs().at(0).replace_output(m_cvt_lt->get_outputs().at(0));
return true;
};
auto m = make_shared<pattern::Matcher>(conv, callback);
this->add_matcher(m);
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
class CPUPostLayoutOptimizations;
}
}
}
}
class ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations : public ngraph::pass::GraphRewrite
{
public:
CPUPostLayoutOptimizations()
: GraphRewrite()
{
construct_weight_fusion();
}
void construct_weight_fusion();
};
......@@ -38,12 +38,15 @@
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_mat_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
......@@ -1025,3 +1028,50 @@ TEST(cpu_fusion, rnn_fusion_from_json_model)
auto mmbs = get_ops_of_type<op::MatmulBias>(func);
ASSERT_TRUE(std::any_of(begin(mmbs), end(mmbs), mmb_predicate));
}
TEST(cpu_fusion, weight_fusion)
{
auto param = std::make_shared<op::Parameter>(element::f32, Shape{64});
auto reshape_conv =
std::make_shared<ngraph::op::Reshape>(param, AxisVector{0}, Shape{16, 4, 1, 1});
auto data_conv = std::make_shared<op::Parameter>(element::f32, Shape{16, 4, 7, 7});
auto tvt = reshape_conv->get_outputs().at(0).get_tensor_view().get();
auto lt_desc = std::make_shared<runtime::cpu::LayoutDescriptor>(*tvt, AxisVector{0, 1, 2, 3});
auto cvt_lt_conv = std::make_shared<runtime::cpu::op::ConvertLayout>(reshape_conv, lt_desc);
auto conv = std::make_shared<ngraph::op::Convolution>(
data_conv, cvt_lt_conv, Strides{1, 1}, Strides{1, 1});
auto reshape_conv_bprop =
std::make_shared<op::Reshape>(param, AxisVector{0}, Shape{16, 4, 1, 1});
auto dummy_arg_conv_bprop = std::make_shared<op::Parameter>(element::f32, Shape{1, 16, 7, 7});
auto tvt_bprop = reshape_conv_bprop->get_outputs().at(0).get_tensor_view().get();
auto lt_desc_bprop =
std::make_shared<runtime::cpu::LayoutDescriptor>(*tvt_bprop, AxisVector{0, 1, 2, 3});
auto cvt_lt_conv_bprop =
std::make_shared<runtime::cpu::op::ConvertLayout>(reshape_conv_bprop, lt_desc_bprop);
auto conv_bprop = std::make_shared<op::ConvolutionBackpropData>(Shape{1, 4, 7, 7},
cvt_lt_conv_bprop,
dummy_arg_conv_bprop,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1});
auto conv_relu = std::make_shared<op::Relu>(conv);
auto conv_bprop_abs = std::make_shared<op::Abs>(conv_bprop);
auto f = make_shared<Function>(NodeVector{conv_relu, conv_bprop_abs},
op::ParameterVector{param, data_conv, dummy_arg_conv_bprop});
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
pass_manager.run_passes(f);
auto new_conv_bprop_data = conv_bprop_abs->get_argument(0);
auto new_convert_layout = new_conv_bprop_data->get_argument(0);
ASSERT_EQ(std::dynamic_pointer_cast<runtime::cpu::op::ConvertLayout>(
new_convert_layout->get_argument(0)),
cvt_lt_conv);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment