Commit 2a0e43ef authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Robert Kimball

Propagate input buffers for passthrough kernels (#1312)

parent ef309cf6
......@@ -565,7 +565,7 @@ bool ngraph::possibly_overwritten(Node* node)
{
for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
{
if (input->get_index() == oi_pair.second)
if (input->get_index() == oi_pair.input)
{
return true;
}
......
......@@ -22,23 +22,37 @@ namespace ngraph
{
namespace util
{
/// \brief Abstract base class for annotations added to graph ops
struct oi_pair
{
size_t output;
size_t input;
bool destructive;
};
/// \brief Base class for annotations added to graph ops
class OpAnnotations
{
public:
void set_in_place_oi_pairs(const std::map<size_t, size_t>& oi_pairs)
void add_in_place_oi_pair(const struct oi_pair& oi)
{
for (auto e : m_in_place_oi_pairs)
{
m_in_place_oi_pairs = oi_pairs;
if (e.input == oi.input || e.output == oi.output)
{
throw ngraph_error("In_place hint conflicts with an existing entry");
}
}
m_in_place_oi_pairs.emplace_back(oi);
}
const std::map<size_t, size_t>& get_in_place_oi_pairs() const
const std::vector<struct oi_pair>& get_in_place_oi_pairs() const
{
return m_in_place_oi_pairs;
}
private:
//map of output-input pairs for which in-place computation is valid
std::map<size_t, size_t> m_in_place_oi_pairs;
std::vector<struct oi_pair> m_in_place_oi_pairs;
};
}
}
......
......@@ -47,8 +47,8 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<ngraph::Function> function)
{
for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
{
auto output = &node->get_outputs().at(oi_pair.first).get_tensor();
auto input = &node->get_inputs().at(oi_pair.second).get_tensor();
auto output = &node->get_outputs().at(oi_pair.output).get_tensor();
auto input = &node->get_inputs().at(oi_pair.input).get_tensor();
if (node->liveness_free_list.count(input) != 0 &&
node->liveness_new_list.count(output) != 0)
......
......@@ -659,6 +659,7 @@ using namespace ngraph::runtime;
ss << "((" << type << "*)(inputs[" << arg_index << "]))";
m_variable_name_map[tv->get_tensor().get_name()] = ss.str();
param_index_map[tv->get_tensor().get_name()] = arg_index;
propagate_in_place_input(&param->get_outputs().at(i), ss.str());
arg_index++;
}
}
......@@ -976,6 +977,41 @@ using namespace ngraph::runtime;
}
}
void runtime::cpu::CPU_ExternalFunction::propagate_in_place_input(
ngraph::descriptor::Output* output, std::string input_name)
{
auto it = output;
auto propagate_further = false;
do
{
propagate_further = false;
for (auto input : it->get_inputs())
{
auto c_op = std::dynamic_pointer_cast<ngraph::op::Op>(input->get_node());
if (!c_op || c_op->is_output())
{
break;
}
if (auto op_annotations = c_op->get_op_annotations())
{
for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
{
if (oi_pair.input == input->get_index() && !oi_pair.destructive)
{
size_t output_index = oi_pair.output;
auto& output_tensor = c_op->get_outputs().at(output_index).get_tensor();
m_variable_name_map[output_tensor.get_name()] = input_name;
it = &c_op->get_outputs().at(output_index);
propagate_further = true;
}
}
}
}
} while (propagate_further);
}
void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output(
ngraph::descriptor::Output* res_src_output, std::string output_name)
{
......@@ -995,21 +1031,24 @@ void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output(
}
if (auto op_annotations = arg->get_op_annotations())
{
auto oi_pairs = op_annotations->get_in_place_oi_pairs();
if (oi_pairs.count(it->get_index()) != 0)
for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
{
size_t input_index = oi_pairs.at(it->get_index());
if (oi_pair.output == it->get_index())
{
size_t input_index = oi_pair.input;
auto& input_tensor = arg->get_inputs().at(input_index).get_tensor();
if (input_tensor.get_pool_offset() == offset &&
!arg->get_inputs().at(input_index).get_output().get_node()->is_parameter())
{
NGRAPH_DEBUG << "Reusing " << output_name << " for " << input_tensor.get_name();
NGRAPH_DEBUG << "Reusing " << output_name << " for "
<< input_tensor.get_name();
m_variable_name_map[input_tensor.get_name()] = output_name;
it = &arg->get_inputs().at(input_index).get_output();
propagate_further = true;
}
}
}
}
} while (propagate_further);
}
......
......@@ -116,6 +116,12 @@ namespace ngraph
void compile();
private:
// For non-destructive passthrough kernels, propagate function
// input buffers to internal ops
void propagate_in_place_input(ngraph::descriptor::Output* output,
std::string input_name);
// For in-place kernels, propagate function output buffers to
// internal ops
void propagate_in_place_output(ngraph::descriptor::Output* res_src_output,
std::string output_name);
void emit_debug_function_entry(codegen::CodeWriter& writer,
......
......@@ -215,8 +215,7 @@ namespace ngraph
op_annotations->set_mkldnn_op(true);
const int ADD_INPUT = 3;
// Accumulates conv into the second input of the unfused add
std::map<size_t, size_t> oi_pairs = {{0, ADD_INPUT}};
op_annotations->set_in_place_oi_pairs(oi_pairs);
op_annotations->add_in_place_oi_pair({0, ADD_INPUT, true});
convolution->set_op_annotations(op_annotations);
}
}
......@@ -479,8 +478,7 @@ namespace ngraph
if (get_user_count(node->get_argument(0).get()) == 1)
{
// Safe to overwrite input
std::map<size_t, size_t> oi_pairs = {{0, 0}};
op_annotations->set_in_place_oi_pairs(oi_pairs);
op_annotations->add_in_place_oi_pair({0, 0, true});
}
relu->set_op_annotations(op_annotations);
}
......@@ -564,14 +562,12 @@ namespace ngraph
auto arg = reshape->get_argument(0);
// we need to copy input data if reshape modifies the data or inputs are
// not in the memory pool, or has output users.
bool need_copy =
reshape->get_is_transpose() || arg->is_parameter() || arg->is_constant();
bool need_copy = reshape->get_is_transpose() || arg->is_constant();
if (!need_copy)
{
// map output to the input memory
std::map<size_t, size_t> oi_pairs = {{0, 0}};
op_annotations->set_in_place_oi_pairs(oi_pairs);
op_annotations->add_in_place_oi_pair({0, 0, false});
reshape->set_op_annotations(op_annotations);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment