Commit ad3a1b6b authored by Jayaram Bobba, committed by Sang Ik Lee

- Bug fix to pick a reference to staleness instead of a copy (#2613)

- Enable caching irrespective of cacheability hints when reuse_memory is disabled
parent 37b95a02
......@@ -894,6 +894,7 @@ using namespace ngraph::runtime;
// Always enable nodes computing output tensors or nodes whose outputs might get
// overwritten due to inplace kernels
// TODO (jbobba) - Do we need to handle cacheability
if (computes_result(node.get()) || possibly_overwritten(node.get()))
{
writer << " || 1";
......@@ -1269,7 +1270,7 @@ void runtime::cpu::CPU_ExternalFunction::build(ngraph::pass::PassConfig& pass_co
auto output_tensor = &param->get_outputs().at(i).get_tensor();
auto tensor_set = get_tensor_set(output_tensor);
-auto stale = tensor_stale[output_tensor->get_name()];
+auto& stale = tensor_stale[output_tensor->get_name()];
// process all tensors in the set containing the output tensor of the parameter
for (auto& ele_t : tensor_set)
{
......@@ -1334,6 +1335,8 @@ void runtime::cpu::CPU_ExternalFunction::build(ngraph::pass::PassConfig& pass_co
handler->second(this, node.get(), in, out);
auto cacheable = true;
auto reuse_memory = pass_config.get_pass_attribute("CPUMemoryAssignment::ReuseMemory") ||
pass_config.get_pass_attribute("ReuseMemory");
if (node->is_op())
{
auto op = std::static_pointer_cast<ngraph::op::Op>(node);
......@@ -1342,7 +1345,9 @@ void runtime::cpu::CPU_ExternalFunction::build(ngraph::pass::PassConfig& pass_co
}
bool disable_caching =
-!cacheable || computes_result(node.get()) || possibly_overwritten(node.get());
+(reuse_memory &&
+!cacheable) // Check cacheability only if we are reusing intermediate tensors
+|| computes_result(node.get()) || possibly_overwritten(node.get());
vector<reference_wrapper<bool>> in_stale, out_stale;
for (const auto& name : in_names)
......@@ -1358,7 +1363,14 @@ void runtime::cpu::CPU_ExternalFunction::build(ngraph::pass::PassConfig& pass_co
}
for (const auto& name : out_names)
{
-out_stale.emplace_back(tensor_stale[name]);
+if (tensor_alias.count(name))
+{
+out_stale.emplace_back(tensor_stale[tensor_alias[name]]);
+}
+else
+{
+out_stale.emplace_back(tensor_stale[name]);
+}
}
function<bool(CPURuntimeContext*)> enable;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment