Commit 78ff5784 authored by fenglei.tian

add multiply using cudnn

parent 1c74ad24
#./build/test/unit-test --gtest_filter=GPU.ab
./build/test/unit-test --gtest_filter=GPU.maximum
./build/test/unit-test --gtest_filter=GPU.minimum
./build/test/unit-test --gtest_filter=GPU.multiple*
#./build/test/unit-test --gtest_filter=GPU.abs
#./build/test/unit-test --gtest_filter=GPU.dot*
......@@ -63,8 +63,6 @@ void runtime::gpu::GPU_Emitter::EmitAdd(codegen::CodeWriter& writer,
const vector<runtime::gpu::GPU_TensorViewWrapper>& args,
const vector<runtime::gpu::GPU_TensorViewWrapper>& out)
{
{
// TODO Assert arg0_shape[0] == arg1_shape[0]?
writer << "{ // " << n->get_name() << "\n";
writer.indent++;
writer << "const float alpha = 1.0;\n";
......@@ -84,7 +82,6 @@ void runtime::gpu::GPU_Emitter::EmitAdd(codegen::CodeWriter& writer,
<< out[0].get_size() << ");\n";
writer.indent--;
writer << "}\n";
}
}
void runtime::gpu::GPU_Emitter::EmitConcat(codegen::CodeWriter& writer,
......@@ -260,27 +257,26 @@ void runtime::gpu::GPU_Emitter::EmitMaximum(codegen::CodeWriter& writer,
const vector<runtime::gpu::GPU_TensorViewWrapper>& args,
const vector<runtime::gpu::GPU_TensorViewWrapper>& out)
{
writer << "{ // " << n->get_name() << "\n";
writer.indent++;
writer << " // " << n->get_name() << "\n";
writer << "int count = " << out[0].get_size() << ";\n";
writer += R"(
float alpha1 = 1.0, alpha2 = 1.0, beta = 0;
cudnnTensorDescriptor_t descriptor;
(cudnnCreateTensorDescriptor(&descriptor));
(cudnnSetTensor4dDescriptor(descriptor,
cudnnCreateTensorDescriptor(&descriptor);
cudnnSetTensor4dDescriptor(descriptor,
/*format=*/CUDNN_TENSOR_NHWC,
/*dataType=*/CUDNN_DATA_FLOAT,
/*batch_size=*/1,
/*channels=*/1,
/*image_height=*/1,
/*image_width=*/count));
/*image_width=*/count);
cudnnOpTensorDescriptor_t opTensorDesc;
(cudnnCreateOpTensorDescriptor(&opTensorDesc));
(cudnnSetOpTensorDescriptor(opTensorDesc,
cudnnCreateOpTensorDescriptor(&opTensorDesc);
cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_MAX,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN));
CUDNN_NOT_PROPAGATE_NAN);
)";
writer << "cudnnOpTensor(cudnn_handle,"
......@@ -294,11 +290,6 @@ cudnnOpTensorDescriptor_t opTensorDesc;
<< "&beta,"
<< "descriptor,"
<< out[0].get_name() << ");\n";
writer += R"(
)";
writer.indent--;
writer << "}\n";
}
void runtime::gpu::GPU_Emitter::EmitMinimum(codegen::CodeWriter& writer,
......@@ -306,6 +297,40 @@ void runtime::gpu::GPU_Emitter::EmitMinimum(codegen::CodeWriter& writer,
const vector<runtime::gpu::GPU_TensorViewWrapper>& args,
const vector<runtime::gpu::GPU_TensorViewWrapper>& out)
{
writer << " // " << n->get_name() << "\n";
writer << "int count = " << out[0].get_size() << ";\n";
writer += R"(
float alpha1 = 1.0, alpha2 = 1.0, beta = 0;
cudnnTensorDescriptor_t descriptor;
cudnnCreateTensorDescriptor(&descriptor);
cudnnSetTensor4dDescriptor(descriptor,
/*format=*/CUDNN_TENSOR_NHWC,
/*dataType=*/CUDNN_DATA_FLOAT,
/*batch_size=*/1,
/*channels=*/1,
/*image_height=*/1,
/*image_width=*/count);
cudnnOpTensorDescriptor_t opTensorDesc;
cudnnCreateOpTensorDescriptor(&opTensorDesc);
cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_MIN,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN);
)";
writer << "cudnnOpTensor(cudnn_handle,"
<< "opTensorDesc,"
<< "&alpha1,"
<< "descriptor,"
<< args[0].get_name() << ","
<< "&alpha2,"
<< "descriptor,"
<< args[1].get_name() << ","
<< "&beta,"
<< "descriptor,"
<< out[0].get_name() << ");\n";
}
void runtime::gpu::GPU_Emitter::EmitNegative(
......@@ -476,34 +501,40 @@ void runtime::gpu::GPU_Emitter::EmitMultiply(
const vector<runtime::gpu::GPU_TensorViewWrapper>& args,
const vector<runtime::gpu::GPU_TensorViewWrapper>& out)
{
const Shape& arg0_shape = args[0].get_shape();
const Shape& arg1_shape = args[1].get_shape();
// Until we have EW kernel gen, use cuBLAS
// From https://stackoverflow.com/questions/7621520/element-wise-vector-vector-multiplication-in-blas/7634831
writer << " // " << n->get_name() << "\n";
writer << "int count = " << out[0].get_size() << ";\n";
writer += R"(
float alpha1 = 1.0, alpha2 = 1.0, beta = 0;
cudnnTensorDescriptor_t descriptor;
cudnnCreateTensorDescriptor(&descriptor);
cudnnSetTensor4dDescriptor(descriptor,
/*format=*/CUDNN_TENSOR_NHWC,
/*dataType=*/CUDNN_DATA_FLOAT,
/*batch_size=*/1,
/*channels=*/1,
/*image_height=*/1,
/*image_width=*/count);
// clang-format off
writer << "{ // " << n->get_name() << "\n";
writer.indent++;
writer << "static const float alpha = 1.0;\n";
writer << "static const float beta = 0.0;\n";
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);\n";;
writer << "cublasSsbmv("
<< "cublas_handle,"
<< "CUBLAS_FILL_MODE_LOWER," // Corresponds to FORTRAN "L"
<< out[0].get_size() << "," // N = input size
<< "0," // k = super-diagonal i.e. just use the diagonal of A
<< "&alpha," // Alpha
<< args[0].get_name() << "," // vec A (broadcast to a matrix)
<< "1," // LDA = 1
<< args[1].get_name() << "," // vector x
<< "1," // Stride x
<< "&beta," // beta
<< out[0].get_name() << "," // y
<< "1" // Stride y
<< ");\n";
writer.indent--;
writer << "}\n";
// clang-format on
cudnnOpTensorDescriptor_t opTensorDesc;
cudnnCreateOpTensorDescriptor(&opTensorDesc);
cudnnSetOpTensorDescriptor(opTensorDesc,
CUDNN_OP_TENSOR_MUL,
CUDNN_DATA_FLOAT,
CUDNN_NOT_PROPAGATE_NAN);
)";
writer << "cudnnOpTensor(cudnn_handle,"
<< "opTensorDesc,"
<< "&alpha1,"
<< "descriptor,"
<< args[0].get_name() << ","
<< "&alpha2,"
<< "descriptor,"
<< args[1].get_name() << ","
<< "&beta,"
<< "descriptor,"
<< out[0].get_name() << ");\n";
}
void runtime::gpu::GPU_Emitter::EmitExp(codegen::CodeWriter& writer,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment