Commit a65b5155 authored by Louis Feng, committed by Scott Cyphers

Change FusionType to enum class and use EnumMask (#2957)

* constexpr ctor for EnumMask

* added pass properties to core passes.

* change fusion type to have better type safety.

* refactor to use enum mask.

* remove extra code.

* added constants for FusionType backward compatibility.

* spelling.

* grammar fix.
parent da1cacde
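
For readers unfamiliar with the pattern, here is a minimal sketch of an EnumMask-style wrapper over an enum class. Only the constexpr constructor and the `is_set` query are grounded in this diff; the other member names and the initializer-list constructor are illustrative assumptions, not nGraph's actual `EnumMask` API.

    #include <initializer_list>
    #include <type_traits>

    // Illustrative sketch only; not the actual ngraph::EnumMask implementation.
    template <typename T>
    class EnumMask
    {
    public:
        using value_type = typename std::underlying_type<T>::type;

        constexpr EnumMask()
            : m_value{0}
        {
        }
        // Implicit conversion from a single enumerator keeps call sites terse,
        // e.g. `CoreFusion(FusionType::ALL_FUSIONS)` in the hunks below.
        constexpr EnumMask(const T& enum_value)
            : m_value{static_cast<value_type>(enum_value)}
        {
        }
        // Assumed convenience: combine several enumerators into one mask.
        EnumMask(std::initializer_list<T> enum_values)
            : m_value{0}
        {
            for (const T& v : enum_values)
            {
                m_value |= static_cast<value_type>(v);
            }
        }
        // True if every bit of enum_value is set in this mask (matches the
        // `is_set` calls in the diff below).
        constexpr bool is_set(const T& enum_value) const
        {
            return (m_value & static_cast<value_type>(enum_value)) ==
                   static_cast<value_type>(enum_value);
        }

    private:
        value_type m_value;
    };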
@@ -2,6 +2,7 @@
 ## Passes
 * `LikeReplacement` pass must be run by all transformers.
+* `ngraph::pass::FusionType` is now an enum class. Constant values defined by `FusionType` are created for backward compatibility and will be removed in future releases.
 ## Nodes, Parameters
...
@@ -172,7 +172,7 @@ bool ngraph::pass::BatchFusion::run_on_function(std::shared_ptr<Function> func)
         const Node& node = *n;
         if (TI(node) == TI(op::Concat))
         {
-            if (m_fusion_type & ngraph::pass::REGULAR_FUSIONS)
+            if (m_fusion_type.is_set(FusionType::REGULAR_FUSIONS))
             {
                 if (auto fused_conv = fuse_group_convolution(n))
                 {
...
@@ -25,7 +25,7 @@ namespace ngraph
         class BatchFusion : public ngraph::pass::FunctionPass
         {
         public:
-            BatchFusion(ngraph::pass::FusionType type = ngraph::pass::ALL_FUSIONS)
+            BatchFusion(FusionTypeMask type = FusionType::ALL_FUSIONS)
                 : FunctionPass()
                 , m_fusion_type(type)
             {
@@ -34,7 +34,7 @@ namespace ngraph
             virtual bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
         private:
-            ngraph::pass::FusionType m_fusion_type;
+            FusionTypeMask m_fusion_type;
         };
     }
 }
@@ -29,10 +29,10 @@ namespace ngraph
 class ngraph::pass::CoreFusion : public ngraph::pass::GraphRewrite
 {
 public:
-    CoreFusion(ngraph::pass::FusionType fusions = ngraph::pass::REGULAR_FUSIONS)
+    CoreFusion(FusionTypeMask fusions = FusionType::REGULAR_FUSIONS)
         : GraphRewrite()
     {
-        if (fusions & ngraph::pass::REGULAR_FUSIONS)
+        if (fusions.is_set(FusionType::REGULAR_FUSIONS))
         {
             construct_relu();
             construct_folded_batch_norm();
@@ -47,7 +47,7 @@ public:
         // be all supported by certain backends. In such a case, backends
         // can register a FusedOpDecomposition pass after CoreFusion that will
         // selectively decompose the unsupported ops back to the Core opset
-        if (fusions & ngraph::pass::FOP_FUSIONS)
+        if (fusions.is_set(FusionType::FOP_FUSIONS))
        {
             construct_conv_bias();
             construct_conv_bias_add();
...
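
A hypothetical call site showing why the mask parameter matters: a backend that wants the FusedOp fusions on top of the regular set could pass a combined mask. This fragment assumes the initializer-list constructor from the EnumMask sketch at the top of this page; it is not a call that appears in this commit.

    // Hypothetical registration with a combined fusion mask.
    pass::Manager pass_manager;
    pass_manager.register_pass<ngraph::pass::CoreFusion>(
        ngraph::pass::FusionTypeMask{ngraph::pass::FusionType::REGULAR_FUSIONS,
                                     ngraph::pass::FusionType::FOP_FUSIONS});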
@@ -35,7 +35,7 @@ namespace ngraph
         class NodePass;
         class CallGraphPass;
         class Manager;
-        enum FusionType
+        enum class FusionType : uint32_t
         {
             //`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff
             // i.e. implement `generate_adjoints`
@@ -46,6 +46,18 @@ namespace ngraph
             FOP_FUSIONS = 0x4,
             ALL_FUSIONS = 0xFFFFFFFF
         };
+        typedef EnumMask<FusionType> FusionTypeMask;
+        // These constants are for backward compatibility only, will deprecate soon.
+        NGRAPH_DEPRECATED("use FusionType enum class instead")
+        constexpr FusionType DIFFERENTIABLE_FUSIONS = FusionType::DIFFERENTIABLE_FUSIONS;
+        NGRAPH_DEPRECATED("use FusionType enum class instead")
+        constexpr FusionType REGULAR_FUSIONS = FusionType::REGULAR_FUSIONS;
+        NGRAPH_DEPRECATED("use FusionType enum class instead")
+        constexpr FusionType FOP_FUSIONS = FusionType::FOP_FUSIONS;
+        NGRAPH_DEPRECATED("use FusionType enum class instead")
+        constexpr FusionType ALL_FUSIONS = FusionType::ALL_FUSIONS;
         enum class PassProperty : uint32_t
         {
             // Pass requires node shapes to be static
...
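
The switch from a plain enum to `enum class` is what forces the `is_set` migration seen throughout this commit: a plain enumerator converts implicitly to its underlying integer (and the `&` result to bool), while a scoped enumerator does neither. A self-contained sketch with hypothetical names:

    #include <cstdint>

    enum OldFusion { OLD_REGULAR = 0x2 };              // plain enum
    enum class NewFusion : uint32_t { REGULAR = 0x2 }; // scoped enum

    int main()
    {
        OldFusion old_mask = OLD_REGULAR;
        if (old_mask & OLD_REGULAR) // OK: enumerators decay to int, and the
        {                           // int result converts to bool
        }

        // NewFusion new_mask = NewFusion::REGULAR;
        // if (new_mask & NewFusion::REGULAR) {} // error: no operator& defined
        //                                       // and no implicit bool conversion
        return 0;
    }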
@@ -1193,7 +1193,8 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
     REGISTER_KNOBBED_PASS(CPUBatchFusion, true, runtime::cpu::pass);
     REGISTER_KNOBBED_PASS(ReshapeSinking, false, ngraph::pass);
     REGISTER_KNOBBED_PASS(ReshapeElimination, false, ngraph::pass);
-    REGISTER_KNOBBED_PASS_WITH_ARGS(CoreFusion, true, ngraph::pass, ngraph::pass::ALL_FUSIONS);
+    REGISTER_KNOBBED_PASS_WITH_ARGS(
+        CoreFusion, true, ngraph::pass, ngraph::pass::FusionType::ALL_FUSIONS);
     REGISTER_KNOBBED_PASS_WITH_ARGS(FusedOpDecomposition, true, ngraph::pass, is_supported);
     REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
     REGISTER_KNOBBED_PASS(CPUQuantFusion, true, runtime::cpu::pass);
...
@@ -37,16 +37,18 @@ namespace ngraph
 class CPU_BACKEND_API ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite
 {
 public:
-    CPUFusion(ngraph::pass::FusionType fusions = ngraph::pass::ALL_FUSIONS)
+    typedef ngraph::pass::FusionType FusionType;
+    typedef ngraph::pass::FusionTypeMask FusionTypeMask;
+    CPUFusion(FusionTypeMask fusions = FusionType::ALL_FUSIONS)
         : GraphRewrite()
     {
-        if (fusions & ngraph::pass::DIFFERENTIABLE_FUSIONS)
+        if (fusions.is_set(FusionType::DIFFERENTIABLE_FUSIONS))
         {
             construct_conv_bias(); // DEPRECATED - Use CoreFusion
             construct_sigmoid_multiply();
         }
-        if (fusions & ngraph::pass::REGULAR_FUSIONS)
+        if (fusions.is_set(FusionType::REGULAR_FUSIONS))
         {
             construct_matmul();
             construct_matmulbias();
...
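
Putting the pieces together, a self-contained demo of the new call-site ergonomics. It reuses the EnumMask sketch from the top of this page, and the 0x1/0x2 values for the first two enumerators are assumptions (the hunks on this page only show FOP_FUSIONS and ALL_FUSIONS).

    #include <cassert>
    #include <cstdint>

    enum class FusionType : uint32_t
    {
        DIFFERENTIABLE_FUSIONS = 0x1, // assumed value
        REGULAR_FUSIONS = 0x2,        // assumed value
        FOP_FUSIONS = 0x4,
        ALL_FUSIONS = 0xFFFFFFFF
    };
    typedef EnumMask<FusionType> FusionTypeMask; // EnumMask sketch from above

    int main()
    {
        // A single enumerator converts implicitly; this is what keeps default
        // arguments like `FusionTypeMask fusions = FusionType::ALL_FUSIONS` legal.
        FusionTypeMask all = FusionType::ALL_FUSIONS;
        assert(all.is_set(FusionType::DIFFERENTIABLE_FUSIONS));

        // Composing groups, assuming the initializer-list constructor above.
        FusionTypeMask some{FusionType::REGULAR_FUSIONS, FusionType::FOP_FUSIONS};
        assert(some.is_set(FusionType::REGULAR_FUSIONS));
        assert(!some.is_set(FusionType::DIFFERENTIABLE_FUSIONS));
        return 0;
    }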
@@ -589,7 +589,7 @@ bool runtime::cpu::pass::CPUBatchFusion::run_on_function(std::shared_ptr<Functio
         const Node& node = *n;
         if (TI(node) == TI(op::Concat))
         {
-            if (m_fusion_type & ngraph::pass::DIFFERENTIABLE_FUSIONS)
+            if (m_fusion_type.is_set(FusionType::DIFFERENTIABLE_FUSIONS))
             {
                 if (auto fused_node = fuse_batch_mat_mul_transpose(n))
                 {
@@ -597,7 +597,7 @@ bool runtime::cpu::pass::CPUBatchFusion::run_on_function(std::shared_ptr<Functio
                     modified = true;
                 }
             }
-            if (m_fusion_type & ngraph::pass::REGULAR_FUSIONS)
+            if (m_fusion_type.is_set(FusionType::REGULAR_FUSIONS))
             {
                 /*
                 if (auto fused_conv = fuse_group_convolution(n))
...
@@ -36,16 +36,18 @@ namespace ngraph
             class CPU_BACKEND_API CPUBatchFusion : public ngraph::pass::FunctionPass
             {
             public:
-                CPUBatchFusion(ngraph::pass::FusionType type = ngraph::pass::ALL_FUSIONS)
+                typedef ngraph::pass::FusionType FusionType;
+                typedef ngraph::pass::FusionTypeMask FusionTypeMask;
+                CPUBatchFusion(FusionTypeMask fusions = FusionType::ALL_FUSIONS)
                     : FunctionPass()
-                    , m_fusion_type(type)
+                    , m_fusion_type(fusions)
                 {
                 }
                 virtual bool
                     run_on_function(std::shared_ptr<ngraph::Function> function) override;
             private:
-                ngraph::pass::FusionType m_fusion_type;
+                FusionTypeMask m_fusion_type;
             };
         }
     }
...
@@ -433,7 +433,7 @@ shared_ptr<runtime::Executable>
     pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
     pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
     pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
-    pass_manager.register_pass<ngraph::pass::CoreFusion>(ngraph::pass::ALL_FUSIONS);
+    pass_manager.register_pass<ngraph::pass::CoreFusion>(ngraph::pass::FusionType::ALL_FUSIONS);
     // GetOutputElementElimination must be after CommonSubexpressionElimination
     pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
...
@@ -339,7 +339,7 @@ TEST(core_fusion, conv_bias)
     auto decomp_f2 = gen_f(false);
     pass::Manager pass_manager;
-    pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::ALL_FUSIONS);
+    pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::FusionType::ALL_FUSIONS);
     pass_manager.run_passes(decomp_f1);
     ASSERT_EQ(count_ops_of_type<op::ConvolutionBias>(decomp_f1), 1);
@@ -390,7 +390,7 @@ TEST(core_fusion, conv_bias_bcast_reshape)
     auto decomp_f2 = gen_f(false);
     pass::Manager pass_manager;
-    pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::ALL_FUSIONS);
+    pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::FusionType::ALL_FUSIONS);
     pass_manager.run_passes(decomp_f1);
     ASSERT_EQ(count_ops_of_type<op::ConvolutionBias>(decomp_f1), 1);
@@ -442,7 +442,7 @@ TEST(core_fusion, conv_bias_add)
     auto decomp_f2 = gen_f(false);
     pass::Manager pass_manager;
-    pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::ALL_FUSIONS);
+    pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::FusionType::ALL_FUSIONS);
     pass_manager.run_passes(decomp_f1);
     ASSERT_EQ(count_ops_of_type<op::ConvolutionBiasAdd>(decomp_f1), 1);
@@ -511,7 +511,7 @@ TEST(core_fusion, DISABLED_conv_bias_bprop)
     auto decomp_f2 = gen_f(false);
     pass::Manager pass_manager;
-    pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::ALL_FUSIONS);
+    pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::FusionType::ALL_FUSIONS);
     pass_manager.run_passes(decomp_f1);
     ASSERT_EQ(count_ops_of_type<op::ConvolutionBiasBackpropFiltersBias>(decomp_f1), 1);
...
@@ -288,7 +288,7 @@ TEST(cpu_fusion, cpu_fusion_pass_basic)
     auto add = dot + broadcast;
     auto graph = make_shared<op::Abs>(add);
     pass::Manager pass_manager;
-    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
     auto func = make_shared<Function>(graph, ParameterVector{A, B, C});
     pass_manager.run_passes(func);
     ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr);
@@ -309,7 +309,7 @@ TEST(cpu_fusion, commutative_matmul_bias)
     auto add = broadcast + dot;
     auto graph = make_shared<op::Abs>(add);
     pass::Manager pass_manager;
-    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
     auto func = make_shared<Function>(graph, ParameterVector{A, B, C});
     pass_manager.run_passes(func);
     ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr);
@@ -331,7 +331,7 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_bias)
     auto graph = make_shared<op::Abs>(add);
     pass::Manager pass_manager;
-    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
     auto func = make_shared<Function>(graph, ParameterVector{W, x, b});
     pass_manager.run_passes(func);
     auto gmm = graph->get_argument(0);
@@ -352,7 +352,7 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_no_bias)
     auto graph = make_shared<op::Abs>(re_dot);
     pass::Manager pass_manager;
-    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
     auto func = make_shared<Function>(graph, ParameterVector{W, x});
     pass_manager.run_passes(func);
     size_t mmb = count_ops_of_type<op::MatmulBias>(func);
@@ -810,7 +810,7 @@ TEST(cpu_fusion, fuse_conv_relu)
     auto func = make_shared<Function>(abs_node, ParameterVector{A, weights});
     pass::Manager pass_manager;
-    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
     pass_manager.run_passes(func);
     size_t cb = count_ops_of_type<op::ConvolutionRelu>(func);
     ASSERT_GT(cb, 0);
@@ -1325,7 +1325,8 @@ std::vector<shared_ptr<runtime::Tensor>> rnn_matrix_fusion_eval(const size_t tim
     {
         pass::Manager pass_manager;
         pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
-        pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+        pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
+            pass::FusionType::REGULAR_FUSIONS);
         pass_manager.run_passes(func);
         // check all of our dot/add are converted to a single MatmulBias op.
         size_t count = count_ops_of_type<op::MatmulBias>(func);
@@ -3744,7 +3745,7 @@ TEST(cpu_fusion, gemm_mlp)
     stringstream ss(json_string);
     shared_ptr<Function> func = ngraph::deserialize(ss);
     pass::Manager pass_manager;
-    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
     pass_manager.run_passes(func);
     auto mmbs = count_ops_of_type<op::MatmulBias>(func);
     ASSERT_EQ(mmbs, 3);
@@ -3755,7 +3756,7 @@ TEST(cpu_fusion, fuse_fprop_bn)
     pass::Manager pass_manager;
     pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_before_fusion.png");
     pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
-    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+    pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
     pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_after_fusion.png");
     const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/bn_fprop_b2c3h2w2.json");
     const string json_string = file_util::read_file_to_string(json_path);
@@ -3921,7 +3922,7 @@ TEST(cpu_fusion, rnn_fusion_from_json_model)
     {
         pass::Manager pass_manager;
         pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
-        pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+        pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
         const string json_path =
             file_util::path_join(SERIALIZED_ZOO, "mxnet/rnn-10-step-fusion-test.json");
         const string json_string = file_util::read_file_to_string(json_path);
...