Commit 38c70648 authored by Robert Kimball, committed by Scott Cyphers

Cache the ordered ops during codegen for a nice compile speedup (#800)

parent 86f88126
...@@ -313,6 +313,12 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -313,6 +313,12 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment); pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment);
pass_manager.run_passes(m_function); pass_manager.run_passes(m_function);
unordered_map<shared_ptr<Function>, list<shared_ptr<Node>>> function_ordered_ops;
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{
function_ordered_ops.insert({current_function, current_function->get_ordered_ops()});
}
codegen::CodeWriter writer; codegen::CodeWriter writer;
writer += writer +=
...@@ -380,7 +386,7 @@ using namespace ngraph::runtime; ...@@ -380,7 +386,7 @@ using namespace ngraph::runtime;
size_t index = 0; size_t index = 0;
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions()) for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{ {
for (shared_ptr<Node> node : current_function->get_ordered_ops()) for (shared_ptr<Node> node : function_ordered_ops.at(current_function))
{ {
if (!node->is_parameter() && !node->is_constant()) if (!node->is_parameter() && !node->is_constant())
{ {
...@@ -430,7 +436,7 @@ using namespace ngraph::runtime; ...@@ -430,7 +436,7 @@ using namespace ngraph::runtime;
writer << "// Declare all constants\n"; writer << "// Declare all constants\n";
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions()) for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{ {
for (shared_ptr<Node> node : current_function->get_ordered_ops()) for (shared_ptr<Node> node : function_ordered_ops.at(current_function))
{ {
const ngraph::op::Constant* c = dynamic_cast<ngraph::op::Constant*>(node.get()); const ngraph::op::Constant* c = dynamic_cast<ngraph::op::Constant*>(node.get());
if (c) if (c)
...@@ -459,7 +465,7 @@ using namespace ngraph::runtime; ...@@ -459,7 +465,7 @@ using namespace ngraph::runtime;
unordered_map<Node*, string> match_functions; unordered_map<Node*, string> match_functions;
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions()) for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{ {
const list<shared_ptr<Node>>& tmp = current_function->get_ordered_ops(); list<shared_ptr<Node>> tmp = function_ordered_ops.at(current_function);
if (tmp.size() < 2) if (tmp.size() < 2)
{ {
// Since we are comparing ops there must be at least two ops to proceed. // Since we are comparing ops there must be at least two ops to proceed.
...@@ -518,6 +524,7 @@ using namespace ngraph::runtime; ...@@ -518,6 +524,7 @@ using namespace ngraph::runtime;
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions()) for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{ {
auto ordered_ops = function_ordered_ops.at(current_function);
set<string> output_names; set<string> output_names;
for (shared_ptr<Node> op : current_function->get_results()) for (shared_ptr<Node> op : current_function->get_results())
{ {
...@@ -525,7 +532,7 @@ using namespace ngraph::runtime; ...@@ -525,7 +532,7 @@ using namespace ngraph::runtime;
output_names.insert(tv->get_tensor().get_name()); output_names.insert(tv->get_tensor().get_name());
} }
set<descriptor::TensorView*> constants; set<descriptor::TensorView*> constants;
for (shared_ptr<Node> node : current_function->get_ordered_ops()) for (shared_ptr<Node> node : ordered_ops)
{ {
if (dynamic_cast<ngraph::op::Constant*>(node.get())) if (dynamic_cast<ngraph::op::Constant*>(node.get()))
{ {
...@@ -554,7 +561,7 @@ using namespace ngraph::runtime; ...@@ -554,7 +561,7 @@ using namespace ngraph::runtime;
bool temporaries_used = false; bool temporaries_used = false;
size_t worst_case_tmp_size = 0; size_t worst_case_tmp_size = 0;
for (shared_ptr<Node> node : current_function->get_ordered_ops()) for (shared_ptr<Node> node : ordered_ops)
{ {
if (node->liveness_new_list.size() > 0) if (node->liveness_new_list.size() > 0)
{ {
...@@ -577,7 +584,7 @@ using namespace ngraph::runtime; ...@@ -577,7 +584,7 @@ using namespace ngraph::runtime;
writer << "\n"; writer << "\n";
// Add temporaries to the variable name map // Add temporaries to the variable name map
for (shared_ptr<Node> node : current_function->get_ordered_ops()) for (shared_ptr<Node> node : ordered_ops)
{ {
for (descriptor::Tensor* tensor : node->liveness_new_list) for (descriptor::Tensor* tensor : node->liveness_new_list)
{ {
...@@ -645,7 +652,7 @@ using namespace ngraph::runtime; ...@@ -645,7 +652,7 @@ using namespace ngraph::runtime;
} }
} }
for (shared_ptr<Node> node : current_function->get_ordered_ops()) for (shared_ptr<Node> node : ordered_ops)
{ {
auto& n = *node; // Work around a compiler warning (*node inside typeid may have effects auto& n = *node; // Work around a compiler warning (*node inside typeid may have effects
// with shared pointers, which is fine here but clang doesn't like it.) // with shared pointers, which is fine here but clang doesn't like it.)
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <iostream> #include <iostream>
using namespace std; using namespace std;
using namespace ngraph;
std::string ngraph::to_cplusplus_sourcecode_literal(bool val) std::string ngraph::to_cplusplus_sourcecode_literal(bool val)
{ {
...@@ -280,6 +281,63 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -280,6 +281,63 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
return fprop_cache; return fprop_cache;
} }
// Reports the current value of the m_total_count counter.
// NOTE(review): presumably this counts completed start/stop intervals --
// confirm against the code that increments m_total_count.
size_t stopwatch::get_call_count() const
{
    const size_t call_count = m_total_count;
    return call_count;
}
// Returns the value reported by get_timer_value(), converted to whole seconds.
// duration_cast truncates, so any fractional second is discarded.
size_t stopwatch::get_seconds() const
{
    const auto elapsed = get_timer_value();
    const auto whole_seconds = chrono::duration_cast<chrono::seconds>(elapsed);
    return whole_seconds.count();
}
// Returns the value reported by get_timer_value(), converted to whole
// milliseconds. duration_cast truncates, so sub-millisecond time is discarded.
size_t stopwatch::get_milliseconds() const
{
    const auto elapsed = get_timer_value();
    const auto whole_millis = chrono::duration_cast<chrono::milliseconds>(elapsed);
    return whole_millis.count();
}
// Returns the value reported by get_timer_value(), converted to whole
// microseconds. duration_cast truncates, so sub-microsecond time is discarded.
size_t stopwatch::get_microseconds() const
{
    const auto elapsed = get_timer_value();
    const auto whole_micros = chrono::duration_cast<chrono::microseconds>(elapsed);
    return whole_micros.count();
}
// Returns the tick count of the value reported by get_timer_value(), which is
// already expressed in nanoseconds, so no unit conversion is needed.
size_t stopwatch::get_nanoseconds() const
{
    const auto elapsed = get_timer_value();
    return elapsed.count();
}
// Returns the timer reading as a nanosecond duration.
//
// When m_active is false the stored m_last_time is returned unchanged.
// When m_active is true the result is the time elapsed between m_start_time
// and the present moment as reported by m_clock.
// NOTE(review): m_last_time presumably holds the length of the most recently
// completed interval -- confirm against the code that assigns it.
chrono::nanoseconds stopwatch::get_timer_value() const
{
    if (!m_active)
    {
        return m_last_time;
    }

    // Still running: measure from the recorded start up to now.
    return (m_clock.now() - m_start_time);
}
// Returns m_total_time converted to whole seconds.
// duration_cast truncates, so any fractional second is discarded.
size_t stopwatch::get_total_seconds() const
{
    const auto total_seconds = chrono::duration_cast<chrono::seconds>(m_total_time);
    return total_seconds.count();
}
// Returns m_total_time converted to whole milliseconds.
// duration_cast truncates, so sub-millisecond time is discarded.
size_t stopwatch::get_total_milliseconds() const
{
    const auto total_millis = chrono::duration_cast<chrono::milliseconds>(m_total_time);
    return total_millis.count();
}
// Returns m_total_time converted to whole microseconds.
// duration_cast truncates, so sub-microsecond time is discarded.
size_t stopwatch::get_total_microseconds() const
{
    const auto total_micros = chrono::duration_cast<chrono::microseconds>(m_total_time);
    return total_micros.count();
}
// Returns the tick count of m_total_time, which is stored as a nanosecond
// duration, so no unit conversion is needed.
size_t stopwatch::get_total_nanoseconds() const
{
    const auto total_ticks = m_total_time.count();
    return total_ticks;
}
namespace ngraph namespace ngraph
{ {
template <> template <>
......
...@@ -135,33 +135,25 @@ namespace ngraph ...@@ -135,33 +135,25 @@ namespace ngraph
} }
} }
size_t get_call_count() const { return m_total_count; } size_t get_call_count() const;
size_t get_seconds() const { return get_nanoseconds() / 1e9; } size_t get_seconds() const;
size_t get_milliseconds() const { return get_nanoseconds() / 1e6; } size_t get_milliseconds() const;
size_t get_microseconds() const { return get_nanoseconds() / 1e3; } size_t get_microseconds() const;
size_t get_nanoseconds() const std::chrono::nanoseconds get_timer_value() const;
{ size_t get_nanoseconds() const;
if (m_active)
{ size_t get_total_seconds() const;
return (m_clock.now() - m_start_time).count(); size_t get_total_milliseconds() const;
} size_t get_total_microseconds() const;
else size_t get_total_nanoseconds() const;
{
return m_last_time.count();
}
}
size_t get_total_seconds() const { return get_total_nanoseconds() / 1e9; }
size_t get_total_milliseconds() const { return get_total_nanoseconds() / 1e6; }
size_t get_total_microseconds() const { return get_total_nanoseconds() / 1e3; }
size_t get_total_nanoseconds() const { return m_total_time.count(); }
private: private:
std::chrono::high_resolution_clock m_clock; std::chrono::high_resolution_clock m_clock;
std::chrono::time_point<std::chrono::high_resolution_clock> m_start_time; std::chrono::time_point<std::chrono::high_resolution_clock> m_start_time;
bool m_active = false; bool m_active = false;
std::chrono::nanoseconds m_total_time = std::chrono::nanoseconds m_total_time =
std::chrono::high_resolution_clock::duration::zero(); std::chrono::high_resolution_clock::duration::zero();
std::chrono::nanoseconds m_last_time; std::chrono::nanoseconds m_last_time = std::chrono::high_resolution_clock::duration::zero();
size_t m_total_count = 0; size_t m_total_count = 0;
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment