Commit b6387054 authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

Do not allow in place slice if the arg is in CONSTANT buffer set. (#2625)

* Do not allow in place slice if the arg is in CONSTANT buffer set.

* Add a unit test.

* Address PR feedback.
parent c0f29b47
...@@ -518,10 +518,21 @@ void runtime::cpu::pass::CPUMemoryAssignment::build_buffer_sets_maps(list<shared ...@@ -518,10 +518,21 @@ void runtime::cpu::pass::CPUMemoryAssignment::build_buffer_sets_maps(list<shared
} }
if (!no_in_place) if (!no_in_place)
{ {
auto bufferID = get_bufferID(input_tensor);
auto input_buffer_it = m_bufferID_to_tensorSets.find(bufferID);
NGRAPH_ASSERT(input_buffer_it !=
m_bufferID_to_tensorSets.end());
if (node->description() == "Slice") if (node->description() == "Slice")
{ {
// build in place slice chain if (input_buffer_it->second.first !=
in_place_slice_chain.insert(output_tensor); CPUTensorRole::CONSTANT)
{
// build in place slice chain
in_place_slice_chain.insert(output_tensor);
input_buffer_it->second.second.insert(output_tensor);
m_tensor_to_bufferID[output_tensor] = bufferID;
}
} }
else else
{ {
...@@ -531,13 +542,9 @@ void runtime::cpu::pass::CPUMemoryAssignment::build_buffer_sets_maps(list<shared ...@@ -531,13 +542,9 @@ void runtime::cpu::pass::CPUMemoryAssignment::build_buffer_sets_maps(list<shared
{ {
in_place_slice_chain.insert(output_tensor); in_place_slice_chain.insert(output_tensor);
} }
input_buffer_it->second.second.insert(output_tensor);
m_tensor_to_bufferID[output_tensor] = bufferID;
} }
auto bufferID = get_bufferID(input_tensor);
auto input_buffer_it = m_bufferID_to_tensorSets.find(bufferID);
NGRAPH_ASSERT(input_buffer_it !=
m_bufferID_to_tensorSets.end());
input_buffer_it->second.second.insert(output_tensor);
m_tensor_to_bufferID[output_tensor] = bufferID;
} }
} }
} }
......
...@@ -873,6 +873,30 @@ TEST(cpu_test, memory_reuse_in_place_slice_after_in_place_concat) ...@@ -873,6 +873,30 @@ TEST(cpu_test, memory_reuse_in_place_slice_after_in_place_concat)
EXPECT_TRUE(test::all_close_f((vector<float>{3, 7}), read_vector<float>(result))); EXPECT_TRUE(test::all_close_f((vector<float>{3, 7}), read_vector<float>(result)));
} }
TEST(cpu_test, memory_reuse_in_place_slice_after_in_place_reshape_from_constant)
{
Shape shape_a{2, 1, 2, 2};
Shape shape_r{2, 1, 2, 2};
vector<float> a_data(shape_size(shape_a));
iota(a_data.begin(), a_data.end(), 1);
auto A = op::Constant::create(element::f32, shape_a, a_data);
auto reshape = make_shared<op::Reshape>(A, AxisVector{0, 1, 2, 3}, shape_r);
Shape shape{1, 1, 2, 2};
auto slice = make_shared<op::Slice>(reshape, Coordinate{1, 0, 0, 0}, Coordinate{2, 1, 2, 2});
auto neg = make_shared<op::Negative>(slice);
auto f = make_shared<Function>(neg, ParameterVector{});
auto backend = runtime::Backend::create("CPU");
auto result = backend->create_tensor(element::f32, shape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {});
EXPECT_TRUE(test::all_close_f(
vector<float>{-5., -6., -7., -8.}, read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
}
TEST(cpu_test, convert_inplace) TEST(cpu_test, convert_inplace)
{ {
Shape shape{2, 2}; Shape shape{2, 2};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment