Commit 05d42945 authored by Robert Kimball's avatar Robert Kimball

add utility method to emit function call

parent 24a7db84
...@@ -1019,29 +1019,50 @@ void Emitter::EmitReshape(const ngraph::Node* n, ...@@ -1019,29 +1019,50 @@ void Emitter::EmitReshape(const ngraph::Node* n,
} }
} }
void Emitter::EmitFunctionCall(const ngraph::Node* n, void Emitter::generate_call(const std::vector<TensorViewInfo>& inputs,
ExternalFunction* ef, const std::vector<TensorViewInfo>& outputs,
const std::vector<TensorViewInfo>& inputs, shared_ptr<Function> function)
const std::vector<TensorViewInfo>& outputs)
{ {
auto function_call = static_cast<const op::FunctionCall*>(n); vector<string> input_names;
shared_ptr<Function> function = function_call->get_function(); vector<string> output_names;
TU << "{ // Call " << function->get_name() << "\n";
TU.indent++;
TU << "std::vector<void*> inputs;\n";
for (const TensorViewInfo& input : inputs) for (const TensorViewInfo& input : inputs)
{ {
TU << "inputs.push_back(" << input.get_tensor().get_name() << ");\n"; input_names.push_back(input.get_tensor().get_name());
} }
TU << "\n";
TU << "std::vector<void*> outputs;\n";
for (const TensorViewInfo& output : outputs) for (const TensorViewInfo& output : outputs)
{ {
TU << "outputs.push_back(" << output.get_tensor().get_name() << ");\n"; output_names.push_back(output.get_tensor().get_name());
} }
TU << "std::vector<void*> inputs =\n{";
TU.indent++;
TU << "\n" << join(input_names, ",\n");
TU.indent--;
TU << "\n};\n";
TU << "std::vector<void*> outputs =\n{";
TU.indent++;
TU << "\n" << join(output_names, ",\n");
TU.indent--;
TU << "\n};\n";
TU << "\n"; TU << "\n";
TU << function->get_name() << "(inputs, outputs);\n"; TU << function->get_name() << "(inputs, outputs);\n";
}
void Emitter::EmitFunctionCall(const ngraph::Node* n,
ExternalFunction* ef,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs)
{
auto function_call = static_cast<const op::FunctionCall*>(n);
shared_ptr<Function> function = function_call->get_function();
TU << "{ // Call " << function->get_name() << "\n";
TU.indent++;
generate_call(inputs, outputs, function);
TU.indent--; TU.indent--;
TU << "}\n"; TU << "}\n";
} }
......
...@@ -94,6 +94,11 @@ namespace ngraph ...@@ -94,6 +94,11 @@ namespace ngraph
void EMITTER_DECL(EmitAsin); void EMITTER_DECL(EmitAsin);
void EMITTER_DECL(EmitAcos); void EMITTER_DECL(EmitAcos);
void EMITTER_DECL(EmitAtan); void EMITTER_DECL(EmitAtan);
private:
void generate_call(const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs,
std::shared_ptr<Function> function);
}; };
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment