Commit 757621be authored by Chris Sullivan's avatar Chris Sullivan Committed by Robert Kimball

nvgpu backend without clang (#2115)

* Separate out external function base class.

* pt1 first step to removing m_writer from GPU_Emitter.

* pt2 add gpu_internal function skeleton

* pt3 temporarily add to gpu_backend for prototyping.

* pt4 add call frame (partial) and runtime constructor

* pt 5 implement resolution for function memory reservations. build new tensor wrapper for use with call frame.

* pt 6 resolve compilation errors.

* pt 7 Add host emitter for emitting host primtives and implement in gpu emitter.

* pt 8 add compile time manifest.

* pt 9 add simple runtime tracer.

* pt 10 seperate runtimes for different functions. index by function name, should switch to using function instance_id for look up performance.

* pt 11 add function call interface and support nested call frames

* pt 12 Reshape elimination check in emitter needs to include offset.

* pt 13 Add default indentation to all op emissions in gpu external functions.

* pt 14 fix constant mem reservation (should not depend on the tmeporary buffers existence check.

* pt 15 backward pooling for avg pool requires only one param. rather than passing this param
three times, this commit changes the runtime to detect if its avgpooling and pass the appropriate pointers.
This is a hold over until max and avgpool are refactored into separate cudnn emitters.

* pt 16 update cmake compatibility. gpu backend can now be built without clang via NGRAPH_DEX_ONLY.
if this cmake variable is not define, then both clang codegen (via gpu external function) and interpreter (via gpu internal function) modes will be built.
for now codegen is the default backend but can be explicitly disabled by setting the env. variable to NGRAPH_CODEGEN=0/FALSE/NO/etc.

additional note: made codegen::CodeWriter header-only so that it can be used independently of whether the clang codegen library is compiled.

* pt 17 fix issues with merge from master

* pt 18 factor compile function into a few virtual calls so that common passes can be added in a single location for both backends.

* pt 19 formatting

* Remove code_writer.cpp from cmake and disable (temporarily) some reduce tests that require changes to gpu_emitter.cpp

* Move call frame and runtime constructor implementations to source files.

* Use member m_common_function_string.

* Applying analogous bug fix as found in #2145

* Remove underscore from GPU_CompiledFunction, GPU_ExternalFunction, and GPU_InternalFunction.

* Made static members of GPUCompiledFunction static methods.

* Remove 'No' codegen options, use std::toupper and applied format

* review comments

* Remove vector overload for resolve inputs/outputs in GPUCallFrame.

* Remove diagnostic pragmas
parent 84d6ae08
......@@ -108,7 +108,7 @@ option(NGRAPH_GPUH_ENABLE "Control the building of the Hybrid GPU backend" FALSE
option(NGRAPH_DISTRIBUTED_ENABLE "Add distributed mode to the CPU backend" FALSE)
option(NGRAPH_DEBUG_ENABLE "Enable output for NGRAPH_DEBUG statements" FALSE)
option(NGRAPH_ONNX_IMPORT_ENABLE "Enable ONNX importer" FALSE)
option(NGRAPH_DEX_ONLY "Build CPU DEX without codegen" FALSE)
option(NGRAPH_DEX_ONLY "Build CPU or GPU DEX without codegen" FALSE)
option(NGRAPH_CODE_COVERAGE_ENABLE "Enable code coverage data collection" FALSE)
option(NGRAPH_LIB_VERSIONING_ENABLE "Enable shared library versioning" FALSE)
option(NGRAPH_PYTHON_BUILD_ENABLE "Enable build nGraph python package wheel" FALSE)
......
......@@ -15,13 +15,12 @@
# ******************************************************************************
if (NGRAPH_GPU_ENABLE OR (NGRAPH_CPU_ENABLE AND NOT NGRAPH_DEX_ONLY))
if ((NGRAPH_GPU_ENABLE OR NGRAPH_CPU_ENABLE) AND NOT NGRAPH_DEX_ONLY)
if (DEFINED NGRAPH_USE_CXX_ABI AND NGRAPH_USE_PREBUILT_LLVM)
message(FATAL_ERROR "Unable to use NGRAPH_USE_PREBUILT_LLVM with NGRAPH_USE_CXX_ABI")
endif()
set(SRC
code_writer.cpp
compiler.cpp
execution_engine.cpp
)
......
......@@ -30,13 +30,14 @@ namespace ngraph
class ngraph::codegen::CodeWriter
{
public:
CodeWriter();
std::string get_code() const;
void operator+=(const std::string&);
size_t indent;
CodeWriter()
: indent(0)
, m_pending_indent(true)
, m_temporary_name_count(0)
{
}
std::string get_code() const { return m_ss.str(); }
void operator+=(const std::string& s) { *this << s; }
template <typename T>
friend CodeWriter& operator<<(CodeWriter& out, const T& obj)
{
......@@ -66,7 +67,15 @@ public:
return out;
}
std::string generate_temporary_name(std::string prefix = "tempvar");
std::string generate_temporary_name(const std::string& prefix = "tempvar")
{
std::stringstream ss;
ss << prefix << m_temporary_name_count;
m_temporary_name_count++;
return ss.str();
}
void block_begin()
{
......@@ -80,6 +89,8 @@ public:
*this << "}\n";
}
size_t indent;
private:
std::stringstream m_ss;
bool m_pending_indent;
......
......@@ -22,30 +22,39 @@ set(SRC
cuda_emitter.cpp
cudnn_emitter.cpp
cublas_emitter.cpp
host_emitter.cpp
gpu_backend.cpp
gpu_call_frame.cpp
gpu_cuda_context_manager.cpp
gpu_cuda_function_builder.cpp
gpu_cuda_function_pool.cpp
gpu_cuda_kernel_builder.cpp
gpu_emitter.cpp
gpu_external_function.cpp
gpu_compiled_function.cpp
gpu_internal_function.cpp
gpu_invoke.cpp
gpu_kernel_args.cpp
gpu_kernel_emitters.cpp
gpu_memory_manager.cpp
gpu_primitive_emitter.cpp
gpu_runtime_constructor.cpp
gpu_runtime_context.cpp
gpu_tensor_wrapper.cpp
gpu_tensor.cpp
gpu_util.cpp
type_info.cpp
pass/gpu_batch_norm_cache.cpp
pass/gpu_layout.cpp
pass/tensor_memory_reservation.cpp
gpu_kernel_args.cpp
pass/gpu_rnn_fusion.cpp
pass/gpu_batch_norm_cache.cpp
pass/tensor_memory_reservation.cpp
op/batch_norm.cpp
op/rnn.cpp
)
if (NOT NGRAPH_DEX_ONLY)
list(APPEND SRC gpu_external_function.cpp)
endif()
set(CUDA_INC
${PROJECT_SOURCE_DIR}/src/
)
......@@ -141,7 +150,12 @@ if (NGRAPH_GPU_ENABLE)
VERSION ${NGRAPH_VERSION}
SOVERSION ${NGRAPH_API_VERSION})
endif()
target_link_libraries(gpu_backend PUBLIC ngraph codegen)
target_link_libraries(gpu_backend PUBLIC ngraph)
if (NGRAPH_DEX_ONLY)
target_compile_definitions(gpu_backend PRIVATE "NGRAPH_DEX_ONLY")
else()
target_link_libraries(gpu_backend PUBLIC codegen)
endif()
find_library(CUDA_nvrtc_LIBRARY nvrtc
PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)
find_library(CUDA_cuda_LIBRARY cuda
......
......@@ -1688,20 +1688,26 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const cudnnPoolingMode_t& pool_
{
pool.reset(new gpu::primitive{
[=, &desc, &input_desc, &output_desc](void** inputs, void** outputs) {
void* y = inputs[0];
void* dy = inputs[0];
void* x = inputs[0];
if (pool_op == 0 || pool_op == 3)
{
y = inputs[2];
dy = inputs[1];
x = inputs[0];
}
CUDNN_SAFE_CALL(cudnnPoolingBackward(*m_ctx->cudnn_handle,
desc,
alpha,
// output (wrt maxpool) tensor
output_desc,
inputs[2],
// adjoint of output
y,
output_desc,
inputs[1],
// input (wrt maxpool) tensor
dy,
input_desc,
inputs[0],
x,
beta,
// adjoint of input
// adjoint of input (dx)
input_desc,
outputs[0]));
debug_sync();
......
......@@ -23,6 +23,7 @@
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_internal_function.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_tensor.hpp"
#include "ngraph/runtime/hybrid/hybrid_backend.hpp"
......@@ -120,13 +121,13 @@ shared_ptr<runtime::Tensor> runtime::gpu::GPU_Backend::create_tensor(
runtime::Handle runtime::gpu::GPU_Backend::compile(shared_ptr<Function> func)
{
FunctionInstance& instance = m_function_map[func];
if (instance.m_external_function == nullptr)
if (instance.m_compiled_function == nullptr)
{
m_context->bind_cuda_context_to_thread();
instance.m_external_function = make_shared<GPU_ExternalFunction>(func, m_context);
instance.m_external_function->m_emit_timing = instance.m_performance_counters_enabled;
instance.m_external_function->compile();
instance.m_compiled_function = instance.m_external_function->m_compiled_function;
instance.m_compiled_function = runtime::gpu::GPUCompiledFunction::make(func, m_context);
instance.m_compiled_function->m_emit_timing = instance.m_performance_counters_enabled;
instance.m_compiled_function->compile();
instance.m_runtime = instance.m_compiled_function->m_runtime;
instance.m_inputs.resize(func->get_parameters().size());
instance.m_outputs.resize(func->get_output_size());
}
......@@ -156,7 +157,7 @@ bool runtime::gpu::GPU_Backend::call(shared_ptr<Function> func,
const vector<shared_ptr<runtime::Tensor>>& inputs)
{
FunctionInstance& instance = m_function_map[func];
if (instance.m_external_function == nullptr)
if (instance.m_compiled_function == nullptr)
{
throw runtime_error("compile() must be called before call().");
}
......@@ -169,7 +170,7 @@ bool runtime::gpu::GPU_Backend::call(shared_ptr<Function> func,
initialize_io(instance.m_outputs.data(), outputs);
auto ctx = m_context->m_runtime_context.get();
instance.m_compiled_function(instance.m_inputs.data(), instance.m_outputs.data(), ctx);
instance.m_runtime(instance.m_inputs.data(), instance.m_outputs.data(), ctx);
return true;
}
......@@ -182,7 +183,7 @@ void runtime::gpu::GPU_Backend::remove_compiled_function(shared_ptr<Function> fu
void runtime::gpu::GPU_Backend::enable_performance_data(shared_ptr<Function> func, bool enable)
{
FunctionInstance& instance = m_function_map[func];
if (instance.m_external_function != nullptr)
if (instance.m_compiled_function != nullptr)
{
throw runtime_error("Performance data collection must be enabled prior to compiling.");
}
......@@ -197,27 +198,9 @@ vector<runtime::PerformanceCounter>
if (it != m_function_map.end())
{
const FunctionInstance& instance = it->second;
if (instance.m_external_function != nullptr)
if (instance.m_compiled_function != nullptr)
{
auto* engine = instance.m_external_function->m_execution_engine.get();
if (engine)
{
auto get_count = engine->find_function<size_t()>("get_debug_timer_count");
auto get_name = engine->find_function<const char*(size_t)>("get_debug_timer_name");
auto get_microseconds =
engine->find_function<size_t(size_t)>("get_debug_timer_microseconds");
auto get_call_count =
engine->find_function<size_t(size_t)>("get_debug_timer_call_count");
if (get_count && get_name && get_microseconds && get_call_count)
{
size_t count = get_count();
for (size_t i = 0; i < count; i++)
{
rc.push_back({get_name(i), get_microseconds(i), get_call_count(i)});
}
}
}
instance.m_compiled_function->get_performance_data(rc);
}
}
return rc;
......
......@@ -29,7 +29,7 @@ namespace ngraph
{
static size_t alignment = 64;
class GPU_ExternalFunction;
class GPUCompiledFunction;
class GPUPrimitiveEmitter;
struct GPURuntimeContext;
class CudaContextManager;
......@@ -83,9 +83,9 @@ namespace ngraph
class FunctionInstance
{
public:
std::shared_ptr<GPU_ExternalFunction> m_external_function;
std::shared_ptr<GPUCompiledFunction> m_compiled_function;
bool m_performance_counters_enabled = false;
EntryPoint m_compiled_function;
EntryPoint m_runtime;
std::vector<void*> m_inputs;
std::vector<void*> m_outputs;
};
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/gpu/gpu_call_frame.hpp"
using namespace ngraph;
runtime::gpu::GPUCallFrame::GPUCallFrame(const size_t& num_inputs, const size_t& num_outputs)
: m_inputs(num_inputs, nullptr)
, m_outputs(num_outputs, nullptr)
{
}
void runtime::gpu::GPUCallFrame::resolve_reservations(
const GPUCompiledFunction* compiled_function,
const std::unordered_map<std::string, size_t>& memory_reservations)
{
auto& mem_primitives = compiled_function->get_primitive_emitter()->get_memory_primitives();
for (auto const& p : memory_reservations)
{
// mem_primitives may return pointers for constant or workspace reservations
m_memory_reservations[p.first] = static_cast<unsigned char*>(mem_primitives.at(p.second)());
}
}
void runtime::gpu::GPUCallFrame::resolve_inputs(void** inputs, size_t num_inputs)
{
// num_inputs is > 0 iff we are resolving inputs from a nested function call
if (num_inputs == 0)
{
num_inputs = m_inputs.size();
}
for (size_t i = 0; i < num_inputs; i++)
{
void* input = inputs[i];
m_inputs[i] = static_cast<unsigned char*>(input);
}
}
void runtime::gpu::GPUCallFrame::resolve_outputs(void** outputs, size_t num_outputs)
{
// num_outputs is > 0 iff we are resolving outputs from a nested function call
if (num_outputs == 0)
{
num_outputs = m_outputs.size();
}
for (size_t i = 0; i < num_outputs; i++)
{
void* output = outputs[i];
m_outputs[i] = static_cast<unsigned char*>(output);
}
}
// returns pointers of any GPUTensorWrapper::TensorType
std::vector<void*>
runtime::gpu::GPUCallFrame::get_tensor_io(const std::vector<GPUTensorWrapper>& tensors)
{
std::vector<void*> ptrs;
for (auto const& tensor : tensors)
{
auto offset = tensor.get_offset();
auto ptr = get_pointer(offset.first, offset.second, tensor.get_name());
ptrs.push_back(ptr);
}
return ptrs;
}
void* runtime::gpu::GPUCallFrame::get_pointer(const TensorType& type,
const size_t& offset,
const std::string& name)
{
switch (type)
{
case TensorType::CONSTANT:
case TensorType::INTERMEDIATE:
return static_cast<void*>(m_memory_reservations.at(name) + offset);
case TensorType::INPUT: return static_cast<void*>(m_inputs.at(offset));
case TensorType::OUTPUT: return static_cast<void*>(m_outputs.at(offset));
case TensorType::UNKNOWN:
default: throw ngraph_error("GPUCallFrame encountered unknown or uninitialized tensor type");
};
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <memory>
#include <unordered_map>
#include "ngraph/runtime/gpu/gpu_compiled_function.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class GPUCallFrame
{
public:
using TensorType = GPUTensorWrapper::TensorType;
GPUCallFrame(const size_t& num_inputs, const size_t& num_outputs);
void resolve_reservations(
const GPUCompiledFunction* compiled_function,
const std::unordered_map<std::string, size_t>& memory_reservations);
void resolve_inputs(void** inputs, size_t num_inputs = 0);
void resolve_outputs(void** outputs, size_t num_outputs = 0);
std::vector<void*> get_tensor_io(const std::vector<GPUTensorWrapper>& tensors);
private:
void* get_pointer(const TensorType& type,
const size_t& offset,
const std::string& name = "");
std::unordered_map<std::string, unsigned char*> m_memory_reservations;
std::vector<unsigned char*> m_inputs;
std::vector<unsigned char*> m_outputs;
};
}
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cstdlib>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cudnn.h>
#include <fstream>
#include <locale>
#include <mutex>
#include <string>
#include <tuple>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/any_all_replacement.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_compiled_function.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_internal_function.hpp"
#include "ngraph/runtime/gpu/op/batch_norm.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
#include "ngraph/runtime/gpu/pass/gpu_batch_norm_cache.hpp"
#include "ngraph/runtime/gpu/pass/gpu_layout.hpp"
#include "ngraph/runtime/gpu/pass/gpu_rnn_fusion.hpp"
#include "ngraph/runtime/gpu/pass/tensor_memory_reservation.hpp"
using namespace std;
using namespace ngraph;
std::string runtime::gpu::GPUCompiledFunction::get_output_dir()
{
static std::string output_dir = "gpu_codegen";
return output_dir;
}
size_t runtime::gpu::GPUCompiledFunction::get_memory_alignment()
{
static size_t memory_pool_alignment = 64;
return memory_pool_alignment;
}
static std::mutex s_compilation;
class GPUStaticInitializers
{
public:
GPUStaticInitializers()
{
file_util::remove_directory(runtime::gpu::GPUCompiledFunction::get_output_dir());
file_util::make_directory(runtime::gpu::GPUCompiledFunction::get_output_dir());
}
};
static GPUStaticInitializers s_static_initializers;
runtime::gpu::GPUCompiledFunction::GPUCompiledFunction(
const shared_ptr<ngraph::Function>& function,
const std::shared_ptr<GPU_Backend::BackendContext>& shared_context)
: m_runtime(nullptr)
, m_function(function)
, m_emit_timing(false)
, m_is_compiled(false)
, m_shared_context(shared_context)
{
}
runtime::gpu::GPUCompiledFunction::~GPUCompiledFunction()
{
}
std::vector<std::string> get_case_variants(std::vector<std::string> cases)
{
std::vector<std::string> results;
for (auto& c : cases)
{
results.push_back(c);
if (std::all_of(c.begin(), c.end(), ::isdigit))
{
continue;
}
for (auto i = 0u; i < c.size(); i++)
{
c[i] = std::toupper(c[i], std::locale());
if (i == 0)
{
results.emplace_back(c);
}
}
results.emplace_back(c);
}
return results;
}
std::shared_ptr<runtime::gpu::GPUCompiledFunction> runtime::gpu::GPUCompiledFunction::make(
const std::shared_ptr<ngraph::Function>& function,
const std::shared_ptr<GPU_Backend::BackendContext>& shared_context)
{
#if defined(NGRAPH_DEX_ONLY)
return std::make_shared<runtime::gpu::GPUInternalFunction>(function, shared_context);
#else
// For now codegen is default unless explicitly disabled
bool use_codegen = true;
if (auto env = std::getenv("NGRAPH_CODEGEN"))
{
std::string env_codegen(env);
for (auto& opt : get_case_variants({"0", "false"}))
{
if (env_codegen == opt)
{
use_codegen = false;
}
}
}
if (use_codegen)
{
return std::make_shared<runtime::gpu::GPUExternalFunction>(function, shared_context);
}
else
{
return std::make_shared<runtime::gpu::GPUInternalFunction>(function, shared_context);
}
#endif
}
void runtime::gpu::GPUCompiledFunction::compile()
{
if (m_is_compiled)
{
return;
}
std::unique_lock<std::mutex> lock(s_compilation);
m_function_name = m_function->get_name();
auto allocator = std::make_shared<runtime::gpu::GPUAllocator>(
m_shared_context->m_primitive_emitter->get_memory_allocator());
ngraph::pass::Manager pass_manager;
#if CUDNN_VERSION >= 7200
// recurrent network fusion
pass_manager.register_pass<runtime::gpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::gpu::pass::RNNFusion>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<runtime::gpu::pass::MultiLayerRNNFusion>();
#else
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
#endif
pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>();
pass_manager.register_pass<ngraph::pass::AnyAllReplacement>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
pass_manager.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::MemoryLayout>(get_memory_alignment());
pass_manager.register_pass<runtime::gpu::pass::TensorMemoryReservation>(
*allocator, m_tensor_memory_buffers);
string dump_filename = file_util::path_join(get_output_dir(), m_function_name + "_ops.txt");
pass_manager.register_pass<ngraph::pass::DumpSorted>(dump_filename);
pass_manager.run_passes(m_function);
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{
m_function_ordered_ops.emplace(current_function, current_function->get_ordered_ops());
}
add_passes(pass_manager);
emit();
// allocate device buffers for primitive arguments and workspace
allocator->close();
m_shared_context->m_primitive_emitter->allocate_primitive_memory();
compile_function();
m_is_compiled = true;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
#define EMIT_ARGS \
runtime::gpu::GPUCompiledFunction *compiled_function, const std::string &function_name, \
const Node *node, const std::vector<runtime::gpu::GPUTensorWrapper> &args, \
const std::vector<runtime::gpu::GPUTensorWrapper> &out
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class GPU_Emitter;
struct GPURuntimeContext;
class GPUCompiledFunction
{
friend class GPU_Backend;
public:
GPUCompiledFunction(
const std::shared_ptr<ngraph::Function>& function,
const std::shared_ptr<GPU_Backend::BackendContext>& shared_context);
virtual ~GPUCompiledFunction();
static std::shared_ptr<GPUCompiledFunction>
make(const std::shared_ptr<ngraph::Function>& function,
const std::shared_ptr<GPU_Backend::BackendContext>& shared_context);
std::unique_ptr<runtime::gpu::GPURuntimeContext>& ctx();
const std::unique_ptr<GPUPrimitiveEmitter>& get_primitive_emitter() const
{
return m_shared_context->m_primitive_emitter;
}
virtual std::string
add_to_runtime(size_t primitive_index,
const std::string& function_name,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out) = 0;
virtual std::string
add_call_to_runtime(const std::string& caller,
const std::string& callee,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out) = 0;
void compile();
virtual void
get_performance_data(std::vector<runtime::PerformanceCounter>& rc) const = 0;
static size_t get_memory_alignment();
static std::string get_output_dir();
protected:
virtual void compile_function() = 0;
virtual void add_passes(ngraph::pass::Manager& pass_manager) = 0;
virtual void emit() = 0;
EntryPoint m_runtime;
// For non-destructive passthrough kernels, propagate function
// input buffers to internal ops
virtual void propagate_in_place_input(ngraph::descriptor::Output* output,
const std::string& input_name) = 0;
// For in-place kernels, propagate function output buffers to
// internal ops
virtual void propagate_in_place_output(ngraph::descriptor::Output* res_src_output,
const std::string& output_name) = 0;
std::shared_ptr<ngraph::Function> m_function;
std::unordered_map<std::shared_ptr<Function>, std::list<std::shared_ptr<Node>>>
m_function_ordered_ops;
bool m_emit_timing;
bool m_is_compiled;
size_t m_offset;
std::string m_function_name;
std::unordered_map<std::string, size_t> m_tensor_memory_buffers;
std::shared_ptr<GPU_Backend::BackendContext> m_shared_context;
};
}
}
}
......@@ -124,14 +124,14 @@ using namespace ngraph;
#define TI(x) type_index(typeid(x))
function<void(EMIT_ARGS)> runtime::gpu::GPU_Emitter::get_emit_function(const Node& node)
function<std::string(EMIT_ARGS)> runtime::gpu::GPU_Emitter::get_emit_function(const Node& node)
{
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// {<Abs typeid>, function<void(EMIT_ARGS)},
// {<Acos typeid>, function<void(EMIT_ARGS)},
// {<Abs typeid>, function<std::string(EMIT_ARGS)},
// {<Acos typeid>, function<std::string(EMIT_ARGS)},
// ...
#define NGRAPH_OP(a, b) {type_index(typeid(b::a)), runtime::gpu::GPU_Emitter::emit_##a},
static const map<type_index, function<void(EMIT_ARGS)>> typeid_map{
static const map<type_index, function<std::string(EMIT_ARGS)>> typeid_map{
#include "ngraph/runtime/gpu/op/op_tbl.hpp"
};
#undef NGRAPH_OP
......@@ -144,60 +144,61 @@ function<void(EMIT_ARGS)> runtime::gpu::GPU_Emitter::get_emit_function(const Nod
return it->second;
}
void runtime::gpu::GPU_Emitter::emit_Abs(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Abs(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Abs>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Abs>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Acos(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Acos(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Acos>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Acos>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Add(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Add(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Add>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Add>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_All(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_All(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_AllReduce(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_AllReduce(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_And(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_And(EMIT_ARGS)
{
emit_elementwise<ngraph::op::And>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::And>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Any(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Any(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_ArgMax(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_ArgMax(EMIT_ARGS)
{
cudnnReduceTensorOp_t reduce_op = CUDNN_REDUCE_TENSOR_MAX;
runtime::gpu::GPU_Emitter::emit_ArgReduce(
external_function, writer, node, args, out, reduce_op);
return runtime::gpu::GPU_Emitter::emit_ArgReduce(
compiled_function, function_name, node, args, out, reduce_op);
}
void runtime::gpu::GPU_Emitter::emit_ArgMin(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_ArgMin(EMIT_ARGS)
{
cudnnReduceTensorOp_t reduce_op = CUDNN_REDUCE_TENSOR_MIN;
runtime::gpu::GPU_Emitter::emit_ArgReduce(
external_function, writer, node, args, out, reduce_op);
return runtime::gpu::GPU_Emitter::emit_ArgReduce(
compiled_function, function_name, node, args, out, reduce_op);
}
void runtime::gpu::GPU_Emitter::emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t reduce_op)
std::string runtime::gpu::GPU_Emitter::emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t reduce_op)
{
if (out[0].get_size() == 0)
{
return;
// return;
return "";
}
size_t axis;
......@@ -219,135 +220,113 @@ void runtime::gpu::GPU_Emitter::emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t
std::vector<element::Type> dtypes{args[0].get_element_type(), out[0].get_element_type()};
writer.block_begin();
{
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
auto& cudnn_emitter = compiled_function->get_primitive_emitter()->get_cudnn_emitter();
auto index = cudnn_emitter->build_reduce_forward(reduce_op,
dtypes,
args[0].get_shape(),
axis_set,
CUDNNEmitter::ReductionMode::ArgReduce);
auto index = cudnn_emitter->build_reduce_forward(
reduce_op, dtypes, args[0].get_shape(), axis_set, CUDNNEmitter::ReductionMode::ArgReduce);
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Asin(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Asin(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Asin>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Asin>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Atan(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Atan(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Atan>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Atan>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_AvgPool(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_AvgPool(EMIT_ARGS)
{
// assumes NC{d1,d2,...} format
auto avg_pool = static_cast<const ngraph::op::AvgPool*>(node);
writer.block_begin();
{
auto& input_shape = args[0].get_shape();
auto& result_shape = out[0].get_shape();
auto padding_below = avg_pool->get_padding_below();
auto padding_above = avg_pool->get_padding_above();
auto& input_shape = args[0].get_shape();
auto& result_shape = out[0].get_shape();
auto padding_below = avg_pool->get_padding_below();
auto padding_above = avg_pool->get_padding_above();
size_t index = 0;
size_t index = 0;
// if 1d or has asymmetric padding, must handle pooling manually
if (input_shape.size() == 3 || padding_below != padding_above)
{
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
// if 1d or has asymmetric padding, must handle pooling manually
if (input_shape.size() == 3 || padding_below != padding_above)
{
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
index =
cuda_emitter->build_avg_pool({{args[0].get_type(), out[0].get_type()}},
index = cuda_emitter->build_avg_pool({{args[0].get_type(), out[0].get_type()}},
input_shape,
result_shape,
avg_pool->get_window_shape(),
avg_pool->get_window_movement_strides(),
padding_below,
avg_pool->get_include_padding_in_avg_computation());
}
// 2d and 3d avg pool (NCHW) with either symetric padding or no padding
else if (input_shape.size() == 4 || input_shape.size() == 5)
{
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
auto cudnn_avg_type = avg_pool->get_include_padding_in_avg_computation()
? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
index = cudnn_emitter->build_pooling(cudnn_avg_type,
out[0].get_element_type(),
CUDNNEmitter::Prop::Forward,
input_shape,
result_shape,
avg_pool->get_window_movement_strides(),
avg_pool->get_window_shape(),
padding_below,
padding_above);
}
else
{
throw runtime_error("Pooling currently only supports up to 3 spatial dimensions.");
}
}
// 2d and 3d avg pool (NCHW) with either symetric padding or no padding
else if (input_shape.size() == 4 || input_shape.size() == 5)
{
auto& cudnn_emitter = compiled_function->get_primitive_emitter()->get_cudnn_emitter();
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
auto cudnn_avg_type = avg_pool->get_include_padding_in_avg_computation()
? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
index = cudnn_emitter->build_pooling(cudnn_avg_type,
out[0].get_element_type(),
CUDNNEmitter::Prop::Forward,
input_shape,
result_shape,
avg_pool->get_window_movement_strides(),
avg_pool->get_window_shape(),
padding_below,
padding_above);
}
else
{
throw runtime_error("Pooling currently only supports up to 3 spatial dimensions.");
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_AvgPoolBackprop(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_AvgPoolBackprop(EMIT_ARGS)
{
writer.block_begin();
auto apb = static_cast<const ngraph::op::AvgPoolBackprop*>(node);
auto output_shape = out[0].get_shape();
auto delta_shape = args[0].get_shape();
auto& cudnn_emitter = compiled_function->get_primitive_emitter()->get_cudnn_emitter();
if (output_shape.size() >= 4)
{
auto apb = static_cast<const ngraph::op::AvgPoolBackprop*>(node);
auto output_shape = out[0].get_shape();
auto delta_shape = args[0].get_shape();
auto cudnn_avg_type = apb->get_include_padding_in_avg_computation()
? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
auto index = cudnn_emitter->build_pooling(cudnn_avg_type,
out[0].get_element_type(),
CUDNNEmitter::Prop::Backward,
output_shape,
delta_shape,
apb->get_window_movement_strides(),
apb->get_window_shape(),
apb->get_padding_below(),
apb->get_padding_above());
if (output_shape.size() >= 4)
{
auto cudnn_avg_type = apb->get_include_padding_in_avg_computation()
? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
auto index = cudnn_emitter->build_pooling(cudnn_avg_type,
out[0].get_element_type(),
CUDNNEmitter::Prop::Backward,
output_shape,
delta_shape,
apb->get_window_movement_strides(),
apb->get_window_shape(),
apb->get_padding_below(),
apb->get_padding_above());
// cuDNN backwards pooling requests input and output tensors from
// the forward pass but does not use them. It also behaves differently
// for max pool vs avg pool. The repetition of args below is to address
// this interface in a way that supports both max and avg pooling
writer << "void* input[] = {" << node_names(args, {0, 0, 0}) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
return compiled_function->add_to_runtime(index, function_name, args, out);
}
else
{
throw ngraph_error("AvgPoolBackprop currently only supports tensors of rank 4 and greater");
}
writer.block_end();
}
template <typename T>
void emit_BatchNorm(EMIT_ARGS, runtime::gpu::CUDNNEmitter::Prop direction, bool save_stats)
std::string emit_BatchNorm(EMIT_ARGS, runtime::gpu::CUDNNEmitter::Prop direction, bool save_stats)
{
const T* batchnorm = static_cast<const T*>(node);
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
auto& cudnn_emitter = compiled_function->get_primitive_emitter()->get_cudnn_emitter();
bool global_stats = false;
if (direction == runtime::gpu::CUDNNEmitter::Prop::Forward)
......@@ -364,39 +343,33 @@ void emit_BatchNorm(EMIT_ARGS, runtime::gpu::CUDNNEmitter::Prop direction, bool
global_stats,
save_stats);
writer.block_begin();
{
writer << "void* input[] = {" << runtime::gpu::GPU_Emitter::node_names(args) << "};\n";
writer << "void* output[] = {" << runtime::gpu::GPU_Emitter::node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_BatchNormInference(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_BatchNormInference(EMIT_ARGS)
{
::emit_BatchNorm<ngraph::op::BatchNormInference>(
external_function, writer, node, args, out, CUDNNEmitter::Prop::Inference, false);
return ::emit_BatchNorm<ngraph::op::BatchNormInference>(
compiled_function, function_name, node, args, out, CUDNNEmitter::Prop::Inference, false);
}
void runtime::gpu::GPU_Emitter::emit_BatchNormTraining(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_BatchNormTraining(EMIT_ARGS)
{
::emit_BatchNorm<ngraph::op::BatchNormTraining>(
external_function, writer, node, args, out, CUDNNEmitter::Prop::Forward, false);
return ::emit_BatchNorm<ngraph::op::BatchNormTraining>(
compiled_function, function_name, node, args, out, CUDNNEmitter::Prop::Forward, false);
}
void runtime::gpu::GPU_Emitter::emit_BatchNormTrainingWithStats(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_BatchNormTrainingWithStats(EMIT_ARGS)
{
::emit_BatchNorm<ngraph::op::gpu::BatchNormTrainingWithStats>(
external_function, writer, node, args, out, CUDNNEmitter::Prop::Forward, true);
return ::emit_BatchNorm<ngraph::op::gpu::BatchNormTrainingWithStats>(
compiled_function, function_name, node, args, out, CUDNNEmitter::Prop::Forward, true);
}
void runtime::gpu::GPU_Emitter::emit_BatchNormTrainingBackprop(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_BatchNormTrainingBackprop(EMIT_ARGS)
{
const ngraph::op::BatchNormTrainingBackprop* batchnorm =
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node);
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
auto& cudnn_emitter = compiled_function->get_primitive_emitter()->get_cudnn_emitter();
bool needs_variance_inversion = false;
auto annotation = batchnorm->get_op_annotations();
......@@ -418,63 +391,52 @@ void runtime::gpu::GPU_Emitter::emit_BatchNormTrainingBackprop(EMIT_ARGS)
false,
false,
needs_variance_inversion);
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Broadcast(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Broadcast(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return;
return "";
}
auto broadcast = static_cast<const ngraph::op::Broadcast*>(node);
auto arg_shape = args[0].get_shape();
auto result_shape = out[0].get_shape();
auto& axes = broadcast->get_broadcast_axes();
size_t index;
// broadcast axes is empty, do a copy
if (axes.empty())
{
writer.block_begin();
kernel::emit_memcpyDtD(writer, out[0], args[0]);
writer.block_end();
return;
auto& host_emitter = compiled_function->get_primitive_emitter()->get_host_emitter();
index = host_emitter->build_memcpy(cudaMemcpyDeviceToDevice,
out[0].get_size() * out[0].get_element_type().size());
}
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto bcast_index = cuda_emitter->build_broadcast(
{{args[0].get_type(), out[0].get_type()}}, result_shape, axes);
writer.block_begin();
else
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << bcast_index << ", input, output);\n";
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
index = cuda_emitter->build_broadcast(
{{args[0].get_type(), out[0].get_type()}}, result_shape, axes);
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_BroadcastLike(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_BroadcastLike(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_Ceiling(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Ceiling(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Ceiling>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Ceiling>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Concat(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Concat(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return;
return "";
}
auto concat = static_cast<const ngraph::op::Concat*>(node);
auto axis = concat->get_concatenation_axis();
......@@ -485,63 +447,53 @@ void runtime::gpu::GPU_Emitter::emit_Concat(EMIT_ARGS)
input_shapes.push_back(arg.get_shape());
}
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
auto index =
cuda_emitter->build_concat(out[0].get_type(), input_shapes, axis, out[0].get_shape());
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Constant(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Constant(EMIT_ARGS)
{
return "";
}
void runtime::gpu::GPU_Emitter::emit_Convert(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Convert(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Convert>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Convert>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Convolution(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Convolution(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return;
return "";
}
auto convolution = static_cast<const ngraph::op::Convolution*>(node);
size_t conv_index = 0;
size_t index = 0;
if (convolution->get_padding_below().size() > 3)
{
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
conv_index = cuda_emitter->build_primitive(convolution);
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
index = cuda_emitter->build_primitive(convolution);
}
else
{
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
conv_index = cudnn_emitter->build_primitive(convolution);
auto& cudnn_emitter = compiled_function->get_primitive_emitter()->get_cudnn_emitter();
index = cudnn_emitter->build_primitive(convolution);
}
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << conv_index << ", input, output);\n";
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_ConvolutionBackpropData(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_ConvolutionBackpropData(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return;
// return;
return "";
}
auto convolution = static_cast<const ngraph::op::ConvolutionBackpropData*>(node);
......@@ -551,23 +503,18 @@ void runtime::gpu::GPU_Emitter::emit_ConvolutionBackpropData(EMIT_ARGS)
throw runtime_error(node->get_name() + "with more than 3D is not implemented.");
}
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
size_t conv_index = cudnn_emitter->build_primitive(convolution);
auto& cudnn_emitter = compiled_function->get_primitive_emitter()->get_cudnn_emitter();
size_t index = cudnn_emitter->build_primitive(convolution);
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << conv_index << ", input, output);\n";
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_ConvolutionBackpropFilters(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_ConvolutionBackpropFilters(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return;
// return;
return "";
}
auto convolution = static_cast<const ngraph::op::ConvolutionBackpropFilters*>(node);
......@@ -577,43 +524,37 @@ void runtime::gpu::GPU_Emitter::emit_ConvolutionBackpropFilters(EMIT_ARGS)
throw runtime_error(node->get_name() + "with more than 3D is not implemented.");
}
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
size_t conv_index = cudnn_emitter->build_primitive(convolution);
auto& cudnn_emitter = compiled_function->get_primitive_emitter()->get_cudnn_emitter();
size_t index = cudnn_emitter->build_primitive(convolution);
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << conv_index << ", input, output);\n";
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Cos(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Cos(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Cos>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Cos>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Cosh(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Cosh(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Cosh>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Cosh>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Divide(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Divide(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Divide>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Divide>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Dequantize(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Dequantize(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_Dot(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Dot(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return;
return "";
}
auto dot = static_cast<const ngraph::op::Dot*>(node);
size_t reduction_axes_count = dot->get_reduction_axes_count();
......@@ -621,116 +562,103 @@ void runtime::gpu::GPU_Emitter::emit_Dot(EMIT_ARGS)
const Shape& arg1_shape = args[1].get_shape();
const Shape& out_shape = out[0].get_shape();
writer.block_begin();
size_t index;
// set output to 0 if input size is 0
if (args[0].get_size() == 0 || args[1].get_size() == 0)
{
// set output to 0 if input size is 0
if (args[0].get_size() == 0 || args[1].get_size() == 0)
{
writer << "runtime::gpu::cuda_memset(" << out[0].get_name() << ", 0, "
<< out[0].get_size() << " * " << out[0].get_element_type().size() << ");\n";
}
else
{
auto& cublas_emitter = external_function->get_primitive_emitter()->get_cublas_emitter();
auto index = cublas_emitter->build_dot(out[0].get_element_type(),
arg0_shape,
arg1_shape,
out_shape,
reduction_axes_count,
node);
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
auto& host_emitter = compiled_function->get_primitive_emitter()->get_host_emitter();
index =
host_emitter->build_zero_out(0, out[0].get_size() * out[0].get_element_type().size());
}
writer.block_end();
else
{
auto& cublas_emitter = compiled_function->get_primitive_emitter()->get_cublas_emitter();
index = cublas_emitter->build_dot(out[0].get_element_type(),
arg0_shape,
arg1_shape,
out_shape,
reduction_axes_count,
node);
}
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_EmbeddingLookup(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_EmbeddingLookup(EMIT_ARGS)
{
throw ngraph_error("EmbeddingLookup is not yet implemented for NVIDIA GPU");
}
void runtime::gpu::GPU_Emitter::emit_Equal(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Equal(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Equal>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Equal>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Exp(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Exp(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Exp>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Exp>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Floor(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Floor(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Floor>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Floor>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_FunctionCall(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_FunctionCall(EMIT_ARGS)
{
auto function_call = static_cast<const ngraph::op::FunctionCall*>(node);
shared_ptr<Function> function = function_call->get_functions()[0];
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << function->get_name() << "(input, output, ctx);\n";
}
writer.block_end();
return compiled_function->add_call_to_runtime(function_name, function->get_name(), args, out);
}
void runtime::gpu::GPU_Emitter::emit_GenerateMask(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_GenerateMask(EMIT_ARGS)
{
throw ngraph_error("GenerateMask is not supported yet on NVIDIA GPU");
}
void runtime::gpu::GPU_Emitter::emit_GetOutputElement(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_GetOutputElement(EMIT_ARGS)
{
auto get_tuple_element = static_cast<const ngraph::op::GetOutputElement*>(node);
writer.block_begin();
writer << "runtime::gpu::cuda_memcpyDtD(" << out[0].get_name() << ", "
<< args[get_tuple_element->get_n()].get_name() << ", "
<< out[0].get_size() * out[0].get_element_type().size() << ");\n";
writer.block_end();
auto& host_emitter = compiled_function->get_primitive_emitter()->get_host_emitter();
size_t index = host_emitter->build_memcpy(cudaMemcpyDeviceToDevice,
out[0].get_size() * out[0].get_element_type().size(),
0,
get_tuple_element->get_n());
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Greater(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Greater(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Greater>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Greater>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_GreaterEq(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_GreaterEq(EMIT_ARGS)
{
emit_elementwise<ngraph::op::GreaterEq>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::GreaterEq>(
compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Less(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Less(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Less>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Less>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_LessEq(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_LessEq(EMIT_ARGS)
{
emit_elementwise<ngraph::op::LessEq>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::LessEq>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Log(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Log(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Log>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Log>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_LRN(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_LRN(EMIT_ARGS)
{
auto lrn = static_cast<const ngraph::op::LRN*>(node);
auto& input_shape = args[0].get_shape();
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
auto& cudnn_emitter = compiled_function->get_primitive_emitter()->get_cudnn_emitter();
size_t index = cudnn_emitter->build_lrn(out[0].get_type(),
CUDNNEmitter::Prop::Forward,
input_shape,
......@@ -738,42 +666,33 @@ void runtime::gpu::GPU_Emitter::emit_LRN(EMIT_ARGS)
lrn->get_beta(),
lrn->get_bias(),
lrn->get_nsize());
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Max(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Max(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return;
return "";
}
const ngraph::op::Max* max = static_cast<const ngraph::op::Max*>(node);
vector<element::Type> dtypes;
dtypes.push_back(args[0].get_element_type());
dtypes.push_back(out[0].get_element_type());
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
size_t index = cuda_emitter->build_reduce<ngraph::op::Max>(
dtypes, args[0].get_shape(), out[0].get_shape(), max->get_reduction_axes());
writer.block_begin();
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Maximum(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Maximum(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Maximum>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Maximum>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_MaxPool(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_MaxPool(EMIT_ARGS)
{
// assumes NC{d1,d2,...} format
auto max_pool = static_cast<const ngraph::op::MaxPool*>(node);
......@@ -793,113 +712,110 @@ void runtime::gpu::GPU_Emitter::emit_MaxPool(EMIT_ARGS)
throw runtime_error("Pooling currently only supports up to 3 spatial dimensions.");
}
size_t max_pool_index;
size_t index;
// 1d max pool (NCW)
if (input_shape.size() == 3)
{
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
max_pool_index = cuda_emitter->build_primitive(max_pool);
index = cuda_emitter->build_primitive(max_pool);
}
// 2d and 3d max pool (NCHW)
else if (input_shape.size() == 4 || input_shape.size() == 5)
{
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
auto& cudnn_emitter = compiled_function->get_primitive_emitter()->get_cudnn_emitter();
max_pool_index = cudnn_emitter->build_primitive(max_pool);
index = cudnn_emitter->build_primitive(max_pool);
}
writer.block_begin();
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << max_pool_index << ", input, output);\n";
writer.block_end();
else
{
throw ngraph_error("Unsupported tensor rank encountered in " + node->description());
}
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_MaxPoolBackprop(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_MaxPoolBackprop(EMIT_ARGS)
{
writer.block_begin();
{
auto mpb = static_cast<const ngraph::op::MaxPoolBackprop*>(node);
auto fp_input_shape = out[0].get_shape();
auto fp_output_shape = args[1].get_shape();
auto mpb = static_cast<const ngraph::op::MaxPoolBackprop*>(node);
auto fp_input_shape = out[0].get_shape();
auto fp_output_shape = args[1].get_shape();
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
auto& cudnn_emitter = compiled_function->get_primitive_emitter()->get_cudnn_emitter();
bool needs_fprop = (args.size() != 3);
if (fp_input_shape.size() >= 4)
{
auto index = cudnn_emitter->build_pooling(CUDNN_POOLING_MAX,
out[0].get_element_type(),
CUDNNEmitter::Prop::Backward,
fp_input_shape,
fp_output_shape,
mpb->get_window_movement_strides(),
mpb->get_window_shape(),
mpb->get_padding_below(),
mpb->get_padding_above(),
needs_fprop);
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
bool needs_fprop = (args.size() != 3);
if (fp_input_shape.size() >= 4)
{
auto index = cudnn_emitter->build_pooling(CUDNN_POOLING_MAX,
out[0].get_element_type(),
CUDNNEmitter::Prop::Backward,
fp_input_shape,
fp_output_shape,
mpb->get_window_movement_strides(),
mpb->get_window_shape(),
mpb->get_padding_below(),
mpb->get_padding_above(),
needs_fprop);
return compiled_function->add_to_runtime(index, function_name, args, out);
}
else
{
throw ngraph_error("Unsupported tensor rank encountered in " + node->description());
}
writer.block_end();
}
void runtime::gpu::GPU_Emitter::emit_Min(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Min(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return;
return "";
}
const ngraph::op::Min* min = static_cast<const ngraph::op::Min*>(node);
size_t index;
vector<element::Type> dtypes;
dtypes.push_back(args[0].get_element_type());
dtypes.push_back(out[0].get_element_type());
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
index = cuda_emitter->build_reduce<ngraph::op::Min>(
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
size_t index = cuda_emitter->build_reduce<ngraph::op::Min>(
dtypes, args[0].get_shape(), out[0].get_shape(), min->get_reduction_axes());
writer.block_begin();
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Minimum(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Minimum(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Minimum>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Minimum>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Multiply(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Multiply(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Multiply>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Multiply>(
compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Negative(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Negative(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Negative>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Negative>(
compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Not(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Not(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Not>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Not>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_NotEqual(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_NotEqual(EMIT_ARGS)
{
emit_elementwise<ngraph::op::NotEqual>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::NotEqual>(
compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_OneHot(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_OneHot(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return;
return "";
}
auto onehot = static_cast<const ngraph::op::OneHot*>(node);
auto arg_shape = args[0].get_shape();
......@@ -907,170 +823,151 @@ void runtime::gpu::GPU_Emitter::emit_OneHot(EMIT_ARGS)
auto output_datatype_size = out[0].get_element_type().size();
size_t idx = onehot->get_one_hot_axis();
writer.block_begin();
{
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto index = cuda_emitter->build_onehot({{args[0].get_type(), out[0].get_type()}},
arg_shape,
result_shape,
idx,
output_datatype_size);
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
auto index = cuda_emitter->build_onehot({{args[0].get_type(), out[0].get_type()}},
arg_shape,
result_shape,
idx,
output_datatype_size);
writer.block_begin();
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
writer.block_end();
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Or(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Or(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Or>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Or>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Pad(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Pad(EMIT_ARGS)
{
auto pad = static_cast<const ngraph::op::Pad*>(node);
writer.block_begin();
{
auto input_shape = args[0].get_shape();
auto output_shape = out[0].get_shape();
auto padding_below = pad->get_padding_below();
auto padding_above = pad->get_padding_above();
auto padding_interior = pad->get_padding_interior();
auto input_shape = args[0].get_shape();
auto output_shape = out[0].get_shape();
auto padding_below = pad->get_padding_below();
auto padding_above = pad->get_padding_above();
auto padding_interior = pad->get_padding_interior();
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
auto pad_index = cuda_emitter->build_pad_fill(
{{args[0].get_type(), args[1].get_type(), out[0].get_type()}},
input_shape,
output_shape,
padding_below,
padding_interior);
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << pad_index << ", input, output);\n";
}
writer.block_end();
auto index =
cuda_emitter->build_pad_fill({{args[0].get_type(), args[1].get_type(), out[0].get_type()}},
input_shape,
output_shape,
padding_below,
padding_interior);
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Parameter(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Parameter(EMIT_ARGS)
{
return "";
}
void runtime::gpu::GPU_Emitter::emit_Power(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Power(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Power>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Power>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Product(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Product(EMIT_ARGS)
{
const ngraph::op::Product* prod = static_cast<const ngraph::op::Product*>(node);
writer.block_begin();
if (out[0].get_size() == 0)
{
if (out[0].get_size() != 0)
{
size_t prod_index;
vector<element::Type> dtypes;
dtypes.push_back(args[0].get_element_type());
dtypes.push_back(out[0].get_element_type());
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
prod_index = cuda_emitter->build_reduce<ngraph::op::Multiply>(
dtypes, args[0].get_shape(), out[0].get_shape(), prod->get_reduction_axes());
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << prod_index << ", input, output);\n";
}
return "";
}
writer.block_end();
vector<element::Type> dtypes;
dtypes.push_back(args[0].get_element_type());
dtypes.push_back(out[0].get_element_type());
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
size_t index = cuda_emitter->build_reduce<ngraph::op::Multiply>(
dtypes, args[0].get_shape(), out[0].get_shape(), prod->get_reduction_axes());
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Quantize(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Quantize(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_Reduce(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Reduce(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return "";
}
const ngraph::op::Reduce* reduce_op = static_cast<const ngraph::op::Reduce*>(node);
writer.block_begin();
auto axes_set = reduce_op->get_reduction_axes();
std::vector<element::Type> dtypes;
dtypes.push_back(args[0].get_element_type());
dtypes.push_back(out[0].get_element_type());
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
auto reduction_function_ops = reduce_op->get_functions()[0]->get_ops();
size_t emitter_index;
// Reduction function should only have one op
std::shared_ptr<Node> reduce_func;
std::string op_name;
int op_count = 0;
for (auto op : reduction_function_ops)
{
if (out[0].get_size() != 0)
if (op->is_constant() || op->is_parameter() || op->is_output())
{
auto axes_set = reduce_op->get_reduction_axes();
std::vector<element::Type> dtypes;
dtypes.push_back(args[0].get_element_type());
dtypes.push_back(out[0].get_element_type());
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto reduction_function_ops = reduce_op->get_functions()[0]->get_ops();
size_t emitter_index;
// Reduction function should only have one op
std::shared_ptr<Node> reduce_func;
std::string op_name;
int op_count = 0;
for (auto op : reduction_function_ops)
{
if (op->is_constant() || op->is_parameter() || op->is_output())
{
continue;
}
op_count++;
op_name = op->get_name();
reduce_func = op;
if (op_count != 1)
{
throw runtime_error("reduce with more than one op is not implement yet.");
}
}
if (dynamic_pointer_cast<ngraph::op::Add>(reduce_func))
{
emitter_index = cuda_emitter->build_reduce<ngraph::op::Add>(
dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
}
else if (dynamic_pointer_cast<ngraph::op::Multiply>(reduce_func))
{
emitter_index = cuda_emitter->build_reduce<ngraph::op::Multiply>(
dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
}
else if (dynamic_pointer_cast<ngraph::op::Maximum>(reduce_func))
{
emitter_index = cuda_emitter->build_reduce<ngraph::op::Maximum>(
dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
}
else if (dynamic_pointer_cast<ngraph::op::Minimum>(reduce_func))
{
emitter_index = cuda_emitter->build_reduce<ngraph::op::Minimum>(
dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
}
else if (dynamic_pointer_cast<ngraph::op::And>(reduce_func))
{
emitter_index = cuda_emitter->build_reduce<ngraph::op::And>(
dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
}
else if (dynamic_pointer_cast<ngraph::op::Or>(reduce_func))
{
emitter_index = cuda_emitter->build_reduce<ngraph::op::Or>(
dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
}
else
{
throw runtime_error("reduce with function " + op_name + " is not implement yet.");
}
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << emitter_index << ", input, output);\n";
continue;
}
op_count++;
op_name = op->get_name();
reduce_func = op;
if (op_count != 1)
{
throw runtime_error("reduce with more than one op is not implement yet.");
}
}
writer.block_end();
if (dynamic_pointer_cast<ngraph::op::Add>(reduce_func))
{
emitter_index = cuda_emitter->build_reduce<ngraph::op::Add>(
dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
}
else if (dynamic_pointer_cast<ngraph::op::Multiply>(reduce_func))
{
emitter_index = cuda_emitter->build_reduce<ngraph::op::Multiply>(
dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
}
else if (dynamic_pointer_cast<ngraph::op::Maximum>(reduce_func))
{
emitter_index = cuda_emitter->build_reduce<ngraph::op::Maximum>(
dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
}
else if (dynamic_pointer_cast<ngraph::op::Minimum>(reduce_func))
{
emitter_index = cuda_emitter->build_reduce<ngraph::op::Minimum>(
dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
}
else if (dynamic_pointer_cast<ngraph::op::And>(reduce_func))
{
emitter_index = cuda_emitter->build_reduce<ngraph::op::And>(
dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
}
else if (dynamic_pointer_cast<ngraph::op::Or>(reduce_func))
{
emitter_index = cuda_emitter->build_reduce<ngraph::op::Or>(
dtypes, args[0].get_shape(), out[0].get_shape(), axes_set, true);
}
else
{
throw runtime_error("reduce with function " + op_name + " is not implement yet.");
}
return compiled_function->add_to_runtime(emitter_index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_ReduceWindow(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_ReduceWindow(EMIT_ARGS)
{
static const unordered_map<type_index, ngraph::runtime::gpu::OpName> reduce_window_map{
{TI(ngraph::op::Add), ngraph::runtime::gpu::OpName::add},
......@@ -1080,130 +977,113 @@ void runtime::gpu::GPU_Emitter::emit_ReduceWindow(EMIT_ARGS)
const ngraph::op::ReduceWindow* reduce_window_op =
static_cast<const ngraph::op::ReduceWindow*>(node);
writer.block_begin();
if (out[0].get_size() == 0)
{
return "";
}
size_t index;
// one of args0 axes has zero size, zero output, use args1 value
if (args[0].get_size() == 0)
{
throw ngraph_error("ReduceWindow for zero-size input is not currently supported");
}
else if (args[0].get_size() == out[0].get_size())
{
if (out[0].get_size() != 0)
auto& host_emitter = compiled_function->get_primitive_emitter()->get_host_emitter();
index = host_emitter->build_memcpy(cudaMemcpyDeviceToDevice,
out[0].get_size() * out[0].get_element_type().size());
}
else
{
// in current implementation:
// 1. reduction function should only have one op
// 2. the op should be in the op_map
// otherwise, throw an error message
auto reduction_function_ops = reduce_window_op->get_functions()[0]->get_ops();
unordered_map<type_index, ngraph::runtime::gpu::OpName>::const_iterator it =
reduce_window_map.end();
int op_count = 0;
for (auto op : reduction_function_ops)
{
// one of args0 axes has zero size, zero output, use args1 value
if (args[0].get_size() == 0)
if (op->is_constant() || op->is_parameter() || op->is_output())
{
writer << out[0].get_type() << " init_value;\n";
writer << "runtime::gpu::cuda_memcpyDtH(&init_value, " << args[1].get_name() << " ,"
<< args[1].get_element_type().size() << ");\n";
writer << "vector<" << out[0].get_type() << "> temp(" << out[0].get_size()
<< ", init_value);\n";
writer << "runtime::gpu::cuda_memcpyHtD(" << out[0].get_name()
<< ", (void*)temp.data(), " << out[0].get_size() << " * "
<< out[0].get_element_type().size() << ");\n";
continue;
}
else if (args[0].get_size() == out[0].get_size())
op_count++;
// Work around a compiler warning (*node inside typeid may have effects
// with shared pointers, which is fine here but clang doesn't like it.)
auto& fn = *op;
auto f_ptr = reduce_window_map.find(type_index(typeid(fn)));
if (op_count != 1)
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
throw runtime_error("reduce with more than one op is not implement yet.");
}
else if (f_ptr == reduce_window_map.end())
{
throw runtime_error("reduce with function " + fn.get_name() +
" is not implement yet.");
}
else
{
// in current implementation:
// 1. reduction function should only have one op
// 2. the op should be in the op_map
// otherwise, throw an error message
auto reduction_function_ops = reduce_window_op->get_functions()[0]->get_ops();
unordered_map<type_index, ngraph::runtime::gpu::OpName>::const_iterator it =
reduce_window_map.end();
int op_count = 0;
for (auto op : reduction_function_ops)
{
if (op->is_constant() || op->is_parameter() || op->is_output())
{
continue;
}
op_count++;
// Work around a compiler warning (*node inside typeid may have effects
// with shared pointers, which is fine here but clang doesn't like it.)
auto& fn = *op;
auto f_ptr = reduce_window_map.find(type_index(typeid(fn)));
if (op_count != 1)
{
throw runtime_error("reduce with more than one op is not implement yet.");
}
else if (f_ptr == reduce_window_map.end())
{
throw runtime_error("reduce with function " + fn.get_name() +
" is not implement yet.");
}
else
{
it = f_ptr;
}
}
if (it == reduce_window_map.end())
{
throw runtime_error("no valid op found in reduction function.");
}
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
size_t reduce_index;
// this dtypes is two build the binary op, expect both input has same type with args[0]
vector<string> dtypes{args[0].get_type(), args[0].get_type(), out[0].get_type()};
reduce_index = cuda_emitter->build_reduce_window(
it->second,
dtypes,
args[0].get_shape(),
out[0].get_shape(),
reduce_window_op->get_window_shape(),
reduce_window_op->get_window_movement_strides());
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << reduce_index << ", input, output);\n";
it = f_ptr;
}
}
if (it == reduce_window_map.end())
{
throw runtime_error("no valid op found in reduction function.");
}
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
// this dtypes is two build the binary op, expect both input has same type with args[0]
vector<string> dtypes{args[0].get_type(), args[0].get_type(), out[0].get_type()};
index = cuda_emitter->build_reduce_window(it->second,
dtypes,
args[0].get_shape(),
out[0].get_shape(),
reduce_window_op->get_window_shape(),
reduce_window_op->get_window_movement_strides());
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Relu(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Relu(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Relu>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Relu>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_ReluBackprop(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_ReluBackprop(EMIT_ARGS)
{
emit_elementwise<ngraph::op::ReluBackprop>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::ReluBackprop>(
compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_ReplaceSlice(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_ReplaceSlice(EMIT_ARGS)
{
// assumes NC{d1,d2,...} format
auto rep_slice = static_cast<const ngraph::op::ReplaceSlice*>(node);
bool in_place_op = (args[0].get_name() == out[0].get_name());
writer.block_begin();
{
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto index = cuda_emitter->build_primitive(rep_slice, in_place_op);
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
auto index = cuda_emitter->build_primitive(rep_slice, in_place_op);
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Reshape(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Reshape(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return;
return "";
}
auto reshape = static_cast<const op::Reshape*>(node);
if (out[0].get_name() == args[0].get_name())
if (out[0].get_name() == args[0].get_name() && out[0].get_offset() == args[0].get_offset())
{
writer << "// Logical reshape eliminated\n";
return;
return "// Logical reshape eliminated\n";
}
auto arg_shape = args[0].get_shape();
......@@ -1215,12 +1095,10 @@ void runtime::gpu::GPU_Emitter::emit_Reshape(EMIT_ARGS)
//for a zero-size tensor, or change from 1^m shape to 1^n shape, just do a copy
if (!reshape->get_is_transpose() || result_shape_product < 2)
{
writer.block_begin();
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
}
writer.block_end();
return;
auto& host_emitter = compiled_function->get_primitive_emitter()->get_host_emitter();
size_t index = host_emitter->build_memcpy(
cudaMemcpyDeviceToDevice, out[0].get_size() * out[0].get_element_type().size());
return compiled_function->add_to_runtime(index, function_name, args, out);
}
//combine inordered dimensons after reorder in shape, update output shape and input order
......@@ -1286,62 +1164,59 @@ void runtime::gpu::GPU_Emitter::emit_Reshape(EMIT_ARGS)
new_result_shape.push_back(new_arg_shape[new_input_order[i]]);
}
size_t index;
// If there is no layout change, we can just copy.
writer.block_begin();
bool same_layout = is_sorted(new_input_order.begin(), new_input_order.end());
if (same_layout)
{
bool same_layout = is_sorted(new_input_order.begin(), new_input_order.end());
if (same_layout)
auto& host_emitter = compiled_function->get_primitive_emitter()->get_host_emitter();
index = host_emitter->build_memcpy(cudaMemcpyDeviceToDevice,
out[0].get_size() * out[0].get_element_type().size());
}
// If there *is* a layout change in the 2D case, we transpose the input.
else
{
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
if (new_rank == 2)
{
index = cuda_emitter->build_reshape_2d(
{{args[0].get_type(), out[0].get_type()}}, new_arg_shape, new_input_order);
}
// If there *is* a layout change in the 3D case, we do 3D tiled reshape.
else if (new_rank == 3)
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
index = cuda_emitter->build_reshape_3d(
{{args[0].get_type(), out[0].get_type()}}, new_arg_shape, new_input_order);
}
// If there *is* a layout change in the 2D case, we transpose the input.
// Other cases (reordering of axes for tensors with rank>3).
else
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
size_t index;
if (new_rank == 2)
{
index = cuda_emitter->build_reshape_2d(
{{args[0].get_type(), out[0].get_type()}}, new_arg_shape, new_input_order);
}
// If there *is* a layout change in the 3D case, we do 3D tiled reshape.
else if (new_rank == 3)
{
index = cuda_emitter->build_reshape_3d(
{{args[0].get_type(), out[0].get_type()}}, new_arg_shape, new_input_order);
}
// Other cases (reordering of axes for tensors with rank>3).
else
{
index = cuda_emitter->build_reshape(
{{args[0].get_type(), out[0].get_type()}}, new_arg_shape, new_input_order);
}
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
index = cuda_emitter->build_reshape(
{{args[0].get_type(), out[0].get_type()}}, new_arg_shape, new_input_order);
}
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Result(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Result(EMIT_ARGS)
{
if (args[0].get_name() == out[0].get_name())
{
writer << "// Skipping generation for " << node->get_name() << "\n";
return;
return "// Skipping generation for " + node->get_name() + "\n";
}
writer.block_begin();
kernel::emit_memcpyDtD(writer, out[0], args[0]);
writer.block_end();
auto& host_emitter = compiled_function->get_primitive_emitter()->get_host_emitter();
size_t index = host_emitter->build_memcpy(cudaMemcpyDeviceToDevice,
out[0].get_size() * out[0].get_element_type().size());
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Reverse(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Reverse(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return;
return "";
}
auto reverse = static_cast<const op::Reverse*>(node);
......@@ -1354,29 +1229,27 @@ void runtime::gpu::GPU_Emitter::emit_Reverse(EMIT_ARGS)
{
reverse_axes_flag[a] = 1;
}
writer.block_begin();
size_t index;
if (out[0].get_size() == 1)
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
auto& host_emitter = compiled_function->get_primitive_emitter()->get_host_emitter();
index = host_emitter->build_memcpy(cudaMemcpyDeviceToDevice,
out[0].get_size() * out[0].get_element_type().size());
}
else
{
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto index = cuda_emitter->build_reverse(
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
index = cuda_emitter->build_reverse(
{{args[0].get_type(), out[0].get_type()}}, arg_shape, reverse_axes_flag);
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_ReverseSequence(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_ReverseSequence(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return;
return "";
}
auto rs = static_cast<const ngraph::op::ReverseSequence*>(node);
......@@ -1386,90 +1259,79 @@ void runtime::gpu::GPU_Emitter::emit_ReverseSequence(EMIT_ARGS)
auto arg_shape1 = args[1].get_shape();
auto out_shape = out[0].get_shape();
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto rs_index = cuda_emitter->build_reverse_sequence(
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
auto index = cuda_emitter->build_reverse_sequence(
{{args[0].get_type(), args[1].get_type(), out[0].get_type()}},
arg_shape0,
arg_shape1,
out_shape,
bi,
si);
writer.block_begin();
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << rs_index << ", input, output);\n";
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
#if CUDNN_VERSION >= 7200
void runtime::gpu::GPU_Emitter::emit_Rnn(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Rnn(EMIT_ARGS)
{
auto rnn = static_cast<const ngraph::op::gpu::Rnn*>(node);
auto& cudnn_emitter = external_function->get_primitive_emitter()->get_cudnn_emitter();
auto& cudnn_emitter = compiled_function->get_primitive_emitter()->get_cudnn_emitter();
size_t index = cudnn_emitter->build_primitive(rnn);
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
#endif
void runtime::gpu::GPU_Emitter::emit_ScalarConstantLike(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_ScalarConstantLike(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_Select(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Select(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Select>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Select>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_SelectAndScatter(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_SelectAndScatter(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_ShapeOf(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_ShapeOf(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_Sigmoid(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Sigmoid(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Sigmoid>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Sigmoid>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_SigmoidBackprop(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_SigmoidBackprop(EMIT_ARGS)
{
emit_elementwise<ngraph::op::SigmoidBackprop>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::SigmoidBackprop>(
compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Sign(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Sign(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Sign>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Sign>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Sin(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Sin(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Sin>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Sin>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Sinh(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Sinh(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Sinh>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Sinh>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Slice(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Slice(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return;
return "";
}
auto slice = static_cast<const op::Slice*>(node);
......@@ -1478,149 +1340,137 @@ void runtime::gpu::GPU_Emitter::emit_Slice(EMIT_ARGS)
const Coordinate& lower_bounds = slice->get_lower_bounds();
const Strides slice_strides = slice->get_strides();
writer.block_begin();
size_t index;
if (args[0].get_size() == out[0].get_size())
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
auto& host_emitter = compiled_function->get_primitive_emitter()->get_host_emitter();
index = host_emitter->build_memcpy(cudaMemcpyDeviceToDevice,
out[0].get_size() * out[0].get_element_type().size());
}
else
{
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto index = cuda_emitter->build_slice({{args[0].get_type(), out[0].get_type()}},
arg_shape,
lower_bounds,
slice_strides,
result_shape);
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
index = cuda_emitter->build_slice({{args[0].get_type(), out[0].get_type()}},
arg_shape,
lower_bounds,
slice_strides,
result_shape);
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Softmax(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Softmax(EMIT_ARGS)
{
auto softmax = static_cast<const ngraph::op::Softmax*>(node);
writer.block_begin();
{
auto axes_set = softmax->get_axes();
std::vector<element::Type> dtypes;
dtypes.push_back(args[0].get_element_type());
dtypes.push_back(out[0].get_element_type());
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
size_t index = cuda_emitter->build_softmax(dtypes, args[0].get_shape(), axes_set);
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
writer.block_end();
auto axes_set = softmax->get_axes();
std::vector<element::Type> dtypes;
dtypes.push_back(args[0].get_element_type());
dtypes.push_back(out[0].get_element_type());
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
size_t index = cuda_emitter->build_softmax(dtypes, args[0].get_shape(), axes_set);
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Sqrt(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Sqrt(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Sqrt>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Sqrt>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_StopGradient(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_StopGradient(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
void runtime::gpu::GPU_Emitter::emit_Subtract(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Subtract(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Subtract>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Subtract>(
compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Sum(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Sum(EMIT_ARGS)
{
runtime::gpu::GPU_Emitter::emit_Sum_0(external_function, writer, node, args, out);
return runtime::gpu::GPU_Emitter::emit_Sum_0(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Sum_0(EMIT_ARGS)
/* emit_Sum_0 uses native cuda kernels to perform Sum reduction. This method
is faster than cudnn implementation but in its current state is less precise
than cudnn reduce. That is causing tensorflow tests aimed at testing stabilty
to fail */
std::string runtime::gpu::GPU_Emitter::emit_Sum_0(EMIT_ARGS)
// emit_Sum_0 uses native cuda kernels to perform Sum reduction. This method
// is faster than cudnn implementation but in its current state is less precise
// than cudnn reduce. That is causing tensorflow tests aimed at testing stabilty
// to fail
{
const ngraph::op::Sum* sum = static_cast<const ngraph::op::Sum*>(node);
writer.block_begin();
if (out[0].get_size() == 0)
{
if (out[0].get_size() != 0)
{
auto axes_set = sum->get_reduction_axes();
vector<element::Type> dtypes;
dtypes.push_back(args[0].get_element_type());
dtypes.push_back(out[0].get_element_type());
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto sum_index = cuda_emitter->build_reduce<ngraph::op::Add>(
dtypes, args[0].get_shape(), out[0].get_shape(), axes_set);
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << sum_index << ", input, output);\n";
}
return "";
}
writer.block_end();
}
const ngraph::op::Sum* sum = static_cast<const ngraph::op::Sum*>(node);
void runtime::gpu::GPU_Emitter::emit_Sum_1(EMIT_ARGS)
auto axes_set = sum->get_reduction_axes();
vector<element::Type> dtypes;
dtypes.push_back(args[0].get_element_type());
dtypes.push_back(out[0].get_element_type());
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
auto sum_index = cuda_emitter->build_reduce<ngraph::op::Add>(
dtypes, args[0].get_shape(), out[0].get_shape(), axes_set);
/* emit_Sum_1 uses cudnn to perform Sum reduction. This method, although
slower than the native cuda implementation is more precise and fixes the issue with
tensorflow test failures*/
return compiled_function->add_to_runtime(sum_index, function_name, args, out);
}
std::string runtime::gpu::GPU_Emitter::emit_Sum_1(EMIT_ARGS)
// emit_Sum_1 uses cudnn to perform Sum reduction. This method, although
// slower than the native cuda implementation is more precise and fixes the issue with
// tensorflow test failures
{
const ngraph::op::Sum* sum = static_cast<const ngraph::op::Sum*>(node);
std::vector<element::Type> dtypes{args[0].get_element_type(), out[0].get_element_type()};
cudnnReduceTensorOp_t reduce_op = CUDNN_REDUCE_TENSOR_ADD;
writer.block_begin();
if (out[0].get_size() == 0)
{
if (out[0].get_size() != 0)
{
// one of args[] axes has zero size, zero output
if (args[0].get_size() == 0)
{
kernel::emit_memset(writer, out[0], 0);
}
else if (args[0].get_size() == out[0].get_size())
{
kernel::emit_memcpyDtD(writer, out[0], args[0]);
}
else
{
auto& cudnn_emitter =
external_function->get_primitive_emitter()->get_cudnn_emitter();
auto sum_index =
cudnn_emitter->build_reduce_forward(reduce_op,
dtypes,
args[0].get_shape(),
sum->get_reduction_axes(),
CUDNNEmitter::ReductionMode::Reduce);
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << sum_index << ", input, output);\n";
}
}
return "";
}
size_t index;
// one of args[] axes has zero size, zero output
if (args[0].get_size() == 0)
{
auto& host_emitter = compiled_function->get_primitive_emitter()->get_host_emitter();
index =
host_emitter->build_zero_out(0, out[0].get_size() * out[0].get_element_type().size());
}
else if (args[0].get_size() == out[0].get_size())
{
auto& host_emitter = compiled_function->get_primitive_emitter()->get_host_emitter();
index = host_emitter->build_memcpy(cudaMemcpyDeviceToDevice,
out[0].get_size() * out[0].get_element_type().size());
}
writer.block_end();
else
{
auto& cudnn_emitter = compiled_function->get_primitive_emitter()->get_cudnn_emitter();
index = cudnn_emitter->build_reduce_forward(reduce_op,
dtypes,
args[0].get_shape(),
sum->get_reduction_axes(),
CUDNNEmitter::ReductionMode::Reduce);
}
return compiled_function->add_to_runtime(index, function_name, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Tan(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Tan(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Tan>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Tan>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_Tanh(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_Tanh(EMIT_ARGS)
{
emit_elementwise<ngraph::op::Tanh>(external_function, writer, node, args, out);
return emit_elementwise<ngraph::op::Tanh>(compiled_function, function_name, node, args, out);
}
void runtime::gpu::GPU_Emitter::emit_TopK(EMIT_ARGS)
std::string runtime::gpu::GPU_Emitter::emit_TopK(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return;
return "";
}
auto topk = static_cast<const ngraph::op::TopK*>(node);
size_t topk_axis = topk->get_top_k_axis();
......@@ -1634,16 +1484,11 @@ void runtime::gpu::GPU_Emitter::emit_TopK(EMIT_ARGS)
dtypes.push_back(out[i].get_element_type());
}
auto& input_shape = args[0].get_shape();
auto& cuda_emitter = external_function->get_primitive_emitter()->get_cuda_emitter();
auto& cuda_emitter = compiled_function->get_primitive_emitter()->get_cuda_emitter();
auto index = cuda_emitter->build_topk(
dtypes, input_shape, topk_axis, topk_k, index_elem_type, compute_max);
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << index << ", input, output);\n";
}
writer.block_end();
return compiled_function->add_to_runtime(index, function_name, args, out);
}
string runtime::gpu::GPU_Emitter::node_names(const vector<GPUTensorWrapper>& args,
......
......@@ -19,9 +19,8 @@
#include <string>
#include <vector>
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/node.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_compiled_function.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
namespace ngraph
......@@ -33,21 +32,21 @@ namespace ngraph
class GPU_Emitter
{
public:
static std::function<void(EMIT_ARGS)> get_emit_function(const Node& node);
static std::function<std::string(EMIT_ARGS)> get_emit_function(const Node& node);
// This defines a collection of function declarations like this
// static void emit_Abs(EMIT_ARGS);
// static void emit_Acos(EMIT_ARGS);
#define NGRAPH_OP(a, b) static void emit_##a(EMIT_ARGS);
// static std::string emit_Abs(EMIT_ARGS);
// static std::string emit_Acos(EMIT_ARGS);
#define NGRAPH_OP(a, b) static std::string emit_##a(EMIT_ARGS);
#include "ngraph/runtime/gpu/op/op_tbl.hpp"
#undef NGRAPH_OP
template <typename T>
static void emit_elementwise(EMIT_ARGS)
static std::string emit_elementwise(EMIT_ARGS)
{
if (out[0].get_size() == 0)
{
return;
return "";
}
else if (out.size() > 1)
{
......@@ -55,29 +54,22 @@ namespace ngraph
"Multi-output elementwise ops are not currently supported.");
}
auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter();
compiled_function->get_primitive_emitter()->get_cuda_emitter();
writer.block_begin();
std::vector<std::string> dtypes;
for (auto& arg : args)
{
std::vector<std::string> dtypes;
for (auto& arg : args)
{
dtypes.push_back(arg.get_type());
}
dtypes.push_back(out[0].get_type());
auto ew_index =
cuda_emitter->build_elementwise<T>(dtypes, out[0].get_shape());
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << ew_index
<< ", input, output);\n";
dtypes.push_back(arg.get_type());
}
writer.block_end();
dtypes.push_back(out[0].get_type());
auto ew_index = cuda_emitter->build_elementwise<T>(dtypes, out[0].get_shape());
return compiled_function->add_to_runtime(ew_index, function_name, args, out);
}
static void emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t);
static void emit_Sum_0(EMIT_ARGS);
static void emit_Sum_1(EMIT_ARGS);
static std::string emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t);
static std::string emit_Sum_0(EMIT_ARGS);
static std::string emit_Sum_1(EMIT_ARGS);
/// \brief Create a list of node names for each arg in args
/// \param args list of tensor arguments
......
......@@ -26,7 +26,6 @@
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/function.hpp"
......@@ -103,39 +102,17 @@
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/any_all_replacement.hpp"
#include "ngraph/pass/common_function_collection.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
#include "ngraph/runtime/gpu/op/batch_norm.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
#include "ngraph/runtime/gpu/pass/gpu_batch_norm_cache.hpp"
#include "ngraph/runtime/gpu/pass/gpu_layout.hpp"
#include "ngraph/runtime/gpu/pass/gpu_rnn_fusion.hpp"
#include "ngraph/runtime/gpu/pass/tensor_memory_reservation.hpp"
using namespace std;
using namespace ngraph;
static const string s_output_dir = "gpu_codegen";
static std::mutex s_compilation;
class GPUStaticInitializers
{
public:
GPUStaticInitializers()
{
file_util::remove_directory(s_output_dir);
file_util::make_directory(s_output_dir);
}
};
static string emit_string_array(const vector<string>& s, size_t max_line_length)
{
stringstream ss;
......@@ -168,36 +145,68 @@ static string emit_string_array(const vector<string>& s, size_t max_line_length)
return ss.str();
}
static GPUStaticInitializers s_static_initializers;
void runtime::gpu::GPU_ExternalFunction::emit_op(GPU_ExternalFunction* external_function,
codegen::CodeWriter& writer,
const ngraph::Node* node,
const std::vector<GPUTensorWrapper>& args,
const std::vector<GPUTensorWrapper>& out)
std::string runtime::gpu::GPUExternalFunction::emit_op(GPUCompiledFunction* external_function,
const std::string& function_name,
const ngraph::Node* node,
const std::vector<GPUTensorWrapper>& args,
const std::vector<GPUTensorWrapper>& out)
{
auto emit_function = GPU_Emitter::get_emit_function(*node);
emit_function(external_function, writer, node, args, out);
return emit_function(external_function, function_name, node, args, out);
};
const size_t runtime::gpu::GPU_ExternalFunction::GPU_ExternalFunction::s_memory_pool_alignment = 64;
runtime::gpu::GPU_ExternalFunction::GPU_ExternalFunction(
runtime::gpu::GPUExternalFunction::GPUExternalFunction(
const shared_ptr<ngraph::Function>& function,
std::shared_ptr<GPU_Backend::BackendContext>& shared_context)
: m_compiled_function(nullptr)
, m_function(function)
, m_emit_timing(false)
, m_is_compiled(false)
, m_shared_context(shared_context)
const std::shared_ptr<GPU_Backend::BackendContext>& shared_context)
: GPUCompiledFunction(function, shared_context)
{
}
runtime::gpu::GPU_ExternalFunction::~GPU_ExternalFunction()
runtime::gpu::GPUExternalFunction::~GPUExternalFunction()
{
}
const string& runtime::gpu::GPU_ExternalFunction::get_pch_header_source()
std::string runtime::gpu::GPUExternalFunction::add_to_runtime(
size_t primitive_index,
const std::string& function_name,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out)
{
codegen::CodeWriter writer;
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << primitive_index << ", input, output);\n";
writer.block_end();
}
return writer.get_code();
}
std::string runtime::gpu::GPUExternalFunction::add_call_to_runtime(
const std::string& caller,
const std::string& callee,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out)
{
codegen::CodeWriter writer;
writer.block_begin();
{
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << callee << "(input, output, ctx);\n";
}
writer.block_end();
return writer.get_code();
}
std::string runtime::gpu::GPUExternalFunction::node_names(
const std::vector<runtime::gpu::GPUTensorWrapper>& args, std::initializer_list<int> arg_indexes)
{
return runtime::gpu::GPU_Emitter::node_names(args, arg_indexes);
}
const string& runtime::gpu::GPUExternalFunction::get_pch_header_source()
{
static string s_pch_header_source = R"(
// Generated by the nGraph GPU backend
......@@ -214,7 +223,7 @@ const string& runtime::gpu::GPU_ExternalFunction::get_pch_header_source()
return s_pch_header_source;
}
const string& runtime::gpu::GPU_ExternalFunction::get_header_source()
const string& runtime::gpu::GPUExternalFunction::get_header_source()
{
static string s_header_source =
get_pch_header_source() + R"(
......@@ -232,12 +241,12 @@ using namespace std;
return s_header_source;
}
void runtime::gpu::GPU_ExternalFunction::emit_header()
void runtime::gpu::GPUExternalFunction::emit_header()
{
m_writer << get_header_source();
}
void runtime::gpu::GPU_ExternalFunction::emit_timer_functions()
void runtime::gpu::GPUExternalFunction::emit_timer_functions()
{
if (m_emit_timing)
{
......@@ -296,7 +305,7 @@ void runtime::gpu::GPU_ExternalFunction::emit_timer_functions()
}
}
void runtime::gpu::GPU_ExternalFunction::emit_constant_declarations()
void runtime::gpu::GPUExternalFunction::emit_constant_declarations()
{
m_writer << "// Declare all constants\n";
for (const auto& p : m_function_ordered_ops)
......@@ -348,7 +357,7 @@ void runtime::gpu::GPU_ExternalFunction::emit_constant_declarations()
m_writer.block_end();
}
void runtime::gpu::GPU_ExternalFunction::emit_function_declarations()
void runtime::gpu::GPUExternalFunction::emit_function_declarations()
{
m_writer << "// Declare all functions\n";
for (const auto& p : m_function_ordered_ops)
......@@ -359,7 +368,7 @@ void runtime::gpu::GPU_ExternalFunction::emit_function_declarations()
m_writer << "\n";
}
void runtime::gpu::GPU_ExternalFunction::emit_temp_mem_pool_allocation(
void runtime::gpu::GPUExternalFunction::emit_temp_mem_pool_allocation(
shared_ptr<Function> current_function)
{
bool temporaries_used = false;
......@@ -397,7 +406,7 @@ void runtime::gpu::GPU_ExternalFunction::emit_temp_mem_pool_allocation(
}
}
void runtime::gpu::GPU_ExternalFunction::emit_functions()
void runtime::gpu::GPUExternalFunction::emit_functions()
{
for (const auto& p : m_function_ordered_ops)
{
......@@ -508,7 +517,7 @@ void runtime::gpu::GPU_ExternalFunction::emit_functions()
auto it = m_node_function_map.find(node.get());
if (it == m_node_function_map.end())
{
emit_op(this, m_writer, node.get(), in, out);
m_writer << emit_op(this, current_function->get_name(), node.get(), in, out);
}
else
{
......@@ -538,75 +547,37 @@ void runtime::gpu::GPU_ExternalFunction::emit_functions()
}
}
void runtime::gpu::GPU_ExternalFunction::store_emitted_functions(const string& code)
void runtime::gpu::GPUExternalFunction::store_emitted_functions(const string& code)
{
// TODO: Cleanup and make this a utility function
string filename = file_util::path_join(s_output_dir, m_function_name + "_codegen.cpp");
string filename = file_util::path_join(get_output_dir(), m_function_name + "_codegen.cpp");
ofstream out(filename);
out << code;
out.close();
}
void runtime::gpu::GPU_ExternalFunction::compile()
void runtime::gpu::GPUExternalFunction::add_passes(ngraph::pass::Manager& pass_manager)
{
if (m_is_compiled)
{
return;
}
std::unique_lock<std::mutex> lock(s_compilation);
m_function_name = m_function->get_name();
auto allocator = std::make_shared<runtime::gpu::GPUAllocator>(
m_shared_context->m_primitive_emitter->get_memory_allocator());
ngraph::pass::Manager pass_manager;
#if CUDNN_VERSION >= 7200
// recurrent network fusion
pass_manager.register_pass<runtime::gpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::gpu::pass::RNNFusion>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<runtime::gpu::pass::MultiLayerRNNFusion>();
#else
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
#endif
pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>();
pass_manager.register_pass<ngraph::pass::AnyAllReplacement>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
pass_manager.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment);
pass_manager.register_pass<runtime::gpu::pass::TensorMemoryReservation>(
*allocator, m_tensor_memory_buffers);
std::string common_function_string;
auto femitter = bind(&ngraph::runtime::gpu::GPU_ExternalFunction::emit_op_as_function,
auto femitter = bind(&ngraph::runtime::gpu::GPUExternalFunction::emit_op_as_function,
this,
placeholders::_1,
placeholders::_2);
pass_manager.register_pass<ngraph::pass::CommonFunctionCollection>(
femitter, m_node_function_map, common_function_string);
string dump_filename = file_util::path_join(s_output_dir, m_function_name + "_ops.txt");
pass_manager.register_pass<ngraph::pass::DumpSorted>(dump_filename);
pass_manager.run_passes(m_function);
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{
m_function_ordered_ops.emplace(current_function, current_function->get_ordered_ops());
}
femitter, m_node_function_map, m_common_function_string);
}
void runtime::gpu::GPUExternalFunction::emit()
{
emit_header();
emit_timer_functions();
emit_constant_declarations();
emit_function_declarations();
m_writer << common_function_string << "\n";
m_writer << m_common_function_string << "\n";
emit_functions();
}
// allocate device buffers for primitive arguments and workspace
allocator->close();
m_shared_context->m_primitive_emitter->allocate_primitive_memory();
void runtime::gpu::GPUExternalFunction::compile_function()
{
string code = m_writer.get_code();
store_emitted_functions(code);
......@@ -623,16 +594,14 @@ void runtime::gpu::GPU_ExternalFunction::compile()
m_execution_engine->add_module(codegen_module);
m_execution_engine->finalize();
m_compiled_function = m_execution_engine->find_function<EntryPoint_t>(m_function_name);
if (!m_compiled_function)
m_runtime = m_execution_engine->find_function<EntryPoint_t>(m_function_name);
if (!m_runtime)
{
throw runtime_error("Function failed to compile");
}
m_is_compiled = true;
}
void runtime::gpu::GPU_ExternalFunction::emit_debug_function_entry(Node* node)
void runtime::gpu::GPUExternalFunction::emit_debug_function_entry(Node* node)
{
if (m_emit_timing)
{
......@@ -641,7 +610,7 @@ void runtime::gpu::GPU_ExternalFunction::emit_debug_function_entry(Node* node)
}
}
void runtime::gpu::GPU_ExternalFunction::emit_debug_function_exit(Node* node)
void runtime::gpu::GPUExternalFunction::emit_debug_function_exit(Node* node)
{
if (m_emit_timing)
{
......@@ -650,8 +619,8 @@ void runtime::gpu::GPU_ExternalFunction::emit_debug_function_exit(Node* node)
}
}
string runtime::gpu::GPU_ExternalFunction::emit_op_as_function(const Node& node,
const string& function_name)
string runtime::gpu::GPUExternalFunction::emit_op_as_function(const Node& node,
const string& function_name)
{
codegen::CodeWriter writer;
writer << "static void " << function_name << "(";
......@@ -692,9 +661,7 @@ string runtime::gpu::GPU_ExternalFunction::emit_op_as_function(const Node& node,
writer << ",\ngpu::GPURuntimeContext* ctx";
writer.indent--;
writer << "\n)\n";
codegen::CodeWriter tmp_writer;
emit_op(this, tmp_writer, &node, in, out);
string body = tmp_writer.get_code();
string body = emit_op(this, function_name, &node, in, out);
if (body.size() > 0 && body[0] == '{')
{
// Body already surrounded by curly braces so don't add more
......@@ -715,48 +682,8 @@ string runtime::gpu::GPU_ExternalFunction::emit_op_as_function(const Node& node,
return rc;
}
string runtime::gpu::GPU_ExternalFunction::strip_comments(const string& s) const
{
stringstream out;
for (size_t i = 0; i < s.size(); i++)
{
if (i < s.size() - 2)
{
if (s[i] == '/' && s[i + 1] == '/')
{
// line comment
i += 2;
while (s[i] != '\n')
{
i++;
}
out << '\n';
}
else if (s[i] == '/' && s[i + 1] == '*')
{
// multi-line comment
i += 2;
while (!(s[i] == '*' && s[i + 1] == '/'))
{
i++;
}
i++;
}
else
{
out << s[i];
}
}
else
{
out << s[i];
}
}
return out.str();
}
void runtime::gpu::GPU_ExternalFunction::propagate_in_place_input(
ngraph::descriptor::Output* output, std::string input_name)
void runtime::gpu::GPUExternalFunction::propagate_in_place_input(ngraph::descriptor::Output* output,
const std::string& input_name)
{
std::deque<ngraph::descriptor::Output*> stack;
stack.push_front(output);
......@@ -795,8 +722,8 @@ void runtime::gpu::GPU_ExternalFunction::propagate_in_place_input(
}
}
void runtime::gpu::GPU_ExternalFunction::propagate_in_place_output(
ngraph::descriptor::Output* res_src_output, std::string output_name)
void runtime::gpu::GPUExternalFunction::propagate_in_place_output(
ngraph::descriptor::Output* res_src_output, const std::string& output_name)
{
// we start with a particular output
// which is an argument to a given op::Result
......@@ -837,3 +764,66 @@ void runtime::gpu::GPU_ExternalFunction::propagate_in_place_output(
}
} while (propagate_further);
}
string runtime::gpu::GPUExternalFunction::strip_comments(const string& s) const
{
stringstream out;
for (size_t i = 0; i < s.size(); i++)
{
if (i < s.size() - 2)
{
if (s[i] == '/' && s[i + 1] == '/')
{
// line comment
i += 2;
while (s[i] != '\n')
{
i++;
}
out << '\n';
}
else if (s[i] == '/' && s[i + 1] == '*')
{
// multi-line comment
i += 2;
while (!(s[i] == '*' && s[i + 1] == '/'))
{
i++;
}
i++;
}
else
{
out << s[i];
}
}
else
{
out << s[i];
}
}
return out.str();
}
void runtime::gpu::GPUExternalFunction::get_performance_data(
std::vector<runtime::PerformanceCounter>& rc) const
{
auto* engine = this->m_execution_engine.get();
if (engine)
{
auto get_count = engine->find_function<size_t()>("get_debug_timer_count");
auto get_name = engine->find_function<const char*(size_t)>("get_debug_timer_name");
auto get_microseconds =
engine->find_function<size_t(size_t)>("get_debug_timer_microseconds");
auto get_call_count = engine->find_function<size_t(size_t)>("get_debug_timer_call_count");
if (get_count && get_name && get_microseconds && get_call_count)
{
size_t count = get_count();
for (size_t i = 0; i < count; i++)
{
rc.push_back({get_name(i), get_microseconds(i), get_call_count(i)});
}
}
}
}
......@@ -16,6 +16,8 @@
#pragma once
#if !defined(NGRAPH_DEX_ONLY)
#include <functional>
#include <memory>
#include <typeindex>
......@@ -32,14 +34,10 @@
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_compiled_function.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
#define EMIT_ARGS \
runtime::gpu::GPU_ExternalFunction *external_function, codegen::CodeWriter &writer, \
const Node *node, const std::vector<runtime::gpu::GPUTensorWrapper> &args, \
const std::vector<runtime::gpu::GPUTensorWrapper> &out
namespace ngraph
{
namespace runtime
......@@ -49,37 +47,41 @@ namespace ngraph
class GPU_Emitter;
struct GPURuntimeContext;
class GPU_ExternalFunction
class GPUExternalFunction : public GPUCompiledFunction
{
friend class GPU_Backend;
public:
GPU_ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
std::shared_ptr<GPU_Backend::BackendContext>& shared_context);
~GPU_ExternalFunction();
std::unique_ptr<runtime::gpu::GPURuntimeContext>& ctx();
const std::unique_ptr<GPUPrimitiveEmitter>& get_primitive_emitter() const
{
return m_shared_context->m_primitive_emitter;
}
static const size_t s_memory_pool_alignment;
GPUExternalFunction(
const std::shared_ptr<ngraph::Function>& function,
const std::shared_ptr<GPU_Backend::BackendContext>& shared_context);
virtual ~GPUExternalFunction();
virtual std::string
add_to_runtime(size_t primitive_index,
const std::string& function_name,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out) override;
virtual std::string add_call_to_runtime(
const std::string& caller,
const std::string& callee,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out) override;
virtual void get_performance_data(
std::vector<runtime::PerformanceCounter>& rc) const override;
protected:
void compile();
EntryPoint m_compiled_function;
virtual void compile_function() override;
virtual void add_passes(ngraph::pass::Manager& pass_manager) override;
virtual void emit() override;
private:
// For non-destructive passthrough kernels, propagate function
// input buffers to internal ops
void propagate_in_place_input(ngraph::descriptor::Output* output,
std::string input_name);
// For in-place kernels, propagate function output buffers to
// internal ops
void propagate_in_place_output(ngraph::descriptor::Output* res_src_output,
std::string output_name);
/// \brief Create a list of node names for each arg in args
/// \param args list of tensor arguments
/// \param arg_indexes a list of indexes into args for which args to include in
/// the output list, so {1, 2} will include args 1 and 2 and skip 0.
/// \ return returns a string containing "arg0_name, arg1_name, etc."
std::string node_names(const std::vector<runtime::gpu::GPUTensorWrapper>& args,
std::initializer_list<int> arg_indexes = {});
void emit_header();
void emit_timer_functions();
void emit_constant_declarations();
......@@ -88,35 +90,31 @@ namespace ngraph
void emit_debug_function_entry(Node* node);
void emit_debug_function_exit(Node* node);
void emit_temp_mem_pool_allocation(std::shared_ptr<Function> current_function);
void emit_op(EMIT_ARGS);
void store_emitted_functions(const std::string& code);
std::string emit_op(EMIT_ARGS);
std::string emit_op_as_function(const Node& node, const std::string& function_name);
std::string strip_comments(const std::string& s) const;
static const std::string& get_pch_header_source();
static const std::string& get_header_source();
// For non-destructive passthrough kernels, propagate function
// input buffers to internal ops
virtual void propagate_in_place_input(ngraph::descriptor::Output* output,
const std::string& input_name) override;
// For in-place kernels, propagate function output buffers to
// internal ops
virtual void propagate_in_place_output(ngraph::descriptor::Output* res_src_output,
const std::string& output_name) override;
codegen::CodeWriter m_writer;
std::string m_common_function_string;
std::unique_ptr<codegen::Compiler> m_compiler;
std::unique_ptr<codegen::ExecutionEngine> m_execution_engine;
std::shared_ptr<ngraph::Function> m_function;
std::map<std::string, size_t> m_name_index_map;
std::unordered_map<std::string, std::string> m_variable_name_map;
std::unordered_map<Node*, Node*> m_node_function_map;
std::unordered_map<std::shared_ptr<Function>, std::list<std::shared_ptr<Node>>>
m_function_ordered_ops;
bool m_emit_timing;
bool m_is_compiled;
size_t m_offset;
std::string m_function_name;
std::unordered_map<std::string, size_t> m_tensor_memory_buffers;
std::shared_ptr<GPU_Backend::BackendContext> m_shared_context;
};
}
}
}
#endif // !defined(NGRAPH_DEX_ONLY)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstdlib>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cudnn.h>
#include <fstream>
#include <mutex>
#include <string>
#include <tuple>
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/ceiling.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/function_call.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reduce.hpp"
#include "ngraph/op/reduce_window.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
#include "ngraph/op/sinh.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/common_function_collection.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_call_frame.hpp"
#include "ngraph/runtime/gpu/gpu_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_internal_function.hpp"
#include "ngraph/runtime/gpu/gpu_invoke.hpp"
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_constructor.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp"
using namespace std;
using namespace ngraph;
std::string runtime::gpu::GPUInternalFunction::emit_op(GPUCompiledFunction* compiled_function,
const std::string& function_name,
const ngraph::Node* node,
const std::vector<GPUTensorWrapper>& args,
const std::vector<GPUTensorWrapper>& out)
{
auto emit_function = GPU_Emitter::get_emit_function(*node);
return emit_function(compiled_function, function_name, node, args, out);
};
runtime::gpu::GPUInternalFunction::GPUInternalFunction(
const shared_ptr<ngraph::Function>& function,
const std::shared_ptr<GPU_Backend::BackendContext>& shared_context)
: GPUCompiledFunction(function, shared_context)
{
}
runtime::gpu::GPUInternalFunction::~GPUInternalFunction()
{
if (m_trace)
{
string filename = file_util::path_join(get_output_dir(), m_function_name + "_trace.txt");
ofstream out(filename);
out << m_trace->get_code();
out.close();
}
}
std::string runtime::gpu::GPUInternalFunction::add_to_runtime(
size_t primitive_index,
const std::string& function_name,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out)
{
std::function<void(GPUCallFrame & call_frame, GPURuntimeContext * ctx)> primitive_invocation;
if (!m_trace)
{
primitive_invocation = [args, out, primitive_index](GPUCallFrame& call_frame,
GPURuntimeContext* ctx) mutable {
// here, these inputs and outputs could be any of [constant, input, output, intermediate]
auto inputs = call_frame.get_tensor_io(args);
auto outputs = call_frame.get_tensor_io(out);
runtime::gpu::invoke_primitive(ctx, primitive_index, inputs.data(), outputs.data());
};
}
else
{
primitive_invocation = [this, args, out, primitive_index](GPUCallFrame& call_frame,
GPURuntimeContext* ctx) mutable {
// here, these inputs and outputs could be any of [constant, input, output, intermediate]
auto inputs = call_frame.get_tensor_io(args);
auto outputs = call_frame.get_tensor_io(out);
*m_trace << "(";
for (size_t i = 0; i < outputs.size(); i++)
{
if (i != 0)
{
*m_trace << ", ";
}
*m_trace << std::hex << outputs[i];
}
*m_trace << ") = primitive(" << primitive_index << ", ";
for (size_t i = 0; i < inputs.size(); i++)
{
if (i != 0)
{
*m_trace << ", ";
}
*m_trace << std::hex << inputs[i];
}
*m_trace << ");\n";
*m_trace << compose_manifest(primitive_index, args, out);
runtime::gpu::invoke_primitive(ctx, primitive_index, inputs.data(), outputs.data());
};
}
m_runtime_constructor->add(function_name, primitive_invocation);
return compose_manifest(primitive_index, args, out);
}
std::string runtime::gpu::GPUInternalFunction::add_call_to_runtime(
const std::string& caller,
const std::string& callee,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out)
{
m_runtime_constructor->add_call(caller, callee, args, out);
codegen::CodeWriter writer;
writer.block_begin();
{
for (auto const& tensor : args)
{
writer << "push " << tensor << "\n";
}
writer << "call " << callee << "\n";
for (auto const& tensor : out)
{
writer << "pop " << tensor << "\n";
}
}
writer.block_end();
return writer.get_code();
}
std::string runtime::gpu::GPUInternalFunction::compose_manifest(
size_t primitive_index,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out) const
{
codegen::CodeWriter writer;
writer.block_begin();
{
for (auto const& tensor : args)
{
writer << "push " << tensor << "\n";
}
writer << "call primitive(" << primitive_index << ")\n";
for (auto const& tensor : out)
{
writer << "pop " << tensor << "\n";
}
}
writer.block_end();
return writer.get_code();
}
void runtime::gpu::GPUInternalFunction::build_functions()
{
for (const auto& p : m_function_ordered_ops)
{
auto& current_function = p.first;
// Add inputs to the variable name map
size_t arg_index = 0;
for (shared_ptr<ngraph::op::Parameter> param : current_function->get_parameters())
{
for (size_t i = 0; i < param->get_output_size(); ++i)
{
shared_ptr<descriptor::Tensor> tv = param->get_output_tensor_ptr(i);
const element::Type& et = tv->get_element_type();
string type = et.c_type_string();
stringstream ss;
ss << "((" << type << "*)(inputs[" << arg_index << "]))";
m_variable_name_map[tv->get_name()] = std::make_tuple(
runtime::gpu::GPUTensorWrapper::TensorType::INPUT, arg_index, ss.str());
// propagate_in_place_input(&param->get_outputs().at(i), ss.str());
arg_index++;
}
}
// Add outputs to the variable name map
for (size_t i = 0; i < current_function->get_output_size(); ++i)
{
shared_ptr<Node> op = current_function->get_output_op(i);
shared_ptr<descriptor::Tensor> tv = op->get_output_tensor_ptr();
string type = tv->get_element_type().c_type_string();
stringstream ss;
ss << "((" << type << "*)(outputs[" << i << "]))";
m_variable_name_map[tv->get_name()] =
std::make_tuple(runtime::gpu::GPUTensorWrapper::TensorType::OUTPUT, i, ss.str());
auto res = dynamic_pointer_cast<ngraph::op::Result>(op);
//keep assigning different outputs to a result descriptor
//op::Result emitter will check if in and out descriptors are the same
//and skip a copy
auto input_node = res->get_inputs().at(0).get_output().get_node();
if (!input_node->is_constant() && !input_node->is_parameter())
{
shared_ptr<descriptor::Tensor> itv =
res->get_inputs().at(0).get_output().get_tensor_ptr();
auto output_name = ss.str();
m_variable_name_map[itv->get_name()] = std::make_tuple(
runtime::gpu::GPUTensorWrapper::TensorType::OUTPUT, i, ss.str());
//propagate_in_place_output(&(res->get_inputs().at(0).get_output()), output_name);
}
}
// Add temporaries to the variable name map
bool temporaries_used = false;
for (shared_ptr<Node> node : m_function_ordered_ops.at(current_function))
{
if (node->liveness_new_list.size() > 0)
{
temporaries_used = true;
break;
}
}
if (temporaries_used)
{
for (shared_ptr<Node> node : m_function_ordered_ops.at(current_function))
{
for (descriptor::Tensor* tensor : node->liveness_new_list)
{
m_variable_name_map[tensor->get_name()] =
std::make_tuple(runtime::gpu::GPUTensorWrapper::TensorType::INTERMEDIATE,
tensor->get_pool_offset(),
current_function->get_name());
}
}
}
// Add constants to the variable name map
for (shared_ptr<Node> node : p.second)
{
if (auto c = std::dynamic_pointer_cast<op::Constant>(node))
{
shared_ptr<descriptor::Tensor> tv = node->get_outputs()[0].get_tensor_ptr();
m_variable_name_map[tv->get_name()] = std::make_tuple(
runtime::gpu::GPUTensorWrapper::TensorType::CONSTANT, 0, node->get_name());
}
}
for (shared_ptr<Node> node : m_function_ordered_ops.at(current_function))
{
vector<string> node_input_names;
vector<string> node_output_names;
vector<GPUTensorWrapper> in;
for (const descriptor::Input& input : node->get_inputs())
{
const descriptor::Output& output = input.get_output();
shared_ptr<descriptor::Tensor> tv = output.get_tensor_ptr();
auto& var = m_variable_name_map[tv->get_name()];
in.push_back(
GPUTensorWrapper(tv, std::get<0>(var), std::get<1>(var), std::get<2>(var)));
node_input_names.emplace_back(tv->get_name());
}
vector<GPUTensorWrapper> out;
for (const descriptor::Output& output : node->get_outputs())
{
shared_ptr<descriptor::Tensor> tv = output.get_tensor_ptr();
auto& var = m_variable_name_map[tv->get_name()];
out.push_back(
GPUTensorWrapper(tv, std::get<0>(var), std::get<1>(var), std::get<2>(var)));
node_output_names.emplace_back(tv->get_name());
}
// Emit function description comment
if (!node->is_parameter() && !node->is_constant())
{
m_manifest << "\n// " << current_function->get_name() << "::" << node->get_name()
<< "(";
vector<string> parameter_nodes = node_input_names;
parameter_nodes.insert(
parameter_nodes.end(), node_output_names.begin(), node_output_names.end());
m_manifest << join(parameter_nodes);
m_manifest << ")\n";
// emit_debug_function_entry(node.get());
}
// Emit operation body
// m_writer << emit_op(this, node.get(), in, out);
m_manifest << emit_op(this, current_function->get_name(), node.get(), in, out);
// Emit operation epilogue
// if (!node->is_parameter() && !node->is_constant())
// {
// emit_debug_function_exit(node.get());
// }
}
}
}
void runtime::gpu::GPUInternalFunction::add_passes(ngraph::pass::Manager& pass_manager)
{
}
void runtime::gpu::GPUInternalFunction::emit()
{
m_runtime_constructor =
runtime::gpu::make_unique<GPURuntimeConstructor>(m_function_ordered_ops);
if (std::getenv("NGRAPH_GPU_TRACE"))
{
m_trace = std::make_shared<codegen::CodeWriter>();
}
// build and emit functions
build_functions();
}
void runtime::gpu::GPUInternalFunction::compile_function()
{
GPUCallFrame call_frame(m_function->get_parameters().size(), m_function->get_output_size());
// resolve memory reservations (constants and intermediate buffers)
call_frame.resolve_reservations(this, m_tensor_memory_buffers);
// build runtime
m_runtime = m_runtime_constructor->build(m_function_name, call_frame);
// store manifest
save_manifest_to_disk();
m_is_compiled = true;
}
void runtime::gpu::GPUInternalFunction::save_manifest_to_disk() const
{
string filename = file_util::path_join(get_output_dir(), m_function_name + "_manifest.txt");
ofstream out(filename);
out << m_manifest.get_code();
out.close();
}
void runtime::gpu::GPUInternalFunction::propagate_in_place_input(ngraph::descriptor::Output* output,
const std::string& input_name)
{
// std::deque<ngraph::descriptor::Output*> stack;
// stack.push_front(output);
// while (stack.size() > 0)
// {
// ngraph::descriptor::Output* it = stack.front();
// stack.pop_front();
// for (auto input : it->get_inputs())
// {
// auto c_op = std::dynamic_pointer_cast<ngraph::op::Op>(input->get_node());
// if (!c_op || c_op->is_output())
// {
// continue;
// }
// if (auto op_annotations = c_op->get_op_annotations())
// {
// for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
// {
// if (oi_pair.input == input->get_index() && !oi_pair.destructive)
// {
// size_t output_index = oi_pair.output;
// auto& output_tensor = c_op->get_outputs().at(output_index).get_tensor();
// m_variable_name_map[output_tensor.get_name()] = input_name;
// NGRAPH_DEBUG << "GPU codegen: Forwarding " << input_name << " through "
// << output_tensor.get_name();
// stack.push_back(&c_op->get_outputs().at(output_index));
// }
// }
// }
// }
// }
}
void runtime::gpu::GPUInternalFunction::propagate_in_place_output(
ngraph::descriptor::Output* res_src_output, const std::string& output_name)
{
// // we start with a particular output
// // which is an argument to a given op::Result
// size_t offset = res_src_output->get_tensor().get_pool_offset();
// auto it = res_src_output;
// bool propagate_further = false;
// do
// {
// propagate_further = false;
// auto arg = std::dynamic_pointer_cast<ngraph::op::Op>(it->get_node());
// if (!arg)
// {
// break;
// }
// if (auto op_annotations = arg->get_op_annotations())
// {
// for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
// {
// if (oi_pair.output == it->get_index())
// {
// size_t input_index = oi_pair.input;
// auto& input_tensor = arg->get_inputs().at(input_index).get_tensor();
// auto tmp_node = arg->get_inputs().at(input_index).get_output().get_node();
// if (input_tensor.get_pool_offset() == offset && !tmp_node->is_parameter() &&
// !tmp_node->is_constant())
// {
// NGRAPH_DEBUG << "Reusing " << output_name << " for "
// << input_tensor.get_name();
// m_variable_name_map[input_tensor.get_name()] = output_name;
// it = &arg->get_inputs().at(input_index).get_output();
// propagate_further = true;
// }
// }
// }
// }
// } while (propagate_further);
}
void runtime::gpu::GPUInternalFunction::get_performance_data(
std::vector<runtime::PerformanceCounter>& rc) const
{
// auto* engine = this->m_execution_engine.get();
// if (engine)
// {
// auto get_count = engine->find_function<size_t()>("get_debug_timer_count");
// auto get_name = engine->find_function<const char*(size_t)>("get_debug_timer_name");
// auto get_microseconds =
// engine->find_function<size_t(size_t)>("get_debug_timer_microseconds");
// auto get_call_count =
// engine->find_function<size_t(size_t)>("get_debug_timer_call_count");
// if (get_count && get_name && get_microseconds && get_call_count)
// {
// size_t count = get_count();
// for (size_t i = 0; i < count; i++)
// {
// rc.push_back({get_name(i), get_microseconds(i), get_call_count(i)});
// }
// }
// }
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <memory>
#include <tuple>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_compiled_function.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class GPU_Emitter;
class GPURuntimeConstructor;
struct GPURuntimeContext;
class GPUInternalFunction : public GPUCompiledFunction
{
public:
GPUInternalFunction(
const std::shared_ptr<ngraph::Function>& function,
const std::shared_ptr<GPU_Backend::BackendContext>& shared_context);
virtual ~GPUInternalFunction();
virtual std::string
add_to_runtime(size_t primitive_index,
const std::string& function_name,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out) override;
virtual std::string add_call_to_runtime(
const std::string& caller,
const std::string& callee,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out) override;
virtual void get_performance_data(
std::vector<runtime::PerformanceCounter>& rc) const override;
protected:
virtual void compile_function() override;
virtual void add_passes(ngraph::pass::Manager& pass_manager) override;
virtual void emit() override;
private:
void build_functions();
std::string emit_op(EMIT_ARGS);
std::string
compose_manifest(size_t primitive_index,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out) const;
void save_manifest_to_disk() const;
// For non-destructive passthrough kernels, propagate function
// input buffers to internal ops
virtual void propagate_in_place_input(ngraph::descriptor::Output* output,
const std::string& input_name) override;
// For in-place kernels, propagate function output buffers to
// internal ops
virtual void propagate_in_place_output(ngraph::descriptor::Output* res_src_output,
const std::string& output_name) override;
std::unordered_map<
std::string,
std::tuple<runtime::gpu::GPUTensorWrapper::TensorType, size_t, std::string>>
m_variable_name_map;
std::unique_ptr<GPURuntimeConstructor> m_runtime_constructor;
std::shared_ptr<codegen::CodeWriter> m_trace;
codegen::CodeWriter m_manifest;
};
}
}
}
......@@ -24,6 +24,7 @@ using namespace ngraph::runtime::gpu;
GPUPrimitiveEmitter::GPUPrimitiveEmitter()
: m_memory_manager(this)
, m_host_parameters(new GPUHostParameters)
, m_host_emitter(new HostEmitter(this, nullptr))
, m_cuda_emitter(new CUDAEmitter(this, nullptr, nullptr))
, m_cudnn_emitter(new CUDNNEmitter(this, nullptr, nullptr))
, m_cublas_emitter(new CUBLASEmitter(this, nullptr))
......@@ -33,6 +34,7 @@ GPUPrimitiveEmitter::GPUPrimitiveEmitter()
GPUPrimitiveEmitter::GPUPrimitiveEmitter(const std::unique_ptr<GPURuntimeContext>& ctx)
: m_memory_manager(this)
, m_host_parameters(new GPUHostParameters)
, m_host_emitter(new HostEmitter(this, ctx.get()))
, m_cuda_emitter(new CUDAEmitter(this, ctx.get(), this->m_host_parameters))
, m_cudnn_emitter(new CUDNNEmitter(this, ctx.get(), this->m_host_parameters))
, m_cublas_emitter(new CUBLASEmitter(this, ctx.get()))
......@@ -40,6 +42,10 @@ GPUPrimitiveEmitter::GPUPrimitiveEmitter(const std::unique_ptr<GPURuntimeContext
{
}
std::unique_ptr<HostEmitter>& GPUPrimitiveEmitter::get_host_emitter()
{
return m_host_emitter;
}
std::unique_ptr<CUDAEmitter>& GPUPrimitiveEmitter::get_cuda_emitter()
{
return m_cuda_emitter;
......@@ -48,7 +54,6 @@ std::unique_ptr<CUDNNEmitter>& GPUPrimitiveEmitter::get_cudnn_emitter()
{
return m_cudnn_emitter;
}
std::unique_ptr<CUBLASEmitter>& GPUPrimitiveEmitter::get_cublas_emitter()
{
return m_cublas_emitter;
......
......@@ -22,6 +22,7 @@
#include "ngraph/runtime/gpu/gpu_kernel_args.hpp"
#include "ngraph/runtime/gpu/gpu_memory_manager.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/host_emitter.hpp"
namespace ngraph
{
......@@ -34,6 +35,7 @@ namespace ngraph
public:
GPUPrimitiveEmitter();
GPUPrimitiveEmitter(const std::unique_ptr<GPURuntimeContext>& ctx);
std::unique_ptr<HostEmitter>& get_host_emitter();
std::unique_ptr<CUDAEmitter>& get_cuda_emitter();
std::unique_ptr<CUDNNEmitter>& get_cudnn_emitter();
std::unique_ptr<CUBLASEmitter>& get_cublas_emitter();
......@@ -59,6 +61,7 @@ namespace ngraph
std::vector<std::unique_ptr<gpu::primitive>> m_managed_primitives;
GPUMemoryManager m_memory_manager;
std::shared_ptr<GPUHostParameters> m_host_parameters;
std::unique_ptr<HostEmitter> m_host_emitter;
std::unique_ptr<CUDAEmitter> m_cuda_emitter;
std::unique_ptr<CUDNNEmitter> m_cudnn_emitter;
std::unique_ptr<CUBLASEmitter> m_cublas_emitter;
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/gpu/gpu_runtime_constructor.hpp"
using namespace ngraph;
runtime::gpu::GPURuntimeConstructor::GPURuntimeConstructor(const op_order_t& ordered_ops)
{
for (auto const& ops : ordered_ops)
{
m_runtime[ops.first->get_name()].reserve(ops.second.size());
}
}
void runtime::gpu::GPURuntimeConstructor::add(const std::string& name, const op_runtime_t& step)
{
m_runtime[name].push_back(step);
}
void runtime::gpu::GPURuntimeConstructor::add_call(
const std::string& caller,
const std::string& callee,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out)
{
auto& runtime = m_runtime[callee];
auto call = [args, out, &runtime](GPUCallFrame& caller_frame, GPURuntimeContext* ctx) mutable {
// extract memory pointers from the callers stack
auto inputs = caller_frame.get_tensor_io(args);
auto outputs = caller_frame.get_tensor_io(out);
// create a new call frame for the nested function
GPUCallFrame callee_frame = caller_frame;
// resolve the inputs of the new call frame
callee_frame.resolve_inputs(inputs.data(), inputs.size());
callee_frame.resolve_outputs(outputs.data(), outputs.size());
for (auto const& step : runtime)
{
step(callee_frame, ctx);
}
};
add(caller, call);
}
runtime::gpu::EntryPoint runtime::gpu::GPURuntimeConstructor::build(const std::string& function,
GPUCallFrame& call_frame)
{
auto& runtime = m_runtime.at(function);
return [call_frame, &runtime](void** inputs, void** outputs, GPURuntimeContext* ctx) mutable {
call_frame.resolve_inputs(inputs);
call_frame.resolve_outputs(outputs);
for (auto const& step : runtime)
{
step(call_frame, ctx);
}
};
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "ngraph/function.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_call_frame.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class GPUCallFrame;
class GPURuntimeConstructor
{
public:
using op_runtime_t =
std::function<void(GPUCallFrame& call_frame, GPURuntimeContext* ctx)>;
using op_order_t =
std::unordered_map<std::shared_ptr<Function>, std::list<std::shared_ptr<Node>>>;
GPURuntimeConstructor(const op_order_t& ordered_ops);
void add(const std::string& name, const op_runtime_t& step);
void add_call(const std::string& caller,
const std::string& callee,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out);
EntryPoint build(const std::string& function, GPUCallFrame& call_frame);
private:
std::unordered_map<std::string, std::vector<op_runtime_t>> m_runtime;
};
}
}
}
......@@ -13,10 +13,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <limits>
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
#include "ngraph/descriptor/layout/tensor_layout.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
using namespace std;
using namespace ngraph;
......@@ -25,6 +26,18 @@ runtime::gpu::GPUTensorWrapper::GPUTensorWrapper(const shared_ptr<descriptor::Te
const string& alias)
: m_tensor(tv)
, m_alias(alias)
, m_offset(std::make_pair(runtime::gpu::GPUTensorWrapper::TensorType::UNKNOWN,
std::numeric_limits<size_t>::max()))
{
}
runtime::gpu::GPUTensorWrapper::GPUTensorWrapper(const std::shared_ptr<descriptor::Tensor>& tv,
runtime::gpu::GPUTensorWrapper::TensorType type,
size_t offset,
const std::string& alias)
: m_tensor(tv)
, m_alias(alias)
, m_offset(std::make_pair(type, offset))
{
}
......@@ -60,7 +73,25 @@ const std::string& runtime::gpu::GPUTensorWrapper::get_name() const
}
}
const std::pair<runtime::gpu::GPUTensorWrapper::TensorType, size_t>&
runtime::gpu::GPUTensorWrapper::get_offset() const
{
return m_offset;
}
const std::string& runtime::gpu::GPUTensorWrapper::get_type() const
{
return get_element_type().c_type_string();
}
std::ostream& ngraph::runtime::gpu::operator<<(std::ostream& out,
const ngraph::runtime::gpu::GPUTensorWrapper& obj)
{
static std::vector<std::string> types{"CONSTANT", "INTERMEDIATE", "INPUT", "OUTPUT", "UNKNOWN"};
out << "gpu::tensor { name: " << obj.m_tensor->get_name()
<< " tensor_type: " << types.at(static_cast<size_t>(obj.m_offset.first))
<< ", offset/index: " << obj.m_offset.second << ", dtype: " << obj.get_element_type()
<< ", shape: " << obj.get_shape() << ", size: " << obj.get_size()
<< ", alias: " << obj.m_alias << " }";
return out;
}
......@@ -17,6 +17,7 @@
#pragma once
#include <memory>
#include <tuple>
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/type/element_type.hpp"
......@@ -28,6 +29,8 @@ namespace ngraph
namespace gpu
{
class GPUTensorWrapper;
std::ostream& operator<<(std::ostream& out,
const ngraph::runtime::gpu::GPUTensorWrapper& obj);
}
}
}
......@@ -35,7 +38,19 @@ namespace ngraph
class ngraph::runtime::gpu::GPUTensorWrapper
{
public:
enum TensorType : std::size_t
{
CONSTANT,
INTERMEDIATE,
INPUT,
OUTPUT,
UNKNOWN
};
GPUTensorWrapper(const std::shared_ptr<descriptor::Tensor>&, const std::string& alias = "");
GPUTensorWrapper(const std::shared_ptr<descriptor::Tensor>&,
TensorType,
size_t,
const std::string& alias);
size_t get_size() const;
const Shape& get_shape() const;
......@@ -43,8 +58,12 @@ public:
const element::Type& get_element_type() const;
const std::string& get_name() const;
const std::string& get_type() const;
const std::pair<TensorType, size_t>& get_offset() const;
friend std::ostream& ngraph::runtime::gpu::
operator<<(std::ostream& out, const ngraph::runtime::gpu::GPUTensorWrapper& obj);
private:
std::shared_ptr<descriptor::Tensor> m_tensor;
std::string m_alias;
std::pair<TensorType, size_t> m_offset;
};
......@@ -17,6 +17,7 @@
#pragma once
#include <cudnn.h>
#include <memory>
#include <vector>
namespace ngraph
......@@ -37,6 +38,12 @@ namespace ngraph
std::pair<uint64_t, uint64_t> idiv_magic_u64(uint64_t divisor);
uint32_t idiv_ceil(int n, int d);
template <typename T, typename... Args>
std::unique_ptr<T> make_unique(Args&&... args)
{
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}
// This is commented out because it increases the compile time.
// It should be moved to a debug header.
// template <typename T>
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <iostream>
#include <sstream>
#include <vector>
#include "ngraph/runtime/gpu/gpu_invoke.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/host_emitter.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
runtime::gpu::HostEmitter::HostEmitter(GPUPrimitiveEmitter* emitter, GPURuntimeContext* ctx)
: m_primitive_emitter(emitter)
, m_ctx(ctx)
{
}
size_t runtime::gpu::HostEmitter::build_memcpy(const cudaMemcpyKind& kind,
size_t size,
size_t dst,
size_t src)
{
std::stringstream ss;
ss << "memcpy" << kind << "_dst" << dst << "_src" << src << "_sz" << size;
std::string hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
std::unique_ptr<gpu::primitive> launch_kernel(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
CUDA_RT_SAFE_CALL(cudaMemcpy(outputs[dst], inputs[src], size, kind));
}});
return this->m_primitive_emitter->register_primitive(launch_kernel, hash);
}
size_t runtime::gpu::HostEmitter::build_zero_out(size_t dst, size_t size, bool is_local)
{
std::stringstream ss;
ss << "zero"
<< "_dst" << dst << "_sz" << size << "_local" << is_local;
std::string hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
std::unique_ptr<gpu::primitive> launch_kernel;
if (is_local)
{
launch_kernel.reset(new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void* tensor = gpu::invoke_memory_primitive(m_ctx, dst);
CUDA_RT_SAFE_CALL(cudaMemset(tensor, 0, size));
}});
}
else
{
launch_kernel.reset(new gpu::primitive{[=](void** inputs, void** outputs) mutable {
CUDA_RT_SAFE_CALL(cudaMemset(outputs[dst], 0, size));
}});
}
return this->m_primitive_emitter->register_primitive(launch_kernel, hash);
}
......@@ -14,34 +14,39 @@
// limitations under the License.
//*****************************************************************************
#include "code_writer.hpp"
#pragma once
using namespace std;
using namespace ngraph;
#include <functional>
#include <memory>
#include <vector>
codegen::CodeWriter::CodeWriter()
: indent(0)
, m_pending_indent(true)
, m_temporary_name_count(0)
{
}
string codegen::CodeWriter::get_code() const
{
return m_ss.str();
}
#include <cuda_runtime.h>
void codegen::CodeWriter::operator+=(const std::string& s)
namespace ngraph
{
*this << s;
}
std::string codegen::CodeWriter::generate_temporary_name(std::string prefix)
{
std::stringstream ss;
ss << prefix << m_temporary_name_count;
m_temporary_name_count++;
return ss.str();
namespace runtime
{
namespace gpu
{
struct GPURuntimeContext;
class GPUPrimitiveEmitter;
class HostEmitter
{
friend class GPUPrimitiveEmitter;
public:
size_t build_memcpy(const cudaMemcpyKind& kind,
size_t size,
size_t dst = 0,
size_t src = 0);
size_t build_zero_out(size_t dst, size_t size, bool is_local = false);
private:
HostEmitter(GPUPrimitiveEmitter* emitter, GPURuntimeContext* ctx);
GPUPrimitiveEmitter* m_primitive_emitter;
GPURuntimeContext* m_ctx;
};
}
}
}
......@@ -197,7 +197,7 @@ bool runtime::gpu::pass::GPULayout::run_on_call_graph(const std::list<std::share
auto handler = s_dispatcher.find(TI(n));
if (handler != s_dispatcher.end())
{
handler->second(m_external_function, node);
handler->second(m_compiled_function, node);
}
}
......
......@@ -17,10 +17,10 @@
#pragma once
#include "ngraph/pass/pass.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_compiled_function.hpp"
#define LAYOUT_DECL(op_type) \
layout<op_type>(ngraph::runtime::gpu::GPU_ExternalFunction * external_function, \
layout<op_type>(ngraph::runtime::gpu::GPUCompiledFunction * compiled_function, \
std::shared_ptr<ngraph::Node> node)
namespace ngraph
......@@ -32,27 +32,26 @@ namespace ngraph
namespace pass
{
using LayoutFunction =
std::function<void(GPU_ExternalFunction*, std::shared_ptr<ngraph::Node>)>;
std::function<void(GPUCompiledFunction*, std::shared_ptr<ngraph::Node>)>;
using LayoutOpMap = std::unordered_map<std::type_index, LayoutFunction>;
class GPULayout : public ngraph::pass::CallGraphPass
{
public:
GPULayout(GPU_ExternalFunction* external_function)
: m_external_function(external_function)
GPULayout(GPUCompiledFunction* compiled_function)
: m_compiled_function(compiled_function)
{
}
virtual bool
run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override;
template <typename OP>
static void
layout(ngraph::runtime::gpu::GPU_ExternalFunction* external_function,
std::shared_ptr<ngraph::Node> node);
static void layout(ngraph::runtime::gpu::GPUCompiledFunction* compiled_function,
std::shared_ptr<ngraph::Node> node);
private:
GPU_ExternalFunction* m_external_function;
GPUCompiledFunction* m_compiled_function;
};
NodeVector insert_new_reshape_after(NodeVector& parents,
......
......@@ -19,6 +19,7 @@
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/pass/manager_state.hpp"
#include "ngraph/runtime/gpu/gpu_memory_manager.hpp"
#include "ngraph/runtime/gpu/pass/tensor_memory_reservation.hpp"
......@@ -28,13 +29,27 @@ using namespace std;
bool runtime::gpu::pass::TensorMemoryReservation::run_on_function(shared_ptr<Function> f)
{
bool reservation = false;
size_t mem_pool_size = f->get_temporary_pool_size();
// intermediate memory reservation
if (mem_pool_size)
{
size_t pool_idx = m_allocator.reserve_workspace(mem_pool_size, false);
m_memory_buffers.insert({f->get_name(), pool_idx});
reservation = true;
}
return true;
// constant memory reservation
for (auto const& node : f->get_ops())
{
if (auto constant = std::dynamic_pointer_cast<ngraph::op::Constant>(node))
{
std::shared_ptr<descriptor::Tensor> tv = node->get_outputs()[0].get_tensor_ptr();
size_t idx = m_allocator.reserve_argspace(constant->get_data_ptr(), tv->size());
m_memory_buffers.insert({node->get_name(), idx});
reservation = true;
}
}
return false;
return reservation;
}
......@@ -57,3 +57,17 @@ shape_of_scalar
shape_of_vector
shape_of_matrix
shape_of_5d
# zero size axis needs to be implemented
# differently in gpu_emitter.cpp
# these should be re-enabled before
# merging to master
product_matrix_rows_zero
product_matrix_cols_zero
product_vector_zero
product_matrix_to_scalar_zero_by_zero
product_3d_eliminate_zero_dim
max_matrix_rows_zero_int32
reduce_matrix_rows_zero
reduce_matrix_cols_zero
reduce_vector_zero
reduce_matrix_to_scalar_zero_by_zero
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment