Commit 7c8e9250 authored by Fenglei's avatar Fenglei Committed by Robert Kimball

gpu function call (#1111)

* enable tests

* add funciton call

* working version

* remove test from ski list
parent 3d66cba4
...@@ -996,6 +996,45 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -996,6 +996,45 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
writer.block_end(); writer.block_end();
} }
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::FunctionCall)
{
auto function_call = static_cast<const ngraph::op::FunctionCall*>(node);
shared_ptr<Function> function = function_call->get_functions()[0];
writer.block_begin();
{
std::vector<string> input_names;
std::vector<string> output_names;
for (const runtime::gpu::GPU_TensorViewWrapper& input : args)
{
input_names.push_back(input.get_name());
}
for (const runtime::gpu::GPU_TensorViewWrapper& output : out)
{
output_names.push_back(output.get_name());
}
writer << "void* args[] =\n";
writer.block_begin();
writer << "\n" << join(input_names, ",\n");
writer.block_end();
writer << ";\n";
writer << "void* out[] =\n";
writer.block_begin();
writer << "\n" << join(output_names, ",\n");
writer.block_end();
writer << ";\n";
writer << "\n";
writer << function->get_name() << "(args, out, ctx);\n";
}
writer.block_end();
}
template <> template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::Slice) void GPU_Emitter::EMITTER_DECL(ngraph::op::Slice)
{ {
...@@ -1112,11 +1151,6 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -1112,11 +1151,6 @@ CUDNN_SAFE_CALL(cudnnSetOpTensorDescriptor(opTensorDesc,
writer.block_end(); writer.block_end();
} }
template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::FunctionCall)
{
}
template <> template <>
void GPU_Emitter::EMITTER_DECL(ngraph::op::Multiply) void GPU_Emitter::EMITTER_DECL(ngraph::op::Multiply)
{ {
......
...@@ -21,7 +21,6 @@ divide_by_zero_float32 ...@@ -21,7 +21,6 @@ divide_by_zero_float32
divide_by_zero_int32 divide_by_zero_int32
dot_4d_5d_multi_axis_big_fp64_VERY_SLOW dot_4d_5d_multi_axis_big_fp64_VERY_SLOW
dot_matrix_vector_int64 dot_matrix_vector_int64
function_call
mkldnn_layouts mkldnn_layouts
numeric_double_nan numeric_double_nan
numeric_float_inf numeric_float_inf
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment