Commit 603a7d1a authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Rewrite the way constants are emitted in the CPU backend (#332)

* wip

* constants as globals

* const emitter rewrite
parent c2c33748
......@@ -573,19 +573,6 @@ void runtime::cpu::CPU_Emitter::EmitConstant(const ngraph::Node* n,
const vector<runtime::cpu::TensorViewWrapper>& args,
const vector<runtime::cpu::TensorViewWrapper>& out)
{
auto c = static_cast<const op::Constant*>(n);
auto c_value_strings = c->get_value_strings();
auto type = out[0].get_type();
m_out << "{ // " << n->get_name() << " EmitConstant\n";
m_out.indent++;
for (size_t i = 0; i < c_value_strings.size(); i++)
{
m_out << out[0].get_name() << "[" << i << "] = static_cast<" << type << ">("
<< c_value_strings[i] << ");\n";
}
m_out.indent--;
m_out << "}\n";
}
void runtime::cpu::CPU_Emitter::EmitReshape(const ngraph::Node* n,
......
......@@ -292,6 +292,28 @@ using namespace ngraph::runtime;
writer << "\n";
}
writer << "// Declare all constants\n";
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{
for (shared_ptr<Node> node : current_function->get_ordered_ops())
{
const op::Constant* c = dynamic_cast<op::Constant*>(node.get());
if (c)
{
shared_ptr<descriptor::TensorView> tv = node->get_outputs()[0].get_tensor_view();
auto c_value_strings = c->get_value_strings();
writer << tv->get_tensor().get_element_type().c_type_string() << " "
<< tv->get_tensor().get_name() << "[" << c_value_strings.size() << "] =\n";
writer << "{\n";
for (size_t i = 0; i < c_value_strings.size(); i++)
{
writer << " " << c_value_strings[i] << ",\n";
}
writer << "};\n\n";
}
}
}
writer << "// Declare all functions\n";
for (shared_ptr<Function> f : pass_manager.get_state().get_functions())
{
......@@ -301,6 +323,22 @@ using namespace ngraph::runtime;
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{
set<string> output_names;
for (const descriptor::Output* output : current_function->get_outputs())
{
shared_ptr<descriptor::TensorView> tv = output->get_tensor_view();
output_names.insert(tv->get_tensor().get_name());
}
set<descriptor::TensorView*> constants;
for (shared_ptr<Node> node : current_function->get_ordered_ops())
{
if (dynamic_cast<op::Constant*>(node.get()))
{
shared_ptr<descriptor::TensorView> tv = node->get_outputs()[0].get_tensor_view();
constants.insert(tv.get());
}
}
writer << "extern \"C\" void " << current_function->get_name();
writer << "(void** inputs, void** outputs)\n";
writer << "{\n";
......@@ -360,7 +398,6 @@ using namespace ngraph::runtime;
writer << "\n";
writer << "// Define outputs\n";
// create alias list
size_t output_index = 0;
unordered_map<descriptor::TensorView*, vector<size_t>> output_alias_map;
......@@ -378,7 +415,6 @@ using namespace ngraph::runtime;
}
output_index = 0;
set<string> output_names;
for (const descriptor::Output* output : current_function->get_outputs())
{
shared_ptr<descriptor::TensorView> tv = output->get_tensor_view();
......@@ -401,28 +437,20 @@ using namespace ngraph::runtime;
}
if (!parameter_as_output && !contains(aliases, output_index))
{
string type = et.c_type_string();
writer << type << "* " << tv->get_tensor().get_name() << " = static_cast<" << type
<< "*>(outputs[" << output_index << "]);\n";
}
output_names.insert(tv->get_tensor().get_name());
output_index++;
}
writer << "\n";
writer << "// Define constants\n";
for (shared_ptr<Node> node : current_function->get_ordered_ops())
{
if (dynamic_cast<op::Constant*>(node.get()))
if (contains(constants, tv.get()))
{
shared_ptr<descriptor::TensorView> tv = node->get_outputs()[0].get_tensor_view();
if (!contains(output_names, tv->get_tensor().get_name()))
writer << "memcpy(outputs[" << output_index << "], "
<< tv->get_tensor().get_name() << ", " << tv->get_tensor().size()
<< ");\n";
}
else
{
writer << tv->get_tensor().get_element_type().c_type_string() << " "
<< tv->get_tensor().get_name() << "[" << tv->get_tensor().size()
<< "];\n";
string type = et.c_type_string();
writer << type << "* " << tv->get_tensor().get_name() << " = static_cast<"
<< type << "*>(outputs[" << output_index << "]);\n";
}
}
output_index++;
}
writer << "\n";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment