Commit 57fd873d authored by Nick Korovaiko, committed by Scott Cyphers

give frontends some flexibility over fusions they would like to run (#1010)

* give frontends some flexibility over fusions they would like to run

* address jbobba's feedback
parent 09e07c7b
...@@ -35,21 +35,40 @@ namespace ngraph ...@@ -35,21 +35,40 @@ namespace ngraph
class ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite class ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite
{ {
public: public:
CPUFusion() //30 different fusion groups that we can nest/mix&match/etc
//should be good enough for quite a while
enum fusions
{
//`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff
//i.e. implement `generate_adjoints`
DIFFERENTIABLE_FUSIONS = 0x1,
REGULAR_FUSIONS = 0x2,
ALL = 0xFFFFFFFF
};
CPUFusion(int fusions = ALL)
: GraphRewrite() : GraphRewrite()
{ {
construct_matmul(); if (fusions & REGULAR_FUSIONS)
construct_matmulbias(); {
construct_fprop_bn(); construct_matmul();
construct_zero_padded_reshaped_conv(); construct_matmulbias();
construct_zero_padded_conv(); construct_fprop_bn();
construct_zero_padded_conv_backprop_filters(); construct_zero_padded_reshaped_conv();
construct_sigmoid(); construct_zero_padded_conv();
construct_sigmoid_bprop(); construct_zero_padded_conv_backprop_filters();
construct_conv_bias(); construct_sigmoid();
construct_batch_norm_relu(); construct_sigmoid_bprop();
construct_batch_norm_relu_global_stats();
construct_conv_relu(); construct_batch_norm_relu();
construct_batch_norm_relu_global_stats();
construct_conv_relu();
}
if (fusions & DIFFERENTIABLE_FUSIONS)
{
construct_conv_bias();
}
} }
private: private:
......
...@@ -256,7 +256,8 @@ TEST(cpu_fusion, cpu_fusion_pass_basic) ...@@ -256,7 +256,8 @@ TEST(cpu_fusion, cpu_fusion_pass_basic)
auto add = dot + broadcast; auto add = dot + broadcast;
auto graph = make_shared<op::Abs>(add); auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, op::ParameterVector{A, B, C}); auto func = make_shared<Function>(graph, op::ParameterVector{A, B, C});
pass_manager.run_passes(func); pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr); ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr);
...@@ -277,7 +278,8 @@ TEST(cpu_fusion, commutative_matmul_bias) ...@@ -277,7 +278,8 @@ TEST(cpu_fusion, commutative_matmul_bias)
auto add = broadcast + dot; auto add = broadcast + dot;
auto graph = make_shared<op::Abs>(add); auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, op::ParameterVector{A, B, C}); auto func = make_shared<Function>(graph, op::ParameterVector{A, B, C});
pass_manager.run_passes(func); pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr); ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr);
...@@ -299,7 +301,8 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_bias) ...@@ -299,7 +301,8 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_bias)
auto graph = make_shared<op::Abs>(add); auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, op::ParameterVector{W, x, b}); auto func = make_shared<Function>(graph, op::ParameterVector{W, x, b});
pass_manager.run_passes(func); pass_manager.run_passes(func);
auto gmm = graph->get_argument(0); auto gmm = graph->get_argument(0);
...@@ -320,7 +323,8 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_no_bias) ...@@ -320,7 +323,8 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_no_bias)
auto graph = make_shared<op::Abs>(re_dot); auto graph = make_shared<op::Abs>(re_dot);
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, op::ParameterVector{W, x}); auto func = make_shared<Function>(graph, op::ParameterVector{W, x});
pass_manager.run_passes(func); pass_manager.run_passes(func);
size_t mmb = count_ops_of_type<op::MatmulBias>(func); size_t mmb = count_ops_of_type<op::MatmulBias>(func);
...@@ -334,7 +338,8 @@ TEST(cpu_fusion, gemm_mlp) ...@@ -334,7 +338,8 @@ TEST(cpu_fusion, gemm_mlp)
stringstream ss(json_string); stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss); shared_ptr<Function> func = ngraph::deserialize(ss);
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.run_passes(func); pass_manager.run_passes(func);
auto mmbs = count_ops_of_type<op::MatmulBias>(func); auto mmbs = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(mmbs, 3); ASSERT_EQ(mmbs, 3);
...@@ -345,7 +350,8 @@ TEST(cpu_fusion, fuse_fprop_bn) ...@@ -345,7 +350,8 @@ TEST(cpu_fusion, fuse_fprop_bn)
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_before_fusion.png"); pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_before_fusion.png");
pass_manager.register_pass<ngraph::pass::ReshapeElimination>(); pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_after_fusion.png"); pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_after_fusion.png");
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/bn_fprop_b2c3h2w2.json"); const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/bn_fprop_b2c3h2w2.json");
const string json_string = file_util::read_file_to_string(json_path); const string json_string = file_util::read_file_to_string(json_path);
...@@ -475,7 +481,8 @@ TEST(cpu_fusion, fuse_conv_bias) ...@@ -475,7 +481,8 @@ TEST(cpu_fusion, fuse_conv_bias)
{ {
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::ReshapeElimination>(); pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::DIFFERENTIABLE_FUSIONS);
const string json_path = file_util::path_join(SERIALIZED_ZOO, "conv_bias.json"); const string json_path = file_util::path_join(SERIALIZED_ZOO, "conv_bias.json");
const string json_string = file_util::read_file_to_string(json_path); const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string); stringstream ss(json_string);
...@@ -656,7 +663,8 @@ TEST(cpu_fusion, conv_bias_bprop_n1c1h3w3) ...@@ -656,7 +663,8 @@ TEST(cpu_fusion, conv_bias_bprop_n1c1h3w3)
TEST(cpu_fusion, sigmoid_fprop_fusion) TEST(cpu_fusion, sigmoid_fprop_fusion)
{ {
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/Graph_fprop_sigmoid.json"); const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/Graph_fprop_sigmoid.json");
const string json_string = file_util::read_file_to_string(json_path); const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string); stringstream ss(json_string);
...@@ -830,7 +838,8 @@ TEST(cpu_fusion, fuse_conv_relu) ...@@ -830,7 +838,8 @@ TEST(cpu_fusion, fuse_conv_relu)
auto func = make_shared<Function>(abs_node, op::ParameterVector{A, weights}); auto func = make_shared<Function>(abs_node, op::ParameterVector{A, weights});
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.run_passes(func); pass_manager.run_passes(func);
size_t cb = count_ops_of_type<op::ConvolutionRelu>(func); size_t cb = count_ops_of_type<op::ConvolutionRelu>(func);
ASSERT_GT(cb, 0); ASSERT_GT(cb, 0);
...@@ -1002,7 +1011,8 @@ std::vector<shared_ptr<runtime::TensorView>> ...@@ -1002,7 +1011,8 @@ std::vector<shared_ptr<runtime::TensorView>>
{ {
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.run_passes(func); pass_manager.run_passes(func);
// check all of our dot/add are converted to a single MatmulBias op. // check all of our dot/add are converted to a single MatmulBias op.
size_t count = count_ops_of_type<op::MatmulBias>(func); size_t count = count_ops_of_type<op::MatmulBias>(func);
...@@ -1060,7 +1070,8 @@ TEST(cpu_fusion, rnn_fusion_from_json_model) ...@@ -1060,7 +1070,8 @@ TEST(cpu_fusion, rnn_fusion_from_json_model)
{ {
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
const string json_path = const string json_path =
file_util::path_join(SERIALIZED_ZOO, "mxnet/rnn-10-step-fusion-test.json"); file_util::path_join(SERIALIZED_ZOO, "mxnet/rnn-10-step-fusion-test.json");
const string json_string = file_util::read_file_to_string(json_path); const string json_string = file_util::read_file_to_string(json_path);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment