Commit 59ec0797 authored by Tristan Webb's avatar Tristan Webb

fixed hardcoded param

parent 8bf6b3ff
......@@ -295,7 +295,6 @@ void runtime::gpu::GPU_Emitter::EmitMaximum(codegen::CodeWriter& writer,
writer << "{ // " << n->get_name() << "\n";
writer.indent++;
writer << "int count = " << out[0].get_size() << ";\n";
// writer << "static const float beta = 0.0;\n";
writer << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST);\n";;
writer += R"(
float alpha1 = 1.0, alpha2 = 1.0, beta = 0;
......@@ -530,8 +529,7 @@ void runtime::gpu::GPU_Emitter::EmitMultiply(
writer << "cublasSsbmv("
<< "cublas_handle,"
<< "CUBLAS_FILL_MODE_LOWER," // Corresponds to FORTRAN "L"
// << arg0_shape[0] << "," // N = input size
<< "4," // N = input size
<< out[0].get_size() << "," // N = input size
<< "0," // k = super-diagonal i.e. just use the diagonal of A
<< "&alpha," // Alpha
<< args[0].get_name() << "," // vec A (broadcast to a matrix)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment