Commit 2db236b7 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

RNN Fusion using Pattern Matcher (#741)

* initial refactoring using PM

* unit test pass

* cosmetic changes

* add another rnn test

* address louis' feedback

* lower-case labels
parent 6909850e
...@@ -346,8 +346,8 @@ TEST(cpu_fusion, gemm_mlp) ...@@ -346,8 +346,8 @@ TEST(cpu_fusion, gemm_mlp)
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.run_passes(func); pass_manager.run_passes(func);
size_t mmb = count_ops_of_type<op::MatmulBias>(func); auto mmbs = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(mmb, 3); ASSERT_EQ(mmbs, 3);
} }
TEST(cpu_fusion, fuse_fprop_bn) TEST(cpu_fusion, fuse_fprop_bn)
...@@ -1071,3 +1071,28 @@ TEST(cpu_fusion, rnn_matrix_fusion_eval_pass) ...@@ -1071,3 +1071,28 @@ TEST(cpu_fusion, rnn_matrix_fusion_eval_pass)
EXPECT_TRUE(test::all_close<float>(result_expected[i], result_fused[i])); EXPECT_TRUE(test::all_close<float>(result_expected[i], result_fused[i]));
} }
} }
TEST(cpu_fusion, rnn_fusion_from_json_model)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
const string json_path =
file_util::path_join(SERIALIZED_ZOO, "mxnet/rnn-10-step-fusion-test.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass_manager.run_passes(func);
const size_t NUM_STEPS = 10;
auto mmb_predicate = [NUM_STEPS](std::shared_ptr<Node> node) {
auto users = node->get_users();
return users.size() == NUM_STEPS &&
std::all_of(begin(users), end(users), [](std::shared_ptr<Node> n) {
return std::dynamic_pointer_cast<op::Slice>(n) != nullptr;
});
};
auto mmbs = get_ops_of_type<op::MatmulBias>(func);
ASSERT_TRUE(std::any_of(begin(mmbs), end(mmbs), mmb_predicate));
}
This diff is collapsed.
...@@ -76,6 +76,21 @@ void write_vector(std::shared_ptr<ngraph::runtime::TensorView> tv, const std::ve ...@@ -76,6 +76,21 @@ void write_vector(std::shared_ptr<ngraph::runtime::TensorView> tv, const std::ve
tv->write(values.data(), 0, values.size() * sizeof(T)); tv->write(values.data(), 0, values.size() * sizeof(T));
} }
template <typename T>
std::vector<std::shared_ptr<T>> get_ops_of_type(std::shared_ptr<ngraph::Function> f)
{
std::vector<std::shared_ptr<T>> ops;
for (auto op : f->get_ops())
{
if (auto cop = std::dynamic_pointer_cast<T>(op))
{
ops.push_back(cop);
}
}
return ops;
}
template <typename T> template <typename T>
size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f) size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment