Unverified Commit ea29c6e3 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

fuse dot(a,b) + c (#418)

cblas_gemm working on mlp

rebase & small fixes

enable debug output

support replacing function's outputs

productizing CPUFusion

addressing Bob and Jayaram's feedback

removing json used for simplification tests

adding comments

fixing formatting errors and removing dead code

TODO msg

removing serializer changes
parent d933d531
......@@ -169,6 +169,8 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
runtime/cpu/cpu_emitter.cpp
runtime/cpu/cpu_external_function.cpp
runtime/cpu/cpu_tensor_view_wrapper.cpp
runtime/cpu/ops/matmul_bias.cpp
runtime/cpu/pass/cpu_fusion.cpp
)
# LLVM binary builds are typically built without RTTI
# The built-in headers are in a version-specific directory
......
......@@ -152,3 +152,19 @@ std::list<shared_ptr<Node>> Function::get_ops() const
});
return ops;
}
void Function::replace_output_op(std::shared_ptr<Node> old, std::shared_ptr<Node> repl)
{
auto it = std::find(begin(m_results), end(m_results), old);
if (it != end(m_results))
{
NGRAPH_DEBUG << "Replacing output " << old->get_name() << " w/ " << repl->get_name();
*it = repl;
}
}
void Function::replace_node(std::shared_ptr<Node> old, std::shared_ptr<Node> repl)
{
replace_output_op(old, repl);
ngraph::replace_node(old, repl, true);
}
......@@ -78,6 +78,10 @@ namespace ngraph
size_t get_instance_id() { return m_instance_id; }
size_t get_temporary_pool_size();
void set_temporary_pool_size(size_t);
//updates old w/ repl in m_results list
void replace_output_op(std::shared_ptr<Node> old, std::shared_ptr<Node> repl);
//updates graph and m_results list
void replace_node(std::shared_ptr<Node> old, std::shared_ptr<Node> repl);
protected:
Nodes m_results;
......
......@@ -105,12 +105,15 @@ void ngraph::free_nodes(shared_ptr<Function> p)
}
}
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
void ngraph::replace_node(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement,
bool replace_output)
{
if (target->is_output()) //this restriction can be lifted when we find an use case for it
if (target->is_output() && !replace_output)
{
return;
}
//fix input/output descriptors
assert(target->get_outputs().size() == replacement->get_outputs().size());
for (size_t i = 0; i < target->get_outputs().size(); i++)
......
......@@ -42,7 +42,9 @@ namespace ngraph
void free_nodes(std::shared_ptr<Function>);
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
void replace_node(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement,
bool replace_output = false);
void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement);
......
......@@ -110,12 +110,12 @@ namespace nervana
__PRETTY_FUNCTION__) \
.stream()
//#define NGRAPH_DEBUG \
// nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_DEBUG, \
// nervana::get_file_name(__FILE__), \
// __LINE__, \
// __PRETTY_FUNCTION__) \
// .stream()
// #define NGRAPH_DEBUG \
// nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_DEBUG, \
// nervana::get_file_name(__FILE__), \
// __LINE__, \
// __PRETTY_FUNCTION__) \
// .stream()
#define NGRAPH_DEBUG nervana::get_nil_stream()
}
#include "graph_rewrite.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include <iostream>
#include <unordered_set>
#include "graph_rewrite.hpp"
#include "ngraph/log.hpp"
#include "ngraph/pattern/matcher.hpp"
bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list(
const std::list<std::shared_ptr<ngraph::Node>>& nodes,
const std::vector<std::shared_ptr<pattern::Matcher>>& matchers)
const std::vector<std::shared_ptr<pattern::Matcher>>& matchers,
std::shared_ptr<ngraph::Function> f)
{
bool rewritten = false;
for (auto node : nodes)
......@@ -15,23 +31,26 @@ bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list(
for (auto matcher : matchers)
{
NGRAPH_DEBUG << "Running matcher " << matcher << " on " << node << " , "
<< node->get_name();
if (!node->is_output() /*this restriction can be lifted when we find an use case for it*/
&&
matcher->match(node))
<< node->get_name() << " , is_output = " << node->is_output();
if (matcher->match(node))
{
NGRAPH_DEBUG << "Matcher " << matcher << " matched " << node << " , "
<< node->get_name();
rewritten = true;
matcher->process_match();
break; //move onto the next node
auto result = matcher->process_match();
if (result)
{
f->replace_node(node, result);
//move onto the next node
break;
}
}
}
}
return rewritten;
}
bool ngraph::pass::GraphRewrite::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
bool ngraph::pass::GraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f)
{
return run_matchers_on_nodes_list(nodes, m_matchers);
return run_matchers_on_nodes_list(f->get_ordered_ops(), m_matchers, f);
}
......@@ -40,19 +40,21 @@ namespace ngraph
/// Patterns can be added by using \sa add_matcher
/// Callbacks should use \sa replace_node to transform matched sub graphs
class ngraph::pass::GraphRewrite : public CallGraphPass
class ngraph::pass::GraphRewrite : public FunctionPass
{
public:
GraphRewrite()
: CallGraphPass()
: FunctionPass()
{
}
void add_matcher(std::shared_ptr<pattern::Matcher> m) { m_matchers.push_back(m); }
virtual bool run_on_call_graph(const std::list<std::shared_ptr<ngraph::Node>>&) override;
static bool
run_matchers_on_nodes_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes,
const std::vector<std::shared_ptr<pattern::Matcher>>& matchers);
const std::vector<std::shared_ptr<pattern::Matcher>>& matchers,
std::shared_ptr<ngraph::Function> f);
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
private:
//enable cascading rewrites
......
......@@ -63,8 +63,10 @@ namespace ngraph
auto args = get_arguments(label);
if (args.size() > 0)
{
assert(args.size() ==
1); //it should be impossible to construct labels w/ more than one arg
if (args.size() != 1)
{
throw ngraph_error("Labels can only take 1 argument!");
}
NGRAPH_DEBUG << "[MATCHER] Label describes a sub graph in the pattern";
is_match = match_node(args.at(0), graph_node, pattern_map);
}
......@@ -92,7 +94,11 @@ namespace ngraph
else
{
auto args = get_arguments(any);
assert(args.size() == 1);
if (args.size() != 1)
{
throw ngraph_error("Any can only take one argument");
}
return match_node(args.at(0), graph_node, pattern_map);
}
}
......@@ -101,7 +107,10 @@ namespace ngraph
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
{
assert(pattern_node && graph_node);
if (!pattern_node || !graph_node)
{
throw ngraph_error("pattern_node or graph_node shouldn't be nullptrs!");
}
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_node : "
<< "pattern = " << pattern_node->get_name() << " matched "
......@@ -191,17 +200,24 @@ namespace ngraph
return false;
}
void Matcher::process_match(::ngraph::pattern::gr_callback_fn callback)
std::shared_ptr<Node> Matcher::process_match(::ngraph::pattern::gr_callback_fn callback)
{
gr_callback_fn cb = m_callback;
if (callback)
{
cb = callback;
}
if (!cb)
{
throw ngraph_error("process_match invoked w/o a callback function");
}
if (!this->m_match_root)
{
throw ngraph_error("process_match invoked w/o a match");
}
assert(cb);
assert(this->m_match_root);
cb(*this);
return cb(*this);
}
static Nodes get_users(std::shared_ptr<Node> node)
......
......@@ -29,7 +29,7 @@ namespace ngraph
namespace pattern
{
using gr_callback_fn = std::function<void(class Matcher& m)>;
using gr_callback_fn = std::function<std::shared_ptr<Node>(class Matcher& m)>;
namespace op
{
......@@ -60,7 +60,7 @@ namespace ngraph
/// \param graph_node is an input graph to be matched against
bool match(const std::shared_ptr<Node>& graph_node);
void process_match(gr_callback_fn callback = nullptr);
std::shared_ptr<Node> process_match(gr_callback_fn callback = nullptr);
void reset() {}
std::shared_ptr<Node> pattern_node() { return m_pattern_node; }
......
......@@ -42,6 +42,7 @@
#include "ngraph/ops/sum.hpp"
#include "ngraph/runtime/cpu/cpu_emitter.hpp"
#include "ngraph/runtime/cpu/cpu_kernel_emitters.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/util.hpp"
using namespace std;
......@@ -129,6 +130,62 @@ void runtime::cpu::CPU_Emitter::EmitAdd(codegen::CodeWriter& writer,
writer << "}\n";
}
//TODO: This could be further optimized to reduce the impact of memcpy by either
//a) emitting customized code for initializing output/bias
//b) emitting two cblas calls (one for gemm on W and x and the second for gemm on Bias and E^T + the result of the first gemm)
//@jbobba suggests b) is more efficient but we should benchmark both
void runtime::cpu::CPU_Emitter::EmitMatmulBias(codegen::CodeWriter& writer,
const ngraph::Node* node,
const vector<runtime::cpu::TensorViewWrapper>& args,
const vector<runtime::cpu::TensorViewWrapper>& out)
{
const ngraph::op::MatmulBias* cg = static_cast<const ngraph::op::MatmulBias*>(node);
const Shape& arg0_shape = cg->get_arg0_shape(); //W
const Shape& arg1_shape = cg->get_arg1_shape(); //x
const Shape& arg2_shape = args[2].get_shape(); //bias (C)
static const char* ctranspose = "cblas::Transpose::Transpose, ";
static const char* cnotranspose = "cblas::Transpose::None, ";
size_t m = arg0_shape[0];
size_t n = arg1_shape[1];
size_t k = arg0_shape[1];
//
const char* tranpose_a = cnotranspose;
const char* tranpose_b = cnotranspose;
size_t lda = arg0_shape[1];
size_t ldb = arg1_shape[1];
if (cg->get_is_arg0_transposed())
{
tranpose_a = ctranspose;
m = arg0_shape[1];
k = arg0_shape[0];
}
if (cg->get_is_arg1_transposed())
{
tranpose_b = ctranspose;
n = arg1_shape[0];
}
writer << "{ // " << node->get_name() << "\n";
writer.indent++;
writer << "memcpy(" << out[0].get_name() << ", " << args[2].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer << "cblas::cblas_sgemm("
<< "cblas::Layout::RowMajor, " << tranpose_a << tranpose_b << m << ", " << n << ", " << k
<< ",\n"
<< " 1.0f, " << args[0].get_name() << ", " << max(1UL, lda) << ", "
<< args[1].get_name() << ", " << max(1UL, ldb) << ", 1.0f,\n"
<< " " << out[0].get_name() << ", " << max(1UL, arg2_shape[1]) << ");\n";
writer.indent--;
writer << "}\n";
}
void runtime::cpu::CPU_Emitter::EmitDot(codegen::CodeWriter& writer,
const ngraph::Node* n,
const vector<runtime::cpu::TensorViewWrapper>& args,
......
......@@ -60,6 +60,7 @@ namespace ngraph
static void EMITTER_DECL(EmitSelect);
static void EMITTER_DECL(EmitSubtract);
static void EMITTER_DECL(EmitBroadcast);
static void EMITTER_DECL(EmitMatmulBias);
static void EMITTER_DECL(EmitConvert);
static void EMITTER_DECL(EmitConstant);
static void EMITTER_DECL(EmitReshape);
......
......@@ -92,6 +92,7 @@
#include "ngraph/runtime/cpu/cpu_call_frame.hpp"
#include "ngraph/runtime/cpu/cpu_emitter.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/runtime/host_tensor_view.hpp"
using namespace std;
......@@ -143,6 +144,7 @@ static StaticInitializers s_static_initializers;
static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Add), &runtime::cpu::CPU_Emitter::EmitAdd},
{TI(ngraph::op::MatmulBias), &runtime::cpu::CPU_Emitter::EmitMatmulBias},
{TI(ngraph::op::Dot), &runtime::cpu::CPU_Emitter::EmitDot},
{TI(ngraph::op::Multiply), &runtime::cpu::CPU_Emitter::EmitMultiply},
{TI(ngraph::op::Parameter), &runtime::cpu::CPU_Emitter::EmitNop},
......
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "matmul_bias.hpp"
std::shared_ptr<ngraph::Node> ngraph::op::MatmulBias::copy_with_new_args(
const std::vector<std::shared_ptr<ngraph::Node>>& new_args) const
{
if (new_args.size() != 2)
throw ngraph_error("Incorrect number of new arguments");
return std::make_shared<MatmulBias>(new_args.at(0),
new_args.at(1),
new_args.at(1),
m_shape_w,
m_shape_x,
m_transpose_w,
m_transpose_x);
}
ngraph::op::MatmulBias::MatmulBias(std::shared_ptr<ngraph::Node> W,
std::shared_ptr<ngraph::Node> x,
std::shared_ptr<ngraph::Node> b,
Shape shape_w,
Shape shape_x,
bool transpose_w,
bool transpose_x)
: RequiresTensorViewArgs("CblassGemm", {W, x, b})
, m_shape_w(shape_w)
, m_shape_x(shape_x)
, m_transpose_w(transpose_w)
, m_transpose_x(transpose_x)
{
if (shape_w.size() != 2)
{
NGRAPH_DEBUG << "W shape = " << vector_to_string(shape_w);
throw ngraph_error("W.shape.rank != 2 while creating MatmulBias");
}
if (shape_x.size() != 2)
{
NGRAPH_DEBUG << "x shape = " << vector_to_string(shape_x);
throw ngraph_error("x.shape.rank != 2 while creating MatmulBias");
}
size_t dot_dimension_w = (transpose_w) ? 0 : 1;
size_t dot_dimension_x = (transpose_x) ? 1 : 0;
NGRAPH_DEBUG << "dot_dimension_w = " << dot_dimension_w
<< " , dot_dimension_x = " << dot_dimension_x;
NGRAPH_DEBUG << "W shape = " << vector_to_string(shape_w)
<< " , x shape = " << vector_to_string(shape_x);
if (shape_w.at(dot_dimension_w) != shape_x.at(dot_dimension_x))
{
throw ngraph_error("product dimensions are not equal while creating MatmulBias");
}
auto dot_shape = Shape{shape_w.at(1 - dot_dimension_w), shape_x.at(1 - dot_dimension_x)};
NGRAPH_DEBUG << "dot_shape shape = " << vector_to_string(dot_shape)
<< " , b shape = " << vector_to_string(b->get_shape());
add_output(W->get_element_type(), dot_shape);
}
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/util.hpp"
#include <memory>
namespace ngraph
{
namespace op
{
class MatmulBias : public RequiresTensorViewArgs
{
public:
MatmulBias(std::shared_ptr<Node> W,
std::shared_ptr<Node> x,
std::shared_ptr<Node> b,
Shape shape_w,
Shape shape_x,
bool transpose_w,
bool transpose_x);
bool get_is_arg0_transposed() const { return m_transpose_w; }
bool get_is_arg1_transposed() const { return m_transpose_x; }
Shape get_arg0_shape() const { return m_shape_w; }
Shape get_arg1_shape() const { return m_shape_x; }
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override;
private:
Shape m_shape_w;
Shape m_shape_x;
bool m_transpose_w;
bool m_transpose_x;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "cpu_fusion.hpp"
#include <algorithm>
#include <iostream>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
static bool init_cblas_arg(std::shared_ptr<ngraph::Node> reshape,
std::shared_ptr<ngraph::Node> arg,
bool& transpose_w,
ngraph::Shape& shape_w)
{
auto r_w = std::dynamic_pointer_cast<ngraph::op::Reshape>(reshape);
if (!r_w)
{
return true; //nth to do; reshape isn't a reshape
}
if (r_w->get_shape().size() != 2)
{
NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " doesn't reshape into matrix"
<< ngraph::vector_to_string(r_w->get_shape());
return false;
}
auto io = r_w->get_input_order();
if (r_w->get_shape().size() != arg->get_shape().size()) //reshape
{
ngraph::AxisVector dio(io.size());
std::iota(begin(dio), end(dio), 0);
if (io != dio) //we can't reshape and transpose at the same time
{
NGRAPH_DEBUG << "Reshape for " << reshape->get_name() << " is not in default order "
<< ngraph::vector_to_string(io);
NGRAPH_DEBUG << "r_w shape = " << ngraph::vector_to_string(r_w->get_shape());
NGRAPH_DEBUG << "arg shape = " << ngraph::vector_to_string(arg->get_shape());
return false;
}
shape_w = r_w->get_shape();
}
else
{
if (io == ngraph::AxisVector{1, 0})
{
transpose_w = true;
}
//otherwise no-op reshape
}
return true;
}
template <typename T>
static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector order)
{
if (input.size() != order.size())
{
throw "input and order sizes don't match!";
}
std::vector<T> output(input.size());
for (size_t i = 0; i < order.size(); i++)
{
output[i] = input.at(order.at(i));
}
return output;
}
void ngraph::pass::CPUFusion::construct_gemm_pattern()
{
auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1};
auto shape_b = Shape{1};
auto shape_dot = Shape{2, 1};
auto W = std::make_shared<pattern::op::Label>(element::f32, shape_w);
auto x = std::make_shared<pattern::op::Label>(element::f32, shape_x);
auto reshape_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n));
};
auto skip_w = std::make_shared<pattern::op::Any>(W, reshape_pred);
auto skip_x = std::make_shared<pattern::op::Any>(x, reshape_pred);
auto pdot = std::make_shared<op::Dot>(skip_w, skip_x);
auto b = std::make_shared<pattern::op::Label>(element::f32, shape_b);
auto pbroadcast = std::make_shared<op::Broadcast>(b, shape_dot, AxisSet{0});
auto padd = pdot + pbroadcast;
ngraph::pattern::gr_callback_fn callback = [W, x, b](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_gemm_pattern against node = "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
std::shared_ptr<Node> nn = nullptr;
auto mpattern = m.match_root();
if (mpattern->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << mpattern->get_name() << " type is not float!";
return nn;
}
auto dot = mpattern->get_input_op(0);
if (dot->get_shape().size() != 2)
{
NGRAPH_DEBUG << "dot = " << dot->get_name() << " shape is not equal to 2!";
return nn;
}
bool transpose_w = false;
Shape shape_arg0{pattern_map[W]->get_shape()};
if (!init_cblas_arg(dot->get_input_op(0), pattern_map[W], transpose_w, shape_arg0))
{
return nn;
}
bool transpose_x = false;
Shape shape_arg1{pattern_map[x]->get_shape()};
if (!init_cblas_arg(dot->get_input_op(1), pattern_map[x], transpose_x, shape_arg1))
{
return nn;
}
auto cg = std::shared_ptr<Node>(new op::MatmulBias(pattern_map[W],
pattern_map[x],
mpattern->get_input_op(1),
shape_arg0,
shape_arg1,
transpose_w,
transpose_x));
return cg;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(padd, callback);
this->add_matcher(m);
}
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace pass
{
class CPUFusion;
}
}
class ngraph::pass::CPUFusion : public ngraph::pass::GraphRewrite
{
public:
CPUFusion()
: GraphRewrite()
{
construct_gemm_pattern();
}
private:
void construct_gemm_pattern();
};
......@@ -69,7 +69,7 @@ endif()
if(NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR)
include_directories(SYSTEM ${LLVM_INCLUDE_DIR})
link_directories(${LLVM_LIB_DIR})
set(SRC ${SRC} backend_performance.cpp codegen.cpp)
set(SRC ${SRC} backend_performance.cpp codegen.cpp cpu_fusion.cpp)
set(BACKEND_NAMES ${BACKEND_NAMES} "CPU")
endif()
......
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
//
#include "ngraph/file_util.hpp"
#include "ngraph/json.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "util/matcher.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(cpu_fusion, gemm_pattern)
{
auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1};
auto shape_b = Shape{1};
auto A = make_shared<op::Parameter>(element::f32, shape_w);
auto B = make_shared<op::Parameter>(element::f32, shape_x);
auto C = make_shared<op::Parameter>(element::f32, shape_b);
auto dot = make_shared<op::Dot>(A, B);
auto broadcast = make_shared<op::Broadcast>(C, dot->get_shape(), AxisSet{0});
auto add = dot + broadcast;
auto W = std::make_shared<pattern::op::Label>(A);
auto x = std::make_shared<pattern::op::Label>(B);
auto reshape_pred = [](std::shared_ptr<Node> n) {
return static_cast<bool>(std::dynamic_pointer_cast<op::Reshape>(n));
};
auto skip_w = std::make_shared<pattern::op::Any>(W, reshape_pred);
auto skip_x = std::make_shared<pattern::op::Any>(x, reshape_pred);
auto pdot = make_shared<op::Dot>(skip_w, skip_x);
auto b = std::make_shared<pattern::op::Label>(C);
auto pbroadcast = make_shared<op::Broadcast>(b, dot->get_shape(), AxisSet{0});
auto padd = pdot + pbroadcast;
TestMatcher n(nullptr);
ASSERT_TRUE(n.match(padd, add));
ASSERT_EQ(n.get_pattern_map()[W], A);
ASSERT_EQ(n.get_pattern_map()[x], B);
ASSERT_EQ(n.get_pattern_map()[b], C);
auto reshape_w = make_shared<op::Reshape>(A, AxisVector{1, 0}, W->get_shape());
auto reshape_x = make_shared<op::Reshape>(B, AxisVector{1, 0}, x->get_shape());
auto re_dot = make_shared<op::Dot>(reshape_w, reshape_x);
auto re_add = re_dot + broadcast;
ASSERT_TRUE(n.match(padd, re_add));
ASSERT_EQ(n.get_pattern_map()[W], A);
ASSERT_EQ(n.get_pattern_map()[x], B);
ASSERT_EQ(n.get_pattern_map()[b], C);
auto cg =
make_shared<op::MatmulBias>(W, x, broadcast, W->get_shape(), x->get_shape(), false, false);
}
TEST(cpu_fusion, gemm_cpu)
{
auto shapeA = Shape{3, 2};
auto shapeB = Shape{2, 3};
auto shapeC = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shapeA);
auto B = make_shared<op::Parameter>(element::f32, shapeB);
auto reshape_w = make_shared<op::Reshape>(A, AxisVector{1, 0}, Shape{2, 3});
auto reshape_x = make_shared<op::Reshape>(B, AxisVector{1, 0}, Shape{3, 2});
auto one = op::Constant::create<float>(element::f32, Shape{}, std::vector<float>{1.0f});
auto broadcast = make_shared<op::Broadcast>(one, shapeC, AxisSet{0, 1});
auto cg =
make_shared<op::MatmulBias>(A, B, broadcast, A->get_shape(), B->get_shape(), true, true);
auto f = make_shared<Function>(cg, op::Parameters{A, B});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
shared_ptr<runtime::TensorView> a = backend->make_primary_tensor_view(element::f32, shapeA);
shared_ptr<runtime::TensorView> b = backend->make_primary_tensor_view(element::f32, shapeB);
shared_ptr<runtime::TensorView> result =
backend->make_primary_tensor_view(element::f32, shapeC);
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f, 1.0f, 4.0f};
vector<float> dataB{3.0f, 3.0f, 3.0f, 9.0f, 9.0f, 9.0f};
copy_data(a, dataA);
copy_data(b, dataB);
cf->call({a, b}, {result});
vector<float> expected{10, 28, 37, 109};
ASSERT_TRUE(read_vector<float>(result) == expected);
}
TEST(cpu_fusion, cpu_fusion_pass_basic)
{
auto shape = Shape{};
auto shape_w = Shape{2, 4};
auto shape_x = Shape{4, 1};
auto shape_b = Shape{1};
auto A = make_shared<op::Parameter>(element::f32, shape_w);
auto B = make_shared<op::Parameter>(element::f32, shape_x);
auto C = make_shared<op::Parameter>(element::f32, shape_b);
auto dot = make_shared<op::Dot>(A, B);
auto broadcast = make_shared<op::Broadcast>(C, dot->get_shape(), AxisSet{0});
auto add = dot + broadcast;
auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager;
pass_manager.register_pass<pass::CPUFusion>();
auto func = make_shared<Function>(graph, op::Parameters{A, B, C});
pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_input_op(0)), nullptr);
}
TEST(cpu_fusion, gemm_mlp)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/mnist_mlp_forward.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass::Manager pass_manager;
pass_manager.register_pass<pass::CPUFusion>();
pass_manager.run_passes(func);
size_t ccg = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(ccg, 3);
}
......@@ -28,47 +28,11 @@
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "util/matcher.hpp"
using namespace ngraph;
using namespace std;
//this is for more nuanced testing
class TestMatcher : public pattern::Matcher
{
using pattern::Matcher::Matcher;
bool virtual match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map) override
{
if (std::dynamic_pointer_cast<::ngraph::op::Parameter>(pattern_node))
{
return pattern_node.get() == dynamic_cast<::ngraph::op::Parameter*>(graph_node.get());
}
return this->pattern::Matcher::match_node(pattern_node, graph_node, pattern_map);
}
public:
bool match(const std::shared_ptr<Node>& pattern_node, const std::shared_ptr<Node>& graph_node)
{
assert(
pattern_node &&
graph_node); //the same condition throws an exception in the non-test version of `match`
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();
m_pattern_map.clear();
m_match_root.reset();
bool is_match = match_node(pattern_node, graph_node, m_pattern_map);
if (is_match)
{
m_match_root = graph_node;
}
return is_match;
}
};
template <typename T>
std::shared_ptr<Node> create_reduction(const std::shared_ptr<Node>& node,
const std::string& init_val,
......@@ -181,13 +145,13 @@ public:
auto second_node = m.match_root()->get_input_ops().at(const_node_index);
NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name();
ASSERT_TRUE(const_node);
std::shared_ptr<ngraph::Node> nn = nullptr;
if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
pattern_map[pattern]->get_shape() != const_node->get_shape())
{
NGRAPH_DEBUG << "Operands' types and/or shape don't match";
return;
return nn;
}
auto const_values = const_node->get_vector<int32_t>();
......@@ -197,9 +161,9 @@ public:
if (!all_ones)
{
NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
return;
return nn;
}
ngraph::replace_node(m.match_root(), pattern_map[pattern]);
return pattern_map[pattern];
};
auto m = make_shared<TestMatcher>(pattern * iconst1, callback);
......@@ -212,7 +176,7 @@ public:
auto iconst0 = construct_constant_node(0);
auto pattern = std::make_shared<pattern::op::Label>(iconst0);
ngraph::pattern::gr_callback_fn callback = [pattern](pattern::Matcher& m) {
auto callback = [pattern](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_add_zero against "
<< m.match_root()->get_name();
assert(m.match_root()->get_input_ops().size() == 2);
......@@ -225,13 +189,15 @@ public:
auto second_node = m.match_root()->get_input_ops().at(const_node_index);
NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name();
ASSERT_NE(nullptr, const_node);
//ASSERT_NE(nullptr, const_node);
std::shared_ptr<ngraph::Node> nn = nullptr;
if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
pattern_map[pattern]->get_shape() != const_node->get_shape())
{
NGRAPH_DEBUG << "Operands' types and/or shape don't match";
return;
return nn;
}
auto const_values = const_node->get_vector<int>();
......@@ -241,10 +207,10 @@ public:
if (!all_zeros)
{
NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
return;
return nn;
}
ngraph::replace_node(m.match_root(), pattern_map[pattern]);
return pattern_map[pattern];
};
auto m = make_shared<TestMatcher>(pattern + iconst0, callback);
......@@ -261,8 +227,9 @@ public:
auto reduce = std::dynamic_pointer_cast<op::Reduce>(m.match_root());
auto reducee = reduce->get_inputs().at(0).get_output().get_node();
NGRAPH_DEBUG << "reducee = " << reducee->get_name();
auto sum = std::make_shared<op::Sum>(reducee, reduce->get_reduction_axes());
ngraph::replace_node(reduce, sum);
auto sum =
std::shared_ptr<ngraph::Node>(new op::Sum(reducee, reduce->get_reduction_axes()));
return sum;
};
auto m = make_shared<TestMatcher>(sum_pattern, callback);
......@@ -290,9 +257,27 @@ TEST(pattern, graph_rewrite)
{
auto shape = Shape{};
pass::Manager pass_manager;
pass_manager.register_pass<TestGraphRewrite>();
{
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto c = make_shared<op::Parameter>(element::i32, shape);
auto iconst0 = construct_constant_node(0);
auto graph_a = a + iconst0;
auto graph_b = b + iconst0;
auto f = std::make_shared<Function>(ngraph::Nodes{a, b, graph_a, c, graph_b},
op::Parameters{a, b, c});
pass_manager.run_passes(f);
ASSERT_TRUE(graph_a->get_output_inputs(0).empty());
ASSERT_TRUE(graph_b->get_output_inputs(0).empty());
auto expected = ngraph::Nodes{a, b, a, c, b};
ASSERT_TRUE(f->get_results() == expected);
}
{
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
//this is for more nuanced testing
class TestMatcher : public ngraph::pattern::Matcher
{
using ngraph::pattern::Matcher::Matcher;
bool virtual match_node(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node,
PatternMap& pattern_map) override
{
if (std::dynamic_pointer_cast<::ngraph::op::Parameter>(pattern_node))
{
return pattern_node.get() == dynamic_cast<::ngraph::op::Parameter*>(graph_node.get());
}
return this->ngraph::pattern::Matcher::match_node(pattern_node, graph_node, pattern_map);
}
public:
bool match(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node)
{
assert(
pattern_node &&
graph_node); //the same condition throws an exception in the non-test version of `match`
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();
m_pattern_map.clear();
m_match_root.reset();
bool is_match = match_node(pattern_node, graph_node, m_pattern_map);
if (is_match)
{
m_match_root = graph_node;
}
return is_match;
}
};
template <typename T>
size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f)
{
size_t count = 0;
for (auto op : f->get_ops())
{
if (std::dynamic_pointer_cast<T>(op))
{
count++;
}
}
return count;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment