Commit 2db236b7 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

RNN Fusion using Pattern Matcher (#741)

* initial refactoring using PM

* unit test pass

* cosmetic changes

* add another rnn test

* address louis' feedback

* lower-case labels
parent 6909850e
......@@ -29,197 +29,110 @@
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
using namespace ngraph;
typedef std::shared_ptr<Node> NodePtr;
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#define TI(x) std::type_index(typeid(x))
using namespace ngraph;
// a sequence of nodes, identified with a segment type for the input parameter type
struct NodeSegment : public NodeVector
struct Type
{
enum Type
enum
{
DATA = 0,
WEIGHTS,
BIAS,
UNDEFINED
};
Type type{UNDEFINED};
};
typedef std::pair<NodeSegment::Type, std::vector<std::type_index>> NodeTypeSequence;
typedef std::list<NodeTypeSequence> NodeTypeSequenceList;
// Preorder traversal to collect all valid segments in the graph
// precondition: all valid sequences must be unique
// [a, b, c] and [a, c, b] are different, for valid sequences like [a, b, c] and [a, b], the
// longest sequence will be matched.
void FindValidSegments(const NodePtr& node,
NodeSegment segment,
std::vector<NodeSegment>& segment_bundle,
NodeTypeSequenceList valid_sequence_list,
int depth)
static std::shared_ptr<Node> construct_data_pattern(std::shared_ptr<pattern::op::Label> data_slice)
{
const Node& node_ref = *node;
// check current node against all valid sequences at current depth level. Remove sequences
// which does not match current node type
for (auto seq_it = valid_sequence_list.begin(); seq_it != valid_sequence_list.end();)
{
const auto& valid_seq = seq_it->second;
// remove sequences which are too short or doesn't match current node type at depth index
if (depth >= valid_seq.size() || TI(node_ref) != valid_seq[depth])
{
seq_it = valid_sequence_list.erase(seq_it);
}
else
{
++seq_it;
}
}
// postconditions:
// valid_sequnce_list.size() > 0 : there's still valid sequences to match
// otherwise : terminate
if (valid_sequence_list.size() > 0)
{
segment.push_back(node);
// base case, we have one valid segment left (since valid sequences are expected to be
// unique), and current depth matches (sequence-length - 1) (i.e. last node)
// we found a match
if (valid_sequence_list.size() == 1 &&
depth == (valid_sequence_list.front().second.size() - 1))
{
segment.type = valid_sequence_list.front().first;
segment_bundle.push_back(segment);
return;
}
// still have more than one sequences to check, continue traversal
else
{
const auto outputs = node->get_users();
for (const auto& out_node : outputs)
{
FindValidSegments(
out_node, segment, segment_bundle, valid_sequence_list, depth + 1);
}
}
}
auto reshape_slice =
std::make_shared<op::Reshape>(data_slice, AxisVector{0, 1, 2}, Shape{2, 4});
auto W = std::make_shared<pattern::op::Label>(element::f32, Shape{4, 1});
auto dot = std::make_shared<op::Dot>(reshape_slice, W);
auto broadcast = std::make_shared<pattern::op::Label>(element::f32, dot->get_shape());
return dot + broadcast;
}
// this is the expected sequence count for fusing
const size_t SEGMENT_COUNT = 3;
struct OrderedParams
static std::shared_ptr<Node>
construct_weights_pattern(std::shared_ptr<pattern::op::Label> weights_reshape)
{
public:
OrderedParams()
: m_params{{nullptr, nullptr, nullptr}}
{
}
bool valid()
{
return std::none_of(m_params.cbegin(), m_params.cend(), [](const NodePtr& n) -> bool {
return n == nullptr;
});
}
void set(const NodeSegment::Type type, const NodePtr& node) { m_params[type] = node; }
NodePtr get(const NodeSegment::Type type) const { return m_params.at(type); }
friend bool operator<(const OrderedParams& a, const OrderedParams& b);
private:
// order based on NodeSegment::Type
// <data, weights, bias>
std::array<NodePtr, SEGMENT_COUNT> m_params;
};
auto X = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 4});
auto dot = std::make_shared<op::Dot>(X, weights_reshape);
auto broadcast = std::make_shared<pattern::op::Label>(element::f32, dot->get_shape());
return dot + broadcast;
}
bool operator<(const OrderedParams& a, const OrderedParams& b)
static std::shared_ptr<Node>
construct_bias_pattern(std::shared_ptr<pattern::op::Label> bias_broadcast)
{
return a.m_params < b.m_params;
auto dot_label = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 1});
return dot_label + bias_broadcast;
}
bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Function> function)
{
bool modified = false;
const NodeTypeSequenceList valid_sequences{
{NodeSegment::DATA,
{TI(op::Parameter), TI(op::Slice), TI(op::Reshape), TI(op::Dot), TI(op::Add)}},
{NodeSegment::WEIGHTS, {TI(op::Parameter), TI(op::Reshape), TI(op::Dot), TI(op::Add)}},
{NodeSegment::BIAS, {TI(op::Parameter), TI(op::Broadcast), TI(op::Add)}}};
// find all parameter nodes
std::vector<NodePtr> param_nodes;
for (auto& node : function->get_ordered_ops())
{
if (node->is_parameter())
{
param_nodes.push_back(node);
}
}
auto data_pred = [](std::shared_ptr<Node> n) {
return std::dynamic_pointer_cast<op::Slice>(n) != nullptr;
};
auto data_slice = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 4}, data_pred);
auto data_pattern = construct_data_pattern(data_slice);
// iterate all parameters and find all valid segments
std::vector<NodeSegment> segment_bundle;
for (auto& node : param_nodes)
{
NodeSegment segment;
FindValidSegments(node, segment, segment_bundle, valid_sequences, 0);
}
auto weights_pred = [](std::shared_ptr<Node> n) {
return std::dynamic_pointer_cast<op::Reshape>(n) != nullptr;
};
auto weights_reshape =
std::make_shared<pattern::op::Label>(element::f32, Shape{4, 1}, weights_pred);
auto weights_pattern = construct_weights_pattern(weights_reshape);
//we don't really need a broadcast node but
//labelling a Broadcast allows us to extract
//params from all 3 labels in the same fashion
//(i.e. via get_input_op(0))
auto broadcast_pred = [](std::shared_ptr<Node> n) {
return std::dynamic_pointer_cast<op::Broadcast>(n) != nullptr;
};
auto bias_broadcast =
std::make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, broadcast_pred);
auto bias_pattern = construct_bias_pattern(bias_broadcast);
// combined all segments by last operator
std::map<NodePtr, std::vector<NodeSegment>> op_seg_map;
for (const auto& segment : segment_bundle)
{
auto op_it = op_seg_map.find(segment.back());
if (op_it == op_seg_map.end())
{
auto insert_result = op_seg_map.insert(
std::make_pair(segment.back(), std::vector<NodeSegment>(SEGMENT_COUNT)));
op_it = insert_result.first;
}
(op_it->second)[segment.type] = segment;
}
const size_t NUM_MMB_ARGS = 3;
std::shared_ptr<pattern::op::Label> labels[] = {data_slice, weights_reshape, bias_broadcast};
//Matchers' ordering is important! Don't change!
std::shared_ptr<pattern::Matcher> matchers[] = {
std::make_shared<pattern::Matcher>(data_pattern),
std::make_shared<pattern::Matcher>(weights_pattern),
std::make_shared<pattern::Matcher>(bias_pattern)};
// remove ops with less than SEGMENT_COUNT number of segments
for (auto op_it = op_seg_map.cbegin(); op_it != op_seg_map.cend();)
{
// remove ops with less than expected segements
bool valid = true;
for (auto& seg : op_it->second)
std::map<std::shared_ptr<Node>, NodeVector> op_seg_map; //add to list of params
std::map<NodeVector, NodeVector> param_list;
for (auto n : function->get_ordered_ops())
{
if (seg.empty())
NodeVector params;
NodeVector matched_nodes;
for (size_t i = 0; i < NUM_MMB_ARGS; i++)
{
valid = false;
break;
}
}
if (!valid)
{
op_it = op_seg_map.erase(op_it);
}
else
auto matcher = matchers[i];
if (matcher->match(n))
{
++op_it;
}
//if we get all 3 matches they will all fall
//in the right spots (e.g. DATA, WEIGHTS, BIAS) since matchers are ordered
//if we have less than 3 matches we skip this node anyways
auto matched = matcher->get_pattern_map()[labels[i]];
params.push_back(matched->get_input_op(0));
matched_nodes.push_back(matched);
}
// create a lookup map for each unique set of parameters
std::map<OrderedParams, NodeVector> param_list;
for (auto& op_seg : op_seg_map)
{
std::vector<NodeSegment>& segments = op_seg.second;
OrderedParams p;
// put each segment's parameter in the OrderedParams by type
for (auto& seg : segments)
if (params.size() != NUM_MMB_ARGS)
{
p.set(seg.type, seg[0]);
continue;
}
// if any of them is missing, p will be invalid
// this can happen for example, when two of them are both
// weights
if (p.valid())
{
param_list[p].push_back(op_seg.first);
//we have a full set for the current Add (n) i.e. data, weights, bias
op_seg_map.insert(std::make_pair(n, matched_nodes));
param_list[params].push_back(n);
}
}
......@@ -242,25 +155,22 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
// iterate each unique set of parameters, replace original operations
for (auto& p : param_list)
{
OrderedParams params = p.first;
NodeVector params = p.first;
NodeVector& op_nodes = p.second;
auto data_node = params.get(NodeSegment::DATA);
auto weights_node = params.get(NodeSegment::WEIGHTS);
auto bias_node = params.get(NodeSegment::BIAS);
auto data_node = params.at(Type::DATA);
auto weights_node = params.at(Type::WEIGHTS);
auto bias_node = params.at(Type::BIAS);
const auto& data_shape = data_node->get_shape();
// get the first combo op
auto first_op = op_nodes[0];
auto first_weights_segment = op_seg_map[first_op][NodeSegment::WEIGHTS];
// construct new op nodes
AxisVector data_order(data_node->get_shape().size());
std::iota(begin(data_order), end(data_order), 0);
auto data_reshape_node = std::make_shared<op::Reshape>(
data_node, data_order, Shape{data_shape[0] * data_shape[1], data_shape[2]});
auto weights_reshape_node = first_weights_segment[1]->copy_with_new_args({weights_node});
auto old_weights_reshape_node = op_seg_map.at(op_nodes.at(0)).at(Type::WEIGHTS);
auto weights_reshape_node = old_weights_reshape_node->copy_with_new_args({weights_node});
auto dot_node = std::make_shared<op::Dot>(data_reshape_node, weights_reshape_node);
const auto& dot_shape = dot_node->get_shape();
......@@ -272,8 +182,8 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
// create a slice for each user of the dot op matching the original dot op's output
for (auto op : op_nodes)
{
const auto& cur_data_segment = op_seg_map[op][NodeSegment::DATA];
const auto old_slice = std::dynamic_pointer_cast<op::Slice>(cur_data_segment[1]);
const auto old_slice =
std::dynamic_pointer_cast<op::Slice>(op_seg_map[op].at(Type::DATA));
const auto& old_lower_bounds = old_slice->get_lower_bounds();
// lower bound matching the current time step
const Coordinate lower_bounds{old_lower_bounds[1], 0};
......
......@@ -346,8 +346,8 @@ TEST(cpu_fusion, gemm_mlp)
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(func);
size_t mmb = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(mmb, 3);
auto mmbs = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(mmbs, 3);
}
TEST(cpu_fusion, fuse_fprop_bn)
......@@ -1071,3 +1071,28 @@ TEST(cpu_fusion, rnn_matrix_fusion_eval_pass)
EXPECT_TRUE(test::all_close<float>(result_expected[i], result_fused[i]));
}
}
TEST(cpu_fusion, rnn_fusion_from_json_model)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
const string json_path =
file_util::path_join(SERIALIZED_ZOO, "mxnet/rnn-10-step-fusion-test.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
const size_t NUM_STEPS = 10;
auto mmb_predicate = [NUM_STEPS](std::shared_ptr<Node> node) {
auto users = node->get_users();
return users.size() == NUM_STEPS &&
std::all_of(begin(users), end(users), [](std::shared_ptr<Node> n) {
return std::dynamic_pointer_cast<op::Slice>(n) != nullptr;
});
};
auto mmbs = get_ops_of_type<op::MatmulBias>(func);
ASSERT_TRUE(std::any_of(begin(mmbs), end(mmbs), mmb_predicate));
}
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -76,6 +76,21 @@ void write_vector(std::shared_ptr<ngraph::runtime::TensorView> tv, const std::ve
tv->write(values.data(), 0, values.size() * sizeof(T));
}
template <typename T>
std::vector<std::shared_ptr<T>> get_ops_of_type(std::shared_ptr<ngraph::Function> f)
{
std::vector<std::shared_ptr<T>> ops;
for (auto op : f->get_ops())
{
if (auto cop = std::dynamic_pointer_cast<T>(op))
{
ops.push_back(cop);
}
}
return ops;
}
template <typename T>
size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment