Commit 4aa8649f authored by Robert Kimball's avatar Robert Kimball

cleanup function codegen

parent 05d42945
......@@ -1019,39 +1019,6 @@ void Emitter::EmitReshape(const ngraph::Node* n,
}
}
void Emitter::generate_call(const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs,
shared_ptr<Function> function)
{
vector<string> input_names;
vector<string> output_names;
for (const TensorViewInfo& input : inputs)
{
input_names.push_back(input.get_tensor().get_name());
}
for (const TensorViewInfo& output : outputs)
{
output_names.push_back(output.get_tensor().get_name());
}
TU << "std::vector<void*> inputs =\n{";
TU.indent++;
TU << "\n" << join(input_names, ",\n");
TU.indent--;
TU << "\n};\n";
TU << "std::vector<void*> outputs =\n{";
TU.indent++;
TU << "\n" << join(output_names, ",\n");
TU.indent--;
TU << "\n};\n";
TU << "\n";
TU << function->get_name() << "(inputs, outputs);\n";
}
void Emitter::EmitFunctionCall(const ngraph::Node* n,
ExternalFunction* ef,
const std::vector<TensorViewInfo>& inputs,
......@@ -1158,14 +1125,12 @@ void Emitter::EmitReduce(const ngraph::Node* n,
TU << "{ // " << n->get_name() << " 3\n";
TU.indent++;
string type = f_result_element_type.c_type_string();
TU << "auto f = [](" << type << " x, " << type << " y) -> " << type << " {\n";
TU << "auto f = [](" << type << " x, " << type << " y) -> " << type << "\n{";
TU.indent++;
TU << "std::vector<void*> inputs;\n";
TU << "inputs.push_back(&x);\n";
TU << "inputs.push_back(&y);\n\n";
TU << "\n";
TU << type << " result;\n";
TU << "std::vector<void*> outputs;\n";
TU << "outputs.push_back(&result);\n";
TU << "std::vector<void*> inputs = {&x, &y};\n";
TU << "std::vector<void*> outputs = {&result};\n";
TU << reduction_function->get_name() << "(inputs, outputs);\n";
TU << "return result;\n";
TU.indent--;
......@@ -1204,14 +1169,12 @@ void Emitter::EmitReduce(const ngraph::Node* n,
TU << "{ // " << n->get_name() << " 5\n";
TU.indent++;
string type = f_result_element_type.c_type_string();
TU << "auto f = [](" << type << " x, " << type << " y) -> " << type << " {\n";
TU << "auto f = [](" << type << " x, " << type << " y) -> " << type << "\n{";
TU.indent++;
TU << "std::vector<void*> inputs;\n";
TU << "inputs.push_back(&x);\n";
TU << "inputs.push_back(&y);\n\n";
TU << "\n";
TU << type << " result;\n";
TU << "std::vector<void*> outputs;\n";
TU << "outputs.push_back(&result);\n";
TU << "std::vector<void*> inputs = {&x, &y};\n";
TU << "std::vector<void*> outputs = {&result};\n";
TU << reduction_function->get_name() << "(inputs, outputs);\n";
TU << "return result;\n";
TU.indent--;
......@@ -1247,14 +1210,12 @@ void Emitter::EmitReduce(const ngraph::Node* n,
TU << "{ // " << n->get_name() << " 7\n";
TU.indent++;
string type = f_result_element_type.c_type_string();
TU << "auto f = [](" << type << " x, " << type << " y) -> " << type << " {\n";
TU << "auto f = [](" << type << " x, " << type << " y) -> " << type << "\n{";
TU.indent++;
TU << "std::vector<void*> inputs;\n";
TU << "inputs.push_back(&x);\n";
TU << "inputs.push_back(&y);\n\n";
TU << "\n";
TU << type << " result;\n";
TU << "std::vector<void*> outputs;\n";
TU << "outputs.push_back(&result);\n";
TU << "std::vector<void*> inputs = {&x, &y};\n";
TU << "std::vector<void*> outputs = {&result};\n";
TU << reduction_function->get_name() << "(inputs, outputs);\n";
TU << "return result;\n";
TU.indent--;
......@@ -1651,3 +1612,40 @@ void Emitter::EmitAtan(const ngraph::Node* n,
TU.indent--;
TU << "}\n";
}
//------------------------------------------------------------------------------------------------
// Utility methods
//------------------------------------------------------------------------------------------------
void Emitter::generate_call(const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs,
shared_ptr<Function> function)
{
vector<string> input_names;
vector<string> output_names;
for (const TensorViewInfo& input : inputs)
{
input_names.push_back(input.get_tensor().get_name());
}
for (const TensorViewInfo& output : outputs)
{
output_names.push_back(output.get_tensor().get_name());
}
TU << "std::vector<void*> inputs =\n{";
TU.indent++;
TU << "\n" << join(input_names, ",\n");
TU.indent--;
TU << "\n};\n";
TU << "std::vector<void*> outputs =\n{";
TU.indent++;
TU << "\n" << join(output_names, ",\n");
TU.indent--;
TU << "\n};\n";
TU << "\n";
TU << function->get_name() << "(inputs, outputs);\n";
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment