Unverified Commit 0c721561 authored by Robert Kimball, committed by GitHub

Liveness optimizations (#1210)

* Faster liveness.

Memory manager optimized for non-sharing of tensors.
Add pass manager profiler.

* Move pass profiler to a separate PR

* Move Memory Layout optimizations to a separate PR

* Use find instead of count
parent ffe3a631
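The "Use find instead of count" item refers to how the membership tests below change: checks that previously went through a contains() helper now call find() on the unordered_set directly and, where an erase follows, reuse the returned iterator so the element is only looked up once. A minimal standalone sketch of that idiom with hypothetical names, separate from the patch itself:

    #include <unordered_set>

    // Hypothetical helper (not from the patch): membership test and erase with a
    // single hash lookup, reusing the iterator returned by find().
    void erase_if_present(std::unordered_set<int>& live, int key)
    {
        auto it = live.find(key); // one lookup answers "is it present?"
        if (it != live.end())
        {
            live.erase(it); // erase by iterator; no second lookup of the key
        }
    }

Where the set is only tested and nothing is erased afterwards, the change mainly avoids the generic helper; the erase-by-iterator case is where a second lookup is actually saved.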
@@ -161,7 +161,6 @@ namespace ngraph
         /// Returns the shape of input i
         const Shape& get_input_shape(size_t i) const;
-        std::unordered_set<descriptor::Tensor*> liveness_live_list;
         std::unordered_set<descriptor::Tensor*> liveness_new_list;
         std::unordered_set<descriptor::Tensor*> liveness_free_list;
...
@@ -58,10 +58,6 @@ bool pass::DumpSorted::run_on_module(vector<shared_ptr<Function>>& functions)
                 out << join(outputs);
                 out << "\n";
-                for (const descriptor::Tensor* tensor : node->liveness_live_list)
-                {
-                    out << " L " << tensor->get_name() << "\n";
-                }
                 for (const descriptor::Tensor* tensor : node->liveness_new_list)
                 {
                     out << " N " << tensor->get_name() << "\n";
...
@@ -72,14 +72,13 @@ bool pass::Liveness::run_on_function(shared_ptr<ngraph::Function> function)
     for (auto it = ops.rbegin(); it != ops.rend(); it++)
     {
         shared_ptr<Node> node = *it;
-        node->liveness_live_list.clear();
         node->liveness_new_list.clear();
         node->liveness_free_list.clear();
         unordered_set<descriptor::Tensor*> input_tensor_decls;
         for (descriptor::Input& input_decl : node->get_inputs())
         {
             descriptor::Tensor& tensor = input_decl.get_tensor();
-            if (!contains(persistent_tensors, &tensor))
+            if (persistent_tensors.find(&tensor) == persistent_tensors.end())
             {
                 input_tensor_decls.insert(&tensor);
             }
@@ -89,7 +88,7 @@ bool pass::Liveness::run_on_function(shared_ptr<ngraph::Function> function)
         for (size_t i = 0; i < node->get_output_size(); ++i)
         {
             descriptor::Tensor& tensor = node->get_output_tensor(i);
-            if (!contains(persistent_tensors, &tensor))
+            if (persistent_tensors.find(&tensor) == persistent_tensors.end())
             {
                 output_tensor_decls.insert(&tensor);
             }
@@ -102,79 +101,31 @@ bool pass::Liveness::run_on_function(shared_ptr<ngraph::Function> function)
         for (descriptor::Tensor* tensor_decl : all_tensor_decls)
         {
-            if (!contains(currently_live, tensor_decl))
+            if (currently_live.find(tensor_decl) == currently_live.end())
             {
                 // this is the last node that value is seen in
                 // delete it at the end of the op
                 currently_live.insert(tensor_decl);
+                if (output_tensors.find(tensor_decl) == output_tensors.end())
+                {
+                    // Don't free output tensors
                     free_tensor_decls.insert(tensor_decl);
+                }
             }
         }
-        node->liveness_live_list = currently_live;
         for (descriptor::Tensor* output_decl : output_tensor_decls)
         {
-            if (contains(currently_live, output_decl))
+            auto currently_live_it = currently_live.find(output_decl);
+            if (currently_live_it != currently_live.end())
             {
                 new_tensor_decls.insert(output_decl);
-                currently_live.erase(output_decl);
+                currently_live.erase(currently_live_it);
             }
         }
         node->liveness_free_list = free_tensor_decls;
         node->liveness_new_list = new_tensor_decls;
     }
-    // Anything marked as output must remain live for the remainder of the graph
-    // Add outputs to live_list and remove from free_list
-    unordered_set<descriptor::Tensor*> outputs;
-    unordered_set<descriptor::Tensor*> seen;
-    for (shared_ptr<Node> node : ops)
-    {
-        for (descriptor::Tensor* tensor : node->liveness_live_list)
-        {
-            if (contains(output_tensors, tensor))
-            {
-                outputs.insert(tensor);
-            }
-        }
-        for (descriptor::Tensor* tensor : outputs)
-        {
-            node->liveness_live_list.insert(tensor);
-            node->liveness_free_list.erase(tensor);
-            if (contains(node->liveness_new_list, tensor))
-            {
-                if (contains(seen, tensor))
-                {
-                    node->liveness_new_list.erase(tensor);
-                }
-                else
-                {
-                    seen.insert(tensor);
-                }
-            }
-        }
-    }
-    // validate_liveness(ops);
     return false;
 }
-void pass::Liveness::validate_liveness(const list<Node*>& ops)
-{
-    unordered_set<descriptor::Tensor*> dead_tensors;
-    for (const Node* node : ops)
-    {
-        auto active = node->liveness_live_list;
-        active.insert(node->liveness_new_list.begin(), node->liveness_new_list.end());
-        active.insert(node->liveness_free_list.begin(), node->liveness_free_list.end());
-        for (const descriptor::Tensor* tensor : active)
-        {
-            if (contains(dead_tensors, tensor))
-            {
-                throw runtime_error("Liveness: Dead tensors intersect active tensors");
-            }
-        }
-        dead_tensors.insert(node->liveness_free_list.begin(), node->liveness_free_list.end());
-    }
-}
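Taken together, the hunks above leave run_on_function doing all of its bookkeeping in the single reverse walk: a tensor's first appearance in the reverse walk is its last use, graph outputs are simply never added to a free list, and the follow-up pass that patched outputs back into the live and free lists disappears along with validate_liveness. A condensed model of that logic, with tensors reduced to plain ints and hypothetical names (a sketch of the idea, not the nGraph code):

    #include <list>
    #include <unordered_set>
    #include <vector>

    // Hypothetical op record: inputs/outputs plus the two lists the pass fills in.
    struct OpInfo
    {
        std::vector<int> inputs;
        std::vector<int> outputs;
        std::unordered_set<int> new_list;  // tensors first allocated at this op
        std::unordered_set<int> free_list; // tensors whose last use is this op
    };

    // Walk ops in reverse execution order; the first time a tensor is seen in the
    // reverse walk is its last use, so it can be freed there unless it is a graph output.
    void compute_liveness(std::list<OpInfo>& ops, const std::unordered_set<int>& graph_outputs)
    {
        std::unordered_set<int> currently_live;
        for (auto it = ops.rbegin(); it != ops.rend(); ++it)
        {
            OpInfo& op = *it;
            std::unordered_set<int> all(op.inputs.begin(), op.inputs.end());
            all.insert(op.outputs.begin(), op.outputs.end());
            for (int t : all)
            {
                if (currently_live.find(t) == currently_live.end())
                {
                    currently_live.insert(t);
                    if (graph_outputs.find(t) == graph_outputs.end())
                    {
                        op.free_list.insert(t); // last use; safe to free after this op
                    }
                }
            }
            for (int t : op.outputs)
            {
                auto live_it = currently_live.find(t);
                if (live_it != currently_live.end())
                {
                    op.new_list.insert(t); // this op produces the tensor
                    currently_live.erase(live_it);
                }
            }
        }
    }

The per-node new_list/free_list pair appears to be all the downstream memory passes need; the full per-node liveness_live_list is what this PR drops.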
@@ -31,7 +31,4 @@ class ngraph::pass::Liveness : public FunctionPass
 {
 public:
     bool run_on_function(std::shared_ptr<ngraph::Function>) override;
-
-private:
-    void validate_liveness(const std::list<Node*>& ops);
 };
@@ -66,7 +66,7 @@ bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<ngraph::Function>>&
         size_t temp_max_size = 0;
         for (shared_ptr<Node> node : nodes)
         {
-            tensors.insert(node->liveness_live_list.begin(), node->liveness_live_list.end());
+            tensors.insert(node->liveness_new_list.begin(), node->liveness_new_list.end());
         }
         for (descriptor::Tensor* tensor : tensors)
         {
@@ -95,37 +95,36 @@ bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<ngraph::Function>>&
     return false;
 }
-shared_ptr<Node> pass::MemoryVisualize::find_largest_op(const list<shared_ptr<Node>>& nodes)
+unordered_set<const descriptor::Tensor*>
+    pass::MemoryVisualize::find_largest_op(const list<shared_ptr<Node>>& nodes)
 {
-    shared_ptr<Node> largest_op = nullptr;
     size_t largest_size = 0;
+    unordered_set<const descriptor::Tensor*> liveness_list;
+    unordered_set<const descriptor::Tensor*> largest_live_list;
     for (shared_ptr<Node> exop : nodes)
     {
         size_t size = 0;
-        for (const descriptor::Tensor* tensor : exop->liveness_live_list)
+        for (const descriptor::Tensor* tensor : exop->liveness_new_list)
+        {
+            liveness_list.insert(tensor);
+            size += tensor->size();
+        }
+        for (const descriptor::Tensor* tensor : liveness_list)
         {
             size += tensor->size();
         }
         if (size > largest_size)
         {
             largest_size = size;
-            largest_op = exop;
+            largest_live_list = liveness_list;
         }
     }
-    return largest_op;
+    return largest_live_list;
 }
 void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<shared_ptr<Node>>& nodes)
 {
-    shared_ptr<Node> largest_op = find_largest_op(nodes);
-    if (largest_op)
-    {
-        unordered_set<descriptor::Tensor*> largest_live;
-        for (descriptor::Tensor* tensor : largest_op->liveness_live_list)
-        {
-            largest_live.insert(tensor);
-        }
+    unordered_set<const descriptor::Tensor*> largest_live_list = find_largest_op(nodes);
     unordered_map<const descriptor::Tensor*, size_t> age_list;
     vector<const descriptor::Tensor*> tensor_set;
@@ -161,7 +160,7 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<shared_
     for (const descriptor::Tensor* tensor : tensor_set)
     {
         int generator_weight = compute_op_weight(generator_op[tensor]);
-        if (contains(largest_live, tensor))
+        if (contains(largest_live_list, tensor))
         {
             file << " <tr style=\"background-color: #f0c0f0\">";
         }
@@ -177,7 +176,6 @@ void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const list<shared_
     }
     file << "</table>\n";
-    }
 }
 void pass::MemoryVisualize::draw_histogram(ostream& file, const list<shared_ptr<Node>>& nodes)
...
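A note on the visualizer change above: with liveness_live_list gone, find_largest_op can no longer read a per-node live set, so it accumulates every tensor reported in the per-node new lists and returns the accumulated set at the point where its total size peaks; nothing is ever removed from that set, so it is a coarse upper bound used only for the visualization highlight. A standalone sketch of that accumulation with hypothetical types, simplifying the per-op size arithmetic of the actual patch:

    #include <cstddef>
    #include <list>
    #include <unordered_set>
    #include <vector>

    // Hypothetical stand-in for descriptor::Tensor: only the size matters here.
    struct FakeTensor
    {
        std::size_t size;
    };

    // Accumulate allocated tensors op by op and remember the set at the point where
    // the accumulated footprint is largest. Nothing is removed, so this is a coarse
    // peak estimate, which is all the visualization needs.
    std::unordered_set<const FakeTensor*>
        largest_allocated_set(const std::list<std::vector<const FakeTensor*>>& per_op_new_lists)
    {
        std::size_t largest_size = 0;
        std::unordered_set<const FakeTensor*> allocated;
        std::unordered_set<const FakeTensor*> largest;
        for (const std::vector<const FakeTensor*>& new_list : per_op_new_lists)
        {
            for (const FakeTensor* t : new_list)
            {
                allocated.insert(t);
            }
            std::size_t size = 0;
            for (const FakeTensor* t : allocated)
            {
                size += t->size;
            }
            if (size > largest_size)
            {
                largest_size = size;
                largest = allocated;
            }
        }
        return largest;
    }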
@@ -37,7 +37,8 @@ public:
     virtual bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;

 private:
-    std::shared_ptr<Node> find_largest_op(const std::list<std::shared_ptr<Node>>& nodes);
+    std::unordered_set<const descriptor::Tensor*>
+        find_largest_op(const std::list<std::shared_ptr<Node>>& nodes);
     void draw_tensor_weight(std::ostream& file, const std::list<std::shared_ptr<Node>>& nodes);
     void draw_histogram(std::ostream& file, const std::list<std::shared_ptr<Node>>& nodes);
     void draw_op_influence(std::ostream& file, const std::list<std::shared_ptr<Node>>& nodes);
...
@@ -48,86 +48,16 @@ TEST(liveness, constant)
     auto tmp = f->get_ordered_ops();
     vector<shared_ptr<Node>> sorted{tmp.begin(), tmp.end()};
     ASSERT_EQ(3, sorted.size());
-    EXPECT_EQ(0, sorted[0]->liveness_live_list.size());
     EXPECT_EQ(0, sorted[0]->liveness_new_list.size());
     EXPECT_EQ(0, sorted[0]->liveness_free_list.size());

     //op::Negative is live on output to op::Result
-    EXPECT_EQ(1, sorted[1]->liveness_live_list.size());
     //op::Negative is new
     EXPECT_EQ(1, sorted[1]->liveness_new_list.size());
     EXPECT_EQ(0, sorted[1]->liveness_free_list.size());

     //op::Negative is live on input to op::Result
-    EXPECT_EQ(1, sorted[2]->liveness_live_list.size());
     EXPECT_EQ(0, sorted[2]->liveness_new_list.size());
     //op::Negative is freed
     EXPECT_EQ(1, sorted[2]->liveness_free_list.size());
 }

-TEST(liveness, liveness)
-{
-    string image = "liveness.png";
-    string dump_file = "liveness.txt";
-    pass::Manager pass_manager;
-    pass_manager.register_pass<pass::VisualizeTree>(image);
-    pass_manager.register_pass<pass::Liveness>();
-    pass_manager.register_pass<pass::DumpSorted>(dump_file);
-
-    shared_ptr<Function> func = make_test_graph();
-    pass_manager.run_passes(func);
-    auto sorted = func->get_ordered_ops();
-
-    // for (const Node* node : sorted)
-    // {
-    //     NGRAPH_INFO << *node;
-    //     for (const descriptor::Tensor* tensor : node->liveness_live_list)
-    //     {
-    //         NGRAPH_INFO << " " << *tensor;
-    //     }
-    // }
-
-    // auto x = ng.variable(axes=[]).named('x');
-    // auto y = ng.variable(axes=[]).named('y');
-    // auto w1 = ng.variable(axes=[]).named('w1');
-    // auto w2 = ng.variable(axes=[]).named('w2');
-    // auto x2 = x * w1;
-    // auto x3 = (x2 * w2).named('result');
-    // auto cost = x3 - y;
-    // auto dw1 = ng.deriv(cost, w1);
-    // auto dw2 = ng.deriv(cost, w2);
-    // auto upd1 = ng.assign(w1, w1 + dw1);
-    // auto upd2 = ng.assign(w2, w2 + dw2);
-    // auto seq_stuff = ng.sequential([upd1, upd2, x3]);
-    // auto exc = ex.executor(seq_stuff);
-    // return exc;
-
-    // lg = LivenessGraph(exc.exop.ops)
-    // lg.layout_memory()
-
-    // for i, node in enumerate(lg.liveness_nodes):
-    //     print i, node
-
-    // for node in lg.liveness_nodes:
-    //     for var1 in node.live_list:
-    //         assert var1.buffer_pool_offset is not None
-    //         for var2 in node.live_list:
-    //             if var1 != var2:
-    //                 if var1.buffer_pool_offset < var2.buffer_pool_offset:
-    //                     assert var1.buffer_pool_offset + var1.size <= var2.buffer_pool_offset
-    //                 else:
-    //                     assert var2.buffer_pool_offset + var2.size <= var1.buffer_pool_offset
-
-    // // for o in egraph.computations:
-    // //     print o.values
-
-    // print("max memory {}".format(lg.memory_footprint()))
-    // print("worst case memory {}".format(lg.worst_case_memory_usage()))
-    // print("memory efficiency {}".format(lg.memory_efficiency()))
-    // // // print lg.liveness_json()
-}