Commit 88aa9e9c authored by Nick Korovaiko, committed by Scott Cyphers

inplace compute (#1141)

* inplace compute

* fix warnings

* address bob's feedback

* bob's feedback 2

* bobs feedback 3

* address bob's feedback 4
parent e2e814e3
......@@ -25,6 +25,20 @@ namespace ngraph
/// \brief Base class for annotations added to graph ops.
///
/// Holds the set of output/input pairs for which an op may compute its
/// output in place. Backends derive from this class to attach their own
/// per-op data, so the destructor is virtual to make destruction through a
/// base pointer well defined.
class OpAnnotations
{
public:
    OpAnnotations() = default;
    OpAnnotations(const OpAnnotations&) = default;
    OpAnnotations(OpAnnotations&&) = default;
    OpAnnotations& operator=(const OpAnnotations&) = default;
    OpAnnotations& operator=(OpAnnotations&&) = default;
    virtual ~OpAnnotations() = default;

    /// \brief Replaces the stored in-place pairs with a copy of \p oi_pairs.
    /// \param oi_pairs Map from output index to the input index whose
    ///                 storage that output may reuse.
    void set_in_place_oi_pairs(const std::map<size_t, size_t>& oi_pairs)
    {
        m_in_place_oi_pairs = oi_pairs;
    }

    /// \brief Returns the stored output-index -> input-index pairs.
    /// The reference stays valid only as long as this object is alive and
    /// set_in_place_oi_pairs() is not called again.
    const std::map<size_t, size_t>& get_in_place_oi_pairs() const
    {
        return m_in_place_oi_pairs;
    }

private:
    // Map of output-input pairs for which in-place computation is valid
    // (key: output index, value: input index). Empty by default.
    std::map<size_t, size_t> m_in_place_oi_pairs;
};
}
}
......
......@@ -38,19 +38,50 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function)
MemoryManager mm(m_alignment);
// Walk the ops in the order returned by get_ordered_ops() and give every
// tensor in an op's liveness_new_list an offset in the temporary pool.
// An output annotated as in-place reuses the offset of the paired input
// instead of getting a fresh allocation.
for (const shared_ptr<Node>& node : function->get_ordered_ops())
{
    // Outputs of this node that take over the storage of one of its inputs
    // (key: output tensor, value: input tensor being reused).
    std::map<descriptor::Tensor*, descriptor::Tensor*> in_place_outputs;
    // Inputs whose storage was handed to an output; they must not be
    // returned to the memory manager below.
    std::set<const descriptor::Tensor*> reused_inputs;
    if (auto op = std::dynamic_pointer_cast<op::Op>(node))
    {
        if (auto op_annotations = op->get_op_annotations())
        {
            for (const auto& oi_pair : op_annotations->get_in_place_oi_pairs())
            {
                auto output = &node->get_outputs().at(oi_pair.first).get_tensor();
                auto input = &node->get_inputs().at(oi_pair.second).get_tensor();
                // Reuse only when the input is in this node's free list and
                // the output is in its new list.
                // NOTE(review): tensor sizes are not compared here; this
                // presumably relies on annotated pairs having equal sizes.
                if (node->liveness_free_list.count(input) != 0 &&
                    node->liveness_new_list.count(output) != 0)
                {
                    NGRAPH_DEBUG << input->get_name() << " will be reused for "
                                 << output->get_name();
                    in_place_outputs.insert({output, input});
                    reused_inputs.insert(input);
                }
            }
        }
    }
    for (descriptor::Tensor* tensor : node->liveness_new_list)
    {
        // Single lookup: take the reused input's offset if there is one,
        // otherwise allocate fresh space.
        const auto reused = in_place_outputs.find(tensor);
        const size_t offset = (reused != in_place_outputs.end())
                                  ? reused->second->get_pool_offset()
                                  : mm.allocate(tensor->size());
        tensor->set_pool_offset(offset);
    }
    // NOTE(review): the in-place reuse above is applied even when
    // m_disable_memory_sharing is set; confirm that is intended.
    if (!m_disable_memory_sharing)
    {
        for (const descriptor::Tensor* tensor : node->liveness_free_list)
        {
            // A reused input's storage now belongs to the output, so skip it.
            if (reused_inputs.count(tensor) == 0)
            {
                mm.free(tensor->get_pool_offset());
            }
        }
    }
}
function->set_temporary_pool_size(mm.max_allocated());
return false;
......
......@@ -28,11 +28,11 @@ namespace ngraph
class CPUOpAnnotations : public ngraph::op::util::OpAnnotations
{
public:
CPUOpAnnotations() { m_mkldnn_op = false; }
CPUOpAnnotations() {}
bool is_mkldnn_op() { return m_mkldnn_op; }
void set_mkldnn_op(bool val) { m_mkldnn_op = val; }
private:
bool m_mkldnn_op;
bool m_mkldnn_op = false;
};
}
}
......
......@@ -468,6 +468,8 @@ namespace ngraph
// Build the CPU annotations for this relu: flag it as an MKLDNN op and
// declare which output may be computed in place over which input.
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
// Output-input pair {0, 0}: output 0 may reuse the storage of input 0.
std::map<size_t, size_t> oi_pairs = {{0, 0}};
op_annotations->set_in_place_oi_pairs(oi_pairs);
// Attach the annotations to the relu node.
relu->set_op_annotations(op_annotations);
}
}
......
......@@ -63,6 +63,33 @@ TEST(cpu_test, unhandled_op)
ASSERT_THROW(backend->compile(f), ngraph_error);
}
// Relu is the only consumer of the Add result here, so after compiling on
// the CPU backend the two output tensors are expected to share a pool offset.
TEST(cpu_test, trivial_in_place_relu)
{
    auto lhs = make_shared<op::Parameter>(element::f32, Shape{16, 1});
    auto rhs = make_shared<op::Parameter>(element::f32, Shape{16, 1});
    auto sum = lhs + rhs;
    auto activation = make_shared<op::Relu>(sum);
    auto func = make_shared<Function>(activation, op::ParameterVector{lhs, rhs});

    auto cpu_backend = runtime::Backend::create("CPU");
    cpu_backend->compile(func);

    const auto relu_offset = activation->get_outputs().at(0).get_tensor().get_pool_offset();
    const auto add_offset = sum->get_outputs().at(0).get_tensor().get_pool_offset();
    ASSERT_EQ(relu_offset, add_offset);
}
// The Add result feeds both the Relu and a second Add, so the Relu output is
// expected to get its own pool offset rather than overwrite the Add result.
TEST(cpu_test, trivial_in_place_relu_fail)
{
    auto lhs = make_shared<op::Parameter>(element::f32, Shape{16, 1});
    auto rhs = make_shared<op::Parameter>(element::f32, Shape{16, 1});
    auto sum = lhs + rhs;
    auto activation = make_shared<op::Relu>(sum);
    auto second_sum = activation + sum;
    auto func = make_shared<Function>(second_sum, op::ParameterVector{lhs, rhs});

    auto cpu_backend = runtime::Backend::create("CPU");
    cpu_backend->compile(func);

    const auto relu_offset = activation->get_outputs().at(0).get_tensor().get_pool_offset();
    const auto add_offset = sum->get_outputs().at(0).get_tensor().get_pool_offset();
    ASSERT_NE(relu_offset, add_offset);
}
#ifdef NGRAPH_TBB_ENABLE
TEST(cpu_test, abc_tbb)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment