Unverified Commit ffe657df authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #262 from NervanaSystems/bob/function_memory

compute temporary memory pool size per function
parents efe4ba18 47a5dcf5
......@@ -31,6 +31,7 @@ Function::Function(const std::shared_ptr<Node>& result,
, m_name(name)
, m_result_type(result_type)
, m_ordered_ops_valid(false)
, m_temporary_pool_size(0)
, m_instance_id(m_next_instance_id.fetch_add(1))
{
m_result->set_is_output();
......@@ -97,6 +98,16 @@ void Function::set_name(const string& name)
}
}
size_t Function::get_temporary_pool_size()
{
return m_temporary_pool_size;
}
void Function::set_temporary_pool_size(size_t size)
{
m_temporary_pool_size = size;
}
std::ostream& ngraph::operator<<(std::ostream& out, const Function& f)
{
out << "Function(" << f.get_name() << ")";
......
......@@ -56,6 +56,9 @@ namespace ngraph
void clear_ordered_ops_valid() { m_ordered_ops_valid = false; }
friend std::ostream& operator<<(std::ostream&, const Function&);
size_t get_instance_id() { return m_instance_id; }
size_t get_temporary_pool_size();
void set_temporary_pool_size(size_t);
protected:
std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters;
......@@ -64,6 +67,7 @@ namespace ngraph
bool m_ordered_ops_valid;
std::list<std::shared_ptr<Node>> m_ordered_ops;
std::list<std::shared_ptr<Node>> m_ops;
size_t m_temporary_pool_size;
private:
Function(const Function&) = delete;
......
......@@ -27,13 +27,3 @@ const vector<shared_ptr<Function>>& ngraph::pass::ManagerState::get_functions()
{
return m_function_list;
}
size_t ngraph::pass::ManagerState::get_temporary_pool_size()
{
return m_temporary_pool_size;
}
void ngraph::pass::ManagerState::set_temporary_pool_size(size_t size)
{
m_temporary_pool_size = size;
}
......@@ -39,10 +39,6 @@ public:
m_function_list.insert(m_function_list.begin(), collection.begin(), collection.end());
}
size_t get_temporary_pool_size();
void set_temporary_pool_size(size_t);
private:
size_t m_temporary_pool_size = 0;
std::vector<std::shared_ptr<Function>> m_function_list;
};
......@@ -31,10 +31,10 @@ pass::MemoryLayout::MemoryLayout(size_t alignment)
{
}
bool pass::MemoryLayout::run_on_call_graph(std::list<std::shared_ptr<Node>>& node_list)
bool pass::MemoryLayout::run_on_function(std::shared_ptr<ngraph::Function> function)
{
MemoryManager mm(m_alignment);
for (shared_ptr<Node> node : node_list)
for (shared_ptr<Node> node : function->get_ordered_ops())
{
for (Tensor* tensor : node->liveness_new_list)
{
......@@ -46,7 +46,7 @@ bool pass::MemoryLayout::run_on_call_graph(std::list<std::shared_ptr<Node>>& nod
mm.free(tensor->get_pool_offset());
}
}
get_state().set_temporary_pool_size(mm.max_allocated());
function->set_temporary_pool_size(mm.max_allocated());
return false;
}
......
......@@ -30,11 +30,11 @@ namespace ngraph
}
}
class ngraph::pass::MemoryLayout : public CallGraphPass
class ngraph::pass::MemoryLayout : public FunctionPass
{
public:
MemoryLayout(size_t alignment = 1);
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>&) override;
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
private:
size_t m_alignment;
......
......@@ -223,7 +223,7 @@ using namespace ngraph::runtime::cpu::eigen;
}
if (temporaries_used)
{
size_t temp_pool_size = pass_manager.get_state().get_temporary_pool_size();
size_t temp_pool_size = current_function->get_temporary_pool_size();
TU << "// Allocate the memory pool\n";
TU << "ngraph::runtime::AlignedBuffer memory_handler(" << temp_pool_size << ", "
<< ngraph::runtime::cpu::alignment << ");\n";
......
......@@ -215,7 +215,7 @@ TEST(memory_layout, basic)
auto graph = make_test_graph();
pass_manager.run_passes(graph);
auto sorted = graph->get_ordered_ops();
size_t temporary_pool_size = pass_manager.get_state().get_temporary_pool_size();
size_t temporary_pool_size = graph->get_temporary_pool_size();
EXPECT_EQ(12, temporary_pool_size);
}
......@@ -235,6 +235,6 @@ TEST(memory_layout, constant)
pass_manager.run_passes(f);
auto sorted = f->get_ordered_ops();
size_t temporary_pool_size = pass_manager.get_state().get_temporary_pool_size();
size_t temporary_pool_size = f->get_temporary_pool_size();
EXPECT_EQ(0, temporary_pool_size);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment