Unverified Commit ffe657df authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #262 from NervanaSystems/bob/function_memory

compute temporary memory pool size per function
parents efe4ba18 47a5dcf5
...@@ -31,6 +31,7 @@ Function::Function(const std::shared_ptr<Node>& result, ...@@ -31,6 +31,7 @@ Function::Function(const std::shared_ptr<Node>& result,
, m_name(name) , m_name(name)
, m_result_type(result_type) , m_result_type(result_type)
, m_ordered_ops_valid(false) , m_ordered_ops_valid(false)
, m_temporary_pool_size(0)
, m_instance_id(m_next_instance_id.fetch_add(1)) , m_instance_id(m_next_instance_id.fetch_add(1))
{ {
m_result->set_is_output(); m_result->set_is_output();
...@@ -97,6 +98,16 @@ void Function::set_name(const string& name) ...@@ -97,6 +98,16 @@ void Function::set_name(const string& name)
} }
} }
size_t Function::get_temporary_pool_size()
{
return m_temporary_pool_size;
}
void Function::set_temporary_pool_size(size_t size)
{
m_temporary_pool_size = size;
}
std::ostream& ngraph::operator<<(std::ostream& out, const Function& f) std::ostream& ngraph::operator<<(std::ostream& out, const Function& f)
{ {
out << "Function(" << f.get_name() << ")"; out << "Function(" << f.get_name() << ")";
......
...@@ -56,6 +56,9 @@ namespace ngraph ...@@ -56,6 +56,9 @@ namespace ngraph
void clear_ordered_ops_valid() { m_ordered_ops_valid = false; } void clear_ordered_ops_valid() { m_ordered_ops_valid = false; }
friend std::ostream& operator<<(std::ostream&, const Function&); friend std::ostream& operator<<(std::ostream&, const Function&);
size_t get_instance_id() { return m_instance_id; } size_t get_instance_id() { return m_instance_id; }
size_t get_temporary_pool_size();
void set_temporary_pool_size(size_t);
protected: protected:
std::shared_ptr<Node> m_result; std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters; std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters;
...@@ -64,6 +67,7 @@ namespace ngraph ...@@ -64,6 +67,7 @@ namespace ngraph
bool m_ordered_ops_valid; bool m_ordered_ops_valid;
std::list<std::shared_ptr<Node>> m_ordered_ops; std::list<std::shared_ptr<Node>> m_ordered_ops;
std::list<std::shared_ptr<Node>> m_ops; std::list<std::shared_ptr<Node>> m_ops;
size_t m_temporary_pool_size;
private: private:
Function(const Function&) = delete; Function(const Function&) = delete;
......
...@@ -27,13 +27,3 @@ const vector<shared_ptr<Function>>& ngraph::pass::ManagerState::get_functions() ...@@ -27,13 +27,3 @@ const vector<shared_ptr<Function>>& ngraph::pass::ManagerState::get_functions()
{ {
return m_function_list; return m_function_list;
} }
size_t ngraph::pass::ManagerState::get_temporary_pool_size()
{
return m_temporary_pool_size;
}
void ngraph::pass::ManagerState::set_temporary_pool_size(size_t size)
{
m_temporary_pool_size = size;
}
...@@ -39,10 +39,6 @@ public: ...@@ -39,10 +39,6 @@ public:
m_function_list.insert(m_function_list.begin(), collection.begin(), collection.end()); m_function_list.insert(m_function_list.begin(), collection.begin(), collection.end());
} }
size_t get_temporary_pool_size();
void set_temporary_pool_size(size_t);
private: private:
size_t m_temporary_pool_size = 0;
std::vector<std::shared_ptr<Function>> m_function_list; std::vector<std::shared_ptr<Function>> m_function_list;
}; };
...@@ -31,10 +31,10 @@ pass::MemoryLayout::MemoryLayout(size_t alignment) ...@@ -31,10 +31,10 @@ pass::MemoryLayout::MemoryLayout(size_t alignment)
{ {
} }
bool pass::MemoryLayout::run_on_call_graph(std::list<std::shared_ptr<Node>>& node_list) bool pass::MemoryLayout::run_on_function(std::shared_ptr<ngraph::Function> function)
{ {
MemoryManager mm(m_alignment); MemoryManager mm(m_alignment);
for (shared_ptr<Node> node : node_list) for (shared_ptr<Node> node : function->get_ordered_ops())
{ {
for (Tensor* tensor : node->liveness_new_list) for (Tensor* tensor : node->liveness_new_list)
{ {
...@@ -46,7 +46,7 @@ bool pass::MemoryLayout::run_on_call_graph(std::list<std::shared_ptr<Node>>& nod ...@@ -46,7 +46,7 @@ bool pass::MemoryLayout::run_on_call_graph(std::list<std::shared_ptr<Node>>& nod
mm.free(tensor->get_pool_offset()); mm.free(tensor->get_pool_offset());
} }
} }
get_state().set_temporary_pool_size(mm.max_allocated()); function->set_temporary_pool_size(mm.max_allocated());
return false; return false;
} }
......
...@@ -30,11 +30,11 @@ namespace ngraph ...@@ -30,11 +30,11 @@ namespace ngraph
} }
} }
class ngraph::pass::MemoryLayout : public CallGraphPass class ngraph::pass::MemoryLayout : public FunctionPass
{ {
public: public:
MemoryLayout(size_t alignment = 1); MemoryLayout(size_t alignment = 1);
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>&) override; bool run_on_function(std::shared_ptr<ngraph::Function>) override;
private: private:
size_t m_alignment; size_t m_alignment;
......
...@@ -223,7 +223,7 @@ using namespace ngraph::runtime::cpu::eigen; ...@@ -223,7 +223,7 @@ using namespace ngraph::runtime::cpu::eigen;
} }
if (temporaries_used) if (temporaries_used)
{ {
size_t temp_pool_size = pass_manager.get_state().get_temporary_pool_size(); size_t temp_pool_size = current_function->get_temporary_pool_size();
TU << "// Allocate the memory pool\n"; TU << "// Allocate the memory pool\n";
TU << "ngraph::runtime::AlignedBuffer memory_handler(" << temp_pool_size << ", " TU << "ngraph::runtime::AlignedBuffer memory_handler(" << temp_pool_size << ", "
<< ngraph::runtime::cpu::alignment << ");\n"; << ngraph::runtime::cpu::alignment << ");\n";
......
...@@ -215,7 +215,7 @@ TEST(memory_layout, basic) ...@@ -215,7 +215,7 @@ TEST(memory_layout, basic)
auto graph = make_test_graph(); auto graph = make_test_graph();
pass_manager.run_passes(graph); pass_manager.run_passes(graph);
auto sorted = graph->get_ordered_ops(); auto sorted = graph->get_ordered_ops();
size_t temporary_pool_size = pass_manager.get_state().get_temporary_pool_size(); size_t temporary_pool_size = graph->get_temporary_pool_size();
EXPECT_EQ(12, temporary_pool_size); EXPECT_EQ(12, temporary_pool_size);
} }
...@@ -235,6 +235,6 @@ TEST(memory_layout, constant) ...@@ -235,6 +235,6 @@ TEST(memory_layout, constant)
pass_manager.run_passes(f); pass_manager.run_passes(f);
auto sorted = f->get_ordered_ops(); auto sorted = f->get_ordered_ops();
size_t temporary_pool_size = pass_manager.get_state().get_temporary_pool_size(); size_t temporary_pool_size = f->get_temporary_pool_size();
EXPECT_EQ(0, temporary_pool_size); EXPECT_EQ(0, temporary_pool_size);
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment