Commit a65b5155 authored by Louis Feng, committed by Scott Cyphers

Change FusionType to enum class and use EnumMask (#2957)

* constexpr ctor for EnumMask

* added pass properties to core passes.

* change fusion type to have better type safety.

* refactor to use enum mask.

* remove extra code.

* added constants for FusionType backward compatibility.

* spelling.

* grammar fix.
parent da1cacde
@@ -2,6 +2,7 @@
 ## Passes
 * `LikeReplacement` pass must be run by all transformers.
+* `ngraph::pass::FusionType` is now an enum class. Constants matching the old `FusionType` enumerators are provided for backward compatibility and will be removed in future releases.
 ## Nodes, Parameters
......
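For context on the backward-compatibility constants mentioned in the changelog entry above, here is a minimal sketch of what the change looks like at a call site; the header paths are assumptions based on the files touched in this commit:

    #include "ngraph/pass/core_fusion.hpp" // assumed header for CoreFusion
    #include "ngraph/pass/manager.hpp"     // assumed header for pass::Manager

    void register_fusions(ngraph::pass::Manager& pass_manager)
    {
        // Old style: the unscoped constant still compiles, but is now marked
        // NGRAPH_DEPRECATED and will be removed in a future release.
        pass_manager.register_pass<ngraph::pass::CoreFusion>(ngraph::pass::ALL_FUSIONS);

        // New style: scoped enumerator of the enum class.
        pass_manager.register_pass<ngraph::pass::CoreFusion>(
            ngraph::pass::FusionType::ALL_FUSIONS);
    }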
@@ -172,7 +172,7 @@ bool ngraph::pass::BatchFusion::run_on_function(std::shared_ptr<Function> func)
         const Node& node = *n;
         if (TI(node) == TI(op::Concat))
         {
-            if (m_fusion_type & ngraph::pass::REGULAR_FUSIONS)
+            if (m_fusion_type.is_set(FusionType::REGULAR_FUSIONS))
             {
                 if (auto fused_conv = fuse_group_convolution(n))
                 {
......
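The switch from bitwise `&` to `EnumMask::is_set` above is where the added type safety shows up: a plain enum converts implicitly to int, so any integer expression could be masked against `m_fusion_type`, while an enum class cannot. A small sketch of the difference, using FusionType and FusionTypeMask as defined later in this commit (the header path is an assumption):

    #include "ngraph/pass/pass.hpp" // assumed header where FusionType is declared

    using ngraph::pass::FusionType;
    using ngraph::pass::FusionTypeMask;

    bool wants_regular_fusions(FusionTypeMask mask)
    {
        // Only FusionType values are accepted here; with the old plain enum,
        // an accidental `mask & 0x2` compiled silently, whereas
        // `mask.is_set(0x2)` is now a compile error.
        return mask.is_set(FusionType::REGULAR_FUSIONS);
    }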
@@ -25,7 +25,7 @@ namespace ngraph
        class BatchFusion : public ngraph::pass::FunctionPass
        {
        public:
-           BatchFusion(ngraph::pass::FusionType type = ngraph::pass::ALL_FUSIONS)
+           BatchFusion(FusionTypeMask type = FusionType::ALL_FUSIONS)
                : FunctionPass()
                , m_fusion_type(type)
            {
@@ -34,7 +34,7 @@ namespace ngraph
            virtual bool run_on_function(std::shared_ptr<ngraph::Function> function) override;

        private:
-           ngraph::pass::FusionType m_fusion_type;
+           FusionTypeMask m_fusion_type;
        };
    }
}
@@ -29,10 +29,10 @@ namespace ngraph
class ngraph::pass::CoreFusion : public ngraph::pass::GraphRewrite
{
public:
-    CoreFusion(ngraph::pass::FusionType fusions = ngraph::pass::REGULAR_FUSIONS)
+    CoreFusion(FusionTypeMask fusions = FusionType::REGULAR_FUSIONS)
        : GraphRewrite()
    {
-        if (fusions & ngraph::pass::REGULAR_FUSIONS)
+        if (fusions.is_set(FusionType::REGULAR_FUSIONS))
        {
            construct_relu();
            construct_folded_batch_norm();
@@ -47,7 +47,7 @@ public:
        // be all supported by certain backends. In such a case, backends
        // can register a FusedOpDecomposition pass after CoreFusion that will
        // selectively decompose the unsupported ops back to the Core opset
-        if (fusions & ngraph::pass::FOP_FUSIONS)
+        if (fusions.is_set(FusionType::FOP_FUSIONS))
        {
            construct_conv_bias();
            construct_conv_bias_add();
......
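Since CoreFusion now takes a FusionTypeMask rather than a single FusionType, callers can combine fusion categories in one argument. A sketch; the initializer-list constructor is an assumption about EnumMask's API, while the single-enumerator conversion is visible from the default arguments in this diff:

    #include "ngraph/pass/core_fusion.hpp" // assumed header path
    #include "ngraph/pass/manager.hpp"     // assumed header path

    using ngraph::pass::CoreFusion;
    using ngraph::pass::FusionType;
    using ngraph::pass::FusionTypeMask;

    void register_core_fusion(ngraph::pass::Manager& pass_manager)
    {
        // Build a mask covering two fusion categories. The initializer-list
        // constructor is assumed; a single enumerator also converts
        // implicitly, as the default arguments show.
        FusionTypeMask mask{FusionType::REGULAR_FUSIONS, FusionType::FOP_FUSIONS};
        pass_manager.register_pass<CoreFusion>(mask);
    }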
@@ -35,7 +35,7 @@ namespace ngraph
        class NodePass;
        class CallGraphPass;
        class Manager;
-       enum FusionType
+       enum class FusionType : uint32_t
        {
            // `DIFFERENTIABLE_FUSIONS` produce ops that support autodiff,
            // i.e. implement `generate_adjoints`
@@ -46,6 +46,18 @@ namespace ngraph
            FOP_FUSIONS = 0x4,
            ALL_FUSIONS = 0xFFFFFFFF
        };
+       typedef EnumMask<FusionType> FusionTypeMask;
+
+       // These constants exist for backward compatibility only and will be removed soon.
+       NGRAPH_DEPRECATED("use FusionType enum class instead")
+       constexpr FusionType DIFFERENTIABLE_FUSIONS = FusionType::DIFFERENTIABLE_FUSIONS;
+       NGRAPH_DEPRECATED("use FusionType enum class instead")
+       constexpr FusionType REGULAR_FUSIONS = FusionType::REGULAR_FUSIONS;
+       NGRAPH_DEPRECATED("use FusionType enum class instead")
+       constexpr FusionType FOP_FUSIONS = FusionType::FOP_FUSIONS;
+       NGRAPH_DEPRECATED("use FusionType enum class instead")
+       constexpr FusionType ALL_FUSIONS = FusionType::ALL_FUSIONS;
+
        enum class PassProperty : uint32_t
        {
            // Pass requires node shapes to be static
......
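EnumMask itself is defined elsewhere in the nGraph sources and is not shown in this diff. As a mental model only (a minimal sketch, not the actual implementation), it wraps the enum's underlying integer and exposes typed queries; the constexpr constructor from the first commit bullet is what lets a bare enumerator such as FusionType::ALL_FUSIONS serve as a compile-time default argument:

    #include <cstdint>
    #include <type_traits>

    // Minimal sketch of the EnumMask idea; the real nGraph EnumMask offers
    // more operations (set/clear, bitwise operators, and so on).
    template <typename T>
    class EnumMask
    {
    public:
        using value_type = typename std::underlying_type<T>::type;

        constexpr EnumMask()
            : m_value{0}
        {
        }

        // constexpr converting constructor: a single enumerator implicitly
        // becomes a mask, usable in constant expressions and default arguments.
        constexpr EnumMask(const T& enum_value)
            : m_value{static_cast<value_type>(enum_value)}
        {
        }

        // True when every bit of `p` is set in this mask.
        bool is_set(const EnumMask& p) const { return (m_value & p.m_value) == p.m_value; }

    private:
        value_type m_value;
    };

With ALL_FUSIONS defined as 0xFFFFFFFF, is_set(FusionType::REGULAR_FUSIONS) holds for an ALL_FUSIONS mask, which is why the ALL_FUSIONS defaults keep every fusion category enabled.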
@@ -1193,7 +1193,8 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
    REGISTER_KNOBBED_PASS(CPUBatchFusion, true, runtime::cpu::pass);
    REGISTER_KNOBBED_PASS(ReshapeSinking, false, ngraph::pass);
    REGISTER_KNOBBED_PASS(ReshapeElimination, false, ngraph::pass);
-   REGISTER_KNOBBED_PASS_WITH_ARGS(CoreFusion, true, ngraph::pass, ngraph::pass::ALL_FUSIONS);
+   REGISTER_KNOBBED_PASS_WITH_ARGS(
+       CoreFusion, true, ngraph::pass, ngraph::pass::FusionType::ALL_FUSIONS);
    REGISTER_KNOBBED_PASS_WITH_ARGS(FusedOpDecomposition, true, ngraph::pass, is_supported);
    REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
    REGISTER_KNOBBED_PASS(CPUQuantFusion, true, runtime::cpu::pass);
......
@@ -37,16 +37,18 @@ namespace ngraph
class CPU_BACKEND_API ngraph::runtime::cpu::pass::CPUFusion : public ngraph::pass::GraphRewrite
{
public:
-    CPUFusion(ngraph::pass::FusionType fusions = ngraph::pass::ALL_FUSIONS)
+    typedef ngraph::pass::FusionType FusionType;
+    typedef ngraph::pass::FusionTypeMask FusionTypeMask;
+    CPUFusion(FusionTypeMask fusions = FusionType::ALL_FUSIONS)
        : GraphRewrite()
    {
-        if (fusions & ngraph::pass::DIFFERENTIABLE_FUSIONS)
+        if (fusions.is_set(FusionType::DIFFERENTIABLE_FUSIONS))
        {
            construct_conv_bias(); // DEPRECATED - Use CoreFusion
            construct_sigmoid_multiply();
        }
-        if (fusions & ngraph::pass::REGULAR_FUSIONS)
+        if (fusions.is_set(FusionType::REGULAR_FUSIONS))
        {
            construct_matmul();
            construct_matmulbias();
......
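The class-scope typedefs above let the inline constructor refer to the short names FusionType and FusionTypeMask. At call sites, a single enumerator still converts implicitly to the mask, exactly as the default argument `= FusionType::ALL_FUSIONS` does. A sketch, with the header paths as assumptions:

    #include "ngraph/pass/manager.hpp"                // assumed header path
    #include "ngraph/runtime/cpu/pass/cpu_fusion.hpp" // assumed header path

    void register_differentiable_only(ngraph::pass::Manager& pass_manager)
    {
        // One enumerator converts implicitly to FusionTypeMask, so only the
        // DIFFERENTIABLE_FUSIONS patterns (e.g. construct_sigmoid_multiply)
        // are registered by this CPUFusion instance.
        pass_manager.register_pass<ngraph::runtime::cpu::pass::CPUFusion>(
            ngraph::pass::FusionType::DIFFERENTIABLE_FUSIONS);
    }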
@@ -589,7 +589,7 @@ bool runtime::cpu::pass::CPUBatchFusion::run_on_function(std::shared_ptr<Function> func)
        const Node& node = *n;
        if (TI(node) == TI(op::Concat))
        {
-            if (m_fusion_type & ngraph::pass::DIFFERENTIABLE_FUSIONS)
+            if (m_fusion_type.is_set(FusionType::DIFFERENTIABLE_FUSIONS))
            {
                if (auto fused_node = fuse_batch_mat_mul_transpose(n))
                {
@@ -597,7 +597,7 @@ bool runtime::cpu::pass::CPUBatchFusion::run_on_function(std::shared_ptr<Function> func)
                    modified = true;
                }
            }
-            if (m_fusion_type & ngraph::pass::REGULAR_FUSIONS)
+            if (m_fusion_type.is_set(FusionType::REGULAR_FUSIONS))
            {
                /*
                if (auto fused_conv = fuse_group_convolution(n))
......
@@ -36,16 +36,18 @@ namespace ngraph
        class CPU_BACKEND_API CPUBatchFusion : public ngraph::pass::FunctionPass
        {
        public:
-           CPUBatchFusion(ngraph::pass::FusionType type = ngraph::pass::ALL_FUSIONS)
+           typedef ngraph::pass::FusionType FusionType;
+           typedef ngraph::pass::FusionTypeMask FusionTypeMask;
+           CPUBatchFusion(FusionTypeMask fusions = FusionType::ALL_FUSIONS)
                : FunctionPass()
-               , m_fusion_type(type)
+               , m_fusion_type(fusions)
            {
            }

            virtual bool
                run_on_function(std::shared_ptr<ngraph::Function> function) override;

        private:
-           ngraph::pass::FusionType m_fusion_type;
+           FusionTypeMask m_fusion_type;
        };
    }
}
......
@@ -433,7 +433,7 @@ shared_ptr<runtime::Executable>
    pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
    pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
    pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
-   pass_manager.register_pass<ngraph::pass::CoreFusion>(ngraph::pass::ALL_FUSIONS);
+   pass_manager.register_pass<ngraph::pass::CoreFusion>(ngraph::pass::FusionType::ALL_FUSIONS);
    // GetOutputElementElimination must be after CommonSubexpressionElimination
    pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
@@ -339,7 +339,7 @@ TEST(core_fusion, conv_bias)
    auto decomp_f2 = gen_f(false);

    pass::Manager pass_manager;
-   pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::ALL_FUSIONS);
+   pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::FusionType::ALL_FUSIONS);
    pass_manager.run_passes(decomp_f1);
    ASSERT_EQ(count_ops_of_type<op::ConvolutionBias>(decomp_f1), 1);
@@ -390,7 +390,7 @@ TEST(core_fusion, conv_bias_bcast_reshape)
    auto decomp_f2 = gen_f(false);

    pass::Manager pass_manager;
-   pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::ALL_FUSIONS);
+   pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::FusionType::ALL_FUSIONS);
    pass_manager.run_passes(decomp_f1);
    ASSERT_EQ(count_ops_of_type<op::ConvolutionBias>(decomp_f1), 1);
@@ -442,7 +442,7 @@ TEST(core_fusion, conv_bias_add)
    auto decomp_f2 = gen_f(false);

    pass::Manager pass_manager;
-   pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::ALL_FUSIONS);
+   pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::FusionType::ALL_FUSIONS);
    pass_manager.run_passes(decomp_f1);
    ASSERT_EQ(count_ops_of_type<op::ConvolutionBiasAdd>(decomp_f1), 1);
@@ -511,7 +511,7 @@ TEST(core_fusion, DISABLED_conv_bias_bprop)
    auto decomp_f2 = gen_f(false);

    pass::Manager pass_manager;
-   pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::ALL_FUSIONS);
+   pass_manager.register_pass<pass::CoreFusion>(ngraph::pass::FusionType::ALL_FUSIONS);
    pass_manager.run_passes(decomp_f1);
    ASSERT_EQ(count_ops_of_type<op::ConvolutionBiasBackpropFiltersBias>(decomp_f1), 1);
......
@@ -288,7 +288,7 @@ TEST(cpu_fusion, cpu_fusion_pass_basic)
    auto add = dot + broadcast;
    auto graph = make_shared<op::Abs>(add);
    pass::Manager pass_manager;
-   pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+   pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
    auto func = make_shared<Function>(graph, ParameterVector{A, B, C});
    pass_manager.run_passes(func);
    ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr);
@@ -309,7 +309,7 @@ TEST(cpu_fusion, commutative_matmul_bias)
    auto add = broadcast + dot;
    auto graph = make_shared<op::Abs>(add);
    pass::Manager pass_manager;
-   pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+   pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
    auto func = make_shared<Function>(graph, ParameterVector{A, B, C});
    pass_manager.run_passes(func);
    ASSERT_NE(std::dynamic_pointer_cast<op::MatmulBias>(graph->get_argument(0)), nullptr);
@@ -331,7 +331,7 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_bias)
    auto graph = make_shared<op::Abs>(add);
    pass::Manager pass_manager;
-   pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+   pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
    auto func = make_shared<Function>(graph, ParameterVector{W, x, b});
    pass_manager.run_passes(func);
    auto gmm = graph->get_argument(0);
@@ -352,7 +352,7 @@ TEST(cpu_fusion, cpu_fusion_pass_matmul_no_bias)
    auto graph = make_shared<op::Abs>(re_dot);
    pass::Manager pass_manager;
-   pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+   pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
    auto func = make_shared<Function>(graph, ParameterVector{W, x});
    pass_manager.run_passes(func);
    size_t mmb = count_ops_of_type<op::MatmulBias>(func);
@@ -810,7 +810,7 @@ TEST(cpu_fusion, fuse_conv_relu)
    auto func = make_shared<Function>(abs_node, ParameterVector{A, weights});
    pass::Manager pass_manager;
-   pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+   pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
    pass_manager.run_passes(func);
    size_t cb = count_ops_of_type<op::ConvolutionRelu>(func);
    ASSERT_GT(cb, 0);
@@ -1325,7 +1325,8 @@ std::vector<shared_ptr<runtime::Tensor>> rnn_matrix_fusion_eval(const size_t time_steps,
    {
        pass::Manager pass_manager;
        pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
-       pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+       pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(
+           pass::FusionType::REGULAR_FUSIONS);
        pass_manager.run_passes(func);
        // check all of our dot/add are converted to a single MatmulBias op.
        size_t count = count_ops_of_type<op::MatmulBias>(func);
@@ -3744,7 +3745,7 @@ TEST(cpu_fusion, gemm_mlp)
    stringstream ss(json_string);
    shared_ptr<Function> func = ngraph::deserialize(ss);
    pass::Manager pass_manager;
-   pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+   pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
    pass_manager.run_passes(func);
    auto mmbs = count_ops_of_type<op::MatmulBias>(func);
    ASSERT_EQ(mmbs, 3);
@@ -3755,7 +3756,7 @@ TEST(cpu_fusion, fuse_fprop_bn)
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_before_fusion.png");
    pass_manager.register_pass<ngraph::pass::ReshapeElimination>();
-   pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+   pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
    pass_manager.register_pass<pass::VisualizeTree>("bn_fprop_after_fusion.png");
    const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/bn_fprop_b2c3h2w2.json");
    const string json_string = file_util::read_file_to_string(json_path);
@@ -3921,7 +3922,7 @@ TEST(cpu_fusion, rnn_fusion_from_json_model)
    {
        pass::Manager pass_manager;
        pass_manager.register_pass<runtime::cpu::pass::CPURnnMatFusion>();
-       pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::REGULAR_FUSIONS);
+       pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
        const string json_path =
            file_util::path_join(SERIALIZED_ZOO, "mxnet/rnn-10-step-fusion-test.json");
        const string json_string = file_util::read_file_to_string(json_path);
......