Commit 4aa8649f authored by Robert Kimball's avatar Robert Kimball

cleanup function codegen

parent 05d42945
...@@ -1019,39 +1019,6 @@ void Emitter::EmitReshape(const ngraph::Node* n, ...@@ -1019,39 +1019,6 @@ void Emitter::EmitReshape(const ngraph::Node* n,
} }
} }
void Emitter::generate_call(const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs,
shared_ptr<Function> function)
{
vector<string> input_names;
vector<string> output_names;
for (const TensorViewInfo& input : inputs)
{
input_names.push_back(input.get_tensor().get_name());
}
for (const TensorViewInfo& output : outputs)
{
output_names.push_back(output.get_tensor().get_name());
}
TU << "std::vector<void*> inputs =\n{";
TU.indent++;
TU << "\n" << join(input_names, ",\n");
TU.indent--;
TU << "\n};\n";
TU << "std::vector<void*> outputs =\n{";
TU.indent++;
TU << "\n" << join(output_names, ",\n");
TU.indent--;
TU << "\n};\n";
TU << "\n";
TU << function->get_name() << "(inputs, outputs);\n";
}
void Emitter::EmitFunctionCall(const ngraph::Node* n, void Emitter::EmitFunctionCall(const ngraph::Node* n,
ExternalFunction* ef, ExternalFunction* ef,
const std::vector<TensorViewInfo>& inputs, const std::vector<TensorViewInfo>& inputs,
...@@ -1158,14 +1125,12 @@ void Emitter::EmitReduce(const ngraph::Node* n, ...@@ -1158,14 +1125,12 @@ void Emitter::EmitReduce(const ngraph::Node* n,
TU << "{ // " << n->get_name() << " 3\n"; TU << "{ // " << n->get_name() << " 3\n";
TU.indent++; TU.indent++;
string type = f_result_element_type.c_type_string(); string type = f_result_element_type.c_type_string();
TU << "auto f = [](" << type << " x, " << type << " y) -> " << type << " {\n"; TU << "auto f = [](" << type << " x, " << type << " y) -> " << type << "\n{";
TU.indent++; TU.indent++;
TU << "std::vector<void*> inputs;\n"; TU << "\n";
TU << "inputs.push_back(&x);\n";
TU << "inputs.push_back(&y);\n\n";
TU << type << " result;\n"; TU << type << " result;\n";
TU << "std::vector<void*> outputs;\n"; TU << "std::vector<void*> inputs = {&x, &y};\n";
TU << "outputs.push_back(&result);\n"; TU << "std::vector<void*> outputs = {&result};\n";
TU << reduction_function->get_name() << "(inputs, outputs);\n"; TU << reduction_function->get_name() << "(inputs, outputs);\n";
TU << "return result;\n"; TU << "return result;\n";
TU.indent--; TU.indent--;
...@@ -1204,14 +1169,12 @@ void Emitter::EmitReduce(const ngraph::Node* n, ...@@ -1204,14 +1169,12 @@ void Emitter::EmitReduce(const ngraph::Node* n,
TU << "{ // " << n->get_name() << " 5\n"; TU << "{ // " << n->get_name() << " 5\n";
TU.indent++; TU.indent++;
string type = f_result_element_type.c_type_string(); string type = f_result_element_type.c_type_string();
TU << "auto f = [](" << type << " x, " << type << " y) -> " << type << " {\n"; TU << "auto f = [](" << type << " x, " << type << " y) -> " << type << "\n{";
TU.indent++; TU.indent++;
TU << "std::vector<void*> inputs;\n"; TU << "\n";
TU << "inputs.push_back(&x);\n";
TU << "inputs.push_back(&y);\n\n";
TU << type << " result;\n"; TU << type << " result;\n";
TU << "std::vector<void*> outputs;\n"; TU << "std::vector<void*> inputs = {&x, &y};\n";
TU << "outputs.push_back(&result);\n"; TU << "std::vector<void*> outputs = {&result};\n";
TU << reduction_function->get_name() << "(inputs, outputs);\n"; TU << reduction_function->get_name() << "(inputs, outputs);\n";
TU << "return result;\n"; TU << "return result;\n";
TU.indent--; TU.indent--;
...@@ -1247,14 +1210,12 @@ void Emitter::EmitReduce(const ngraph::Node* n, ...@@ -1247,14 +1210,12 @@ void Emitter::EmitReduce(const ngraph::Node* n,
TU << "{ // " << n->get_name() << " 7\n"; TU << "{ // " << n->get_name() << " 7\n";
TU.indent++; TU.indent++;
string type = f_result_element_type.c_type_string(); string type = f_result_element_type.c_type_string();
TU << "auto f = [](" << type << " x, " << type << " y) -> " << type << " {\n"; TU << "auto f = [](" << type << " x, " << type << " y) -> " << type << "\n{";
TU.indent++; TU.indent++;
TU << "std::vector<void*> inputs;\n"; TU << "\n";
TU << "inputs.push_back(&x);\n";
TU << "inputs.push_back(&y);\n\n";
TU << type << " result;\n"; TU << type << " result;\n";
TU << "std::vector<void*> outputs;\n"; TU << "std::vector<void*> inputs = {&x, &y};\n";
TU << "outputs.push_back(&result);\n"; TU << "std::vector<void*> outputs = {&result};\n";
TU << reduction_function->get_name() << "(inputs, outputs);\n"; TU << reduction_function->get_name() << "(inputs, outputs);\n";
TU << "return result;\n"; TU << "return result;\n";
TU.indent--; TU.indent--;
...@@ -1651,3 +1612,40 @@ void Emitter::EmitAtan(const ngraph::Node* n, ...@@ -1651,3 +1612,40 @@ void Emitter::EmitAtan(const ngraph::Node* n,
TU.indent--; TU.indent--;
TU << "}\n"; TU << "}\n";
} }
//------------------------------------------------------------------------------------------------
// Utility methods
//------------------------------------------------------------------------------------------------
void Emitter::generate_call(const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs,
shared_ptr<Function> function)
{
vector<string> input_names;
vector<string> output_names;
for (const TensorViewInfo& input : inputs)
{
input_names.push_back(input.get_tensor().get_name());
}
for (const TensorViewInfo& output : outputs)
{
output_names.push_back(output.get_tensor().get_name());
}
TU << "std::vector<void*> inputs =\n{";
TU.indent++;
TU << "\n" << join(input_names, ",\n");
TU.indent--;
TU << "\n};\n";
TU << "std::vector<void*> outputs =\n{";
TU.indent++;
TU << "\n" << join(output_names, ",\n");
TU.indent--;
TU << "\n};\n";
TU << "\n";
TU << function->get_name() << "(inputs, outputs);\n";
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment