Commit 57fd873d authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

give frontends some flexibility over fusions they would like to run (#1010)

* give frontends some flexibility over fusions they would like to run

* address jbobbas feedback
parent 09e07c7b
......@@ -35,21 +35,40 @@ namespace ngraph
class ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite
{
public:
CPUFusion()
//30 different fusion groups that we can nest/mix&match/etc
//should be good enough for quite a while
enum fusions
{
//`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff
//i.e. implement `generate_adjoints`
DIFFERENTIABLE_FUSIONS = 0x1,
REGULAR_FUSIONS = 0x2,
ALL = 0xFFFFFFFF
};
CPUFusion(int fusions = ALL)
: GraphRewrite()
{
construct_matmul();
construct_matmulbias();
construct_fprop_bn();
construct_zero_padded_reshaped_conv();
construct_zero_padded_conv();
construct_zero_padded_conv_backprop_filters();
construct_sigmoid();
construct_sigmoid_bprop();
construct_conv_bias();
construct_batch_norm_relu();
construct_batch_norm_relu_global_stats();
construct_conv_relu();
if (fusions & REGULAR_FUSIONS)
{
construct_matmul();
construct_matmulbias();
construct_fprop_bn();
construct_zero_padded_reshaped_conv();
construct_zero_padded_conv();
construct_zero_padded_conv_backprop_filters();
construct_sigmoid();
construct_sigmoid_bprop();
construct_batch_norm_relu();
construct_batch_norm_relu_global_stats();
construct_conv_relu();
}
if (fusions & DIFFERENTIABLE_FUSIONS)
{
construct_conv_bias();
}
}
private:
......
......@@ -256,7 +256,8 @@ TEST(cpu_fusion, cpu_fusion_pass_basic)
auto add = dot + broadcast;
auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, op::ParameterVector{A, B, C});
pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr);
......@@ -277,7 +278,8 @@ TEST(cpu_fusion, commutative_matmul_bias)
auto add = broadcast + dot;
auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, op::ParameterVector{A, B, C});
pass_manager.run_passes(func);
ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr);
......@@ -299,7 +301,8 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_bias)
auto graph = make_shared<op::Abs>(add);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, op::ParameterVector{W, x, b});
pass_manager.run_passes(func);
auto gmm = graph->get_argument(0);
......@@ -320,7 +323,8 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_no_bias)
auto graph = make_shared<op::Abs>(re_dot);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, op::ParameterVector{W, x});
pass_manager.run_passes(func);
size_t mmb = count_ops_of_type<op::MatmulBias>(func);
......@@ -334,7 +338,8 @@ TEST(cpu_fusion, gemm_mlp)
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.run_passes(func);
auto mmbs = count_ops_of_type<op::MatmulBias>(func);
ASSERT_EQ(mmbs, 3);
......@@ -345,7 +350,8 @@ TEST(cpu_fusion, fuse_fprop_bn)
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_before_fusion.png");
pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_after_fusion.png");
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/bn_fprop_b2c3h2w2.json");
const string json_string = file_util::read_file_to_string(json_path);
......@@ -475,7 +481,8 @@ TEST(cpu_fusion, fuse_conv_bias)
{
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::DIFFERENTIABLE_FUSIONS);
const string json_path = file_util::path_join(SERIALIZED_ZOO, "conv_bias.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
......@@ -656,7 +663,8 @@ TEST(cpu_fusion, conv_bias_bprop_n1c1h3w3)
TEST(cpu_fusion, sigmoid_fprop_fusion)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/Graph_fprop_sigmoid.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
......@@ -830,7 +838,8 @@ TEST(cpu_fusion, fuse_conv_relu)
auto func = make_shared<Function>(abs_node, op::ParameterVector{A, weights});
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.run_passes(func);
size_t cb = count_ops_of_type<op::ConvolutionRelu>(func);
ASSERT_GT(cb, 0);
......@@ -1002,7 +1011,8 @@ std::vector<shared_ptr<runtime::TensorView>>
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
pass_manager.run_passes(func);
// check all of our dot/add are converted to a single MatmulBias op.
size_t count = count_ops_of_type<op::MatmulBias>(func);
......@@ -1060,7 +1070,8 @@ TEST(cpu_fusion, rnn_fusion_from_json_model)
{
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
runtime::cpu::pass::CPUFusion::REGULAR_FUSIONS);
const string json_path =
file_util::path_join(SERIALIZED_ZOO, "mxnet/rnn-10-step-fusion-test.json");
const string json_string = file_util::read_file_to_string(json_path);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment