Unverified Commit 36a1d96f authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

use op::Constant's data rather than emitting the data in the generated cpp code.…

use op::Constant's data rather than emitting the data in the generated cpp code. This make compile times for trained models something like 100x faster. (#624)
parent 2f8b19a8
...@@ -136,38 +136,6 @@ public: ...@@ -136,38 +136,6 @@ public:
StaticInitializers() { ngraph::file_util::remove_directory(s_output_dir); } StaticInitializers() { ngraph::file_util::remove_directory(s_output_dir); }
}; };
static string emit_string_array(const vector<string>& s, size_t max_line_length)
{
stringstream ss;
stringstream line;
for (size_t i = 0; i < s.size(); i++)
{
if (i != 0)
{
line << ",";
}
stringstream value;
value << s[i];
string value_string = value.str();
if (static_cast<size_t>(line.tellp()) + value_string.size() + 1 <= max_line_length)
{
if (i > 0)
{
line << " ";
}
line << value_string;
}
else
{
ss << line.str() << "\n";
line.str("");
line << value_string;
}
}
ss << line.str();
return ss.str();
}
static StaticInitializers s_static_initializers; static StaticInitializers s_static_initializers;
#define TI(x) type_index(typeid(x)) #define TI(x) type_index(typeid(x))
...@@ -443,15 +411,11 @@ using namespace ngraph::runtime; ...@@ -443,15 +411,11 @@ using namespace ngraph::runtime;
const ngraph::op::Constant* c = dynamic_cast<ngraph::op::Constant*>(node.get()); const ngraph::op::Constant* c = dynamic_cast<ngraph::op::Constant*>(node.get());
if (c) if (c)
{ {
m_active_constants.push_back(node);
shared_ptr<descriptor::TensorView> tv = node->get_outputs()[0].get_tensor_view(); shared_ptr<descriptor::TensorView> tv = node->get_outputs()[0].get_tensor_view();
auto c_value_strings = c->get_value_strings(); string type = tv->get_tensor().get_element_type().c_type_string();
writer << "static " << tv->get_tensor().get_element_type().c_type_string() << " " writer << "static " << type << "* " << tv->get_tensor().get_name() << " = (("
<< tv->get_tensor().get_name() << "[" << c_value_strings.size() << "] =\n"; << type << "*)(" << c->get_data_ptr() << "));\n";
writer << "{\n";
writer.indent++;
writer << emit_string_array(c_value_strings, 100 - writer.indent * 4);
writer.indent--;
writer << "\n};\n\n";
m_variable_name_map[tv->get_tensor().get_name()] = tv->get_tensor().get_name(); m_variable_name_map[tv->get_tensor().get_name()] = tv->get_tensor().get_name();
} }
} }
......
...@@ -120,6 +120,11 @@ namespace ngraph ...@@ -120,6 +120,11 @@ namespace ngraph
bool m_use_tbb; bool m_use_tbb;
std::unordered_map<std::string, std::string> m_variable_name_map; std::unordered_map<std::string, std::string> m_variable_name_map;
// Because we are directly accessing the constant data stored in the
// Constant ops we need to keep a list of shared_ptr to each Constant
// so they don't get freed before we are done with them
std::vector<std::shared_ptr<Node>> m_active_constants;
LayoutDescriptorPtrs parameter_layout_descriptors; LayoutDescriptorPtrs parameter_layout_descriptors;
LayoutDescriptorPtrs result_layout_descriptors; LayoutDescriptorPtrs result_layout_descriptors;
std::vector<OpAttributes> m_op_attrs; std::vector<OpAttributes> m_op_attrs;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment