Commit 2db236b7 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

RNN Fusion using Pattern Matcher (#741)

* initial refactoring using PM

* unit test pass

* cosmetic changes

* add another rnn test

* address louis' feedback

* lower-case labels
parent 6909850e
...@@ -29,197 +29,110 @@ ...@@ -29,197 +29,110 @@
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
using namespace ngraph; #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
typedef std::shared_ptr<Node> NodePtr;
#define TI(x) std::type_index(typeid(x)) using namespace ngraph;
// a sequence of nodes, identified with a segment type for the input parameter type struct Type
struct NodeSegment : public NodeVector
{ {
enum Type enum
{ {
DATA = 0, DATA = 0,
WEIGHTS, WEIGHTS,
BIAS, BIAS,
UNDEFINED
}; };
Type type{UNDEFINED};
}; };
typedef std::pair<NodeSegment::Type, std::vector<std::type_index>> NodeTypeSequence;
typedef std::list<NodeTypeSequence> NodeTypeSequenceList;
// Preorder traversal to collect all valid segments in the graph static std::shared_ptr<Node> construct_data_pattern(std::shared_ptr<pattern::op::Label> data_slice)
// precondition: all valid sequences must be unique
// [a, b, c] and [a, c, b] are different, for valid sequences like [a, b, c] and [a, b], the
// longest sequence will be matched.
void FindValidSegments(const NodePtr& node,
NodeSegment segment,
std::vector<NodeSegment>& segment_bundle,
NodeTypeSequenceList valid_sequence_list,
int depth)
{ {
const Node& node_ref = *node; auto reshape_slice =
// check current node against all valid sequences at current depth level. Remove sequences std::make_shared<op::Reshape>(data_slice, AxisVector{0, 1, 2}, Shape{2, 4});
// which does not match current node type auto W = std::make_shared<pattern::op::Label>(element::f32, Shape{4, 1});
for (auto seq_it = valid_sequence_list.begin(); seq_it != valid_sequence_list.end();) auto dot = std::make_shared<op::Dot>(reshape_slice, W);
{ auto broadcast = std::make_shared<pattern::op::Label>(element::f32, dot->get_shape());
const auto& valid_seq = seq_it->second; return dot + broadcast;
// remove sequences which are too short or doesn't match current node type at depth index
if (depth >= valid_seq.size() || TI(node_ref) != valid_seq[depth])
{
seq_it = valid_sequence_list.erase(seq_it);
}
else
{
++seq_it;
}
}
// postconditions:
    // valid_sequence_list.size() > 0 : there are still valid sequences to match
// otherwise : terminate
if (valid_sequence_list.size() > 0)
{
segment.push_back(node);
// base case, we have one valid segment left (since valid sequences are expected to be
// unique), and current depth matches (sequence-length - 1) (i.e. last node)
// we found a match
if (valid_sequence_list.size() == 1 &&
depth == (valid_sequence_list.front().second.size() - 1))
{
segment.type = valid_sequence_list.front().first;
segment_bundle.push_back(segment);
return;
}
// still have more than one sequences to check, continue traversal
else
{
const auto outputs = node->get_users();
for (const auto& out_node : outputs)
{
FindValidSegments(
out_node, segment, segment_bundle, valid_sequence_list, depth + 1);
}
}
}
} }
// this is the expected sequence count for fusing static std::shared_ptr<Node>
const size_t SEGMENT_COUNT = 3; construct_weights_pattern(std::shared_ptr<pattern::op::Label> weights_reshape)
struct OrderedParams
{ {
public: auto X = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 4});
OrderedParams() auto dot = std::make_shared<op::Dot>(X, weights_reshape);
: m_params{{nullptr, nullptr, nullptr}} auto broadcast = std::make_shared<pattern::op::Label>(element::f32, dot->get_shape());
{ return dot + broadcast;
} }
bool valid()
{
return std::none_of(m_params.cbegin(), m_params.cend(), [](const NodePtr& n) -> bool {
return n == nullptr;
});
}
void set(const NodeSegment::Type type, const NodePtr& node) { m_params[type] = node; }
NodePtr get(const NodeSegment::Type type) const { return m_params.at(type); }
friend bool operator<(const OrderedParams& a, const OrderedParams& b);
private:
// order based on NodeSegment::Type
// <data, weights, bias>
std::array<NodePtr, SEGMENT_COUNT> m_params;
};
bool operator<(const OrderedParams& a, const OrderedParams& b) static std::shared_ptr<Node>
construct_bias_pattern(std::shared_ptr<pattern::op::Label> bias_broadcast)
{ {
return a.m_params < b.m_params; auto dot_label = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 1});
return dot_label + bias_broadcast;
} }
bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Function> function) bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Function> function)
{ {
bool modified = false; bool modified = false;
const NodeTypeSequenceList valid_sequences{ auto data_pred = [](std::shared_ptr<Node> n) {
{NodeSegment::DATA, return std::dynamic_pointer_cast<op::Slice>(n) != nullptr;
{TI(op::Parameter), TI(op::Slice), TI(op::Reshape), TI(op::Dot), TI(op::Add)}}, };
{NodeSegment::WEIGHTS, {TI(op::Parameter), TI(op::Reshape), TI(op::Dot), TI(op::Add)}}, auto data_slice = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 4}, data_pred);
{NodeSegment::BIAS, {TI(op::Parameter), TI(op::Broadcast), TI(op::Add)}}}; auto data_pattern = construct_data_pattern(data_slice);
// find all parameter nodes
std::vector<NodePtr> param_nodes;
for (auto& node : function->get_ordered_ops())
{
if (node->is_parameter())
{
param_nodes.push_back(node);
}
}
// iterate all parameters and find all valid segments
std::vector<NodeSegment> segment_bundle;
for (auto& node : param_nodes)
{
NodeSegment segment;
FindValidSegments(node, segment, segment_bundle, valid_sequences, 0);
}
// combined all segments by last operator auto weights_pred = [](std::shared_ptr<Node> n) {
std::map<NodePtr, std::vector<NodeSegment>> op_seg_map; return std::dynamic_pointer_cast<op::Reshape>(n) != nullptr;
for (const auto& segment : segment_bundle) };
auto weights_reshape =
std::make_shared<pattern::op::Label>(element::f32, Shape{4, 1}, weights_pred);
auto weights_pattern = construct_weights_pattern(weights_reshape);
//we don't really need a broadcast node but
//labelling a Broadcast allows us to extract
//params from all 3 labels in the same fashion
//(i.e. via get_input_op(0))
auto broadcast_pred = [](std::shared_ptr<Node> n) {
return std::dynamic_pointer_cast<op::Broadcast>(n) != nullptr;
};
auto bias_broadcast =
std::make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, broadcast_pred);
auto bias_pattern = construct_bias_pattern(bias_broadcast);
const size_t NUM_MMB_ARGS = 3;
std::shared_ptr<pattern::op::Label> labels[] = {data_slice, weights_reshape, bias_broadcast};
//Matchers' ordering is important! Don't change!
std::shared_ptr<pattern::Matcher> matchers[] = {
std::make_shared<pattern::Matcher>(data_pattern),
std::make_shared<pattern::Matcher>(weights_pattern),
std::make_shared<pattern::Matcher>(bias_pattern)};
std::map<std::shared_ptr<Node>, NodeVector> op_seg_map; //add to list of params
std::map<NodeVector, NodeVector> param_list;
for (auto n : function->get_ordered_ops())
{ {
auto op_it = op_seg_map.find(segment.back()); NodeVector params;
if (op_it == op_seg_map.end()) NodeVector matched_nodes;
for (size_t i = 0; i < NUM_MMB_ARGS; i++)
{ {
auto insert_result = op_seg_map.insert( auto matcher = matchers[i];
std::make_pair(segment.back(), std::vector<NodeSegment>(SEGMENT_COUNT))); if (matcher->match(n))
op_it = insert_result.first; {
} //if we get all 3 matches they will all fall
(op_it->second)[segment.type] = segment; //in the right spots (e.g. DATA, WEIGHTS, BIAS) since matchers are ordered
} //if we have less than 3 matches we skip this node anyways
auto matched = matcher->get_pattern_map()[labels[i]];
params.push_back(matched->get_input_op(0));
matched_nodes.push_back(matched);
}
// remove ops with less than SEGMENT_COUNT number of segments if (params.size() != NUM_MMB_ARGS)
for (auto op_it = op_seg_map.cbegin(); op_it != op_seg_map.cend();)
{
    // remove ops with fewer than the expected number of segments
bool valid = true;
for (auto& seg : op_it->second)
{
if (seg.empty())
{ {
valid = false; continue;
break;
} }
}
if (!valid)
{
op_it = op_seg_map.erase(op_it);
}
else
{
++op_it;
}
}
// create a lookup map for each unique set of parameters //we have a full set for the current Add (n) i.e. data, weights, bias
std::map<OrderedParams, NodeVector> param_list; op_seg_map.insert(std::make_pair(n, matched_nodes));
for (auto& op_seg : op_seg_map) param_list[params].push_back(n);
{
std::vector<NodeSegment>& segments = op_seg.second;
OrderedParams p;
// put each segment's parameter in the OrderedParams by type
for (auto& seg : segments)
{
p.set(seg.type, seg[0]);
}
// if any of them is missing, p will be invalid
// this can happen for example, when two of them are both
// weights
if (p.valid())
{
param_list[p].push_back(op_seg.first);
} }
} }
...@@ -242,25 +155,22 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi ...@@ -242,25 +155,22 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
// iterate each unique set of parameters, replace original operations // iterate each unique set of parameters, replace original operations
for (auto& p : param_list) for (auto& p : param_list)
{ {
OrderedParams params = p.first; NodeVector params = p.first;
NodeVector& op_nodes = p.second; NodeVector& op_nodes = p.second;
auto data_node = params.get(NodeSegment::DATA); auto data_node = params.at(Type::DATA);
auto weights_node = params.get(NodeSegment::WEIGHTS); auto weights_node = params.at(Type::WEIGHTS);
auto bias_node = params.get(NodeSegment::BIAS); auto bias_node = params.at(Type::BIAS);
const auto& data_shape = data_node->get_shape(); const auto& data_shape = data_node->get_shape();
// get the first combo op
auto first_op = op_nodes[0];
auto first_weights_segment = op_seg_map[first_op][NodeSegment::WEIGHTS];
// construct new op nodes // construct new op nodes
AxisVector data_order(data_node->get_shape().size()); AxisVector data_order(data_node->get_shape().size());
std::iota(begin(data_order), end(data_order), 0); std::iota(begin(data_order), end(data_order), 0);
auto data_reshape_node = std::make_shared<op::Reshape>( auto data_reshape_node = std::make_shared<op::Reshape>(
data_node, data_order, Shape{data_shape[0] * data_shape[1], data_shape[2]}); data_node, data_order, Shape{data_shape[0] * data_shape[1], data_shape[2]});
auto weights_reshape_node = first_weights_segment[1]->copy_with_new_args({weights_node});
auto old_weights_reshape_node = op_seg_map.at(op_nodes.at(0)).at(Type::WEIGHTS);
auto weights_reshape_node = old_weights_reshape_node->copy_with_new_args({weights_node});
auto dot_node = std::make_shared<op::Dot>(data_reshape_node, weights_reshape_node); auto dot_node = std::make_shared<op::Dot>(data_reshape_node, weights_reshape_node);
const auto& dot_shape = dot_node->get_shape(); const auto& dot_shape = dot_node->get_shape();
...@@ -272,8 +182,8 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi ...@@ -272,8 +182,8 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
// create a slice for each user of the dot op matching the original dot op's output // create a slice for each user of the dot op matching the original dot op's output
for (auto op : op_nodes) for (auto op : op_nodes)
{ {
const auto& cur_data_segment = op_seg_map[op][NodeSegment::DATA]; const auto old_slice =
const auto old_slice = std::dynamic_pointer_cast<op::Slice>(cur_data_segment[1]); std::dynamic_pointer_cast<op::Slice>(op_seg_map[op].at(Type::DATA));
const auto& old_lower_bounds = old_slice->get_lower_bounds(); const auto& old_lower_bounds = old_slice->get_lower_bounds();
// lower bound matching the current time step // lower bound matching the current time step
const Coordinate lower_bounds{old_lower_bounds[1], 0}; const Coordinate lower_bounds{old_lower_bounds[1], 0};
......
...@@ -346,8 +346,8 @@ TEST(cpu_fusion, gemm_mlp) ...@@ -346,8 +346,8 @@ TEST(cpu_fusion, gemm_mlp)
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(func); pass_manager.run_passes(func);
size_t mmb = count_ops_of_type<op::MatmulBias>(func); auto mmbs = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(mmb, 3); ASSERT_EQ(mmbs, 3);
} }
TEST(cpu_fusion, fuse_fprop_bn) TEST(cpu_fusion, fuse_fprop_bn)
...@@ -1071,3 +1071,28 @@ TEST(cpu_fusion, rnn_matrix_fusion_eval_pass) ...@@ -1071,3 +1071,28 @@ TEST(cpu_fusion, rnn_matrix_fusion_eval_pass)
EXPECT_TRUE(test::all_close<float>(result_expected[i], result_fused[i])); EXPECT_TRUE(test::all_close<float>(result_expected[i], result_fused[i]));
} }
} }
// Verify that running CPURnnMatFusion followed by CPUFusion on a serialized
// 10-step MXNet RNN model produces at least one MatmulBias whose output is
// consumed by exactly one Slice per time step.
TEST(cpu_fusion, rnn_fusion_from_json_model)
{
    pass::Manager pm;
    pm.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
    pm.register_pass<runtime::cpu::pass::CPUFusion>();

    const string path =
        file_util::path_join(SERIALIZED_ZOO, "mxnet/rnn-10-step-fusion-test.json");
    stringstream model_stream(file_util::read_file_to_string(path));
    shared_ptr<Function> func = ngraph::deserialize(model_stream);
    pm.run_passes(func);

    // A fully fused MatmulBias should feed one Slice per RNN time step.
    const size_t NUM_STEPS = 10;
    auto is_fused_mmb = [NUM_STEPS](std::shared_ptr<Node> node) {
        auto consumers = node->get_users();
        if (consumers.size() != NUM_STEPS)
        {
            return false;
        }
        return std::all_of(begin(consumers), end(consumers), [](std::shared_ptr<Node> n) {
            return std::dynamic_pointer_cast<op::Slice>(n) != nullptr;
        });
    };
    auto mmbs = get_ops_of_type<op::MatmulBias>(func);
    ASSERT_TRUE(std::any_of(begin(mmbs), end(mmbs), is_fused_mmb));
}
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -76,6 +76,21 @@ void write_vector(std::shared_ptr<ngraph::runtime::TensorView> tv, const std::ve ...@@ -76,6 +76,21 @@ void write_vector(std::shared_ptr<ngraph::runtime::TensorView> tv, const std::ve
tv->write(values.data(), 0, values.size() * sizeof(T)); tv->write(values.data(), 0, values.size() * sizeof(T));
} }
// Collect every op in function \p f that is (or derives from) node type \p T,
// in the order produced by Function::get_ops().
//
// \tparam T concrete ngraph op type to filter for (e.g. op::MatmulBias)
// \param f  function whose ops are scanned
// \return   shared pointers to all matching ops; empty vector if none match
template <typename T>
std::vector<std::shared_ptr<T>> get_ops_of_type(std::shared_ptr<ngraph::Function> f)
{
    std::vector<std::shared_ptr<T>> ops;
    // const-ref iteration: copying each shared_ptr would do an atomic
    // refcount increment/decrement per element for no benefit
    for (const auto& op : f->get_ops())
    {
        if (auto cop = std::dynamic_pointer_cast<T>(op))
        {
            ops.push_back(std::move(cop));
        }
    }
    return ops;
}
template <typename T> template <typename T>
size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f) size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment