Unverified Commit 0c721561 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Liveness optimizations (#1210)

* Faster liveness.

Memory manager optimized for non-sharing of tensors.
Add pass manager profiler.

* Move pass profiler to a separate PR

* Move Memory Layout optimizations to a separate PR

* use find instead of count
parent ffe3a631
......@@ -161,7 +161,6 @@ namespace ngraph
/// Returns the shape of input i
const Shape& get_input_shape(size_t i) const;
std::unordered_set<descriptor::Tensor*> liveness_live_list;
std::unordered_set<descriptor::Tensor*> liveness_new_list;
std::unordered_set<descriptor::Tensor*> liveness_free_list;
......
......@@ -58,10 +58,6 @@ bool pass::DumpSorted::run_on_module(vector<shared_ptr<Function>>& functions)
out << join(outputs);
out << "\n";
for (const descriptor::Tensor* tensor : node->liveness_live_list)
{
out << " L " << tensor->get_name() << "\n";
}
for (const descriptor::Tensor* tensor : node->liveness_new_list)
{
out << " N " << tensor->get_name() << "\n";
......
......@@ -72,14 +72,13 @@ bool pass::Liveness::run_on_function(shared_ptr<ngraph::Function> function)
for (auto it = ops.rbegin(); it != ops.rend(); it++)
{
shared_ptr<Node> node = *it;
node->liveness_live_list.clear();
node->liveness_new_list.clear();
node->liveness_free_list.clear();
unordered_set<descriptor::Tensor*> input_tensor_decls;
for (descriptor::Input& input_decl : node->get_inputs())
{
descriptor::Tensor& tensor = input_decl.get_tensor();
if (!contains(persistent_tensors, &tensor))
if (persistent_tensors.find(&tensor) == persistent_tensors.end())
{
input_tensor_decls.insert(&tensor);
}
......@@ -89,7 +88,7 @@ bool pass::Liveness::run_on_function(shared_ptr<ngraph::Function> function)
for (size_t i = 0; i < node->get_output_size(); ++i)
{
descriptor::Tensor& tensor = node->get_output_tensor(i);
if (!contains(persistent_tensors, &tensor))
if (persistent_tensors.find(&tensor) == persistent_tensors.end())
{
output_tensor_decls.insert(&tensor);
}
......@@ -102,79 +101,31 @@ bool pass::Liveness::run_on_function(shared_ptr<ngraph::Function> function)
for (descriptor::Tensor* tensor_decl : all_tensor_decls)
{
if (!contains(currently_live, tensor_decl))
if (currently_live.find(tensor_decl) == currently_live.end())
{
// this is the last node that value is seen in
// delete it at the end of the op
currently_live.insert(tensor_decl);
if (output_tensors.find(tensor_decl) == output_tensors.end())
{
// Don't free output tensors
free_tensor_decls.insert(tensor_decl);
}
}
}
node->liveness_live_list = currently_live;
for (descriptor::Tensor* output_decl : output_tensor_decls)
{
if (contains(currently_live, output_decl))
auto currently_live_it = currently_live.find(output_decl);
if (currently_live_it != currently_live.end())
{
new_tensor_decls.insert(output_decl);
currently_live.erase(output_decl);
currently_live.erase(currently_live_it);
}
}
node->liveness_free_list = free_tensor_decls;
node->liveness_new_list = new_tensor_decls;
}
// Anything marked as output must remain live for the remainder of the graph
// Add outputs to live_list and remove from free_list
unordered_set<descriptor::Tensor*> outputs;
unordered_set<descriptor::Tensor*> seen;
for (shared_ptr<Node> node : ops)
{
for (descriptor::Tensor* tensor : node->liveness_live_list)
{
if (contains(output_tensors, tensor))
{
outputs.insert(tensor);
}
}
for (descriptor::Tensor* tensor : outputs)
{
node->liveness_live_list.insert(tensor);
node->liveness_free_list.erase(tensor);
if (contains(node->liveness_new_list, tensor))
{
if (contains(seen, tensor))
{
node->liveness_new_list.erase(tensor);
}
else
{
seen.insert(tensor);
}
}
}
}
// validate_liveness(ops);
return false;
}
void pass::Liveness::validate_liveness(const list<Node*>& ops)
{
unordered_set<descriptor::Tensor*> dead_tensors;
for (const Node* node : ops)
{
auto active = node->liveness_live_list;
active.insert(node->liveness_new_list.begin(), node->liveness_new_list.end());
active.insert(node->liveness_free_list.begin(), node->liveness_free_list.end());
for (const descriptor::Tensor* tensor : active)
{
if (contains(dead_tensors, tensor))
{
throw runtime_error("Liveness: Dead tensors intersect active tensors");
}
}
dead_tensors.insert(node->liveness_free_list.begin(), node->liveness_free_list.end());
}
}
......@@ -31,7 +31,4 @@ class ngraph::pass::Liveness : public FunctionPass
{
public:
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
private:
void validate_liveness(const std::list<Node*>& ops);
};
......@@ -66,7 +66,7 @@ bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<ngraph::Function>>&
size_t temp_max_size = 0;
for (shared_ptr<Node> node : nodes)
{
tensors.insert(node->liveness_live_list.begin(), node->liveness_live_list.end());
tensors.insert(node->liveness_new_list.begin(), node->liveness_new_list.end());
}
for (descriptor::Tensor* tensor : tensors)
{
......@@ -95,37 +95,36 @@ bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<ngraph::Function>>&
return false;
}
shared_ptr<Node> pass::MemoryVisualize::find_largest_op(const list<shared_ptr<Node>>& nodes)
unordered_set<const descriptor::Tensor*>
pass::MemoryVisualize::find_largest_op(const list<shared_ptr<Node>>& nodes)
{
shared_ptr<Node> largest_op = nullptr;
size_t largest_size = 0;
unordered_set<const descriptor::Tensor*> liveness_list;
unordered_set<const descriptor::Tensor*> largest_live_list;
for (shared_ptr<Node> exop : nodes)
{
size_t size = 0;
for (const descriptor::Tensor* tensor : exop->liveness_live_list)
for (const descriptor::Tensor* tensor : exop->liveness_new_list)
{
liveness_list.insert(tensor);
size += tensor->size();
}
for (const descriptor::Tensor* tensor : liveness_list)
{
size += tensor->size();
}
if (size > largest_size)
{
largest_size = size;
largest_op = exop;
largest_live_list = liveness_list;
}
}
return largest_op;
return largest_live_list;
}
void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<shared_ptr<Node>>& nodes)
{
shared_ptr<Node> largest_op = find_largest_op(nodes);
if (largest_op)
{
unordered_set<descriptor::Tensor*> largest_live;
for (descriptor::Tensor* tensor : largest_op->liveness_live_list)
{
largest_live.insert(tensor);
}
unordered_set<const descriptor::Tensor*> largest_live_list = find_largest_op(nodes);
unordered_map<const descriptor::Tensor*, size_t> age_list;
vector<const descriptor::Tensor*> tensor_set;
......@@ -161,7 +160,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<shared_
for (const descriptor::Tensor* tensor : tensor_set)
{
int generator_weight = compute_op_weight(generator_op[tensor]);
if (contains(largest_live, tensor))
if (contains(largest_live_list, tensor))
{
file << " <tr style=\"background-color: #f0c0f0\">";
}
......@@ -177,7 +176,6 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<shared_
}
file << "</table>\n";
}
}
void pass::MemoryVisualize::draw_histogram(ostream& file, const list<shared_ptr<Node>>& nodes)
......
......@@ -37,7 +37,8 @@ public:
virtual bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;
private:
std::shared_ptr<Node> find_largest_op(const std::list<std::shared_ptr<Node>>& nodes);
std::unordered_set<const descriptor::Tensor*>
find_largest_op(const std::list<std::shared_ptr<Node>>& nodes);
void draw_tensor_weight(std::ostream& file, const std::list<std::shared_ptr<Node>>& nodes);
void draw_histogram(std::ostream& file, const std::list<std::shared_ptr<Node>>& nodes);
void draw_op_influence(std::ostream& file, const std::list<std::shared_ptr<Node>>& nodes);
......
......@@ -48,86 +48,16 @@ TEST(liveness, constant)
auto tmp = f->get_ordered_ops();
vector<shared_ptr<Node>> sorted{tmp.begin(), tmp.end()};
ASSERT_EQ(3, sorted.size());
EXPECT_EQ(0, sorted[0]->liveness_live_list.size());
EXPECT_EQ(0, sorted[0]->liveness_new_list.size());
EXPECT_EQ(0, sorted[0]->liveness_free_list.size());
//op::Negative is live on output to op::Result
EXPECT_EQ(1, sorted[1]->liveness_live_list.size());
//op::Negative is new
EXPECT_EQ(1, sorted[1]->liveness_new_list.size());
EXPECT_EQ(0, sorted[1]->liveness_free_list.size());
//op::Negative is live on input to op::Result
EXPECT_EQ(1, sorted[2]->liveness_live_list.size());
EXPECT_EQ(0, sorted[2]->liveness_new_list.size());
//op::Negative is freed
EXPECT_EQ(1, sorted[2]->liveness_free_list.size());
}
TEST(liveness, liveness)
{
string image = "liveness.png";
string dump_file = "liveness.txt";
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>(image);
pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::DumpSorted>(dump_file);
shared_ptr<Function> func = make_test_graph();
pass_manager.run_passes(func);
auto sorted = func->get_ordered_ops();
// for (const Node* node : sorted)
// {
// NGRAPH_INFO << *node;
// for (const descriptor::Tensor* tensor : node->liveness_live_list)
// {
// NGRAPH_INFO << " " << *tensor;
// }
// }
// auto x = ng.variable(axes=[]).named('x');
// auto y = ng.variable(axes=[]).named('y');
// auto w1 = ng.variable(axes=[]).named('w1');
// auto w2 = ng.variable(axes=[]).named('w2');
// auto x2 = x * w1;
// auto x3 = (x2 * w2).named('result');
// auto cost = x3 - y;
// auto dw1 = ng.deriv(cost, w1);
// auto dw2 = ng.deriv(cost, w2);
// auto upd1 = ng.assign(w1, w1 + dw1);
// auto upd2 = ng.assign(w2, w2 + dw2);
// auto seq_stuff = ng.sequential([upd1, upd2, x3]);
// auto exc = ex.executor(seq_stuff);
// return exc;
// lg = LivenessGraph(exc.exop.ops)
// lg.layout_memory()
// for i, node in enumerate(lg.liveness_nodes):
// print i, node
// for node in lg.liveness_nodes:
// for var1 in node.live_list:
// assert var1.buffer_pool_offset is not None
// for var2 in node.live_list:
// if var1 != var2:
// if var1.buffer_pool_offset < var2.buffer_pool_offset:
// assert var1.buffer_pool_offset + var1.size <= var2.buffer_pool_offset
// else:
// assert var2.buffer_pool_offset + var2.size <= var1.buffer_pool_offset
// // for o in egraph.computations:
// // print o.values
// print("max memory {}".format(lg.memory_footprint()))
// print("worst case memory {}".format(lg.worst_case_memory_usage()))
// print("memory efficiency {}".format(lg.memory_efficiency()))
// // // print lg.liveness_json()
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment