Commit 3d53e58a authored by fenglei.tian's avatar fenglei.tian

merge and resolve conflict with origin master

parents 39dc384d b5467550
...@@ -3279,6 +3279,45 @@ namespace ngraph ...@@ -3279,6 +3279,45 @@ namespace ngraph
<< to_string(sigmoid_index) << ");\n"; << to_string(sigmoid_index) << ");\n";
} }
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::SigmoidBackprop)
{
auto input_shape = args[0].get_shape();
auto delta_shape = args[1].get_shape();
auto result_shape = out[0].get_shape();
int input_1d_size = static_cast<int>(shape_size(input_shape));
int delta_1d_size = static_cast<int>(shape_size(delta_shape));
int result_1d_size = static_cast<int>(shape_size(result_shape));
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn::memory::desc(
{input_1d_size},
mkldnn_utils::get_mkldnn_data_type(args[0].get_element_type()),
mkldnn::memory::format::x);
auto delta_desc = mkldnn::memory::desc(
{delta_1d_size},
mkldnn_utils::get_mkldnn_data_type(args[1].get_element_type()),
mkldnn::memory::format::x);
auto result_desc = mkldnn::memory::desc(
{result_1d_size},
mkldnn_utils::get_mkldnn_data_type(out[0].get_element_type()),
mkldnn::memory::format::x);
size_t sigmoid_index =
mkldnn_emitter->build_sigmoid_backward(input_desc, delta_desc, result_desc);
auto& deps = mkldnn_emitter->get_primitive_deps(sigmoid_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0]) << ", "
<< args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1]) << ", "
<< args[1].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[2]) << ", "
<< out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(sigmoid_index) << ");\n";
}
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Softmax) void CPU_Emitter::EMITTER_DECL(ngraph::op::Softmax)
{ {
......
...@@ -252,6 +252,7 @@ static const runtime::cpu::OpMap dispatcher{ ...@@ -252,6 +252,7 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::ReluBackprop), &runtime::cpu::CPU_Emitter::emit<op::ReluBackprop>}, {TI(ngraph::op::ReluBackprop), &runtime::cpu::CPU_Emitter::emit<op::ReluBackprop>},
{TI(ngraph::op::Sigmoid), &runtime::cpu::CPU_Emitter::emit<op::Sigmoid>}, {TI(ngraph::op::Sigmoid), &runtime::cpu::CPU_Emitter::emit<op::Sigmoid>},
{TI(ngraph::op::Softmax), &runtime::cpu::CPU_Emitter::emit<op::Softmax>}, {TI(ngraph::op::Softmax), &runtime::cpu::CPU_Emitter::emit<op::Softmax>},
{TI(ngraph::op::SigmoidBackprop), &runtime::cpu::CPU_Emitter::emit<op::SigmoidBackprop>},
}; };
runtime::cpu::CPU_ExternalFunction::CPU_ExternalFunction( runtime::cpu::CPU_ExternalFunction::CPU_ExternalFunction(
......
...@@ -513,6 +513,32 @@ size_t MKLDNNEmitter::build_sigmoid_forward(const mkldnn::memory::desc& input_de ...@@ -513,6 +513,32 @@ size_t MKLDNNEmitter::build_sigmoid_forward(const mkldnn::memory::desc& input_de
return primitive_index; return primitive_index;
} }
size_t MKLDNNEmitter::build_sigmoid_backward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& result_desc)
{
size_t input_index = build_memory_primitive(input_desc);
size_t delta_index = build_memory_primitive(delta_desc);
size_t result_index = build_memory_primitive(result_desc);
// sigmoid forward primitive desc
mkldnn::eltwise_forward::primitive_desc sigmoid_fwd_pd =
mkldnn::eltwise_forward::primitive_desc(
{mkldnn::prop_kind::forward, mkldnn::algorithm::eltwise_logistic, input_desc, 0, 0},
mkldnn_utils::global_cpu_engine);
size_t primitive_index = insert_primitive(new mkldnn::eltwise_backward(
{{mkldnn::algorithm::eltwise_logistic, delta_desc, input_desc, 0, 0},
mkldnn_utils::global_cpu_engine,
sigmoid_fwd_pd},
*m_mkldnn_primitives[input_index],
*m_mkldnn_primitives[delta_index],
*m_mkldnn_primitives[result_index]));
m_primitive_deps[primitive_index] = {input_index, delta_index, result_index};
return primitive_index;
}
size_t MKLDNNEmitter::build_elementwise_add( size_t MKLDNNEmitter::build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc, const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc, const mkldnn::memory::desc& input1_data_desc,
......
...@@ -153,6 +153,10 @@ namespace ngraph ...@@ -153,6 +153,10 @@ namespace ngraph
size_t build_sigmoid_forward(const mkldnn::memory::desc& input_desc, size_t build_sigmoid_forward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& result_desc); const mkldnn::memory::desc& result_desc);
size_t build_sigmoid_backward(const mkldnn::memory::desc& input_desc,
const mkldnn::memory::desc& delta_desc,
const mkldnn::memory::desc& result_desc);
size_t build_elementwise_add( size_t build_elementwise_add(
const mkldnn::memory::desc& input0_data_desc, const mkldnn::memory::desc& input0_data_desc,
const mkldnn::memory::desc& input1_data_desc, const mkldnn::memory::desc& input1_data_desc,
......
...@@ -35,3 +35,24 @@ ngraph::op::Sigmoid::Sigmoid(std::shared_ptr<ngraph::Node> input) ...@@ -35,3 +35,24 @@ ngraph::op::Sigmoid::Sigmoid(std::shared_ptr<ngraph::Node> input)
{ {
add_output(input->get_element_type(), m_shape_input); add_output(input->get_element_type(), m_shape_input);
} }
ngraph::op::SigmoidBackprop::SigmoidBackprop(std::shared_ptr<Node> arg, std::shared_ptr<Node> delta)
: RequiresTensorViewArgs("SigmoidBackprop", {arg, delta})
{
if (arg->get_element_type() != delta->get_element_type())
{
throw ngraph_error("Argument and delta element types for Sigmoid backprop do not match");
}
if (arg->get_shape() != delta->get_shape())
{
throw ngraph_error("Argument and delta shape for Sigmoid backprop do not match");
}
set_value_type_checked(delta->get_element_type(), delta->get_shape());
}
void ngraph::op::Sigmoid::generate_adjoints(ngraph::autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
{
auto backprop = std::make_shared<op::SigmoidBackprop>(get_input_op(0), delta);
adjoints.add_delta(get_input_op(0), backprop);
}
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#pragma once #pragma once
#include "ngraph/ops/util/requires_tensor_view_args.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
#include "ngraph/util.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -29,9 +30,32 @@ namespace ngraph ...@@ -29,9 +30,32 @@ namespace ngraph
Shape get_input_shape() const { return m_shape_input; } Shape get_input_shape() const { return m_shape_input; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
private: private:
Shape m_shape_input; Shape m_shape_input;
}; };
/// \brief Elementwise SigmoidBackprop operation.
///
class SigmoidBackprop : public util::RequiresTensorViewArgs
{
public:
/// \brief Constructs a SigmoidBackprop operation.
///
/// \param arg Node that produces the Sigmoid forward input tensor.
SigmoidBackprop(std::shared_ptr<ngraph::Node> arg, std::shared_ptr<ngraph::Node> delta);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override
{
if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<SigmoidBackprop>(new_args.at(0), new_args.at(1));
}
};
} }
} }
...@@ -316,6 +316,19 @@ namespace ngraph ...@@ -316,6 +316,19 @@ namespace ngraph
} }
} }
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::SigmoidBackprop)
{
auto sigmoid = static_cast<op::SigmoidBackprop*>(node);
if (node->get_input_element_type(0) == element::f32)
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
sigmoid->set_op_annotations(op_annotations);
}
}
template <> template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::ReluBackprop) void CPUAssignment::ASSIGN_DECL(ngraph::op::ReluBackprop)
{ {
...@@ -386,6 +399,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{ ...@@ -386,6 +399,8 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
{TI(ngraph::op::ReluBackprop), {TI(ngraph::op::ReluBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ReluBackprop>}, &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::ReluBackprop>},
{TI(ngraph::op::Sigmoid), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Sigmoid>}, {TI(ngraph::op::Sigmoid), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Sigmoid>},
{TI(ngraph::op::SigmoidBackprop),
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::SigmoidBackprop>},
}; };
bool runtime::cpu::pass::CPUAssignment::run_on_call_graph( bool runtime::cpu::pass::CPUAssignment::run_on_call_graph(
......
...@@ -568,6 +568,57 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid() ...@@ -568,6 +568,57 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid()
this->add_matcher(m); this->add_matcher(m);
} }
void ngraph::runtime::cpu::pass::CPUFusion::construct_sigmoid_bprop()
{
//construct variance
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
auto neg_input = std::make_shared<op::Negative>(input);
auto exp_neg_input = std::make_shared<op::Exp>(neg_input);
// broadcast input
auto constant = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto broadcast_constant = std::make_shared<op::Broadcast>(constant, Shape{3, 4}, AxisSet{0, 1});
auto add_exp = std::make_shared<op::Add>(exp_neg_input, broadcast_constant);
// //auto divide_1_over_exp = std::make_shared<op::Divide>(broadcast_constant, add_exp);
auto sigmoid_fwd = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
auto delta = std::make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
auto neg_delta = std::make_shared<op::Negative>(delta);
auto multiply_sigmoid_delta = std::make_shared<op::Multiply>(sigmoid_fwd, neg_delta);
auto divide_2 = std::make_shared<op::Divide>(multiply_sigmoid_delta, add_exp);
auto multiply_2 = std::make_shared<op::Multiply>(divide_2, exp_neg_input);
auto negtive_2 = std::make_shared<op::Negative>(multiply_2);
//Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::gr_callback_fn callback =
[input, delta](pattern::Matcher& m) -> std::shared_ptr<Node> {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
if (m.match_root()->get_element_type() != element::f32)
{
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name() << " type is not float!";
return nullptr;
}
if (m.match_root()->get_shape().size() != pattern_map[input]->get_shape().size())
{
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name()
<< "input= " << pattern_map[input]->get_name() << "size dont match!";
return nullptr;
}
auto dsigmoid =
std::make_shared<op::SigmoidBackprop>(pattern_map[input], pattern_map[delta]);
return dsigmoid;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(negtive_2, callback);
this->add_matcher(m);
}
void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias() void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
{ {
Shape shape{2, 2, 1, 1}; Shape shape{2, 2, 1, 1};
......
...@@ -44,6 +44,7 @@ public: ...@@ -44,6 +44,7 @@ public:
construct_zero_padded_reshaped_conv(); construct_zero_padded_reshaped_conv();
construct_zero_padded_conv(); construct_zero_padded_conv();
construct_sigmoid(); construct_sigmoid();
construct_sigmoid_bprop();
construct_conv_bias(); construct_conv_bias();
} }
...@@ -53,6 +54,7 @@ private: ...@@ -53,6 +54,7 @@ private:
void construct_conv_bias(); void construct_conv_bias();
void construct_fprop_bn(); void construct_fprop_bn();
void construct_sigmoid(); void construct_sigmoid();
void construct_sigmoid_bprop();
void construct_zero_padded_reshaped_conv(); void construct_zero_padded_reshaped_conv();
void construct_zero_padded_conv(); void construct_zero_padded_conv();
}; };
...@@ -960,6 +960,29 @@ namespace ngraph ...@@ -960,6 +960,29 @@ namespace ngraph
} }
} }
template <>
void CPULayout::LAYOUT_DECL(ngraph::op::SigmoidBackprop)
{
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node.get()))
{
auto input_layout =
runtime::cpu::mkldnn_utils::get_input_mkldnn_format(node.get(), 0);
vector<memory::format> prim_input_formats;
vector<memory::format> prim_output_formats;
//ensure delta and input have same layout
prim_input_formats.push_back(input_layout);
prim_input_formats.push_back(input_layout);
prim_output_formats.push_back(input_layout);
node =
insert_input_conversions(external_function, node, prim_input_formats);
set_output_layouts(node, prim_output_formats);
}
else
{
set_default_layouts(external_function, node);
}
}
template <> template <>
void CPULayout::LAYOUT_DECL(ngraph::op::ReluBackprop) void CPULayout::LAYOUT_DECL(ngraph::op::ReluBackprop)
{ {
...@@ -1095,6 +1118,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{ ...@@ -1095,6 +1118,8 @@ static const runtime::cpu::pass::LayoutOpMap s_dispatcher{
{TI(ngraph::op::ReluBackprop), {TI(ngraph::op::ReluBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::ReluBackprop>}, &runtime::cpu::pass::CPULayout::layout<ngraph::op::ReluBackprop>},
{TI(ngraph::op::Sigmoid), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Sigmoid>}, {TI(ngraph::op::Sigmoid), &runtime::cpu::pass::CPULayout::layout<ngraph::op::Sigmoid>},
{TI(ngraph::op::SigmoidBackprop),
&runtime::cpu::pass::CPULayout::layout<ngraph::op::SigmoidBackprop>},
}; };
bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) bool runtime::cpu::pass::CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
......
...@@ -19,25 +19,18 @@ ...@@ -19,25 +19,18 @@
#include "ngraph/runtime/gpu/gpu_cuda_context_manager.hpp" #include "ngraph/runtime/gpu/gpu_cuda_context_manager.hpp"
namespace ngraph using namespace ngraph;
runtime::gpu::CudaContextManager& runtime::gpu::CudaContextManager::instance()
{ {
namespace runtime
{
namespace gpu
{
CudaContextManager& CudaContextManager::instance()
{
static CudaContextManager manager; static CudaContextManager manager;
return manager; return manager;
} }
CudaContextManager::CudaContextManager() runtime::gpu::CudaContextManager::CudaContextManager()
{ {
CUDA_SAFE_CALL(cuInit(0)); CUDA_SAFE_CALL(cuInit(0));
CUDA_SAFE_CALL(cuDeviceGet(&m_device, 0)); CUDA_SAFE_CALL(cuDeviceGet(&m_device, 0));
CUDA_SAFE_CALL(cuCtxCreate(&m_context, 0, m_device)); CUDA_SAFE_CALL(cuCtxCreate(&m_context, 0, m_device));
m_context_ptr = std::make_shared<CUcontext>(m_context); m_context_ptr = std::make_shared<CUcontext>(m_context);
}
}
}
} }
...@@ -20,17 +20,13 @@ ...@@ -20,17 +20,13 @@
#include "ngraph/runtime/gpu/gpu_cuda_function_builder.hpp" #include "ngraph/runtime/gpu/gpu_cuda_function_builder.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp" #include "ngraph/runtime/gpu/gpu_util.hpp"
namespace ngraph using namespace ngraph;
{
namespace runtime std::shared_ptr<CUfunction> runtime::gpu::CudaFunctionBuilder::get(const std::string& name,
{
namespace gpu
{
std::shared_ptr<CUfunction> CudaFunctionBuilder::get(const std::string& name,
const std::string& kernel, const std::string& kernel,
int number_of_options, int number_of_options,
const char** options) const char** options)
{ {
nvrtcProgram prog; nvrtcProgram prog;
NVRTC_SAFE_CALL(nvrtcCreateProgram(&prog, NVRTC_SAFE_CALL(nvrtcCreateProgram(&prog,
kernel.c_str(), kernel.c_str(),
...@@ -49,8 +45,8 @@ namespace ngraph ...@@ -49,8 +45,8 @@ namespace ngraph
size_t ptx_size; size_t ptx_size;
NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptx_size)); NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptx_size));
char* ptx = new char[ptx_size]; char* ptx = new char[ptx_size];
NVRTC_SAFE_CALL(nvrtcGetPTX( NVRTC_SAFE_CALL(
prog, nvrtcGetPTX(prog,
ptx)); // Load the generated PTX and get a handle to the parent kernel. ptx)); // Load the generated PTX and get a handle to the parent kernel.
NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); // Destroy the program. NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); // Destroy the program.
...@@ -59,7 +55,4 @@ namespace ngraph ...@@ -59,7 +55,4 @@ namespace ngraph
CUDA_SAFE_CALL(cuModuleLoadDataEx(&module, ptx, 0, 0, 0)); CUDA_SAFE_CALL(cuModuleLoadDataEx(&module, ptx, 0, 0, 0));
CUDA_SAFE_CALL(cuModuleGetFunction(&function, module, name.c_str())); CUDA_SAFE_CALL(cuModuleGetFunction(&function, module, name.c_str()));
return std::make_shared<CUfunction>(function); return std::make_shared<CUfunction>(function);
}
}
}
} }
...@@ -26,40 +26,31 @@ ...@@ -26,40 +26,31 @@
static const std::string s_output_dir = "gpu_codegen"; static const std::string s_output_dir = "gpu_codegen";
namespace ngraph using namespace ngraph;
runtime::gpu::CudaFunctionPool& runtime::gpu::CudaFunctionPool::instance()
{ {
namespace runtime
{
namespace gpu
{
CudaFunctionPool& CudaFunctionPool::instance()
{
static CudaFunctionPool pool; static CudaFunctionPool pool;
return pool; return pool;
} }
void CudaFunctionPool::set(const std::string& name, const std::string& kernel) void runtime::gpu::CudaFunctionPool::set(const std::string& name, const std::string& kernel)
{ {
const char* opts[] = {"--gpu-architecture=compute_35", const char* opts[] = {"--gpu-architecture=compute_35", "--relocatable-device-code=true"};
"--relocatable-device-code=true"};
std::string filename = std::string filename =
file_util::path_join(s_output_dir, "cuda_kernel_" + name + "_codegen.cu"); file_util::path_join(s_output_dir, "cuda_kernel_" + name + "_codegen.cu");
std::ofstream out(filename); std::ofstream out(filename);
out << kernel; out << kernel;
out.close(); out.close();
m_function_map.insert( m_function_map.insert({name, CudaFunctionBuilder::get("cuda_" + name, kernel, 2, opts)});
{name, CudaFunctionBuilder::get("cuda_" + name, kernel, 2, opts)}); }
}
std::shared_ptr<CUfunction> CudaFunctionPool::get(const std::string& name) std::shared_ptr<CUfunction> runtime::gpu::CudaFunctionPool::get(const std::string& name)
{ {
auto it = m_function_map.find(name); auto it = m_function_map.find(name);
if (it != m_function_map.end()) if (it != m_function_map.end())
{ {
return (*it).second; return (*it).second;
} }
return nullptr; return nullptr;
}
}
}
} }
...@@ -16,18 +16,14 @@ ...@@ -16,18 +16,14 @@
#include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp" #include "ngraph/runtime/gpu/gpu_cuda_kernel_builder.hpp"
#include "ngraph/codegen/code_writer.hpp" #include "ngraph/codegen/code_writer.hpp"
namespace ngraph using namespace ngraph;
{
namespace runtime void runtime::gpu::CudaKernelBuilder::get_elementwise_op(codegen::CodeWriter& writer,
{
namespace gpu
{
void CudaKernelBuilder::get_elementwise_op(codegen::CodeWriter& writer,
const std::string& name, const std::string& name,
const std::string& data_type, const std::string& data_type,
const std::string& op, const std::string& op,
const size_t& num_inputs) const size_t& num_inputs)
{ {
writer << "extern \"C\" __global__ void cuda_" << name << "("; writer << "extern \"C\" __global__ void cuda_" << name << "(";
for (size_t i = 0; i < num_inputs; i++) for (size_t i = 0; i < num_inputs; i++)
{ {
...@@ -57,14 +53,14 @@ namespace ngraph ...@@ -57,14 +53,14 @@ namespace ngraph
writer << "}\n"; writer << "}\n";
return; return;
} }
void CudaKernelBuilder::get_device_helper(codegen::CodeWriter& writer, void runtime::gpu::CudaKernelBuilder::get_device_helper(codegen::CodeWriter& writer,
const std::string& name, const std::string& name,
const std::string& data_type, const std::string& data_type,
const std::string& math_kernel, const std::string& math_kernel,
const size_t& num_inputs) const size_t& num_inputs)
{ {
if (math_kernel.size()) if (math_kernel.size())
{ {
writer << "__device__ " << data_type << " " << name << "("; writer << "__device__ " << data_type << " " << name << "(";
...@@ -83,7 +79,4 @@ namespace ngraph ...@@ -83,7 +79,4 @@ namespace ngraph
writer << "}\n"; writer << "}\n";
} }
return; return;
}
}
}
} }
...@@ -20,15 +20,10 @@ ...@@ -20,15 +20,10 @@
#include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp" #include "ngraph/runtime/gpu/gpu_cuda_kernel_emitters.hpp"
#include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp" #include "ngraph/runtime/gpu/gpu_cuda_kernel_ops.hpp"
namespace ngraph using namespace ngraph;
{ void runtime::gpu::emit_broadcast(
namespace runtime
{
namespace gpu
{
void emit_broadcast(
void* in, void* out, size_t repeat_size, size_t repeat_times, size_t count) void* in, void* out, size_t repeat_size, size_t repeat_times, size_t count)
{ {
std::string name = "broadcast"; std::string name = "broadcast";
// Create an instance of nvrtcProgram with the code string. // Create an instance of nvrtcProgram with the code string.
if (CudaFunctionPool::instance().get(name) == nullptr) if (CudaFunctionPool::instance().get(name) == nullptr)
...@@ -38,8 +33,9 @@ namespace ngraph ...@@ -38,8 +33,9 @@ namespace ngraph
kernel = R"( kernel = R"(
extern "C" __global__ extern "C" __global__
void cuda_)" + name + "(" + data_type + void cuda_)" + name +
"* in, " + data_type + "* out, size_t m, size_t k, size_t n)\n" + R"( "(" + data_type + "* in, " + data_type + "* out, size_t m, size_t k, size_t n)\n" +
R"(
{ {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x; size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n) if(tid < n)
...@@ -53,12 +49,12 @@ void cuda_)" + name + "(" + data_type + ...@@ -53,12 +49,12 @@ void cuda_)" + name + "(" + data_type +
//convert runtime ptr to driver api ptr //convert runtime ptr to driver api ptr
CUdeviceptr d_ptr_in, d_ptr_out; CUdeviceptr d_ptr_in, d_ptr_out;
d_ptr_in = (CUdeviceptr)in; d_ptr_in = CUdeviceptr(in);
d_ptr_out = (CUdeviceptr)out; d_ptr_out = CUdeviceptr(out);
void* args_list[] = {&d_ptr_in, &d_ptr_out, &repeat_size, &repeat_times, &count}; void* args_list[] = {&d_ptr_in, &d_ptr_out, &repeat_size, &repeat_times, &count};
CUDA_SAFE_CALL(cuLaunchKernel(*CudaFunctionPool::instance().get(name).get(), CUDA_SAFE_CALL(cuLaunchKernel(*CudaFunctionPool::instance().get(name).get(),
count, static_cast<unsigned int>(count),
1, 1,
1, // grid dim 1, // grid dim
1, 1,
...@@ -69,7 +65,4 @@ void cuda_)" + name + "(" + data_type + ...@@ -69,7 +65,4 @@ void cuda_)" + name + "(" + data_type +
args_list, args_list,
0)); // arguments 0)); // arguments
CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output. CUDA_SAFE_CALL(cuCtxSynchronize()); // Retrieve and print output.
}
}
}
} }
...@@ -518,15 +518,11 @@ cudnnSetOpTensorDescriptor(opTensorDesc, ...@@ -518,15 +518,11 @@ cudnnSetOpTensorDescriptor(opTensorDesc,
writer.indent++; writer.indent++;
auto arg_shape = args[0].get_shape(); auto arg_shape = args[0].get_shape();
auto arg_rank = arg_shape.size(); auto arg_rank = arg_shape.size();
auto result_shape = out[0].get_shape(); auto result_shape = out[0].get_shape();
auto& result_element_type = out[0].get_element_type();
auto input_order = reshape->get_input_order(); auto input_order = reshape->get_input_order();
bool same_layout = is_sorted(input_order.begin(), input_order.end()); bool same_layout = is_sorted(input_order.begin(), input_order.end());
size_t result_shape_product = 1; size_t result_shape_product = 1;
for (auto i : result_shape) for (auto i : result_shape)
{ {
result_shape_product *= i; result_shape_product *= i;
......
...@@ -114,6 +114,7 @@ ...@@ -114,6 +114,7 @@
#include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp" #include "ngraph/runtime/gpu/gpu_kernel_emitters.hpp"
using namespace std; using namespace std;
using namespace ngraph;
static const string s_output_dir = "gpu_codegen"; static const string s_output_dir = "gpu_codegen";
...@@ -159,110 +160,104 @@ static StaticInitializers s_static_initializers; ...@@ -159,110 +160,104 @@ static StaticInitializers s_static_initializers;
#define TI(x) type_index(typeid(x)) #define TI(x) type_index(typeid(x))
namespace ngraph static const runtime::gpu::OpMap dispatcher{
{ {TI(ngraph::op::Add), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Add>},
namespace runtime {TI(ngraph::op::Dot), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Dot>},
{ {TI(ngraph::op::Multiply), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Multiply>},
namespace gpu {TI(ngraph::op::Parameter), &runtime::gpu::GPU_Emitter::nop},
{ {TI(ngraph::op::Abs), &runtime::gpu::GPU_Emitter::EmitElementwise},
static const OpMap dispatcher{ {TI(ngraph::op::Concat), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Concat>},
{TI(ngraph::op::Add), &GPU_Emitter::emit<ngraph::op::Add>}, {TI(ngraph::op::Divide), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Dot), &GPU_Emitter::emit<ngraph::op::Dot>}, {TI(ngraph::op::Equal), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Equal>},
{TI(ngraph::op::Multiply), &GPU_Emitter::emit<ngraph::op::Multiply>},
{TI(ngraph::op::Parameter), &GPU_Emitter::nop},
{TI(ngraph::op::Abs), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Concat), &GPU_Emitter::emit<ngraph::op::Concat>},
{TI(ngraph::op::Divide), &GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Equal), &GPU_Emitter::emit<ngraph::op::Equal>},
{TI(ngraph::op::GetOutputElement), {TI(ngraph::op::GetOutputElement),
&GPU_Emitter::emit<ngraph::op::GetOutputElement>}, &runtime::gpu::GPU_Emitter::emit<ngraph::op::GetOutputElement>},
{TI(ngraph::op::Greater), &GPU_Emitter::emit<ngraph::op::Greater>}, {TI(ngraph::op::Greater), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Greater>},
{TI(ngraph::op::GreaterEq), &GPU_Emitter::emit<ngraph::op::GreaterEq>}, {TI(ngraph::op::GreaterEq), &runtime::gpu::GPU_Emitter::emit<ngraph::op::GreaterEq>},
{TI(ngraph::op::Less), &GPU_Emitter::emit<ngraph::op::Less>}, {TI(ngraph::op::Less), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Less>},
{TI(ngraph::op::LessEq), &GPU_Emitter::emit<ngraph::op::LessEq>}, {TI(ngraph::op::LessEq), &runtime::gpu::GPU_Emitter::emit<ngraph::op::LessEq>},
{TI(ngraph::op::Log), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Log), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Maximum), &GPU_Emitter::emit<ngraph::op::Maximum>}, {TI(ngraph::op::Maximum), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Maximum>},
{TI(ngraph::op::Minimum), &GPU_Emitter::emit<ngraph::op::Minimum>}, {TI(ngraph::op::Minimum), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Minimum>},
{TI(ngraph::op::Negative), &GPU_Emitter::emit<ngraph::op::Negative>}, {TI(ngraph::op::Negative), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Negative>},
{TI(ngraph::op::NotEqual), &GPU_Emitter::emit<ngraph::op::NotEqual>}, {TI(ngraph::op::NotEqual), &runtime::gpu::GPU_Emitter::emit<ngraph::op::NotEqual>},
{TI(ngraph::op::Power), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Power), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Select), &GPU_Emitter::emit<ngraph::op::Select>}, {TI(ngraph::op::Select), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Select>},
{TI(ngraph::op::Subtract), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Subtract), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Broadcast), &GPU_Emitter::emit<ngraph::op::Broadcast>}, {TI(ngraph::op::Broadcast), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Broadcast>},
{TI(ngraph::op::Convert), &GPU_Emitter::emit<ngraph::op::Convert>}, {TI(ngraph::op::Convert), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Convert>},
{TI(ngraph::op::Constant), &GPU_Emitter::emit<ngraph::op::Constant>}, {TI(ngraph::op::Constant), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Constant>},
{TI(ngraph::op::Reshape), &GPU_Emitter::emit<ngraph::op::Reshape>}, {TI(ngraph::op::Reshape), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Reshape>},
{TI(ngraph::op::FunctionCall), &GPU_Emitter::emit<ngraph::op::FunctionCall>}, {TI(ngraph::op::FunctionCall), &runtime::gpu::GPU_Emitter::emit<ngraph::op::FunctionCall>},
{TI(ngraph::op::Reduce), &GPU_Emitter::emit<ngraph::op::Reduce>}, {TI(ngraph::op::Reduce), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Reduce>},
{TI(ngraph::op::Sign), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Sign), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Slice), &GPU_Emitter::emit<ngraph::op::Slice>}, {TI(ngraph::op::Slice), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Slice>},
{TI(ngraph::op::Sum), &GPU_Emitter::emit<ngraph::op::Sum>}, {TI(ngraph::op::Sum), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Sum>},
{TI(ngraph::op::Exp), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Exp), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Sin), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Sin), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Sinh), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Sinh), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Cos), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Cos), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Cosh), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Cosh), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Tan), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Tan), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Tanh), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Tanh), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Asin), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Asin), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Acos), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Acos), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Atan), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Atan), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::ReplaceSlice), &GPU_Emitter::emit<ngraph::op::ReplaceSlice>}, {TI(ngraph::op::ReplaceSlice), &runtime::gpu::GPU_Emitter::emit<ngraph::op::ReplaceSlice>},
{TI(ngraph::op::OneHot), &GPU_Emitter::emit<ngraph::op::OneHot>}, {TI(ngraph::op::OneHot), &runtime::gpu::GPU_Emitter::emit<ngraph::op::OneHot>},
{TI(ngraph::op::Floor), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Floor), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Ceiling), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Ceiling), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::Sqrt), &GPU_Emitter::emit<ngraph::op::Sqrt>}, {TI(ngraph::op::Sqrt), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Sqrt>},
{TI(ngraph::op::Convolution), &GPU_Emitter::emit<ngraph::op::Convolution>}, {TI(ngraph::op::Convolution), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Convolution>},
{TI(ngraph::op::ConvolutionBackpropFilters), {TI(ngraph::op::ConvolutionBackpropFilters),
&GPU_Emitter::emit<ngraph::op::ConvolutionBackpropFilters>}, &runtime::gpu::GPU_Emitter::emit<ngraph::op::ConvolutionBackpropFilters>},
{TI(ngraph::op::ConvolutionBackpropData), {TI(ngraph::op::ConvolutionBackpropData),
&GPU_Emitter::emit<ngraph::op::ConvolutionBackpropData>}, &runtime::gpu::GPU_Emitter::emit<ngraph::op::ConvolutionBackpropData>},
{TI(ngraph::op::Not), &GPU_Emitter::EmitElementwise}, {TI(ngraph::op::Not), &runtime::gpu::GPU_Emitter::EmitElementwise},
{TI(ngraph::op::MaxPool), &GPU_Emitter::emit<ngraph::op::MaxPool>}, {TI(ngraph::op::MaxPool), &runtime::gpu::GPU_Emitter::emit<ngraph::op::MaxPool>},
{TI(ngraph::op::Reverse), &GPU_Emitter::emit<ngraph::op::Reverse>}, {TI(ngraph::op::Reverse), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Reverse>},
{TI(ngraph::op::Result), &GPU_Emitter::emit<ngraph::op::Result>}, {TI(ngraph::op::Result), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Result>},
{TI(ngraph::op::ReduceWindow), &GPU_Emitter::emit<ngraph::op::ReduceWindow>}, {TI(ngraph::op::ReduceWindow), &runtime::gpu::GPU_Emitter::emit<ngraph::op::ReduceWindow>},
{TI(ngraph::op::SelectAndScatter), {TI(ngraph::op::SelectAndScatter),
&GPU_Emitter::emit<ngraph::op::SelectAndScatter>}, &runtime::gpu::GPU_Emitter::emit<ngraph::op::SelectAndScatter>},
{TI(ngraph::op::AvgPool), &GPU_Emitter::emit<ngraph::op::AvgPool>}, {TI(ngraph::op::AvgPool), &runtime::gpu::GPU_Emitter::emit<ngraph::op::AvgPool>},
{TI(ngraph::op::AvgPoolBackprop), &GPU_Emitter::emit<ngraph::op::AvgPoolBackprop>}, {TI(ngraph::op::AvgPoolBackprop),
{TI(ngraph::op::Pad), &GPU_Emitter::emit<ngraph::op::Pad>}, &runtime::gpu::GPU_Emitter::emit<ngraph::op::AvgPoolBackprop>},
{TI(ngraph::op::BatchNorm), &GPU_Emitter::emit<ngraph::op::BatchNorm>}, {TI(ngraph::op::Pad), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Pad>},
{TI(ngraph::op::BatchNorm), &runtime::gpu::GPU_Emitter::emit<ngraph::op::BatchNorm>},
{TI(ngraph::op::BatchNormBackprop), {TI(ngraph::op::BatchNormBackprop),
&GPU_Emitter::emit<ngraph::op::BatchNormBackprop>}, &runtime::gpu::GPU_Emitter::emit<ngraph::op::BatchNormBackprop>},
{TI(ngraph::op::MaxPoolBackprop), &GPU_Emitter::emit<ngraph::op::MaxPoolBackprop>}, {TI(ngraph::op::MaxPoolBackprop),
{TI(ngraph::op::Product), &GPU_Emitter::emit<ngraph::op::Product>}, &runtime::gpu::GPU_Emitter::emit<ngraph::op::MaxPoolBackprop>},
{TI(ngraph::op::Max), &GPU_Emitter::emit<ngraph::op::Max>}, {TI(ngraph::op::Product), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Product>},
{TI(ngraph::op::Min), &GPU_Emitter::emit<ngraph::op::Min>}, {TI(ngraph::op::Max), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Max>},
{TI(ngraph::op::Relu), &GPU_Emitter::emit<ngraph::op::Relu>}, {TI(ngraph::op::Min), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Min>},
{TI(ngraph::op::ReluBackprop), &GPU_Emitter::emit<ngraph::op::ReluBackprop>}, {TI(ngraph::op::Relu), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Relu>},
{TI(ngraph::op::Softmax), &GPU_Emitter::emit<ngraph::op::Softmax>}, {TI(ngraph::op::ReluBackprop), &runtime::gpu::GPU_Emitter::emit<ngraph::op::ReluBackprop>},
}; {TI(ngraph::op::Softmax), &runtime::gpu::GPU_Emitter::emit<ngraph::op::Softmax>},
};
GPU_ExternalFunction::GPU_ExternalFunction(const shared_ptr<ngraph::Function>& function,
bool release_function) runtime::gpu::GPU_ExternalFunction::GPU_ExternalFunction(
const shared_ptr<ngraph::Function>& function, bool release_function)
: ngraph::runtime::ExternalFunction(function, release_function) : ngraph::runtime::ExternalFunction(function, release_function)
, m_compiled_function(nullptr) , m_compiled_function(nullptr)
, m_emit_timing(std::getenv("NGRAPH_GPU_EMIT_TIMING") != nullptr) , m_emit_timing(std::getenv("NGRAPH_GPU_EMIT_TIMING") != nullptr)
{ {
} }
void GPU_ExternalFunction::compile() void runtime::gpu::GPU_ExternalFunction::compile()
{ {
if (m_is_compiled) if (m_is_compiled)
{ {
return; return;
} }
string function_name = m_function->get_name(); string function_name = m_function->get_name();
string dump_filename = string dump_filename = file_util::path_join(s_output_dir, function_name + "_ops.txt");
file_util::path_join(s_output_dir, function_name + "_ops.txt");
pass::Manager pass_manager; pass::Manager pass_manager;
// pass_manager.register_pass<pass::TopologicalSort>(); // pass_manager.register_pass<pass::TopologicalSort>();
// For now, just make everyone row-major. // For now, just make everyone row-major.
pass_manager pass_manager.register_pass<pass::AssignLayout<descriptor::layout::DenseTensorViewLayout>>();
.register_pass<pass::AssignLayout<descriptor::layout::DenseTensorViewLayout>>();
pass_manager.register_pass<pass::Liveness>(); pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::MemoryLayout>(64); pass_manager.register_pass<pass::MemoryLayout>(64);
pass_manager.register_pass<pass::DumpSorted>(dump_filename); pass_manager.register_pass<pass::DumpSorted>(dump_filename);
...@@ -308,8 +303,7 @@ using namespace std; ...@@ -308,8 +303,7 @@ using namespace std;
{ {
writer << "// Declare debug timers\n"; writer << "// Declare debug timers\n";
vector<string> names; vector<string> names;
for (shared_ptr<Function> current_function : for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
pass_manager.get_state().get_functions())
{ {
for (shared_ptr<Node> node : current_function->get_ordered_ops()) for (shared_ptr<Node> node : current_function->get_ordered_ops())
{ {
...@@ -323,8 +317,8 @@ using namespace std; ...@@ -323,8 +317,8 @@ using namespace std;
{ {
writer << "ngraph::stopwatch timer_" << s << ";\n"; writer << "ngraph::stopwatch timer_" << s << ";\n";
} }
writer << "extern \"C\" size_t get_debug_timer_count() { return " writer << "extern \"C\" size_t get_debug_timer_count() { return " << names.size()
<< names.size() << "; }\n"; << "; }\n";
writer << "extern \"C\" const char* get_debug_timer_name(size_t index)\n"; writer << "extern \"C\" const char* get_debug_timer_name(size_t index)\n";
writer << "{\n"; writer << "{\n";
writer.indent++; writer.indent++;
...@@ -340,8 +334,7 @@ using namespace std; ...@@ -340,8 +334,7 @@ using namespace std;
writer << "return rc;\n"; writer << "return rc;\n";
writer.indent--; writer.indent--;
writer << "}\n"; writer << "}\n";
writer writer << "extern \"C\" const size_t get_debug_timer_microseconds(size_t index)\n";
<< "extern \"C\" const size_t get_debug_timer_microseconds(size_t index)\n";
writer << "{\n"; writer << "{\n";
writer.indent++; writer.indent++;
writer << "size_t rc;\n"; writer << "size_t rc;\n";
...@@ -357,8 +350,7 @@ using namespace std; ...@@ -357,8 +350,7 @@ using namespace std;
writer << "return rc;\n"; writer << "return rc;\n";
writer.indent--; writer.indent--;
writer << "}\n"; writer << "}\n";
writer writer << "extern \"C\" const size_t get_debug_timer_call_count(size_t index)\n";
<< "extern \"C\" const size_t get_debug_timer_call_count(size_t index)\n";
writer << "{\n"; writer << "{\n";
writer.indent++; writer.indent++;
writer << "size_t rc;\n"; writer << "size_t rc;\n";
...@@ -366,8 +358,7 @@ using namespace std; ...@@ -366,8 +358,7 @@ using namespace std;
writer << "{\n"; writer << "{\n";
for (size_t i = 0; i < names.size(); i++) for (size_t i = 0; i < names.size(); i++)
{ {
writer << "case " << i << ": rc = timer_" << names[i] writer << "case " << i << ": rc = timer_" << names[i] << ".get_call_count(); break;\n";
<< ".get_call_count(); break;\n";
} }
writer << "default: rc = 0;\n"; writer << "default: rc = 0;\n";
writer << "}\n"; writer << "}\n";
...@@ -383,31 +374,26 @@ using namespace std; ...@@ -383,31 +374,26 @@ using namespace std;
writer << "void *__dso_handle = 0;\n\n"; writer << "void *__dso_handle = 0;\n\n";
writer << "// Declare all constants\n"; writer << "// Declare all constants\n";
for (shared_ptr<Function> current_function : for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
pass_manager.get_state().get_functions())
{ {
for (shared_ptr<Node> node : current_function->get_ordered_ops()) for (shared_ptr<Node> node : current_function->get_ordered_ops())
{ {
const op::Constant* c = dynamic_cast<ngraph::op::Constant*>(node.get()); const op::Constant* c = dynamic_cast<ngraph::op::Constant*>(node.get());
if (c) if (c)
{ {
shared_ptr<descriptor::TensorView> tv = shared_ptr<descriptor::TensorView> tv = node->get_outputs()[0].get_tensor_view();
node->get_outputs()[0].get_tensor_view();
auto c_value_strings = c->get_value_strings(); auto c_value_strings = c->get_value_strings();
writer << "static " writer << "static " << tv->get_tensor().get_element_type().c_type_string() << " "
<< tv->get_tensor().get_element_type().c_type_string() << " " << tv->get_tensor().get_name() << "_cpu[" << c_value_strings.size()
<< tv->get_tensor().get_name() << "_cpu[" << "] =\n";
<< c_value_strings.size() << "] =\n";
writer << "{\n"; writer << "{\n";
writer.indent++; writer.indent++;
writer << emit_string_array(c_value_strings, 100 - writer.indent * 4); writer << emit_string_array(c_value_strings, 100 - writer.indent * 4);
writer.indent--; writer.indent--;
writer << "\n};\n\n"; writer << "\n};\n\n";
writer << "static " writer << "static " << tv->get_tensor().get_element_type().c_type_string() << " *"
<< tv->get_tensor().get_element_type().c_type_string() << " *"
<< tv->get_tensor().get_name() << ";\n"; << tv->get_tensor().get_name() << ";\n";
m_variable_name_map[tv->get_tensor().get_name()] = m_variable_name_map[tv->get_tensor().get_name()] = tv->get_tensor().get_name();
tv->get_tensor().get_name();
} }
} }
} }
...@@ -415,8 +401,7 @@ using namespace std; ...@@ -415,8 +401,7 @@ using namespace std;
writer << "// Declare all functions\n"; writer << "// Declare all functions\n";
for (shared_ptr<Function> f : pass_manager.get_state().get_functions()) for (shared_ptr<Function> f : pass_manager.get_state().get_functions())
{ {
writer << "extern \"C\" void " << f->get_name() writer << "extern \"C\" void " << f->get_name() << "(void** inputs, void** outputs, "
<< "(void** inputs, void** outputs, "
"cublasHandle_t& cublas_handle, " "cublasHandle_t& cublas_handle, "
"cudnnHandle_t& cudnn_handle);\n"; "cudnnHandle_t& cudnn_handle);\n";
} }
...@@ -424,12 +409,8 @@ using namespace std; ...@@ -424,12 +409,8 @@ using namespace std;
writer << "\n"; writer << "\n";
unordered_map<Node*, string> match_functions; unordered_map<Node*, string> match_functions;
for (shared_ptr<Function> current_function : for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
pass_manager.get_state().get_functions())
{ {
bool temporaries_used = false;
size_t worst_case_tmp_size = 0;
set<string> output_names; set<string> output_names;
for (shared_ptr<Node> op : current_function->get_results()) for (shared_ptr<Node> op : current_function->get_results())
{ {
...@@ -454,18 +435,6 @@ using namespace std; ...@@ -454,18 +435,6 @@ using namespace std;
continue; continue;
} }
string match_function_name; string match_function_name;
for (size_t j = i + 1; j < op_list.size(); j++)
{
if (0) //op_list[i]->is_functionally_identical(*op_list[j]))
{
if (match_function_name.empty())
{
match_function_name = "func_" + op_list[i]->get_name();
match_functions.insert({op_list[i].get(), match_function_name});
}
match_functions.insert({op_list[j].get(), match_function_name});
}
}
if (!match_function_name.empty()) if (!match_function_name.empty())
{ {
writer << "static void " << match_function_name << "("; writer << "static void " << match_function_name << "(";
...@@ -518,8 +487,7 @@ using namespace std; ...@@ -518,8 +487,7 @@ using namespace std;
} }
} }
for (shared_ptr<Function> current_function : for (shared_ptr<Function> current_function : pass_manager.get_state().get_functions())
pass_manager.get_state().get_functions())
{ {
set<string> output_names; set<string> output_names;
for (shared_ptr<Node> op : current_function->get_results()) for (shared_ptr<Node> op : current_function->get_results())
...@@ -532,8 +500,7 @@ using namespace std; ...@@ -532,8 +500,7 @@ using namespace std;
{ {
if (dynamic_cast<ngraph::op::Constant*>(node.get())) if (dynamic_cast<ngraph::op::Constant*>(node.get()))
{ {
shared_ptr<descriptor::TensorView> tv = shared_ptr<descriptor::TensorView> tv = node->get_outputs()[0].get_tensor_view();
node->get_outputs()[0].get_tensor_view();
constants.insert(tv.get()); constants.insert(tv.get());
} }
} }
...@@ -545,32 +512,27 @@ using namespace std; ...@@ -545,32 +512,27 @@ using namespace std;
writer << "{\n"; writer << "{\n";
writer.indent++; writer.indent++;
for (shared_ptr<Function> current_function :
pass_manager.get_state().get_functions())
{
for (shared_ptr<Node> node : current_function->get_ordered_ops()) for (shared_ptr<Node> node : current_function->get_ordered_ops())
{ {
const op::Constant* c = dynamic_cast<op::Constant*>(node.get()); const op::Constant* c = dynamic_cast<op::Constant*>(node.get());
if (c) if (c)
{ {
shared_ptr<descriptor::TensorView> tv = shared_ptr<descriptor::TensorView> tv = node->get_outputs()[0].get_tensor_view();
node->get_outputs()[0].get_tensor_view();
writer << "if(" << tv->get_tensor().get_name() << " == NULL)\n"; writer << "if(" << tv->get_tensor().get_name() << " == NULL)\n";
writer << "{\n"; writer << "{\n";
writer.indent++; writer.indent++;
writer << tv->get_tensor().get_name() << " = (" writer << tv->get_tensor().get_name() << " = ("
<< tv->get_tensor().get_element_type().c_type_string() << tv->get_tensor().get_element_type().c_type_string()
<< " *) ngraph::runtime::gpu::create_gpu_buffer(" << " *) runtime::gpu::create_gpu_buffer(" << tv->get_tensor().size()
<< tv->get_tensor().size() << ");\n"; << ");\n";
writer << "runtime::gpu::cuda_memcpyHtD(" writer << "runtime::gpu::cuda_memcpyHtD(" << tv->get_tensor().get_name() << ", "
<< tv->get_tensor().get_name() << ", " << tv->get_tensor().get_name() << "_cpu, " << tv->get_tensor().size()
<< tv->get_tensor().get_name() << "_cpu, " << ");\n";
<< tv->get_tensor().size() << ");\n";
writer.indent--; writer.indent--;
writer << "}\n"; writer << "}\n";
} }
} }
}
bool temporaries_used = false; bool temporaries_used = false;
size_t worst_case_tmp_size = 0; size_t worst_case_tmp_size = 0;
for (shared_ptr<Node> node : current_function->get_ordered_ops()) for (shared_ptr<Node> node : current_function->get_ordered_ops())
...@@ -599,8 +561,7 @@ using namespace std; ...@@ -599,8 +561,7 @@ using namespace std;
{ {
stringstream ss; stringstream ss;
ss << "((" << tensor->get_element_type().c_type_string() ss << "((" << tensor->get_element_type().c_type_string()
<< "*)((char *)pool_base_ptr + " << tensor->get_pool_offset() << "*)((char *)pool_base_ptr + " << tensor->get_pool_offset() << "))";
<< "))";
m_variable_name_map[tensor->get_name()] = ss.str(); m_variable_name_map[tensor->get_name()] = ss.str();
} }
} }
...@@ -608,15 +569,12 @@ using namespace std; ...@@ -608,15 +569,12 @@ using namespace std;
// Add inputs to the variable name map // Add inputs to the variable name map
size_t arg_index = 0; size_t arg_index = 0;
for (shared_ptr<ngraph::op::Parameter> param : for (shared_ptr<ngraph::op::Parameter> param : current_function->get_parameters())
current_function->get_parameters())
{ {
for (size_t i = 0; i < param->get_output_size(); ++i) for (size_t i = 0; i < param->get_output_size(); ++i)
{ {
shared_ptr<descriptor::TensorView> tv = shared_ptr<descriptor::TensorView> tv = param->get_output_tensor_view(i);
param->get_output_tensor_view(i); const element::Type& et = tv->get_tensor_view_type()->get_element_type();
const element::Type& et =
tv->get_tensor_view_type()->get_element_type();
string type = et.c_type_string(); string type = et.c_type_string();
stringstream ss; stringstream ss;
ss << "((" << type << "*)(inputs[" << arg_index << "]))"; ss << "((" << type << "*)(inputs[" << arg_index << "]))";
...@@ -650,8 +608,7 @@ using namespace std; ...@@ -650,8 +608,7 @@ using namespace std;
shared_ptr<descriptor::TensorView> tv = op->get_output_tensor_view(); shared_ptr<descriptor::TensorView> tv = op->get_output_tensor_view();
const element::Type& et = tv->get_tensor_view_type()->get_element_type(); const element::Type& et = tv->get_tensor_view_type()->get_element_type();
bool parameter_as_output = false; bool parameter_as_output = false;
for (shared_ptr<ngraph::op::Parameter> param : for (shared_ptr<ngraph::op::Parameter> param : current_function->get_parameters())
current_function->get_parameters())
{ {
for (const descriptor::Output& pout : param->get_outputs()) for (const descriptor::Output& pout : param->get_outputs())
{ {
...@@ -659,10 +616,8 @@ using namespace std; ...@@ -659,10 +616,8 @@ using namespace std;
if (tv == ptv) if (tv == ptv)
{ {
parameter_as_output = true; parameter_as_output = true;
writer writer << "ngraph::runtime::gpu::cuda_memcpyDtD(reinterpret_cast<"
<< "ngraph::runtime::gpu::cuda_memcpyDtD(reinterpret_cast<" << et.c_type_string() << "*>(outputs[" << output_index << "]), "
<< et.c_type_string() << "*>(outputs[" << output_index
<< "]), "
<< m_variable_name_map[ptv->get_tensor().get_name()] << ", " << m_variable_name_map[ptv->get_tensor().get_name()] << ", "
<< ptv->get_tensor().size() << ");\n"; << ptv->get_tensor().size() << ");\n";
break; break;
...@@ -673,9 +628,9 @@ using namespace std; ...@@ -673,9 +628,9 @@ using namespace std;
{ {
if (contains(constants, tv.get())) if (contains(constants, tv.get()))
{ {
writer << "ngraph::runtime::gpu::cuda_memcpyHtD(outputs[" writer << "ngraph::runtime::gpu::cuda_memcpyHtD(outputs[" << output_index
<< output_index << "], " << tv->get_tensor().get_name() << "], " << tv->get_tensor().get_name() << ", "
<< ", " << tv->get_tensor().size() << ");\n"; << tv->get_tensor().size() << ");\n";
} }
else else
{ {
...@@ -690,29 +645,27 @@ using namespace std; ...@@ -690,29 +645,27 @@ using namespace std;
for (shared_ptr<Node> node : current_function->get_ordered_ops()) for (shared_ptr<Node> node : current_function->get_ordered_ops())
{ {
auto& n = auto& n = *node; // Work around a compiler warning (*node inside typeid may have effects
*node; // Work around a compiler warning (*node inside typeid may have effects
// with shared pointers, which is fine here but clang doesn't like it.) // with shared pointers, which is fine here but clang doesn't like it.)
auto handler = dispatcher.find(type_index(typeid(n))); auto handler = dispatcher.find(type_index(typeid(n)));
if (handler == dispatcher.end()) if (handler == dispatcher.end())
{ {
throw ngraph_error("Unhandled op during code generation : " + throw ngraph_error("Unhandled op during code generation : " + node->description());
node->description());
} }
vector<GPU_TensorViewWrapper> in; vector<GPU_TensorViewWrapper> in;
for (const descriptor::Input& input : node->get_inputs()) for (const descriptor::Input& input : node->get_inputs())
{ {
const descriptor::Output& output = input.get_output(); const descriptor::Output& output = input.get_output();
shared_ptr<descriptor::TensorView> tv = output.get_tensor_view(); shared_ptr<descriptor::TensorView> tv = output.get_tensor_view();
in.push_back(GPU_TensorViewWrapper( in.push_back(
tv, m_variable_name_map[tv->get_tensor().get_name()])); GPU_TensorViewWrapper(tv, m_variable_name_map[tv->get_tensor().get_name()]));
} }
vector<GPU_TensorViewWrapper> out; vector<GPU_TensorViewWrapper> out;
for (const descriptor::Output& output : node->get_outputs()) for (const descriptor::Output& output : node->get_outputs())
{ {
shared_ptr<descriptor::TensorView> tv = output.get_tensor_view(); shared_ptr<descriptor::TensorView> tv = output.get_tensor_view();
out.push_back(GPU_TensorViewWrapper( out.push_back(
tv, m_variable_name_map[tv->get_tensor().get_name()])); GPU_TensorViewWrapper(tv, m_variable_name_map[tv->get_tensor().get_name()]));
} }
// Emit operation prologue // Emit operation prologue
...@@ -758,6 +711,7 @@ using namespace std; ...@@ -758,6 +711,7 @@ using namespace std;
} }
} }
} }
writer.indent--; writer.indent--;
// End generated function // End generated function
writer += "}\n\n"; writer += "}\n\n";
...@@ -765,8 +719,7 @@ using namespace std; ...@@ -765,8 +719,7 @@ using namespace std;
// TODO: Cleanup and make this a utility function // TODO: Cleanup and make this a utility function
file_util::make_directory(s_output_dir); file_util::make_directory(s_output_dir);
string filename = string filename = file_util::path_join(s_output_dir, function_name + "_codegen.cpp");
file_util::path_join(s_output_dir, function_name + "_codegen.cpp");
ofstream out(filename); ofstream out(filename);
string code = writer.get_code(); string code = writer.get_code();
out << code; out << code;
...@@ -785,8 +738,7 @@ using namespace std; ...@@ -785,8 +738,7 @@ using namespace std;
} }
m_execution_engine->add_module(codegen_module); m_execution_engine->add_module(codegen_module);
m_execution_engine->finalize(); m_execution_engine->finalize();
m_compiled_function = m_compiled_function = m_execution_engine->find_function<EntryPoint_t>(function_name);
m_execution_engine->find_function<EntryPoint_t>(function_name);
assert(m_compiled_function); assert(m_compiled_function);
m_is_compiled = true; m_is_compiled = true;
...@@ -794,13 +746,13 @@ using namespace std; ...@@ -794,13 +746,13 @@ using namespace std;
{ {
release_function(); release_function();
} }
} }
void GPU_ExternalFunction::handle_output_alias( void runtime::gpu::GPU_ExternalFunction::handle_output_alias(
codegen::CodeWriter& writer, codegen::CodeWriter& writer,
const Node& node, const Node& node,
const unordered_map<descriptor::TensorView*, vector<size_t>>& output_alias_map) const unordered_map<descriptor::TensorView*, vector<size_t>>& output_alias_map)
{ {
for (const descriptor::Output& output : node.get_outputs()) for (const descriptor::Output& output : node.get_outputs())
{ {
shared_ptr<descriptor::TensorView> otv = output.get_tensor_view(); shared_ptr<descriptor::TensorView> otv = output.get_tensor_view();
...@@ -816,44 +768,40 @@ using namespace std; ...@@ -816,44 +768,40 @@ using namespace std;
{ {
writer << "ngraph::runtime::gpu::cuda_memcpyDtD(static_cast<void*>(" writer << "ngraph::runtime::gpu::cuda_memcpyDtD(static_cast<void*>("
"outputs[" "outputs["
<< outputs[i] << "]), static_cast<void*>(outputs[" << outputs[i] << "]), static_cast<void*>(outputs[" << outputs[0]
<< outputs[0] << "]), " << otv->get_tensor().size() << "]), " << otv->get_tensor().size() << ");\n";
<< ");\n";
} }
writer.indent--; writer.indent--;
writer << "}\n"; writer << "}\n";
} }
} }
} }
} }
shared_ptr<ngraph::runtime::CallFrame> GPU_ExternalFunction::make_call_frame() shared_ptr<ngraph::runtime::CallFrame> runtime::gpu::GPU_ExternalFunction::make_call_frame()
{ {
if (!m_is_compiled) if (!m_is_compiled)
{ {
compile(); compile();
} }
return make_shared<GPU_CallFrame>(shared_from_this(), m_compiled_function); return make_shared<GPU_CallFrame>(shared_from_this(), m_compiled_function);
} }
void GPU_ExternalFunction::emit_debug_function_entry( void runtime::gpu::GPU_ExternalFunction::emit_debug_function_entry(
codegen::CodeWriter& writer, codegen::CodeWriter& writer,
Node* node, Node* node,
const std::vector<GPU_TensorViewWrapper>& in, const std::vector<GPU_TensorViewWrapper>& in,
const std::vector<GPU_TensorViewWrapper>& out) const std::vector<GPU_TensorViewWrapper>& out)
{ {
writer << "timer_" << node->get_name() << ".start();\n"; writer << "timer_" << node->get_name() << ".start();\n";
} }
void GPU_ExternalFunction::emit_debug_function_exit( void runtime::gpu::GPU_ExternalFunction::emit_debug_function_exit(
codegen::CodeWriter& writer, codegen::CodeWriter& writer,
Node* node, Node* node,
const std::vector<GPU_TensorViewWrapper>& in, const std::vector<GPU_TensorViewWrapper>& in,
const std::vector<GPU_TensorViewWrapper>& out) const std::vector<GPU_TensorViewWrapper>& out)
{ {
writer << "timer_" << node->get_name() << ".stop();\n"; writer << "timer_" << node->get_name() << ".stop();\n";
}
}
}
} }
...@@ -41,7 +41,7 @@ runtime::gpu::GPU_TensorView::GPU_TensorView(const ngraph::element::Type& elemen ...@@ -41,7 +41,7 @@ runtime::gpu::GPU_TensorView::GPU_TensorView(const ngraph::element::Type& elemen
m_buffer_size = shape_size(shape) * element_type.size(); m_buffer_size = shape_size(shape) * element_type.size();
if (m_buffer_size > 0) if (m_buffer_size > 0)
{ {
cudaMalloc((void**)&m_allocated_buffer_pool, m_buffer_size); cudaMalloc(static_cast<void**>(&m_allocated_buffer_pool), m_buffer_size);
} }
} }
......
...@@ -50,7 +50,7 @@ void runtime::gpu::check_cuda_errors(CUresult err) ...@@ -50,7 +50,7 @@ void runtime::gpu::check_cuda_errors(CUresult err)
void* runtime::gpu::create_gpu_buffer(size_t buffer_size) void* runtime::gpu::create_gpu_buffer(size_t buffer_size)
{ {
void* allocated_buffer_pool; void* allocated_buffer_pool;
cudaMalloc((void**)&allocated_buffer_pool, buffer_size); cudaMalloc(static_cast<void**>(&allocated_buffer_pool), buffer_size);
return allocated_buffer_pool; return allocated_buffer_pool;
} }
......
...@@ -48,6 +48,8 @@ ...@@ -48,6 +48,8 @@
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "util/all_close.hpp" #include "util/all_close.hpp"
#include "util/autodiff/backprop_function.hpp"
#include "util/autodiff/numeric_compare.hpp"
#include "util/matcher.hpp" #include "util/matcher.hpp"
#include "util/test_tools.hpp" #include "util/test_tools.hpp"
...@@ -914,3 +916,47 @@ TEST(cpu_fusion, sigmoid_n1c1h4) ...@@ -914,3 +916,47 @@ TEST(cpu_fusion, sigmoid_n1c1h4)
vector<float> expected{0.73105858f, 0.98201379f, 0.73105858f, 0.98201379f}; vector<float> expected{0.73105858f, 0.98201379f, 0.73105858f, 0.98201379f};
ASSERT_TRUE(read_vector<float>(result) == expected); ASSERT_TRUE(read_vector<float>(result) == expected);
} }
TEST(cpu_fusion, sigmoid_bprop_fusion)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/Graph_fprop_sigmoid.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
auto df = autodiff::backprop_function(func);
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(df);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
size_t ccg = count_ops_of_type<op::SigmoidBackprop>(df);
ASSERT_EQ(ccg, 1);
}
TEST(cpu_fusion, sigmoid_bprop_n1c1h4)
{
auto input = make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
auto delta = make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
auto sigmoid_node = make_shared<op::SigmoidBackprop>(input, delta);
auto func = make_shared<Function>(sigmoid_node, op::ParameterVector{input, delta});
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(func);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
shared_ptr<runtime::TensorView> a =
backend->make_primary_tensor_view(element::f32, input->get_shape());
shared_ptr<runtime::TensorView> b =
backend->make_primary_tensor_view(element::f32, delta->get_shape());
shared_ptr<runtime::TensorView> result =
backend->make_primary_tensor_view(element::f32, input->get_shape());
vector<float> dataA{1.0f, 4.0f, 1.0f, 4.0f};
vector<float> dataB{1.0f, 1.0f, 1.0f, 1.0f};
copy_data(a, dataA);
copy_data(b, dataB);
cf->call({a, b}, {result});
vector<float> expected{0.196612f, 0.0176627f, 0.196612f, 0.0176627f};
EXPECT_TRUE(test::all_close(expected, read_vector<float>(result)));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment