Commit c46d4546 authored by Fenglei's avatar Fenglei Committed by Robert Kimball

gpu replace slice optimize (#1411)

* optimize replace slice

* fix bugs

* fix bug

* optimize pad dynamic

* fix bug

* fix bug

* fix bug

* remove *

* add gpu_assignment to pass

* refactor cuda replace slice.

* fix bug

* refactor replace slice

* working version

* clang format

* us layout instead of assignment

* us layout instead of assignment in cmakelist

* update gpu_layout

* fix bugs

* resolve conflict

* GPUShape to NVShape

* using kernel args

* using kernel args

* fix bugs

* fix bugs

* fix bug, remove mkldnn.h from gpu_layout.cpp

* fix bug for pad_below

* remove cast to rep_slice

* fix bugs

* clang format

* change add_in_place_oi_pair({0, 0, false} to add_in_place_oi_pair({0, 0, true};
parent 7d3323c9
......@@ -38,6 +38,7 @@ set(SRC
gpu_tensor_view.cpp
gpu_util.cpp
type_info.cpp
pass/gpu_layout.cpp
pass/tensor_memory_reservation.cpp
gpu_kernel_args.cpp
)
......
......@@ -483,8 +483,9 @@ size_t runtime::gpu::CUDAEmitter::build_pad_dynamic(const std::array<std::string
NVShape padding_below,
NVShape padding_interior)
{
uint32_t rank = static_cast<uint32_t>(input_shape.size());
std::stringstream kernel_name;
kernel_name << "pad_dynamic_" << join(dtypes, "_");
kernel_name << "pad_dynamic_" << join(dtypes, "_") << rank;
std::string hash = kernel_name.str() + "pad_i" + join(input_shape, "_") + "pad_o" +
join(output_shape) + "_pb" + join(padding_below, "_") + "_pi" +
......@@ -500,21 +501,11 @@ size_t runtime::gpu::CUDAEmitter::build_pad_dynamic(const std::array<std::string
return primitive_index;
}
// check if the kernel has already been compiled. if so, create
// a launch primitive for it based on the input tensor shape
// but do not recompile the kernel. otherwise, do it all:
// recompile the kernel and then create the primitive
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name.str());
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_pad_dynamic_op(writer, kernel_name.str(), dtypes);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
}
uint32_t rank = static_cast<uint32_t>(input_shape.size());
uint32_t nthreads = static_cast<uint32_t>(shape_size(input_shape));
//TODO: currently we set it to 64, will add tuning method later
uint32_t block_size_x = 64;
uint32_t aligned_grid_size_x = align_to_block_size(nthreads, block_size_x);
NVShape pad_below(input_shape.size(), 0);
NVShape pad_interior(input_shape.size(), 1);
......@@ -529,48 +520,48 @@ size_t runtime::gpu::CUDAEmitter::build_pad_dynamic(const std::array<std::string
NVShape input_strides = row_major_strides(input_shape);
NVShape output_strides = row_major_strides(output_shape);
// get an allocator for transient per kernel gpu memory
GPUAllocator allocator = this->m_primitive_emitter->get_memory_allocator();
size_t idx_input_strides =
allocator.reserve_argspace(input_strides.data(), input_strides.size() * sizeof(uint32_t));
size_t idx_output_strides =
allocator.reserve_argspace(output_strides.data(), output_strides.size() * sizeof(uint32_t));
size_t idx_padding_below =
allocator.reserve_argspace(pad_below.data(), pad_below.size() * sizeof(uint32_t));
size_t idx_padding_interior =
allocator.reserve_argspace(pad_interior.data(), pad_interior.size() * sizeof(uint32_t));
auto args = m_primitive_emitter->add_kernel_args();
args.add_placeholder(dtypes[0], "in")
.add_placeholder(dtypes[1], "out")
.add("input_strides", input_strides)
.add("output_strides", output_strides)
.add("padding_below", pad_below)
.add("padding_interior", pad_interior)
.add("n", nthreads);
// check if the kernel has already been compiled. if so, create
// a launch primitive for it based on the input tensor shape
// but do not recompile the kernel. otherwise, do it all:
// recompile the kernel and then create the primitive
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name.str());
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
CudaKernelBuilder::add_pod_typedefs(writer);
CudaKernelBuilder::get_pad_dynamic_op(writer, kernel_name.str(), args, dtypes, rank);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name.str(), writer.get_code());
}
// create the launch primitive
std::unique_ptr<gpu::primitive> pad_dynamic(new gpu::primitive{[=](void** inputs,
void** outputs) mutable {
void* param_input_strides = runtime::gpu::invoke_memory_primitive(m_ctx, idx_input_strides);
void* param_output_strides =
runtime::gpu::invoke_memory_primitive(m_ctx, idx_output_strides);
void* param_padding_below = runtime::gpu::invoke_memory_primitive(m_ctx, idx_padding_below);
void* param_padding_interior =
runtime::gpu::invoke_memory_primitive(m_ctx, idx_padding_interior);
std::vector<void*> args_list{&inputs[0],
&outputs[0],
&param_input_strides,
&param_output_strides,
&param_padding_below,
&param_padding_interior,
&rank,
&nthreads};
std::unique_ptr<gpu::primitive> pad_dynamic(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void** args_list = args.resolve_placeholder(0, &inputs[0])
.resolve_placeholder(1, &outputs[0])
.get_argument_list();
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
static_cast<uint32_t>(nthreads),
1,
1, // grid dim
1,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list.data(),
0)); // arguments
debug_sync();
}});
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
aligned_grid_size_x,
1,
1, // grid dim
block_size_x,
1,
1, // block dim
0,
NULL, // shared mem and stream
args_list,
0)); // arguments
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(pad_dynamic));
m_primitive_emitter->cache(hash, primitive_index);
......@@ -1653,113 +1644,6 @@ size_t runtime::gpu::CUDAEmitter::build_reduce_window(const OpName op_name,
return primitive_index;
}
size_t runtime::gpu::CUDAEmitter::build_replace_slice(const std::array<std::string, 3>& dtypes,
NVShape tensor_shape,
NVShape source_shape,
NVShape lower_bounds,
NVShape upper_bounds,
NVShape slice_strides)
{
// assumes NC{d1,...,dn} format
std::string kernel_name =
"repslices_" + join(dtypes, "_") + "_r" + std::to_string(tensor_shape.size());
std::replace(kernel_name.begin(), kernel_name.end(), ' ', '_');
std::stringstream ss;
ss << kernel_name << "_s" << join(tensor_shape, "_") << "_ssrc" << join(source_shape, "_")
<< "_sll" << join(lower_bounds, "_") << "_slu" << join(upper_bounds, "_") << "_slst"
<< join(slice_strides, "_");
auto hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
// calculate strides
NVShape input_strides = row_major_strides(tensor_shape);
NVShape source_strides = row_major_strides(source_shape);
// precacluate invariants for integer division via multiplication
std::vector<int> dmagics;
std::vector<int> dshifts;
std::vector<int> smagics;
std::vector<int> sshifts;
for (int i = 0; i < tensor_shape.size(); i++)
{
int magic;
int shift;
std::tie(magic, shift) = idiv_magic_u64(input_strides[i]);
dmagics.push_back(magic);
dshifts.push_back(shift);
std::tie(magic, shift) = idiv_magic_u64(slice_strides[i]);
smagics.push_back(magic);
sshifts.push_back(shift);
}
size_t rank = tensor_shape.size();
size_t nthreads = shape_size(tensor_shape);
constexpr const int nthreads_per_block = 32;
int nblocks = 1 + ((static_cast<int>(nthreads) - 1) / nthreads_per_block); // ceil_div(nthreads)
// TODO: blending factors are not currently implemented
float alpha = 1.0f;
float beta = 0.0f;
auto args = m_primitive_emitter->add_kernel_args();
args.add_placeholder(dtypes[0], "in")
.add_placeholder(dtypes[1], "source")
.add_placeholder(dtypes[2], "out")
.add("alpha", alpha)
.add("beta", beta)
.add("dim_strides", input_strides)
.add("dim_magic", dmagics)
.add("dim_shift", dshifts)
.add("lower_bounds", lower_bounds)
.add("upper_bounds", upper_bounds)
.add("slice_str", slice_strides)
.add("slice_magic", smagics)
.add("slice_shift", sshifts)
.add("src_strides", source_strides)
.add("nthreads", nthreads);
// if the kernel has not been compiled, build it
auto compiled_kernel = m_ctx->compiled_kernel_pool->get(kernel_name);
if (compiled_kernel == nullptr)
{
codegen::CodeWriter writer;
writer << include_helpers();
runtime::gpu::CudaKernelBuilder::get_replace_slice_op(writer, kernel_name, args, rank);
compiled_kernel = m_ctx->compiled_kernel_pool->set(kernel_name, writer.get_code());
}
std::unique_ptr<gpu::primitive> replace_slice(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void** args_list = args.resolve_placeholder(0, &inputs[0])
.resolve_placeholder(1, &inputs[1])
.resolve_placeholder(2, &outputs[0])
.get_argument_list();
CUDA_SAFE_CALL(cuLaunchKernel(*compiled_kernel.get(),
nblocks,
1,
1,
nthreads_per_block,
1,
1,
0,
NULL,
args_list,
0));
debug_sync();
}});
primitive_index = this->m_primitive_emitter->insert(std::move(replace_slice));
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
size_t runtime::gpu::CUDAEmitter::build_broadcast(const std::array<std::string, 2>& dtypes,
NVShape result_shape,
const std::set<size_t>& reduce_axes)
......@@ -1989,6 +1873,80 @@ size_t runtime::gpu::CUDAEmitter::build_primitive(const op::Convolution* node)
return primitive_index;
}
size_t runtime::gpu::CUDAEmitter::build_primitive(const op::ReplaceSlice* node, bool in_place_op)
{
auto& args = node->get_inputs();
auto& out = node->get_outputs();
auto& input_shape = args[0].get_shape();
auto& replace_shape = args[1].get_shape();
auto& lower_bounds = node->get_lower_bounds();
auto& upper_bounds = node->get_upper_bounds();
auto& slice_strides = node->get_strides();
Shape slice_shape(upper_bounds.size(), 0);
std::transform(upper_bounds.begin(),
upper_bounds.end(),
lower_bounds.begin(),
slice_shape.begin(),
std::minus<size_t>());
std::transform(slice_shape.begin(),
slice_shape.end(),
slice_strides.begin(),
slice_shape.begin(),
std::divides<size_t>());
auto input_type = args[0].get_element_type().c_type_string();
auto replace_type = args[1].get_element_type().c_type_string();
auto output_type = out[0].get_element_type().c_type_string();
// assumes NC{d1,...,dn} format
std::string type_str = input_type + "_" + replace_type + "_" + output_type;
std::replace(type_str.begin(), type_str.end(), ' ', '_');
std::stringstream ss;
ss << "rep_slices_" << type_str << "_s" << join(input_shape, "_") << "_ssrc"
<< join(replace_shape, "_") << "_sll" << join(lower_bounds, "_") << "_slu"
<< join(upper_bounds, "_") << "_slst" << join(slice_strides, "_") << in_place_op;
auto hash = ss.str();
// check if the requested primtive is already built
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
// calculate strides
Shape input_strides = row_major_strides(input_shape);
Shape replace_strides = row_major_strides(replace_shape);
size_t pad_index = build_pad_dynamic(
{{input_type, output_type}}, replace_shape, input_shape, lower_bounds, slice_strides);
if (in_place_op)
{
std::unique_ptr<gpu::primitive> kernel_launch(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
runtime::gpu::invoke_primitive(
m_ctx, pad_index, std::vector<void*>{inputs[1]}.data(), outputs);
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
}
else
{
size_t nthreads = shape_size(input_shape);
size_t size = nthreads * args[1].get_element_type().size();
std::unique_ptr<gpu::primitive> kernel_launch(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
runtime::gpu::cuda_memcpyDtD(outputs[0], inputs[0], size);
runtime::gpu::invoke_primitive(
m_ctx, pad_index, std::vector<void*>{inputs[1]}.data(), outputs);
}});
primitive_index = this->m_primitive_emitter->insert(std::move(kernel_launch));
}
m_primitive_emitter->cache(hash, primitive_index);
return primitive_index;
}
size_t runtime::gpu::CUDAEmitter::build_convolution(const std::array<std::string, 3>& dtypes,
NVShape input_shape,
NVShape filter_shape,
......
......@@ -25,6 +25,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/softmax.hpp"
namespace ngraph
......@@ -46,6 +47,7 @@ namespace ngraph
size_t build_primitive(const op::Softmax* node);
size_t build_primitive(const op::Convolution* node);
size_t build_primitive(const op::MaxPool* node);
size_t build_primitive(const op::ReplaceSlice* node, bool in_place_op);
public:
size_t build_pad(const std::array<std::string, 2>& dtypes,
......@@ -130,13 +132,6 @@ namespace ngraph
save_elementwise);
}
size_t build_replace_slice(const std::array<std::string, 3>& dtypes,
NVShape tensor_shape,
NVShape source_shape,
NVShape lower_bounds,
NVShape upper_bounds,
NVShape slice_stride);
size_t build_broadcast(const std::array<std::string, 2>& dtypes,
NVShape result_shape,
const std::set<size_t>& bcast_axes);
......
......@@ -306,12 +306,11 @@ void runtime::gpu::CudaKernelBuilder::get_concat_op(codegen::CodeWriter& writer,
void runtime::gpu::CudaKernelBuilder::get_pad_dynamic_op(
codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types)
GPUKernelArgs& args,
const std::array<std::string, 2>& data_types,
size_t rank)
{
writer << "extern \"C\" __global__ void cuda_" << name << "(" << data_types[0] << "* in, "
<< data_types[1] << "* out, uint32_t* input_strides, uint32_t* output_strides, "
"uint32_t* padding_below, uint32_t* "
"padding_interior, uint32_t rank, uint32_t n)\n";
writer << "extern \"C\" __global__ void cuda_" << name << args.get_input_signature();
writer.block_begin();
{
writer << "uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;\n";
......@@ -319,17 +318,19 @@ void runtime::gpu::CudaKernelBuilder::get_pad_dynamic_op(
writer.block_begin();
{
writer << "uint32_t output_idx = 0;\n";
writer << "uint32_t input_idx = tid;\n";
writer << "for(uint32_t i = 0; i < rank; i++)\n";
writer.block_begin();
if (rank > 0)
{
writer << "output_idx += (input_idx / input_strides[i] * padding_interior[i] + "
"padding_below[i]) "
"* output_strides[i];\n";
writer << "input_idx %= input_strides[i];\n";
writer << "uint32_t input_idx = tid;\n";
}
for (size_t i = 0; i < rank; i++)
{
writer << "output_idx += (input_idx / input_strides" << i << " * padding_interior"
<< i << " + "
"padding_below"
<< i << ") * output_strides" << i << ";\n";
writer << "input_idx %= input_strides" << i << ";\n";
}
writer.block_end();
writer << "out[output_idx] = in[tid];\n";
}
writer.block_end();
......
......@@ -93,7 +93,9 @@ namespace ngraph
static void get_pad_dynamic_op(codegen::CodeWriter& writer,
const std::string& name,
const std::array<std::string, 2>& data_types);
GPUKernelArgs& args,
const std::array<std::string, 2>& data_types,
size_t rank);
static void get_ew_collective_op(codegen::CodeWriter& writer,
const std::string& name,
......
......@@ -1475,48 +1475,19 @@ namespace ngraph
{
// assumes NC{d1,d2,...} format
auto rep_slice = static_cast<const ngraph::op::ReplaceSlice*>(node);
bool in_place_op = (args[0].get_name() == out[0].get_name());
writer.block_begin();
{
auto& input_shape = args[0].get_shape();
auto& source_shape = args[1].get_shape();
auto& lower_bounds = rep_slice->get_lower_bounds();
auto& upper_bounds = rep_slice->get_upper_bounds();
auto& strides = rep_slice->get_strides();
Shape slice_shape(upper_bounds.size(), 0);
std::transform(upper_bounds.begin(),
upper_bounds.end(),
lower_bounds.begin(),
slice_shape.begin(),
std::minus<size_t>());
std::transform(slice_shape.begin(),
slice_shape.end(),
strides.begin(),
slice_shape.begin(),
std::divides<size_t>());
// replace the input with the source if the slice shape and input shape are equal
if (input_shape == slice_shape)
{
kernel::emit_memcpyDtD(writer, out[0], args[1]);
}
else
{
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
auto replace_slice_index = cuda_emitter->build_replace_slice(
{{args[0].get_type(), args[1].get_type(), out[0].get_type()}},
input_shape,
source_shape,
lower_bounds,
upper_bounds,
rep_slice->get_strides());
auto index = cuda_emitter->build_primitive(rep_slice, in_place_op);
writer << "gpu::invoke_primitive(ctx, " << replace_slice_index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << ", "
<< args[1].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
writer << "gpu::invoke_primitive(ctx, " << index << ", ";
writer << "std::vector<void*>{" << args[0].get_name() << ", "
<< args[1].get_name() << "}.data(), ";
writer << "std::vector<void*>{" << out[0].get_name() << "}.data()";
writer << ");\n";
}
writer.block_end();
}
......
......@@ -105,6 +105,7 @@
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/pass/gpu_layout.hpp"
#include "ngraph/runtime/gpu/pass/tensor_memory_reservation.hpp"
using namespace std;
......@@ -647,6 +648,7 @@ void runtime::gpu::GPU_ExternalFunction::compile()
m_pass_manager
.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorViewLayout>>();
m_pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
m_pass_manager.register_pass<ngraph::pass::Liveness>();
m_pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment);
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/op/util/op_annotations.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
/// \brief Annotations added to graph ops by GPU backend passes
class GPUOpAnnotations : public ngraph::op::util::OpAnnotations
{
public:
GPUOpAnnotations() {}
};
}
}
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>
#include "gpu_layout.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/runtime/gpu/gpu_op_annotations.hpp"
using namespace std;
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace gpu
{
namespace pass
{
template <>
void GPULayout::LAYOUT_DECL(ngraph::op::ReplaceSlice)
{
auto rep_slice = static_cast<ngraph::op::ReplaceSlice*>(node.get());
auto op_annotations = rep_slice->get_op_annotations();
if (op_annotations)
{
// pass-through
op_annotations->add_in_place_oi_pair({0, 0, true});
}
else
{
op_annotations = std::make_shared<ngraph::runtime::gpu::GPUOpAnnotations>();
// pass-through
op_annotations->add_in_place_oi_pair({0, 0, true});
rep_slice->set_op_annotations(op_annotations);
}
}
}
}
}
}
#define TI(x) type_index(typeid(x))
static const runtime::gpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::ReplaceSlice),
&runtime::gpu::pass::GPULayout::layout<ngraph::op::ReplaceSlice>},
};
bool runtime::gpu::pass::GPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
{
for (const auto& node : nodes)
{
auto& n = *node;
auto handler = s_dispatcher.find(TI(n));
if (handler != s_dispatcher.end())
{
handler->second(m_external_function, node);
}
}
return false;
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/pass.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
#define LAYOUT_DECL(op_type) \
layout<op_type>(ngraph::runtime::gpu::GPU_ExternalFunction * external_function, \
std::shared_ptr<ngraph::Node> node)
namespace ngraph
{
namespace runtime
{
namespace gpu
{
namespace pass
{
using LayoutFunction =
std::function<void(GPU_ExternalFunction*, std::shared_ptr<ngraph::Node>)>;
using LayoutOpMap = std::unordered_map<std::type_index, LayoutFunction>;
class GPULayout : public ngraph::pass::CallGraphPass
{
public:
GPULayout(GPU_ExternalFunction* external_function)
: m_external_function(external_function)
{
}
virtual bool
run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override;
template <typename OP>
static void
layout(ngraph::runtime::gpu::GPU_ExternalFunction* external_function,
std::shared_ptr<ngraph::Node> node);
private:
GPU_ExternalFunction* m_external_function;
};
}
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment