Commit 2a0e43ef authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Robert Kimball

Propagate input buffers for passthrough kernels (#1312)

parent ef309cf6
...@@ -565,7 +565,7 @@ bool ngraph::possibly_overwritten(Node* node) ...@@ -565,7 +565,7 @@ bool ngraph::possibly_overwritten(Node* node)
{ {
for (auto oi_pair : op_annotations->get_in_place_oi_pairs()) for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
{ {
if (input->get_index() == oi_pair.second) if (input->get_index() == oi_pair.input)
{ {
return true; return true;
} }
......
...@@ -22,23 +22,37 @@ namespace ngraph ...@@ -22,23 +22,37 @@ namespace ngraph
{ {
namespace util namespace util
{ {
/// \brief Abstract base class for annotations added to graph ops struct oi_pair
{
size_t output;
size_t input;
bool destructive;
};
/// \brief Base class for annotations added to graph ops
class OpAnnotations class OpAnnotations
{ {
public: public:
void set_in_place_oi_pairs(const std::map<size_t, size_t>& oi_pairs) void add_in_place_oi_pair(const struct oi_pair& oi)
{ {
m_in_place_oi_pairs = oi_pairs; for (auto e : m_in_place_oi_pairs)
{
if (e.input == oi.input || e.output == oi.output)
{
throw ngraph_error("In_place hint conflicts with an existing entry");
}
}
m_in_place_oi_pairs.emplace_back(oi);
} }
const std::map<size_t, size_t>& get_in_place_oi_pairs() const const std::vector<struct oi_pair>& get_in_place_oi_pairs() const
{ {
return m_in_place_oi_pairs; return m_in_place_oi_pairs;
} }
private: private:
//map of output-input pairs for which in-place computation is valid //map of output-input pairs for which in-place computation is valid
std::map<size_t, size_t> m_in_place_oi_pairs; std::vector<struct oi_pair> m_in_place_oi_pairs;
}; };
} }
} }
......
...@@ -47,8 +47,8 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function) ...@@ -47,8 +47,8 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function)
{ {
for (auto oi_pair : op_annotations->get_in_place_oi_pairs()) for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
{ {
auto output = &node->get_outputs().at(oi_pair.first).get_tensor(); auto output = &node->get_outputs().at(oi_pair.output).get_tensor();
auto input = &node->get_inputs().at(oi_pair.second).get_tensor(); auto input = &node->get_inputs().at(oi_pair.input).get_tensor();
if (node->liveness_free_list.count(input) != 0 && if (node->liveness_free_list.count(input) != 0 &&
node->liveness_new_list.count(output) != 0) node->liveness_new_list.count(output) != 0)
......
...@@ -659,6 +659,7 @@ using namespace ngraph::runtime; ...@@ -659,6 +659,7 @@ using namespace ngraph::runtime;
ss << "((" << type << "*)(inputs[" << arg_index << "]))"; ss << "((" << type << "*)(inputs[" << arg_index << "]))";
m_variable_name_map[tv->get_tensor().get_name()] = ss.str(); m_variable_name_map[tv->get_tensor().get_name()] = ss.str();
param_index_map[tv->get_tensor().get_name()] = arg_index; param_index_map[tv->get_tensor().get_name()] = arg_index;
propagate_in_place_input(&param->get_outputs().at(i), ss.str());
arg_index++; arg_index++;
} }
} }
...@@ -976,6 +977,41 @@ using namespace ngraph::runtime; ...@@ -976,6 +977,41 @@ using namespace ngraph::runtime;
} }
} }
void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input(
ngraph::descriptor::Output* output, std::string input_name)
{
auto it = output;
auto propagate_further = false;
do
{
propagate_further = false;
for (auto input : it->get_inputs())
{
auto c_op = std::dynamic_pointer_cast<ngraph::op::Op>(input->get_node());
if (!c_op || c_op->is_output())
{
break;
}
if (auto op_annotations = c_op->get_op_annotations())
{
for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
{
if (oi_pair.input == input->get_index() && !oi_pair.destructive)
{
size_t output_index = oi_pair.output;
auto& output_tensor = c_op->get_outputs().at(output_index).get_tensor();
m_variable_name_map[output_tensor.get_name()] = input_name;
it = &c_op->get_outputs().at(output_index);
propagate_further = true;
}
}
}
}
} while (propagate_further);
}
void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output( void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output(
ngraph::descriptor::Output* res_src_output, std::string output_name) ngraph::descriptor::Output* res_src_output, std::string output_name)
{ {
...@@ -995,18 +1031,21 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output( ...@@ -995,18 +1031,21 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output(
} }
if (auto op_annotations = arg->get_op_annotations()) if (auto op_annotations = arg->get_op_annotations())
{ {
auto oi_pairs = op_annotations->get_in_place_oi_pairs(); for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
if (oi_pairs.count(it->get_index()) != 0)
{ {
size_t input_index = oi_pairs.at(it->get_index()); if (oi_pair.output == it->get_index())
auto& input_tensor = arg->get_inputs().at(input_index).get_tensor();
if (input_tensor.get_pool_offset() == offset &&
!arg->get_inputs().at(input_index).get_output().get_node()->is_parameter())
{ {
NGRAPH_DEBUG << "Reusing " << output_name << " for " << input_tensor.get_name(); size_t input_index = oi_pair.input;
m_variable_name_map[input_tensor.get_name()] = output_name; auto& input_tensor = arg->get_inputs().at(input_index).get_tensor();
it = &arg->get_inputs().at(input_index).get_output(); if (input_tensor.get_pool_offset() == offset &&
propagate_further = true; !arg->get_inputs().at(input_index).get_output().get_node()->is_parameter())
{
NGRAPH_DEBUG << "Reusing " << output_name << " for "
<< input_tensor.get_name();
m_variable_name_map[input_tensor.get_name()] = output_name;
it = &arg->get_inputs().at(input_index).get_output();
propagate_further = true;
}
} }
} }
} }
......
...@@ -116,6 +116,12 @@ namespace ngraph ...@@ -116,6 +116,12 @@ namespace ngraph
void compile(); void compile();
private: private:
// For non-destructive passthrough kernels, propagate function
// input buffers to internal ops
void propagate_in_place_input(ngraph::descriptor::Output* output,
std::string input_name);
// For in-place kernels, propagate function output buffers to
// internal ops
void propagate_in_place_output(ngraph::descriptor::Output* res_src_output, void propagate_in_place_output(ngraph::descriptor::Output* res_src_output,
std::string output_name); std::string output_name);
void emit_debug_function_entry(codegen::CodeWriter& writer, void emit_debug_function_entry(codegen::CodeWriter& writer,
......
...@@ -215,8 +215,7 @@ namespace ngraph ...@@ -215,8 +215,7 @@ namespace ngraph
op_annotations->set_mkldnn_op(true); op_annotations->set_mkldnn_op(true);
const int ADD_INPUT = 3; const int ADD_INPUT = 3;
// Accumulates conv into the second input of the unfused add // Accumulates conv into the second input of the unfused add
std::map<size_t, size_t> oi_pairs = {{0, ADD_INPUT}}; op_annotations->add_in_place_oi_pair({0, ADD_INPUT, true});
op_annotations->set_in_place_oi_pairs(oi_pairs);
convolution->set_op_annotations(op_annotations); convolution->set_op_annotations(op_annotations);
} }
} }
...@@ -479,8 +478,7 @@ namespace ngraph ...@@ -479,8 +478,7 @@ namespace ngraph
if (get_user_count(node->get_argument(0).get()) == 1) if (get_user_count(node->get_argument(0).get()) == 1)
{ {
// Safe to overwrite input // Safe to overwrite input
std::map<size_t, size_t> oi_pairs = {{0, 0}}; op_annotations->add_in_place_oi_pair({0, 0, true});
op_annotations->set_in_place_oi_pairs(oi_pairs);
} }
relu->set_op_annotations(op_annotations); relu->set_op_annotations(op_annotations);
} }
...@@ -564,14 +562,12 @@ namespace ngraph ...@@ -564,14 +562,12 @@ namespace ngraph
auto arg = reshape->get_argument(0); auto arg = reshape->get_argument(0);
// we need to copy input data if reshape modifies the data or inputs are // we need to copy input data if reshape modifies the data or inputs are
// not in the memory pool, or has output users. // not in the memory pool, or has output users.
bool need_copy = bool need_copy = reshape->get_is_transpose() || arg->is_constant();
reshape->get_is_transpose() || arg->is_parameter() || arg->is_constant();
if (!need_copy) if (!need_copy)
{ {
// map output to the input memory // map output to the input memory
std::map<size_t, size_t> oi_pairs = {{0, 0}}; op_annotations->add_in_place_oi_pair({0, 0, false});
op_annotations->set_in_place_oi_pairs(oi_pairs);
reshape->set_op_annotations(op_annotations); reshape->set_op_annotations(op_annotations);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment