Commit 429776d2 authored by Amy Zhuang's avatar Amy Zhuang Committed by omarkanawi

Fold Constant + ConvertLayout to reduce memory footprint. (#3465)

* Fold Constant + ConvertLayout.

* Address PR Feedback.

* No folding if data layout is padded.

* Add unit test.

* Fix style error.
parent d218ccf9
...@@ -1244,6 +1244,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes( ...@@ -1244,6 +1244,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
REGISTER_KNOBBED_PASS_WITH_ARGS( REGISTER_KNOBBED_PASS_WITH_ARGS(
CommonSubexpressionElimination, true, ngraph::pass, runtime::cpu::get_cse_handlers_map()); CommonSubexpressionElimination, true, ngraph::pass, runtime::cpu::get_cse_handlers_map());
REGISTER_KNOBBED_PASS(CPUPostLayoutOptimizations, true, runtime::cpu::pass); REGISTER_KNOBBED_PASS(CPUPostLayoutOptimizations, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUConvertLayoutConstantFolding, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUMemoryOptimization, true, runtime::cpu::pass); REGISTER_KNOBBED_PASS(CPUMemoryOptimization, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(GetOutputElementElimination, false, ngraph::pass); REGISTER_KNOBBED_PASS(GetOutputElementElimination, false, ngraph::pass);
REGISTER_KNOBBED_PASS_WITH_ARGS( REGISTER_KNOBBED_PASS_WITH_ARGS(
......
...@@ -27,6 +27,7 @@ namespace ngraph ...@@ -27,6 +27,7 @@ namespace ngraph
namespace pass namespace pass
{ {
class CPUPostLayoutOptimizations; class CPUPostLayoutOptimizations;
class CPUConvertLayoutConstantFolding;
} }
} }
} }
...@@ -47,3 +48,11 @@ public: ...@@ -47,3 +48,11 @@ public:
void construct_slice_convertLayout_fusion(); void construct_slice_convertLayout_fusion();
void construct_reshape_convertLayout_fusion(); void construct_reshape_convertLayout_fusion();
}; };
class CPU_BACKEND_API ngraph::runtime::cpu::pass::CPUConvertLayoutConstantFolding
: public ngraph::pass::FunctionPass
{
public:
CPUConvertLayoutConstantFolding() {}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
};
...@@ -1051,6 +1051,29 @@ TEST(cpu_test, thread_safe_calls_convolution_2d_2items) ...@@ -1051,6 +1051,29 @@ TEST(cpu_test, thread_safe_calls_convolution_2d_2items)
unset_environment("NGRAPH_CPU_CONCURRENCY"); unset_environment("NGRAPH_CPU_CONCURRENCY");
} }
TEST(cpu_test, constant_convertlayout)
{
Shape data_shape{1, 64, 56, 56};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
Shape weights_shape{64, 64, 3, 3};
test::Uniform<float> rng(-100.0f, 100.0f);
vector<float> values_in(shape_size(weights_shape));
rng.initialize(values_in);
auto weights = make_shared<op::Constant>(element::f32, weights_shape, values_in);
Shape bias_shape{64};
auto bias = make_shared<op::Parameter>(element::f32, bias_shape);
auto conv = std::make_shared<op::Convolution>(data, weights, Strides{1, 1}, Strides{1, 1});
auto convbias = make_shared<op::ConvolutionBias>(conv, bias);
auto f = make_shared<Function>(convbias, ParameterVector{data, bias});
auto backend = runtime::Backend::create("CPU");
auto handle = backend->compile(f);
size_t convert_layout = count_ops_of_type<runtime::cpu::op::ConvertLayout>(f);
ASSERT_EQ(convert_layout, 1);
}
TEST(cpu_test, constant_reshape) TEST(cpu_test, constant_reshape)
{ {
Shape shape_in{2, 4}; Shape shape_in{2, 4};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment