Commit 18e41513 authored by Pruthvi's avatar Pruthvi Committed by Robert Kimball

Pruthvi/rnn fusion (#1677)

* WIP input * weights rnn optimization

* concat + slcing + replacing new node works

* WIP unit test case of fusing rnn inputs

* - Added unit test case for fusing rnn input weights
- registered CPURnnMatFusion_v1/v2 in codegen and DEX

* fixed redeclaration of a variable

* Refactored rnn input traformation passes into a single pass

* Refactored CPURnnMatFusion call back functions

* change random generator range to include -ve values in unit test

* address PR comments

* dont fuse if the shape of the data slices dont match
parent ee712ae8
...@@ -1000,11 +1000,12 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma ...@@ -1000,11 +1000,12 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
pass_manager.register_pass<ngraph::pass::NopElimination>(); pass_manager.register_pass<ngraph::pass::NopElimination>();
// TODO (pruthvi): Enable all the disabeled RNN fusion graph pass after fixing // TODO (pruthvi): Enable all the disabeled RNN fusion graph pass after fixing
// failing mxnet unit tests. // failing mxnet unit tests.
// pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>(); pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
// pass_manager.register_pass<runtime::cpu::pass::RNNFusion>(); pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>(); pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
// pass_manager.register_pass<runtime::cpu::pass::MultiLayerRNNFusion>(); pass_manager.register_pass<runtime::cpu::pass::MultiLayerRNNFusion>();
// pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>(); // pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>(); pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>(); pass_manager.register_pass<ngraph::pass::CoreFusion>();
......
...@@ -47,6 +47,17 @@ struct Type ...@@ -47,6 +47,17 @@ struct Type
}; };
}; };
//constructs (x*W + bias)
static std::shared_ptr<pattern::Matcher>
construct_rnn_input_linear_transformation(std::shared_ptr<pattern::op::Label> labels[])
{
auto skip =
std::make_shared<pattern::op::Skip>(labels[Type::DATA], pattern::has_class<op::Reshape>());
auto dot = std::make_shared<op::Dot>(skip, labels[Type::WEIGHTS]);
auto add_bias = std::make_shared<op::Add>(dot, labels[Type::BIAS]);
return std::make_shared<pattern::Matcher>(add_bias);
}
static std::shared_ptr<Node> construct_data_pattern(std::shared_ptr<pattern::op::Label> data_slice) static std::shared_ptr<Node> construct_data_pattern(std::shared_ptr<pattern::op::Label> data_slice)
{ {
auto reshape_slice = auto reshape_slice =
...@@ -75,55 +86,81 @@ static std::shared_ptr<Node> ...@@ -75,55 +86,81 @@ static std::shared_ptr<Node>
bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Function> function) bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Function> function)
{ {
bool modified = false; bool modify_graph = false;
auto data_pred = [](std::shared_ptr<Node> n) { //--------------------------------------------------------
return std::dynamic_pointer_cast<op::Slice>(n) != nullptr; // Construct pattern version_1 for RNN input linear transformation
}; auto data_slice = std::make_shared<pattern::op::Label>(
auto data_slice = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 4}, data_pred); element::f32, Shape{1, 2, 4}, pattern::has_class<op::Slice>());
auto data_pattern = construct_data_pattern(data_slice); auto data_pattern = construct_data_pattern(data_slice);
auto weights_pred = [](std::shared_ptr<Node> n) { auto weights_reshape = std::make_shared<pattern::op::Label>(
return std::dynamic_pointer_cast<op::Reshape>(n) != nullptr; element::f32, Shape{4, 1}, pattern::has_class<op::Reshape>());
};
auto weights_reshape =
std::make_shared<pattern::op::Label>(element::f32, Shape{4, 1}, weights_pred);
auto weights_pattern = construct_weights_pattern(weights_reshape); auto weights_pattern = construct_weights_pattern(weights_reshape);
// we don't really need a broadcast node but // we don't really need a broadcast node but
// labelling a Broadcast allows us to extract // labelling a Broadcast allows us to extract
// params from all 3 labels in the same fashion // params from all 3 labels in the same fashion
//(i.e. via get_argument(0)) //(i.e. via get_argument(0))
auto broadcast_pred = [](std::shared_ptr<Node> n) { auto bias_broadcast = std::make_shared<pattern::op::Label>(
return std::dynamic_pointer_cast<op::Broadcast>(n) != nullptr; element::f32, Shape{2, 1}, pattern::has_class<op::Broadcast>());
};
auto bias_broadcast =
std::make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, broadcast_pred);
auto bias_pattern = construct_bias_pattern(bias_broadcast); auto bias_pattern = construct_bias_pattern(bias_broadcast);
const size_t NUM_MMB_ARGS = 3; const size_t NUM_MMB_ARGS = 3;
std::shared_ptr<pattern::op::Label> labels[] = {data_slice, weights_reshape, bias_broadcast}; std::shared_ptr<pattern::op::Label> labels_v1[] = {data_slice, weights_reshape, bias_broadcast};
// Matchers' ordering is important! Don't change! // Matchers' ordering is important! Don't change!
std::shared_ptr<pattern::Matcher> matchers[] = { std::shared_ptr<pattern::Matcher> matchers_v1[] = {
std::make_shared<pattern::Matcher>(data_pattern), std::make_shared<pattern::Matcher>(data_pattern),
std::make_shared<pattern::Matcher>(weights_pattern), std::make_shared<pattern::Matcher>(weights_pattern),
std::make_shared<pattern::Matcher>(bias_pattern)}; std::make_shared<pattern::Matcher>(bias_pattern)};
// this DS will be used to hold the matched attributes from matchers_v1
std::map<std::shared_ptr<Node>, NodeVector> op_seg_map; // add to list of params std::map<std::shared_ptr<Node>, NodeVector> op_seg_map; // add to list of params
std::map<NodeVector, NodeVector> param_list; std::map<NodeVector, NodeVector> param_list;
//--------------------------------------------------------
//--------------------------------------------------------
// Construct pattern version_2 for RNN input linear transformation
auto input_data = std::make_shared<pattern::op::Label>(
element::f32, Shape{10, 50}, pattern::has_class<op::Parameter>());
auto W = std::make_shared<pattern::op::Label>(
element::f32, Shape{50, 400}, pattern::has_class<op::Reshape>());
auto b = std::make_shared<pattern::op::Label>(
element::f32, Shape{10, 400}, pattern::has_class<op::Broadcast>());
std::shared_ptr<pattern::op::Label> labels_v2[] = {input_data, W, b};
auto matcher_v2 = construct_rnn_input_linear_transformation(labels_v2);
// this DS will be used to hold the matched attributes from matcher_v2
std::map<std::shared_ptr<Node>, NodeVector> map_weights_to_pattern;
std::map<std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>, NodeVector>
map_weights_bias_to_data;
//--------------------------------------------------------
for (auto n : function->get_ordered_ops()) for (auto n : function->get_ordered_ops())
{ {
NodeVector params; NodeVector params;
NodeVector matched_nodes; NodeVector matched_nodes;
// checks if the graph matches to pattern defined in the matcher_v2
if (matcher_v2->match(n))
{
auto matched_weight = matcher_v2->get_pattern_map()[W]->get_argument(0);
auto matched_data = matcher_v2->get_pattern_map()[input_data];
auto matched_bias = matcher_v2->get_pattern_map()[b]->get_argument(0);
map_weights_to_pattern[matched_weight].push_back(matcher_v2->get_match_root());
map_weights_bias_to_data[std::make_pair(matched_weight, matched_bias)].push_back(
matched_data);
}
for (size_t i = 0; i < NUM_MMB_ARGS; i++) for (size_t i = 0; i < NUM_MMB_ARGS; i++)
{ {
auto matcher = matchers[i]; auto matcher = matchers_v1[i];
if (matcher->match(n)) if (matcher->match(n))
{ {
// if we get all 3 matches they will all fall // if we get all 3 matches they will all fall
// in the right spots (e.g. DATA, WEIGHTS, BIAS) since matchers are ordered // in the right spots (e.g. DATA, WEIGHTS, BIAS) since matchers are ordered
// if we have less than 3 matches we skip this node anyways // if we have less than 3 matches we skip this node anyways
auto matched = matcher->get_pattern_map()[labels[i]]; auto matched = matcher->get_pattern_map()[labels_v1[i]];
params.push_back(matched->get_argument(0)); params.push_back(matched->get_argument(0));
matched_nodes.push_back(matched); matched_nodes.push_back(matched);
} }
...@@ -152,54 +189,134 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi ...@@ -152,54 +189,134 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
} }
} }
// Expecting input data shape D=[x, y, z], weights W=[u, v], bias B = [w] auto callback_matcher_v2 = [&]() -> void {
// where y is the time step. We are computing R=dot(D,W)=[x,y,v]. We can reshape D to D'=[x*y, z], then we have dot(D',W), result // fuse the input vector to a matrix
// in R=[x*y, v], then add(R,B). We need to slice the result by strided by time steps. for (auto& it : map_weights_bias_to_data)
// iterate each unique set of parameters, replace original operations {
for (auto& p : param_list) auto weights = it.first.first;
{ auto bias = it.first.second;
NodeVector params = p.first;
NodeVector& op_nodes = p.second; if (map_weights_to_pattern[weights].size() !=
map_weights_bias_to_data[std::make_pair(weights, bias)].size())
auto data_node = params.at(Type::DATA); {
auto weights_node = params.at(Type::WEIGHTS); NGRAPH_DEBUG << "number of input data param's doesnt match the number of matched "
auto bias_node = params.at(Type::BIAS); "pattern root "
<< "nodes";
const auto& data_shape = data_node->get_shape(); return;
// construct new op nodes }
auto data_order = ngraph::get_default_order(data_node->get_shape()); auto& w_shape = weights->get_shape();
auto data_reshape_node = std::make_shared<op::Reshape>( if (w_shape.size() != 2)
data_node, data_order, Shape{data_shape[0] * data_shape[1], data_shape[2]}); {
NGRAPH_DEBUG << "weights shape for linear transformation of input is not 2D";
auto old_weights_reshape_node = op_seg_map.at(op_nodes.at(0)).at(Type::WEIGHTS); return;
auto weights_reshape_node = old_weights_reshape_node->copy_with_new_args({weights_node}); }
auto dot_node = std::make_shared<op::Dot>(data_reshape_node, weights_reshape_node);
const auto& dot_shape = dot_node->get_shape(); auto& data_param_nodes = it.second;
// we will not fuse if the batch_size are not same across all inputs of time step
auto bias_broadcast_node = for (auto& node : data_param_nodes)
std::make_shared<op::Broadcast>(bias_node, dot_shape, AxisSet{0}); {
auto add_node = std::make_shared<op::Add>(dot_node, bias_broadcast_node); if (shape_size(data_param_nodes[0]->get_shape()) != shape_size(node->get_shape()))
const auto& add_shape = add_node->get_shape(); {
return;
// create a slice for each user of the dot op matching the original dot op's output }
for (auto op : op_nodes) }
// now concat the parameter hashed to the same weights
auto concated_data = std::make_shared<op::Concat>(data_param_nodes, 0);
auto& data_shape = concated_data->get_shape();
auto data_order = ngraph::get_default_order(concated_data->get_shape());
// insert reshape on the concated data to make it 2D, if its 3D
std::shared_ptr<Node> input_reshape_node = nullptr;
if (data_shape.size() == 3)
{
input_reshape_node = std::make_shared<op::Reshape>(
concated_data, data_order, Shape{data_shape[0] * data_shape[1], data_shape[2]});
}
auto new_input_node = data_shape.size() == 2 ? concated_data : input_reshape_node;
auto w_reshape_node = std::make_shared<op::Reshape>(
weights, AxisVector{1, 0}, Shape{w_shape[1], w_shape[0]});
auto new_dot = std::make_shared<op::Dot>(new_input_node, w_reshape_node);
auto bias_broadcast_node =
std::make_shared<op::Broadcast>(bias, new_dot->get_shape(), AxisSet{0});
auto new_add_bias = std::make_shared<op::Add>(new_dot, bias_broadcast_node);
// now slice the new_add and feed the corrosponding root nodes
auto batch_size = new_add_bias->get_shape()[0] / data_param_nodes.size();
auto shape_axis_1 = new_add_bias->get_shape()[1];
size_t start_index = 0;
size_t end_index = batch_size;
for (auto& matched_root_node : map_weights_to_pattern[weights])
{
auto slice_node = std::make_shared<op::Slice>(
new_add_bias, Coordinate{start_index, 0}, Coordinate{end_index, shape_axis_1});
start_index += batch_size;
end_index += batch_size;
NGRAPH_DEBUG << "Replacing op " << matched_root_node->get_name() << " with "
<< slice_node->get_name() << std::endl;
function->replace_node(matched_root_node, slice_node);
}
modify_graph = true;
}
};
auto callback_matcher_v1 = [&]() -> void {
// Expecting input data shape D=[x, y, z], weights W=[u, v], bias B = [w]
// where y is the time step. We are computing R=dot(D,W)=[x,y,v]. We can reshape D to D'=[x*y, z], then we have dot(D',W), result
// in R=[x*y, v], then add(R,B). We need to slice the result by strided by time steps.
// iterate each unique set of parameters, replace original operations
for (auto& p : param_list)
{ {
const auto old_slice = NodeVector params = p.first;
std::dynamic_pointer_cast<op::Slice>(op_seg_map[op].at(Type::DATA)); NodeVector& op_nodes = p.second;
const auto& old_lower_bounds = old_slice->get_lower_bounds();
// lower bound matching the current time step auto data_node = params.at(Type::DATA);
const Coordinate lower_bounds{old_lower_bounds[1], 0}; auto weights_node = params.at(Type::WEIGHTS);
// striding by the number of data auto bias_node = params.at(Type::BIAS);
const Strides strides{data_shape[1], 1};
auto slice_node = const auto& data_shape = data_node->get_shape();
std::make_shared<op::Slice>(add_node, lower_bounds, add_shape, strides); // construct new op nodes
auto data_order = ngraph::get_default_order(data_node->get_shape());
// replace old nodes auto data_reshape_node = std::make_shared<op::Reshape>(
function->replace_node(op, slice_node); data_node, data_order, Shape{data_shape[0] * data_shape[1], data_shape[2]});
auto old_weights_reshape_node = op_seg_map.at(op_nodes.at(0)).at(Type::WEIGHTS);
auto weights_reshape_node =
old_weights_reshape_node->copy_with_new_args({weights_node});
auto dot_node = std::make_shared<op::Dot>(data_reshape_node, weights_reshape_node);
const auto& dot_shape = dot_node->get_shape();
auto bias_broadcast_node =
std::make_shared<op::Broadcast>(bias_node, dot_shape, AxisSet{0});
auto add_node = std::make_shared<op::Add>(dot_node, bias_broadcast_node);
const auto& add_shape = add_node->get_shape();
// create a slice for each user of the dot op matching the original dot op's output
for (auto op : op_nodes)
{
const auto old_slice =
std::dynamic_pointer_cast<op::Slice>(op_seg_map[op].at(Type::DATA));
const auto& old_lower_bounds = old_slice->get_lower_bounds();
// lower bound matching the current time step
const Coordinate lower_bounds{old_lower_bounds[1], 0};
// striding by the number of data
const Strides strides{data_shape[1], 1};
auto slice_node =
std::make_shared<op::Slice>(add_node, lower_bounds, add_shape, strides);
// replace old nodes
function->replace_node(op, slice_node);
}
modify_graph = true;
} }
modified = true; };
}
return modified; // Based the matched pattern, this callback's fuse the input across time steps and replaces with
// single DOT operation <X0|X1|X2|..... Xt>*W
callback_matcher_v2();
callback_matcher_v1();
return modify_graph;
} }
#define TI(x) std::type_index(typeid(x)) #define TI(x) std::type_index(typeid(x))
......
...@@ -2822,3 +2822,64 @@ TEST(cpu_fusion, dot_batch_forward) ...@@ -2822,3 +2822,64 @@ TEST(cpu_fusion, dot_batch_forward)
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f)); EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
} }
} }
static std::shared_ptr<Function>
create_rnn_input_linear_transformation_function(size_t num_timesteps)
{
auto W = std::make_shared<op::Parameter>(element::f32, Shape{400, 50});
auto bias = std::make_shared<op::Parameter>(element::f32, Shape{400});
op::ParameterVector params{W, bias};
auto create_graph = [&]() -> std::shared_ptr<Node> {
auto data_param = std::make_shared<op::Parameter>(element::f32, Shape{10, 1, 50});
params.push_back(data_param);
auto data_param_reshape =
std::make_shared<op::Reshape>(data_param, AxisVector{0, 1, 2}, Shape{10, 50});
auto W_reshape = std::make_shared<op::Reshape>(W, AxisVector{1, 0}, Shape{50, 400});
auto dot = std::make_shared<op::Dot>(data_param_reshape, W_reshape);
auto bias_broadcast = make_shared<op::Broadcast>(bias, dot->get_shape(), AxisSet{0});
auto add_bias = std::make_shared<op::Add>(dot, bias_broadcast);
return add_bias;
};
NodeVector graph_nodes;
for (size_t i = 0; i < num_timesteps; i++)
{
graph_nodes.push_back(create_graph());
}
auto concat = std::make_shared<op::Concat>(graph_nodes, 0);
return make_shared<Function>(NodeVector{concat}, params);
}
TEST(cpu_fusion, fuse_rnn_input_across_time_steps)
{
auto func = create_rnn_input_linear_transformation_function(10);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(func);
size_t ref_matmulbias_count = 1;
auto matmulbias_count = count_ops_of_type<op::MatmulBias>(func);
EXPECT_EQ(ref_matmulbias_count, matmulbias_count);
}
TEST(cpu_fusion, rnn_input_fusion_inter_vs_cpu)
{
shared_ptr<Function> cpu_func = create_rnn_input_linear_transformation_function(10);
shared_ptr<Function> int_func = create_rnn_input_linear_transformation_function(10);
test::Uniform<float> rng(-10.0f, 10.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_func->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_func, args, "INTERPRETER");
auto cpu_results = execute(cpu_func, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment