Commit 757621be authored by Chris Sullivan's avatar Chris Sullivan Committed by Robert Kimball

nvgpu backend without clang (#2115)

* Separate out external function base class.

* pt 1 First step toward removing m_writer from GPU_Emitter.

* pt 2 Add gpu_internal function skeleton.

* pt 3 Temporarily add to gpu_backend for prototyping.

* pt 4 Add call frame (partial) and runtime constructor.

* pt 5 Implement resolution for function memory reservations; build a new tensor wrapper for use with the call frame (see the sketch below).

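A minimal sketch of the idea, not the verbatim implementation: each new tensor wrapper records where its data lives as a (TensorType, offset) pair, and the call frame resolves that pair to a concrete pointer at run time. The real code is in gpu_tensor_wrapper.cpp and gpu_call_frame.cpp further down in this diff; the standalone function below is illustrative only.

#include <cstddef>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

enum class TensorType { CONSTANT, INTERMEDIATE, INPUT, OUTPUT, UNKNOWN };

// Resolve a wrapper's (type, offset) pair against the call frame's state:
// constants/intermediates index into a named memory reservation, while
// inputs/outputs use the offset as an argument/result index.
void* resolve_tensor(TensorType type,
                     size_t offset,
                     const std::string& name,
                     const std::unordered_map<std::string, unsigned char*>& reservations,
                     const std::vector<unsigned char*>& inputs,
                     const std::vector<unsigned char*>& outputs)
{
    switch (type)
    {
    case TensorType::CONSTANT:
    case TensorType::INTERMEDIATE: return reservations.at(name) + offset;
    case TensorType::INPUT: return inputs.at(offset);
    case TensorType::OUTPUT: return outputs.at(offset);
    default: throw std::runtime_error("unknown or uninitialized tensor type");
    }
}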
* pt 6 Resolve compilation errors.

* pt 7 Add a host emitter for emitting host primitives and implement it in the gpu emitter (sketched below).

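As a rough sketch, assuming a std::function-based primitive type like the gpu::primitive used elsewhere in this diff (the real implementation is host_emitter.cpp below): a host primitive is just a closure over the copy parameters that the runtime later invokes with the usual (inputs, outputs) pointer arrays.

#include <cstddef>
#include <functional>
#include <cuda_runtime.h>

using host_primitive = std::function<void(void** inputs, void** outputs)>;

// Build a primitive that copies 'size' bytes from input tensor 'src'
// to output tensor 'dst' with the given cudaMemcpyKind.
host_primitive make_memcpy_primitive(cudaMemcpyKind kind, size_t size, size_t dst, size_t src)
{
    return [=](void** inputs, void** outputs) {
        cudaMemcpy(outputs[dst], inputs[src], size, kind);
    };
}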
* pt 8 Add compile-time manifest.

* pt 9 Add simple runtime tracer.

* pt 10 Separate runtimes for different functions, indexed by function name; should switch to using the function instance_id for lookup performance.

* pt 11 Add a function call interface and support nested call frames (see the sketch below).

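A simplified sketch of how the per-function runtimes and nested calls fit together; the type names mirror the GPUCallFrame and GPURuntimeConstructor introduced later in this diff, but the code below is illustrative rather than the actual implementation. Each function name maps to an ordered list of runtime steps, and a nested call copies the caller's frame and re-resolves its inputs and outputs before running the callee's steps.

#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

struct GPUCallFrame;      // resolves tensor wrappers to device pointers (see gpu_call_frame.hpp)
struct GPURuntimeContext; // shared CUDA/cuDNN handles and primitive tables

using op_runtime_t = std::function<void(GPUCallFrame&, GPURuntimeContext*)>;

// One ordered list of steps per (sub)function, keyed by function name.
std::unordered_map<std::string, std::vector<op_runtime_t>> runtimes;

void run(const std::string& function, GPUCallFrame& frame, GPURuntimeContext* ctx)
{
    for (auto const& step : runtimes.at(function))
    {
        step(frame, ctx); // each op was registered as a closure over its primitive index
    }
}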
* pt 12 The reshape-elimination check in the emitter needs to include the offset.

* pt 13 Add default indentation to all op emissions in gpu external functions.

* pt 14 Fix constant memory reservation (it should not depend on the temporary buffer existence check).

* pt 15 Backward pooling for avg pool requires only one parameter. Rather than passing this parameter
three times, this commit changes the runtime to detect whether the op is avg pooling and to pass the appropriate pointers.
This is a holdover until max pool and avg pool are refactored into separate cudnn emitters.

* pt 16 Update cmake compatibility. The gpu backend can now be built without clang via NGRAPH_DEX_ONLY.
If this cmake variable is not defined, then both the clang codegen (via the gpu external function) and interpreter (via the gpu internal function) modes will be built.
For now codegen is the default backend, but it can be explicitly disabled by setting the environment variable NGRAPH_CODEGEN=0/FALSE/NO/etc. (see the sketch below).

Additional note: made codegen::CodeWriter header-only so that it can be used independently of whether the clang codegen library is compiled.

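In rough terms, backend selection then looks like the sketch below. The real factory is GPUCompiledFunction::make in gpu_compiled_function.cpp further down (which also generates upper-case variants of the option strings), and when NGRAPH_DEX_ONLY is defined the codegen path is not compiled at all; the helper name here is illustrative.

#include <cstdlib>
#include <string>

// Sketch: decide between the codegen (external function) and
// interpreter (internal function) modes of the GPU backend.
bool use_codegen_backend()
{
#if defined(NGRAPH_DEX_ONLY)
    return false; // codegen support was not built
#else
    bool use_codegen = true; // codegen remains the default for now
    if (const char* env = std::getenv("NGRAPH_CODEGEN"))
    {
        const std::string value(env);
        // a value of 0/false (in any case) disables codegen
        if (value == "0" || value == "false" || value == "FALSE" || value == "False")
        {
            use_codegen = false;
        }
    }
    return use_codegen;
#endif
}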
* pt 17 Fix issues with merge from master.

* pt 18 Factor the compile function into a few virtual calls so that common passes can be added in a single location for both backends (sketched below).

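Conceptually this is a template-method split: a shared compile() runs the common pass pipeline and then defers to backend-specific virtuals. The sketch below is illustrative only; the real virtuals take the pass manager and live on GPUCompiledFunction, shown later in this diff.

// Sketch of the shared driver with backend-specific hooks.
class CompiledFunctionSketch
{
public:
    virtual ~CompiledFunctionSketch() = default;
    void compile()
    {
        if (m_is_compiled)
        {
            return;
        }
        // ...register and run the passes shared by both backends here
        // (fusion, layout assignment, liveness, memory layout)...
        add_passes();       // backend-specific passes
        emit();             // codegen: emit C++ source; interpreter: build runtime closures
        compile_function(); // codegen: invoke clang; interpreter: finalize call frame/runtime
        m_is_compiled = true;
    }

protected:
    virtual void add_passes() = 0;
    virtual void emit() = 0;
    virtual void compile_function() = 0;
    bool m_is_compiled = false;
};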
* pt 19 Formatting.

* Remove code_writer.cpp from cmake and disable (temporarily) some reduce tests that require changes to gpu_emitter.cpp

* Move call frame and runtime constructor implementations to source files.

* Use member m_common_function_string.

* Apply an analogous bug fix to the one found in #2145.

* Remove underscore from GPU_CompiledFunction, GPU_ExternalFunction, and GPU_InternalFunction.

* Made static members of GPUCompiledFunction static methods.

* Remove 'No' codegen options, use std::toupper, and apply formatting.

* Address review comments.

* Remove vector overload for resolve inputs/outputs in GPUCallFrame.

* Remove diagnostic pragmas
parent 84d6ae08
...@@ -108,7 +108,7 @@ option(NGRAPH_GPUH_ENABLE "Control the building of the Hybrid GPU backend" FALSE ...@@ -108,7 +108,7 @@ option(NGRAPH_GPUH_ENABLE "Control the building of the Hybrid GPU backend" FALSE
option(NGRAPH_DISTRIBUTED_ENABLE "Add distributed mode to the CPU backend" FALSE) option(NGRAPH_DISTRIBUTED_ENABLE "Add distributed mode to the CPU backend" FALSE)
option(NGRAPH_DEBUG_ENABLE "Enable output for NGRAPH_DEBUG statements" FALSE) option(NGRAPH_DEBUG_ENABLE "Enable output for NGRAPH_DEBUG statements" FALSE)
option(NGRAPH_ONNX_IMPORT_ENABLE "Enable ONNX importer" FALSE) option(NGRAPH_ONNX_IMPORT_ENABLE "Enable ONNX importer" FALSE)
option(NGRAPH_DEX_ONLY "Build CPU DEX without codegen" FALSE) option(NGRAPH_DEX_ONLY "Build CPU or GPU DEX without codegen" FALSE)
option(NGRAPH_CODE_COVERAGE_ENABLE "Enable code coverage data collection" FALSE) option(NGRAPH_CODE_COVERAGE_ENABLE "Enable code coverage data collection" FALSE)
option(NGRAPH_LIB_VERSIONING_ENABLE "Enable shared library versioning" FALSE) option(NGRAPH_LIB_VERSIONING_ENABLE "Enable shared library versioning" FALSE)
option(NGRAPH_PYTHON_BUILD_ENABLE "Enable build nGraph python package wheel" FALSE) option(NGRAPH_PYTHON_BUILD_ENABLE "Enable build nGraph python package wheel" FALSE)
......
...@@ -15,13 +15,12 @@ ...@@ -15,13 +15,12 @@
# ****************************************************************************** # ******************************************************************************
if (NGRAPH_GPU_ENABLE OR (NGRAPH_CPU_ENABLE AND NOT NGRAPH_DEX_ONLY)) if ((NGRAPH_GPU_ENABLE OR NGRAPH_CPU_ENABLE) AND NOT NGRAPH_DEX_ONLY)
if (DEFINED NGRAPH_USE_CXX_ABI AND NGRAPH_USE_PREBUILT_LLVM) if (DEFINED NGRAPH_USE_CXX_ABI AND NGRAPH_USE_PREBUILT_LLVM)
message(FATAL_ERROR "Unable to use NGRAPH_USE_PREBUILT_LLVM with NGRAPH_USE_CXX_ABI") message(FATAL_ERROR "Unable to use NGRAPH_USE_PREBUILT_LLVM with NGRAPH_USE_CXX_ABI")
endif() endif()
set(SRC set(SRC
code_writer.cpp
compiler.cpp compiler.cpp
execution_engine.cpp execution_engine.cpp
) )
......
...@@ -30,13 +30,14 @@ namespace ngraph ...@@ -30,13 +30,14 @@ namespace ngraph
class ngraph::codegen::CodeWriter class ngraph::codegen::CodeWriter
{ {
public: public:
CodeWriter(); CodeWriter()
std::string get_code() const; : indent(0)
, m_pending_indent(true)
void operator+=(const std::string&); , m_temporary_name_count(0)
{
size_t indent; }
std::string get_code() const { return m_ss.str(); }
void operator+=(const std::string& s) { *this << s; }
template <typename T> template <typename T>
friend CodeWriter& operator<<(CodeWriter& out, const T& obj) friend CodeWriter& operator<<(CodeWriter& out, const T& obj)
{ {
...@@ -66,7 +67,15 @@ public: ...@@ -66,7 +67,15 @@ public:
return out; return out;
} }
std::string generate_temporary_name(std::string prefix = "tempvar"); std::string generate_temporary_name(const std::string& prefix = "tempvar")
{
std::stringstream ss;
ss << prefix << m_temporary_name_count;
m_temporary_name_count++;
return ss.str();
}
void block_begin() void block_begin()
{ {
...@@ -80,6 +89,8 @@ public: ...@@ -80,6 +89,8 @@ public:
*this << "}\n"; *this << "}\n";
} }
size_t indent;
private: private:
std::stringstream m_ss; std::stringstream m_ss;
bool m_pending_indent; bool m_pending_indent;
......
...@@ -22,30 +22,39 @@ set(SRC ...@@ -22,30 +22,39 @@ set(SRC
cuda_emitter.cpp cuda_emitter.cpp
cudnn_emitter.cpp cudnn_emitter.cpp
cublas_emitter.cpp cublas_emitter.cpp
host_emitter.cpp
gpu_backend.cpp gpu_backend.cpp
gpu_call_frame.cpp
gpu_cuda_context_manager.cpp gpu_cuda_context_manager.cpp
gpu_cuda_function_builder.cpp gpu_cuda_function_builder.cpp
gpu_cuda_function_pool.cpp gpu_cuda_function_pool.cpp
gpu_cuda_kernel_builder.cpp gpu_cuda_kernel_builder.cpp
gpu_emitter.cpp gpu_emitter.cpp
gpu_external_function.cpp gpu_compiled_function.cpp
gpu_internal_function.cpp
gpu_invoke.cpp gpu_invoke.cpp
gpu_kernel_args.cpp
gpu_kernel_emitters.cpp gpu_kernel_emitters.cpp
gpu_memory_manager.cpp gpu_memory_manager.cpp
gpu_primitive_emitter.cpp gpu_primitive_emitter.cpp
gpu_runtime_constructor.cpp
gpu_runtime_context.cpp gpu_runtime_context.cpp
gpu_tensor_wrapper.cpp gpu_tensor_wrapper.cpp
gpu_tensor.cpp gpu_tensor.cpp
gpu_util.cpp gpu_util.cpp
type_info.cpp type_info.cpp
pass/gpu_batch_norm_cache.cpp
pass/gpu_layout.cpp pass/gpu_layout.cpp
pass/tensor_memory_reservation.cpp
gpu_kernel_args.cpp
pass/gpu_rnn_fusion.cpp pass/gpu_rnn_fusion.cpp
pass/gpu_batch_norm_cache.cpp pass/tensor_memory_reservation.cpp
op/batch_norm.cpp op/batch_norm.cpp
op/rnn.cpp op/rnn.cpp
) )
if (NOT NGRAPH_DEX_ONLY)
list(APPEND SRC gpu_external_function.cpp)
endif()
set(CUDA_INC set(CUDA_INC
${PROJECT_SOURCE_DIR}/src/ ${PROJECT_SOURCE_DIR}/src/
) )
...@@ -141,7 +150,12 @@ if (NGRAPH_GPU_ENABLE) ...@@ -141,7 +150,12 @@ if (NGRAPH_GPU_ENABLE)
VERSION ${NGRAPH_VERSION} VERSION ${NGRAPH_VERSION}
SOVERSION ${NGRAPH_API_VERSION}) SOVERSION ${NGRAPH_API_VERSION})
endif() endif()
target_link_libraries(gpu_backend PUBLIC ngraph codegen) target_link_libraries(gpu_backend PUBLIC ngraph)
if (NGRAPH_DEX_ONLY)
target_compile_definitions(gpu_backend PRIVATE "NGRAPH_DEX_ONLY")
else()
target_link_libraries(gpu_backend PUBLIC codegen)
endif()
find_library(CUDA_nvrtc_LIBRARY nvrtc find_library(CUDA_nvrtc_LIBRARY nvrtc
PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64) PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)
find_library(CUDA_cuda_LIBRARY cuda find_library(CUDA_cuda_LIBRARY cuda
......
...@@ -1688,20 +1688,26 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const cudnnPoolingMode_t& pool_ ...@@ -1688,20 +1688,26 @@ size_t runtime::gpu::CUDNNEmitter::build_pooling(const cudnnPoolingMode_t& pool_
{ {
pool.reset(new gpu::primitive{ pool.reset(new gpu::primitive{
[=, &desc, &input_desc, &output_desc](void** inputs, void** outputs) { [=, &desc, &input_desc, &output_desc](void** inputs, void** outputs) {
void* y = inputs[0];
void* dy = inputs[0];
void* x = inputs[0];
if (pool_op == 0 || pool_op == 3)
{
y = inputs[2];
dy = inputs[1];
x = inputs[0];
}
CUDNN_SAFE_CALL(cudnnPoolingBackward(*m_ctx->cudnn_handle, CUDNN_SAFE_CALL(cudnnPoolingBackward(*m_ctx->cudnn_handle,
desc, desc,
alpha, alpha,
// output (wrt maxpool) tensor
output_desc, output_desc,
inputs[2], y,
// adjoint of output
output_desc, output_desc,
inputs[1], dy,
// input (wrt maxpool) tensor
input_desc, input_desc,
inputs[0], x,
beta, beta,
// adjoint of input // adjoint of input (dx)
input_desc, input_desc,
outputs[0])); outputs[0]));
debug_sync(); debug_sync();
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "ngraph/op/batch_norm.hpp" #include "ngraph/op/batch_norm.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp" #include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp" #include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_internal_function.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp" #include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_tensor.hpp" #include "ngraph/runtime/gpu/gpu_tensor.hpp"
#include "ngraph/runtime/hybrid/hybrid_backend.hpp" #include "ngraph/runtime/hybrid/hybrid_backend.hpp"
...@@ -120,13 +121,13 @@ shared_ptr<runtime::Tensor> runtime::gpu::GPU_Backend::create_tensor( ...@@ -120,13 +121,13 @@ shared_ptr<runtime::Tensor> runtime::gpu::GPU_Backend::create_tensor(
runtime::Handle runtime::gpu::GPU_Backend::compile(shared_ptr<Function> func) runtime::Handle runtime::gpu::GPU_Backend::compile(shared_ptr<Function> func)
{ {
FunctionInstance& instance = m_function_map[func]; FunctionInstance& instance = m_function_map[func];
if (instance.m_external_function == nullptr) if (instance.m_compiled_function == nullptr)
{ {
m_context->bind_cuda_context_to_thread(); m_context->bind_cuda_context_to_thread();
instance.m_external_function = make_shared<GPU_ExternalFunction>(func, m_context); instance.m_compiled_function = runtime::gpu::GPUCompiledFunction::make(func, m_context);
instance.m_external_function->m_emit_timing = instance.m_performance_counters_enabled; instance.m_compiled_function->m_emit_timing = instance.m_performance_counters_enabled;
instance.m_external_function->compile(); instance.m_compiled_function->compile();
instance.m_compiled_function = instance.m_external_function->m_compiled_function; instance.m_runtime = instance.m_compiled_function->m_runtime;
instance.m_inputs.resize(func->get_parameters().size()); instance.m_inputs.resize(func->get_parameters().size());
instance.m_outputs.resize(func->get_output_size()); instance.m_outputs.resize(func->get_output_size());
} }
...@@ -156,7 +157,7 @@ bool runtime::gpu::GPU_Backend::call(shared_ptr<Function> func, ...@@ -156,7 +157,7 @@ bool runtime::gpu::GPU_Backend::call(shared_ptr<Function> func,
const vector<shared_ptr<runtime::Tensor>>& inputs) const vector<shared_ptr<runtime::Tensor>>& inputs)
{ {
FunctionInstance& instance = m_function_map[func]; FunctionInstance& instance = m_function_map[func];
if (instance.m_external_function == nullptr) if (instance.m_compiled_function == nullptr)
{ {
throw runtime_error("compile() must be called before call()."); throw runtime_error("compile() must be called before call().");
} }
...@@ -169,7 +170,7 @@ bool runtime::gpu::GPU_Backend::call(shared_ptr<Function> func, ...@@ -169,7 +170,7 @@ bool runtime::gpu::GPU_Backend::call(shared_ptr<Function> func,
initialize_io(instance.m_outputs.data(), outputs); initialize_io(instance.m_outputs.data(), outputs);
auto ctx = m_context->m_runtime_context.get(); auto ctx = m_context->m_runtime_context.get();
instance.m_compiled_function(instance.m_inputs.data(), instance.m_outputs.data(), ctx); instance.m_runtime(instance.m_inputs.data(), instance.m_outputs.data(), ctx);
return true; return true;
} }
...@@ -182,7 +183,7 @@ void runtime::gpu::GPU_Backend::remove_compiled_function(shared_ptr<Function> fu ...@@ -182,7 +183,7 @@ void runtime::gpu::GPU_Backend::remove_compiled_function(shared_ptr<Function> fu
void runtime::gpu::GPU_Backend::enable_performance_data(shared_ptr<Function> func, bool enable) void runtime::gpu::GPU_Backend::enable_performance_data(shared_ptr<Function> func, bool enable)
{ {
FunctionInstance& instance = m_function_map[func]; FunctionInstance& instance = m_function_map[func];
if (instance.m_external_function != nullptr) if (instance.m_compiled_function != nullptr)
{ {
throw runtime_error("Performance data collection must be enabled prior to compiling."); throw runtime_error("Performance data collection must be enabled prior to compiling.");
} }
...@@ -197,27 +198,9 @@ vector<runtime::PerformanceCounter> ...@@ -197,27 +198,9 @@ vector<runtime::PerformanceCounter>
if (it != m_function_map.end()) if (it != m_function_map.end())
{ {
const FunctionInstance& instance = it->second; const FunctionInstance& instance = it->second;
if (instance.m_external_function != nullptr) if (instance.m_compiled_function != nullptr)
{ {
auto* engine = instance.m_external_function->m_execution_engine.get(); instance.m_compiled_function->get_performance_data(rc);
if (engine)
{
auto get_count = engine->find_function<size_t()>("get_debug_timer_count");
auto get_name = engine->find_function<const char*(size_t)>("get_debug_timer_name");
auto get_microseconds =
engine->find_function<size_t(size_t)>("get_debug_timer_microseconds");
auto get_call_count =
engine->find_function<size_t(size_t)>("get_debug_timer_call_count");
if (get_count && get_name && get_microseconds && get_call_count)
{
size_t count = get_count();
for (size_t i = 0; i < count; i++)
{
rc.push_back({get_name(i), get_microseconds(i), get_call_count(i)});
}
}
}
} }
} }
return rc; return rc;
......
...@@ -29,7 +29,7 @@ namespace ngraph ...@@ -29,7 +29,7 @@ namespace ngraph
{ {
static size_t alignment = 64; static size_t alignment = 64;
class GPU_ExternalFunction; class GPUCompiledFunction;
class GPUPrimitiveEmitter; class GPUPrimitiveEmitter;
struct GPURuntimeContext; struct GPURuntimeContext;
class CudaContextManager; class CudaContextManager;
...@@ -83,9 +83,9 @@ namespace ngraph ...@@ -83,9 +83,9 @@ namespace ngraph
class FunctionInstance class FunctionInstance
{ {
public: public:
std::shared_ptr<GPU_ExternalFunction> m_external_function; std::shared_ptr<GPUCompiledFunction> m_compiled_function;
bool m_performance_counters_enabled = false; bool m_performance_counters_enabled = false;
EntryPoint m_compiled_function; EntryPoint m_runtime;
std::vector<void*> m_inputs; std::vector<void*> m_inputs;
std::vector<void*> m_outputs; std::vector<void*> m_outputs;
}; };
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/gpu/gpu_call_frame.hpp"
using namespace ngraph;
runtime::gpu::GPUCallFrame::GPUCallFrame(const size_t& num_inputs, const size_t& num_outputs)
: m_inputs(num_inputs, nullptr)
, m_outputs(num_outputs, nullptr)
{
}
void runtime::gpu::GPUCallFrame::resolve_reservations(
const GPUCompiledFunction* compiled_function,
const std::unordered_map<std::string, size_t>& memory_reservations)
{
auto& mem_primitives = compiled_function->get_primitive_emitter()->get_memory_primitives();
for (auto const& p : memory_reservations)
{
// mem_primitives may return pointers for constant or workspace reservations
m_memory_reservations[p.first] = static_cast<unsigned char*>(mem_primitives.at(p.second)());
}
}
void runtime::gpu::GPUCallFrame::resolve_inputs(void** inputs, size_t num_inputs)
{
// num_inputs is > 0 iff we are resolving inputs from a nested function call
if (num_inputs == 0)
{
num_inputs = m_inputs.size();
}
for (size_t i = 0; i < num_inputs; i++)
{
void* input = inputs[i];
m_inputs[i] = static_cast<unsigned char*>(input);
}
}
void runtime::gpu::GPUCallFrame::resolve_outputs(void** outputs, size_t num_outputs)
{
// num_outputs is > 0 iff we are resolving outputs from a nested function call
if (num_outputs == 0)
{
num_outputs = m_outputs.size();
}
for (size_t i = 0; i < num_outputs; i++)
{
void* output = outputs[i];
m_outputs[i] = static_cast<unsigned char*>(output);
}
}
// returns pointers of any GPUTensorWrapper::TensorType
std::vector<void*>
runtime::gpu::GPUCallFrame::get_tensor_io(const std::vector<GPUTensorWrapper>& tensors)
{
std::vector<void*> ptrs;
for (auto const& tensor : tensors)
{
auto offset = tensor.get_offset();
auto ptr = get_pointer(offset.first, offset.second, tensor.get_name());
ptrs.push_back(ptr);
}
return ptrs;
}
void* runtime::gpu::GPUCallFrame::get_pointer(const TensorType& type,
const size_t& offset,
const std::string& name)
{
switch (type)
{
case TensorType::CONSTANT:
case TensorType::INTERMEDIATE:
return static_cast<void*>(m_memory_reservations.at(name) + offset);
case TensorType::INPUT: return static_cast<void*>(m_inputs.at(offset));
case TensorType::OUTPUT: return static_cast<void*>(m_outputs.at(offset));
case TensorType::UNKNOWN:
default: throw ngraph_error("GPUCallFrame encountered unknown or uninitialized tensor type");
};
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <memory>
#include <unordered_map>
#include "ngraph/runtime/gpu/gpu_compiled_function.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class GPUCallFrame
{
public:
using TensorType = GPUTensorWrapper::TensorType;
GPUCallFrame(const size_t& num_inputs, const size_t& num_outputs);
void resolve_reservations(
const GPUCompiledFunction* compiled_function,
const std::unordered_map<std::string, size_t>& memory_reservations);
void resolve_inputs(void** inputs, size_t num_inputs = 0);
void resolve_outputs(void** outputs, size_t num_outputs = 0);
std::vector<void*> get_tensor_io(const std::vector<GPUTensorWrapper>& tensors);
private:
void* get_pointer(const TensorType& type,
const size_t& offset,
const std::string& name = "");
std::unordered_map<std::string, unsigned char*> m_memory_reservations;
std::vector<unsigned char*> m_inputs;
std::vector<unsigned char*> m_outputs;
};
}
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cstdlib>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cudnn.h>
#include <fstream>
#include <locale>
#include <mutex>
#include <string>
#include <tuple>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/any_all_replacement.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_compiled_function.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_internal_function.hpp"
#include "ngraph/runtime/gpu/op/batch_norm.hpp"
#include "ngraph/runtime/gpu/op/rnn.hpp"
#include "ngraph/runtime/gpu/pass/gpu_batch_norm_cache.hpp"
#include "ngraph/runtime/gpu/pass/gpu_layout.hpp"
#include "ngraph/runtime/gpu/pass/gpu_rnn_fusion.hpp"
#include "ngraph/runtime/gpu/pass/tensor_memory_reservation.hpp"
using namespace std;
using namespace ngraph;
std::string runtime::gpu::GPUCompiledFunction::get_output_dir()
{
static std::string output_dir = "gpu_codegen";
return output_dir;
}
size_t runtime::gpu::GPUCompiledFunction::get_memory_alignment()
{
static size_t memory_pool_alignment = 64;
return memory_pool_alignment;
}
static std::mutex s_compilation;
class GPUStaticInitializers
{
public:
GPUStaticInitializers()
{
file_util::remove_directory(runtime::gpu::GPUCompiledFunction::get_output_dir());
file_util::make_directory(runtime::gpu::GPUCompiledFunction::get_output_dir());
}
};
static GPUStaticInitializers s_static_initializers;
runtime::gpu::GPUCompiledFunction::GPUCompiledFunction(
const shared_ptr<ngraph::Function>& function,
const std::shared_ptr<GPU_Backend::BackendContext>& shared_context)
: m_runtime(nullptr)
, m_function(function)
, m_emit_timing(false)
, m_is_compiled(false)
, m_shared_context(shared_context)
{
}
runtime::gpu::GPUCompiledFunction::~GPUCompiledFunction()
{
}
std::vector<std::string> get_case_variants(std::vector<std::string> cases)
{
std::vector<std::string> results;
for (auto& c : cases)
{
results.push_back(c);
if (std::all_of(c.begin(), c.end(), ::isdigit))
{
continue;
}
for (auto i = 0u; i < c.size(); i++)
{
c[i] = std::toupper(c[i], std::locale());
if (i == 0)
{
results.emplace_back(c);
}
}
results.emplace_back(c);
}
return results;
}
std::shared_ptr<runtime::gpu::GPUCompiledFunction> runtime::gpu::GPUCompiledFunction::make(
const std::shared_ptr<ngraph::Function>& function,
const std::shared_ptr<GPU_Backend::BackendContext>& shared_context)
{
#if defined(NGRAPH_DEX_ONLY)
return std::make_shared<runtime::gpu::GPUInternalFunction>(function, shared_context);
#else
// For now codegen is default unless explicitly disabled
bool use_codegen = true;
if (auto env = std::getenv("NGRAPH_CODEGEN"))
{
std::string env_codegen(env);
for (auto& opt : get_case_variants({"0", "false"}))
{
if (env_codegen == opt)
{
use_codegen = false;
}
}
}
if (use_codegen)
{
return std::make_shared<runtime::gpu::GPUExternalFunction>(function, shared_context);
}
else
{
return std::make_shared<runtime::gpu::GPUInternalFunction>(function, shared_context);
}
#endif
}
void runtime::gpu::GPUCompiledFunction::compile()
{
if (m_is_compiled)
{
return;
}
std::unique_lock<std::mutex> lock(s_compilation);
m_function_name = m_function->get_name();
auto allocator = std::make_shared<runtime::gpu::GPUAllocator>(
m_shared_context->m_primitive_emitter->get_memory_allocator());
ngraph::pass::Manager pass_manager;
#if CUDNN_VERSION >= 7200
// recurrent network fusion
pass_manager.register_pass<runtime::gpu::pass::LSTMFusion>();
pass_manager.register_pass<runtime::gpu::pass::RNNFusion>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<runtime::gpu::pass::MultiLayerRNNFusion>();
#else
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
#endif
pass_manager.register_pass<runtime::gpu::pass::BatchNormCache>();
pass_manager.register_pass<ngraph::pass::AnyAllReplacement>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<runtime::gpu::pass::GPULayout>(this);
pass_manager.register_pass<ngraph::pass::AssignLayout<descriptor::layout::DenseTensorLayout>>();
pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::MemoryLayout>(get_memory_alignment());
pass_manager.register_pass<runtime::gpu::pass::TensorMemoryReservation>(
*allocator, m_tensor_memory_buffers);
string dump_filename = file_util::path_join(get_output_dir(), m_function_name + "_ops.txt");
pass_manager.register_pass<ngraph::pass::DumpSorted>(dump_filename);
pass_manager.run_passes(m_function);
for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
{
m_function_ordered_ops.emplace(current_function, current_function->get_ordered_ops());
}
add_passes(pass_manager);
emit();
// allocate device buffers for primitive arguments and workspace
allocator->close();
m_shared_context->m_primitive_emitter->allocate_primitive_memory();
compile_function();
m_is_compiled = true;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
#define EMIT_ARGS \
runtime::gpu::GPUCompiledFunction *compiled_function, const std::string &function_name, \
const Node *node, const std::vector<runtime::gpu::GPUTensorWrapper> &args, \
const std::vector<runtime::gpu::GPUTensorWrapper> &out
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class GPU_Emitter;
struct GPURuntimeContext;
class GPUCompiledFunction
{
friend class GPU_Backend;
public:
GPUCompiledFunction(
const std::shared_ptr<ngraph::Function>& function,
const std::shared_ptr<GPU_Backend::BackendContext>& shared_context);
virtual ~GPUCompiledFunction();
static std::shared_ptr<GPUCompiledFunction>
make(const std::shared_ptr<ngraph::Function>& function,
const std::shared_ptr<GPU_Backend::BackendContext>& shared_context);
std::unique_ptr<runtime::gpu::GPURuntimeContext>& ctx();
const std::unique_ptr<GPUPrimitiveEmitter>& get_primitive_emitter() const
{
return m_shared_context->m_primitive_emitter;
}
virtual std::string
add_to_runtime(size_t primitive_index,
const std::string& function_name,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out) = 0;
virtual std::string
add_call_to_runtime(const std::string& caller,
const std::string& callee,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out) = 0;
void compile();
virtual void
get_performance_data(std::vector<runtime::PerformanceCounter>& rc) const = 0;
static size_t get_memory_alignment();
static std::string get_output_dir();
protected:
virtual void compile_function() = 0;
virtual void add_passes(ngraph::pass::Manager& pass_manager) = 0;
virtual void emit() = 0;
EntryPoint m_runtime;
// For non-destructive passthrough kernels, propagate function
// input buffers to internal ops
virtual void propagate_in_place_input(ngraph::descriptor::Output* output,
const std::string& input_name) = 0;
// For in-place kernels, propagate function output buffers to
// internal ops
virtual void propagate_in_place_output(ngraph::descriptor::Output* res_src_output,
const std::string& output_name) = 0;
std::shared_ptr<ngraph::Function> m_function;
std::unordered_map<std::shared_ptr<Function>, std::list<std::shared_ptr<Node>>>
m_function_ordered_ops;
bool m_emit_timing;
bool m_is_compiled;
size_t m_offset;
std::string m_function_name;
std::unordered_map<std::string, size_t> m_tensor_memory_buffers;
std::shared_ptr<GPU_Backend::BackendContext> m_shared_context;
};
}
}
}
...@@ -19,9 +19,8 @@ ...@@ -19,9 +19,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/runtime/gpu/gpu_external_function.hpp" #include "ngraph/runtime/gpu/gpu_compiled_function.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp" #include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
namespace ngraph namespace ngraph
...@@ -33,21 +32,21 @@ namespace ngraph ...@@ -33,21 +32,21 @@ namespace ngraph
class GPU_Emitter class GPU_Emitter
{ {
public: public:
static std::function<void(EMIT_ARGS)> get_emit_function(const Node& node); static std::function<std::string(EMIT_ARGS)> get_emit_function(const Node& node);
// This defines a collection of function declarations like this // This defines a collection of function declarations like this
// static void emit_Abs(EMIT_ARGS); // static std::string emit_Abs(EMIT_ARGS);
// static void emit_Acos(EMIT_ARGS); // static std::string emit_Acos(EMIT_ARGS);
#define NGRAPH_OP(a, b) static void emit_##a(EMIT_ARGS); #define NGRAPH_OP(a, b) static std::string emit_##a(EMIT_ARGS);
#include "ngraph/runtime/gpu/op/op_tbl.hpp" #include "ngraph/runtime/gpu/op/op_tbl.hpp"
#undef NGRAPH_OP #undef NGRAPH_OP
template <typename T> template <typename T>
static void emit_elementwise(EMIT_ARGS) static std::string emit_elementwise(EMIT_ARGS)
{ {
if (out[0].get_size() == 0) if (out[0].get_size() == 0)
{ {
return; return "";
} }
else if (out.size() > 1) else if (out.size() > 1)
{ {
...@@ -55,29 +54,22 @@ namespace ngraph ...@@ -55,29 +54,22 @@ namespace ngraph
"Multi-output elementwise ops are not currently supported."); "Multi-output elementwise ops are not currently supported.");
} }
auto& cuda_emitter = auto& cuda_emitter =
external_function->get_primitive_emitter()->get_cuda_emitter(); compiled_function->get_primitive_emitter()->get_cuda_emitter();
writer.block_begin(); std::vector<std::string> dtypes;
for (auto& arg : args)
{ {
std::vector<std::string> dtypes; dtypes.push_back(arg.get_type());
for (auto& arg : args)
{
dtypes.push_back(arg.get_type());
}
dtypes.push_back(out[0].get_type());
auto ew_index =
cuda_emitter->build_elementwise<T>(dtypes, out[0].get_shape());
writer << "void* input[] = {" << node_names(args) << "};\n";
writer << "void* output[] = {" << node_names(out) << "};\n";
writer << "gpu::invoke_primitive(ctx, " << ew_index
<< ", input, output);\n";
} }
writer.block_end(); dtypes.push_back(out[0].get_type());
auto ew_index = cuda_emitter->build_elementwise<T>(dtypes, out[0].get_shape());
return compiled_function->add_to_runtime(ew_index, function_name, args, out);
} }
static void emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t); static std::string emit_ArgReduce(EMIT_ARGS, cudnnReduceTensorOp_t);
static void emit_Sum_0(EMIT_ARGS); static std::string emit_Sum_0(EMIT_ARGS);
static void emit_Sum_1(EMIT_ARGS); static std::string emit_Sum_1(EMIT_ARGS);
/// \brief Create a list of node names for each arg in args /// \brief Create a list of node names for each arg in args
/// \param args list of tensor arguments /// \param args list of tensor arguments
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#pragma once #pragma once
#if !defined(NGRAPH_DEX_ONLY)
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <typeindex> #include <typeindex>
...@@ -32,14 +34,10 @@ ...@@ -32,14 +34,10 @@
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp" #include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp" #include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_compiled_function.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp" #include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp" #include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
#define EMIT_ARGS \
runtime::gpu::GPU_ExternalFunction *external_function, codegen::CodeWriter &writer, \
const Node *node, const std::vector<runtime::gpu::GPUTensorWrapper> &args, \
const std::vector<runtime::gpu::GPUTensorWrapper> &out
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
...@@ -49,37 +47,41 @@ namespace ngraph ...@@ -49,37 +47,41 @@ namespace ngraph
class GPU_Emitter; class GPU_Emitter;
struct GPURuntimeContext; struct GPURuntimeContext;
class GPU_ExternalFunction class GPUExternalFunction : public GPUCompiledFunction
{ {
friend class GPU_Backend;
public: public:
GPU_ExternalFunction(const std::shared_ptr<ngraph::Function>& function, GPUExternalFunction(
std::shared_ptr<GPU_Backend::BackendContext>& shared_context); const std::shared_ptr<ngraph::Function>& function,
~GPU_ExternalFunction(); const std::shared_ptr<GPU_Backend::BackendContext>& shared_context);
virtual ~GPUExternalFunction();
std::unique_ptr<runtime::gpu::GPURuntimeContext>& ctx();
const std::unique_ptr<GPUPrimitiveEmitter>& get_primitive_emitter() const virtual std::string
{ add_to_runtime(size_t primitive_index,
return m_shared_context->m_primitive_emitter; const std::string& function_name,
} const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out) override;
static const size_t s_memory_pool_alignment; virtual std::string add_call_to_runtime(
const std::string& caller,
const std::string& callee,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out) override;
virtual void get_performance_data(
std::vector<runtime::PerformanceCounter>& rc) const override;
protected: protected:
void compile(); virtual void compile_function() override;
virtual void add_passes(ngraph::pass::Manager& pass_manager) override;
EntryPoint m_compiled_function; virtual void emit() override;
private: private:
// For non-destructive passthrough kernels, propagate function /// \brief Create a list of node names for each arg in args
// input buffers to internal ops /// \param args list of tensor arguments
void propagate_in_place_input(ngraph::descriptor::Output* output, /// \param arg_indexes a list of indexes into args for which args to include in
std::string input_name); /// the output list, so {1, 2} will include args 1 and 2 and skip 0.
// For in-place kernels, propagate function output buffers to /// \ return returns a string containing "arg0_name, arg1_name, etc."
// internal ops std::string node_names(const std::vector<runtime::gpu::GPUTensorWrapper>& args,
void propagate_in_place_output(ngraph::descriptor::Output* res_src_output, std::initializer_list<int> arg_indexes = {});
std::string output_name);
void emit_header(); void emit_header();
void emit_timer_functions(); void emit_timer_functions();
void emit_constant_declarations(); void emit_constant_declarations();
...@@ -88,35 +90,31 @@ namespace ngraph ...@@ -88,35 +90,31 @@ namespace ngraph
void emit_debug_function_entry(Node* node); void emit_debug_function_entry(Node* node);
void emit_debug_function_exit(Node* node); void emit_debug_function_exit(Node* node);
void emit_temp_mem_pool_allocation(std::shared_ptr<Function> current_function); void emit_temp_mem_pool_allocation(std::shared_ptr<Function> current_function);
void emit_op(EMIT_ARGS);
void store_emitted_functions(const std::string& code); void store_emitted_functions(const std::string& code);
std::string emit_op(EMIT_ARGS);
std::string emit_op_as_function(const Node& node, const std::string& function_name); std::string emit_op_as_function(const Node& node, const std::string& function_name);
std::string strip_comments(const std::string& s) const; std::string strip_comments(const std::string& s) const;
static const std::string& get_pch_header_source(); static const std::string& get_pch_header_source();
static const std::string& get_header_source(); static const std::string& get_header_source();
// For non-destructive passthrough kernels, propagate function
// input buffers to internal ops
virtual void propagate_in_place_input(ngraph::descriptor::Output* output,
const std::string& input_name) override;
// For in-place kernels, propagate function output buffers to
// internal ops
virtual void propagate_in_place_output(ngraph::descriptor::Output* res_src_output,
const std::string& output_name) override;
codegen::CodeWriter m_writer; codegen::CodeWriter m_writer;
std::string m_common_function_string;
std::unique_ptr<codegen::Compiler> m_compiler; std::unique_ptr<codegen::Compiler> m_compiler;
std::unique_ptr<codegen::ExecutionEngine> m_execution_engine; std::unique_ptr<codegen::ExecutionEngine> m_execution_engine;
std::shared_ptr<ngraph::Function> m_function;
std::map<std::string, size_t> m_name_index_map; std::map<std::string, size_t> m_name_index_map;
std::unordered_map<std::string, std::string> m_variable_name_map; std::unordered_map<std::string, std::string> m_variable_name_map;
std::unordered_map<Node*, Node*> m_node_function_map; std::unordered_map<Node*, Node*> m_node_function_map;
std::unordered_map<std::shared_ptr<Function>, std::list<std::shared_ptr<Node>>>
m_function_ordered_ops;
bool m_emit_timing;
bool m_is_compiled;
size_t m_offset;
std::string m_function_name;
std::unordered_map<std::string, size_t> m_tensor_memory_buffers;
std::shared_ptr<GPU_Backend::BackendContext> m_shared_context;
}; };
} }
} }
} }
#endif // !defined(NGRAPH_DEX_ONLY)
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <memory>
#include <tuple>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_compiled_function.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class GPU_Emitter;
class GPURuntimeConstructor;
struct GPURuntimeContext;
class GPUInternalFunction : public GPUCompiledFunction
{
public:
GPUInternalFunction(
const std::shared_ptr<ngraph::Function>& function,
const std::shared_ptr<GPU_Backend::BackendContext>& shared_context);
virtual ~GPUInternalFunction();
virtual std::string
add_to_runtime(size_t primitive_index,
const std::string& function_name,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out) override;
virtual std::string add_call_to_runtime(
const std::string& caller,
const std::string& callee,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out) override;
virtual void get_performance_data(
std::vector<runtime::PerformanceCounter>& rc) const override;
protected:
virtual void compile_function() override;
virtual void add_passes(ngraph::pass::Manager& pass_manager) override;
virtual void emit() override;
private:
void build_functions();
std::string emit_op(EMIT_ARGS);
std::string
compose_manifest(size_t primitive_index,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out) const;
void save_manifest_to_disk() const;
// For non-destructive passthrough kernels, propagate function
// input buffers to internal ops
virtual void propagate_in_place_input(ngraph::descriptor::Output* output,
const std::string& input_name) override;
// For in-place kernels, propagate function output buffers to
// internal ops
virtual void propagate_in_place_output(ngraph::descriptor::Output* res_src_output,
const std::string& output_name) override;
std::unordered_map<
std::string,
std::tuple<runtime::gpu::GPUTensorWrapper::TensorType, size_t, std::string>>
m_variable_name_map;
std::unique_ptr<GPURuntimeConstructor> m_runtime_constructor;
std::shared_ptr<codegen::CodeWriter> m_trace;
codegen::CodeWriter m_manifest;
};
}
}
}
...@@ -24,6 +24,7 @@ using namespace ngraph::runtime::gpu; ...@@ -24,6 +24,7 @@ using namespace ngraph::runtime::gpu;
GPUPrimitiveEmitter::GPUPrimitiveEmitter() GPUPrimitiveEmitter::GPUPrimitiveEmitter()
: m_memory_manager(this) : m_memory_manager(this)
, m_host_parameters(new GPUHostParameters) , m_host_parameters(new GPUHostParameters)
, m_host_emitter(new HostEmitter(this, nullptr))
, m_cuda_emitter(new CUDAEmitter(this, nullptr, nullptr)) , m_cuda_emitter(new CUDAEmitter(this, nullptr, nullptr))
, m_cudnn_emitter(new CUDNNEmitter(this, nullptr, nullptr)) , m_cudnn_emitter(new CUDNNEmitter(this, nullptr, nullptr))
, m_cublas_emitter(new CUBLASEmitter(this, nullptr)) , m_cublas_emitter(new CUBLASEmitter(this, nullptr))
...@@ -33,6 +34,7 @@ GPUPrimitiveEmitter::GPUPrimitiveEmitter() ...@@ -33,6 +34,7 @@ GPUPrimitiveEmitter::GPUPrimitiveEmitter()
GPUPrimitiveEmitter::GPUPrimitiveEmitter(const std::unique_ptr<GPURuntimeContext>& ctx) GPUPrimitiveEmitter::GPUPrimitiveEmitter(const std::unique_ptr<GPURuntimeContext>& ctx)
: m_memory_manager(this) : m_memory_manager(this)
, m_host_parameters(new GPUHostParameters) , m_host_parameters(new GPUHostParameters)
, m_host_emitter(new HostEmitter(this, ctx.get()))
, m_cuda_emitter(new CUDAEmitter(this, ctx.get(), this->m_host_parameters)) , m_cuda_emitter(new CUDAEmitter(this, ctx.get(), this->m_host_parameters))
, m_cudnn_emitter(new CUDNNEmitter(this, ctx.get(), this->m_host_parameters)) , m_cudnn_emitter(new CUDNNEmitter(this, ctx.get(), this->m_host_parameters))
, m_cublas_emitter(new CUBLASEmitter(this, ctx.get())) , m_cublas_emitter(new CUBLASEmitter(this, ctx.get()))
...@@ -40,6 +42,10 @@ GPUPrimitiveEmitter::GPUPrimitiveEmitter(const std::unique_ptr<GPURuntimeContext ...@@ -40,6 +42,10 @@ GPUPrimitiveEmitter::GPUPrimitiveEmitter(const std::unique_ptr<GPURuntimeContext
{ {
} }
std::unique_ptr<HostEmitter>& GPUPrimitiveEmitter::get_host_emitter()
{
return m_host_emitter;
}
std::unique_ptr<CUDAEmitter>& GPUPrimitiveEmitter::get_cuda_emitter() std::unique_ptr<CUDAEmitter>& GPUPrimitiveEmitter::get_cuda_emitter()
{ {
return m_cuda_emitter; return m_cuda_emitter;
...@@ -48,7 +54,6 @@ std::unique_ptr<CUDNNEmitter>& GPUPrimitiveEmitter::get_cudnn_emitter() ...@@ -48,7 +54,6 @@ std::unique_ptr<CUDNNEmitter>& GPUPrimitiveEmitter::get_cudnn_emitter()
{ {
return m_cudnn_emitter; return m_cudnn_emitter;
} }
std::unique_ptr<CUBLASEmitter>& GPUPrimitiveEmitter::get_cublas_emitter() std::unique_ptr<CUBLASEmitter>& GPUPrimitiveEmitter::get_cublas_emitter()
{ {
return m_cublas_emitter; return m_cublas_emitter;
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ngraph/runtime/gpu/gpu_kernel_args.hpp" #include "ngraph/runtime/gpu/gpu_kernel_args.hpp"
#include "ngraph/runtime/gpu/gpu_memory_manager.hpp" #include "ngraph/runtime/gpu/gpu_memory_manager.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp" #include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/host_emitter.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -34,6 +35,7 @@ namespace ngraph ...@@ -34,6 +35,7 @@ namespace ngraph
public: public:
GPUPrimitiveEmitter(); GPUPrimitiveEmitter();
GPUPrimitiveEmitter(const std::unique_ptr<GPURuntimeContext>& ctx); GPUPrimitiveEmitter(const std::unique_ptr<GPURuntimeContext>& ctx);
std::unique_ptr<HostEmitter>& get_host_emitter();
std::unique_ptr<CUDAEmitter>& get_cuda_emitter(); std::unique_ptr<CUDAEmitter>& get_cuda_emitter();
std::unique_ptr<CUDNNEmitter>& get_cudnn_emitter(); std::unique_ptr<CUDNNEmitter>& get_cudnn_emitter();
std::unique_ptr<CUBLASEmitter>& get_cublas_emitter(); std::unique_ptr<CUBLASEmitter>& get_cublas_emitter();
...@@ -59,6 +61,7 @@ namespace ngraph ...@@ -59,6 +61,7 @@ namespace ngraph
std::vector<std::unique_ptr<gpu::primitive>> m_managed_primitives; std::vector<std::unique_ptr<gpu::primitive>> m_managed_primitives;
GPUMemoryManager m_memory_manager; GPUMemoryManager m_memory_manager;
std::shared_ptr<GPUHostParameters> m_host_parameters; std::shared_ptr<GPUHostParameters> m_host_parameters;
std::unique_ptr<HostEmitter> m_host_emitter;
std::unique_ptr<CUDAEmitter> m_cuda_emitter; std::unique_ptr<CUDAEmitter> m_cuda_emitter;
std::unique_ptr<CUDNNEmitter> m_cudnn_emitter; std::unique_ptr<CUDNNEmitter> m_cudnn_emitter;
std::unique_ptr<CUBLASEmitter> m_cublas_emitter; std::unique_ptr<CUBLASEmitter> m_cublas_emitter;
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/gpu/gpu_runtime_constructor.hpp"
using namespace ngraph;
runtime::gpu::GPURuntimeConstructor::GPURuntimeConstructor(const op_order_t& ordered_ops)
{
for (auto const& ops : ordered_ops)
{
m_runtime[ops.first->get_name()].reserve(ops.second.size());
}
}
void runtime::gpu::GPURuntimeConstructor::add(const std::string& name, const op_runtime_t& step)
{
m_runtime[name].push_back(step);
}
void runtime::gpu::GPURuntimeConstructor::add_call(
const std::string& caller,
const std::string& callee,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out)
{
auto& runtime = m_runtime[callee];
auto call = [args, out, &runtime](GPUCallFrame& caller_frame, GPURuntimeContext* ctx) mutable {
// extract memory pointers from the callers stack
auto inputs = caller_frame.get_tensor_io(args);
auto outputs = caller_frame.get_tensor_io(out);
// create a new call frame for the nested function
GPUCallFrame callee_frame = caller_frame;
// resolve the inputs of the new call frame
callee_frame.resolve_inputs(inputs.data(), inputs.size());
callee_frame.resolve_outputs(outputs.data(), outputs.size());
for (auto const& step : runtime)
{
step(callee_frame, ctx);
}
};
add(caller, call);
}
runtime::gpu::EntryPoint runtime::gpu::GPURuntimeConstructor::build(const std::string& function,
GPUCallFrame& call_frame)
{
auto& runtime = m_runtime.at(function);
return [call_frame, &runtime](void** inputs, void** outputs, GPURuntimeContext* ctx) mutable {
call_frame.resolve_inputs(inputs);
call_frame.resolve_outputs(outputs);
for (auto const& step : runtime)
{
step(call_frame, ctx);
}
};
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "ngraph/function.hpp"
#include "ngraph/runtime/gpu/gpu_backend.hpp"
#include "ngraph/runtime/gpu/gpu_call_frame.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
namespace ngraph
{
namespace runtime
{
namespace gpu
{
class GPUCallFrame;
class GPURuntimeConstructor
{
public:
using op_runtime_t =
std::function<void(GPUCallFrame& call_frame, GPURuntimeContext* ctx)>;
using op_order_t =
std::unordered_map<std::shared_ptr<Function>, std::list<std::shared_ptr<Node>>>;
GPURuntimeConstructor(const op_order_t& ordered_ops);
void add(const std::string& name, const op_runtime_t& step);
void add_call(const std::string& caller,
const std::string& callee,
const std::vector<runtime::gpu::GPUTensorWrapper>& args,
const std::vector<runtime::gpu::GPUTensorWrapper>& out);
EntryPoint build(const std::string& function, GPUCallFrame& call_frame);
private:
std::unordered_map<std::string, std::vector<op_runtime_t>> m_runtime;
};
}
}
}
...@@ -13,10 +13,11 @@ ...@@ -13,10 +13,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <limits>
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
#include "ngraph/descriptor/layout/tensor_layout.hpp" #include "ngraph/descriptor/layout/tensor_layout.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/runtime/gpu/gpu_tensor_wrapper.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -25,6 +26,18 @@ runtime::gpu::GPUTensorWrapper::GPUTensorWrapper(const shared_ptr<descriptor::Te ...@@ -25,6 +26,18 @@ runtime::gpu::GPUTensorWrapper::GPUTensorWrapper(const shared_ptr<descriptor::Te
const string& alias) const string& alias)
: m_tensor(tv) : m_tensor(tv)
, m_alias(alias) , m_alias(alias)
, m_offset(std::make_pair(runtime::gpu::GPUTensorWrapper::TensorType::UNKNOWN,
std::numeric_limits<size_t>::max()))
{
}
runtime::gpu::GPUTensorWrapper::GPUTensorWrapper(const std::shared_ptr<descriptor::Tensor>& tv,
runtime::gpu::GPUTensorWrapper::TensorType type,
size_t offset,
const std::string& alias)
: m_tensor(tv)
, m_alias(alias)
, m_offset(std::make_pair(type, offset))
{ {
} }
...@@ -60,7 +73,25 @@ const std::string& runtime::gpu::GPUTensorWrapper::get_name() const ...@@ -60,7 +73,25 @@ const std::string& runtime::gpu::GPUTensorWrapper::get_name() const
} }
} }
const std::pair<runtime::gpu::GPUTensorWrapper::TensorType, size_t>&
runtime::gpu::GPUTensorWrapper::get_offset() const
{
return m_offset;
}
const std::string& runtime::gpu::GPUTensorWrapper::get_type() const const std::string& runtime::gpu::GPUTensorWrapper::get_type() const
{ {
return get_element_type().c_type_string(); return get_element_type().c_type_string();
} }
std::ostream& ngraph::runtime::gpu::operator<<(std::ostream& out,
const ngraph::runtime::gpu::GPUTensorWrapper& obj)
{
static std::vector<std::string> types{"CONSTANT", "INTERMEDIATE", "INPUT", "OUTPUT", "UNKNOWN"};
out << "gpu::tensor { name: " << obj.m_tensor->get_name()
<< " tensor_type: " << types.at(static_cast<size_t>(obj.m_offset.first))
<< ", offset/index: " << obj.m_offset.second << ", dtype: " << obj.get_element_type()
<< ", shape: " << obj.get_shape() << ", size: " << obj.get_size()
<< ", alias: " << obj.m_alias << " }";
return out;
}
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <tuple>
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
...@@ -28,6 +29,8 @@ namespace ngraph ...@@ -28,6 +29,8 @@ namespace ngraph
namespace gpu namespace gpu
{ {
class GPUTensorWrapper; class GPUTensorWrapper;
std::ostream& operator<<(std::ostream& out,
const ngraph::runtime::gpu::GPUTensorWrapper& obj);
} }
} }
} }
...@@ -35,7 +38,19 @@ namespace ngraph ...@@ -35,7 +38,19 @@ namespace ngraph
class ngraph::runtime::gpu::GPUTensorWrapper class ngraph::runtime::gpu::GPUTensorWrapper
{ {
public: public:
enum TensorType : std::size_t
{
CONSTANT,
INTERMEDIATE,
INPUT,
OUTPUT,
UNKNOWN
};
GPUTensorWrapper(const std::shared_ptr<descriptor::Tensor>&, const std::string& alias = ""); GPUTensorWrapper(const std::shared_ptr<descriptor::Tensor>&, const std::string& alias = "");
GPUTensorWrapper(const std::shared_ptr<descriptor::Tensor>&,
TensorType,
size_t,
const std::string& alias);
size_t get_size() const; size_t get_size() const;
const Shape& get_shape() const; const Shape& get_shape() const;
...@@ -43,8 +58,12 @@ public: ...@@ -43,8 +58,12 @@ public:
const element::Type& get_element_type() const; const element::Type& get_element_type() const;
const std::string& get_name() const; const std::string& get_name() const;
const std::string& get_type() const; const std::string& get_type() const;
const std::pair<TensorType, size_t>& get_offset() const;
friend std::ostream& ngraph::runtime::gpu::
operator<<(std::ostream& out, const ngraph::runtime::gpu::GPUTensorWrapper& obj);
private: private:
std::shared_ptr<descriptor::Tensor> m_tensor; std::shared_ptr<descriptor::Tensor> m_tensor;
std::string m_alias; std::string m_alias;
std::pair<TensorType, size_t> m_offset;
}; };
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <cudnn.h> #include <cudnn.h>
#include <memory>
#include <vector> #include <vector>
namespace ngraph namespace ngraph
...@@ -37,6 +38,12 @@ namespace ngraph ...@@ -37,6 +38,12 @@ namespace ngraph
std::pair<uint64_t, uint64_t> idiv_magic_u64(uint64_t divisor); std::pair<uint64_t, uint64_t> idiv_magic_u64(uint64_t divisor);
uint32_t idiv_ceil(int n, int d); uint32_t idiv_ceil(int n, int d);
template <typename T, typename... Args>
std::unique_ptr<T> make_unique(Args&&... args)
{
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}
// This is commented out because it increases the compile time. // This is commented out because it increases the compile time.
// It should be moved to a debug header. // It should be moved to a debug header.
// template <typename T> // template <typename T>
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <iostream>
#include <sstream>
#include <vector>
#include "ngraph/runtime/gpu/gpu_invoke.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_runtime_context.hpp"
#include "ngraph/runtime/gpu/host_emitter.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
runtime::gpu::HostEmitter::HostEmitter(GPUPrimitiveEmitter* emitter, GPURuntimeContext* ctx)
: m_primitive_emitter(emitter)
, m_ctx(ctx)
{
}
size_t runtime::gpu::HostEmitter::build_memcpy(const cudaMemcpyKind& kind,
size_t size,
size_t dst,
size_t src)
{
std::stringstream ss;
ss << "memcpy" << kind << "_dst" << dst << "_src" << src << "_sz" << size;
std::string hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
std::unique_ptr<gpu::primitive> launch_kernel(
new gpu::primitive{[=](void** inputs, void** outputs) mutable {
CUDA_RT_SAFE_CALL(cudaMemcpy(outputs[dst], inputs[src], size, kind));
}});
return this->m_primitive_emitter->register_primitive(launch_kernel, hash);
}
size_t runtime::gpu::HostEmitter::build_zero_out(size_t dst, size_t size, bool is_local)
{
std::stringstream ss;
ss << "zero"
<< "_dst" << dst << "_sz" << size << "_local" << is_local;
std::string hash = ss.str();
// check if the requested kernel is already an inserted primitive
size_t primitive_index = m_primitive_emitter->lookup(hash);
if (primitive_index != std::numeric_limits<size_t>::max())
{
return primitive_index;
}
std::unique_ptr<gpu::primitive> launch_kernel;
if (is_local)
{
launch_kernel.reset(new gpu::primitive{[=](void** inputs, void** outputs) mutable {
void* tensor = gpu::invoke_memory_primitive(m_ctx, dst);
CUDA_RT_SAFE_CALL(cudaMemset(tensor, 0, size));
}});
}
else
{
launch_kernel.reset(new gpu::primitive{[=](void** inputs, void** outputs) mutable {
CUDA_RT_SAFE_CALL(cudaMemset(outputs[dst], 0, size));
}});
}
return this->m_primitive_emitter->register_primitive(launch_kernel, hash);
}
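A hedged usage sketch (not part of this commit) of the primitives above, as an op emitter might build them; only build_memcpy itself comes from this diff, and the include path is assumed:

#include "ngraph/runtime/gpu/host_emitter.hpp" // path assumed from this commit

// Sketch only: emit a device-to-device copy for an op whose output is a bitwise
// copy of its input (e.g. a no-op reshape). The primitive is cached by its hash,
// so repeated requests for the same copy return the same index.
size_t emit_identity_copy(ngraph::runtime::gpu::HostEmitter& host_emitter, size_t nbytes)
{
    // dst = output slot 0, src = input slot 0
    return host_emitter.build_memcpy(cudaMemcpyDeviceToDevice, nbytes, 0, 0);
}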
@@ -14,34 +14,39 @@
// limitations under the License.
//*****************************************************************************
#include "code_writer.hpp" #pragma once
using namespace std; #include <functional>
using namespace ngraph; #include <memory>
#include <vector>
codegen::CodeWriter::CodeWriter() #include <cuda_runtime.h>
: indent(0)
, m_pending_indent(true)
, m_temporary_name_count(0)
{
}
string codegen::CodeWriter::get_code() const
{
return m_ss.str();
}
void codegen::CodeWriter::operator+=(const std::string& s) namespace ngraph
{ {
*this << s; namespace runtime
} {
namespace gpu
std::string codegen::CodeWriter::generate_temporary_name(std::string prefix) {
{ struct GPURuntimeContext;
std::stringstream ss; class GPUPrimitiveEmitter;
class HostEmitter
ss << prefix << m_temporary_name_count; {
m_temporary_name_count++; friend class GPUPrimitiveEmitter;
return ss.str(); public:
size_t build_memcpy(const cudaMemcpyKind& kind,
size_t size,
size_t dst = 0,
size_t src = 0);
size_t build_zero_out(size_t dst, size_t size, bool is_local = false);
private:
HostEmitter(GPUPrimitiveEmitter* emitter, GPURuntimeContext* ctx);
GPUPrimitiveEmitter* m_primitive_emitter;
GPURuntimeContext* m_ctx;
};
}
}
} }
@@ -197,7 +197,7 @@ bool runtime::gpu::pass::GPULayout::run_on_call_graph(const std::list<std::share
auto handler = s_dispatcher.find(TI(n));
if (handler != s_dispatcher.end())
{
-handler->second(m_external_function, node);
handler->second(m_compiled_function, node);
}
}
...
@@ -17,10 +17,10 @@
#pragma once
#include "ngraph/pass/pass.hpp"
-#include "ngraph/runtime/gpu/gpu_external_function.hpp"
#include "ngraph/runtime/gpu/gpu_compiled_function.hpp"
#define LAYOUT_DECL(op_type) \
-layout<op_type>(ngraph::runtime::gpu::GPU_ExternalFunction * external_function, \
layout<op_type>(ngraph::runtime::gpu::GPUCompiledFunction * compiled_function, \
std::shared_ptr<ngraph::Node> node)
namespace ngraph
@@ -32,27 +32,26 @@ namespace ngraph
namespace pass
{
using LayoutFunction =
-std::function<void(GPU_ExternalFunction*, std::shared_ptr<ngraph::Node>)>;
std::function<void(GPUCompiledFunction*, std::shared_ptr<ngraph::Node>)>;
using LayoutOpMap = std::unordered_map<std::type_index, LayoutFunction>;
class GPULayout : public ngraph::pass::CallGraphPass
{
public:
-GPULayout(GPU_ExternalFunction* external_function)
GPULayout(GPUCompiledFunction* compiled_function)
-: m_external_function(external_function)
: m_compiled_function(compiled_function)
{
}
virtual bool
run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override;
template <typename OP>
-static void
-layout(ngraph::runtime::gpu::GPU_ExternalFunction * external_function,
-std::shared_ptr<ngraph::Node> node);
static void layout(ngraph::runtime::gpu::GPUCompiledFunction* compiled_function,
std::shared_ptr<ngraph::Node> node);
private:
-GPU_ExternalFunction* m_external_function;
GPUCompiledFunction* m_compiled_function;
};
NodeVector insert_new_reshape_after(NodeVector& parents,
...
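For context, a hedged sketch (not from this diff) of how an op-specific handler is typically specialized against the renamed GPUCompiledFunction via LAYOUT_DECL; the op chosen is illustrative only:

// Sketch only: explicit specialization of GPULayout::layout for one op.
// Such specializations are what s_dispatcher maps type_index values to.
#include "ngraph/op/replace_slice.hpp" // include assumed for the illustrative op

namespace ngraph { namespace runtime { namespace gpu { namespace pass {
template <>
void GPULayout::LAYOUT_DECL(ngraph::op::ReplaceSlice)
{
    // Inspect `node` and, if its layout is unsuitable for the GPU kernels,
    // insert reshapes on its arguments; `compiled_function` provides the
    // function-wide state needed while rewriting the graph.
}
}}}}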
@@ -19,6 +19,7 @@
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/pass/manager_state.hpp"
#include "ngraph/runtime/gpu/gpu_memory_manager.hpp"
#include "ngraph/runtime/gpu/pass/tensor_memory_reservation.hpp"
@@ -28,13 +29,27 @@ using namespace std;
bool runtime::gpu::pass::TensorMemoryReservation::run_on_function(shared_ptr<Function> f)
{
bool reservation = false;
size_t mem_pool_size = f->get_temporary_pool_size();
// intermediate memory reservation
if (mem_pool_size)
{
size_t pool_idx = m_allocator.reserve_workspace(mem_pool_size, false);
m_memory_buffers.insert({f->get_name(), pool_idx});
reservation = true;
-return true;
}
// constant memory reservation
for (auto const& node : f->get_ops())
{
if (auto constant = std::dynamic_pointer_cast<ngraph::op::Constant>(node))
{
std::shared_ptr<descriptor::Tensor> tv = node->get_outputs()[0].get_tensor_ptr();
size_t idx = m_allocator.reserve_argspace(constant->get_data_ptr(), tv->size());
m_memory_buffers.insert({node->get_name(), idx});
reservation = true;
}
}
-return false;
return reservation;
}
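As a rough illustration of the new constant branch (the shape and sizes are an assumption, not from the diff): for a Constant producing a float32 tensor of Shape{2, 3}, tv->size() is 2 * 3 * 4 = 24 bytes, so the pass reserves 24 bytes of argument space seeded with the constant's data pointer and records the returned index under the node's name in m_memory_buffers.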
@@ -57,3 +57,17 @@ shape_of_scalar
shape_of_vector
shape_of_matrix
shape_of_5d
# zero size axis needs to be implemented
# differently in gpu_emitter.cpp
# these should be re-enabled before
# merging to master
product_matrix_rows_zero
product_matrix_cols_zero
product_vector_zero
product_matrix_to_scalar_zero_by_zero
product_3d_eliminate_zero_dim
max_matrix_rows_zero_int32
reduce_matrix_rows_zero
reduce_matrix_cols_zero
reduce_vector_zero
reduce_matrix_to_scalar_zero_by_zero
\ No newline at end of file