Commit b8c23f9c authored by Michał Karzyński's avatar Michał Karzyński Committed by Robert Kimball

[ONNX] Enable softmax test (#2333)

* [ONNX] Enable softmax test

* Update onnx_import.in.cpp

* change __expf to exp for accuracy
parent af5340d8
......@@ -311,7 +311,7 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_op(codegen::CodeWriter& writer
size_t reduce_rank)
{
auto stable_sum_lambda = [&]() {
writer << "input_i = __expf(input_i - r_max);\n";
writer << "input_i = exp(input_i - r_max);\n";
writer << "y = input_i - c;\n";
writer << "t = r_sum + y;\n";
writer << "c = (t - r_sum) - y;\n";
......@@ -321,7 +321,7 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_op(codegen::CodeWriter& writer
auto max_lambda = [&]() { writer << "r_max = r_max > input_i ? r_max : input_i;\n"; };
auto divide_lambda = [&]() {
writer << "input_i = __expf(input_i - r_max) / r_sum;\n";
writer << "input_i = exp(input_i - r_max) / r_sum;\n";
writer << "out[reduce_idx] = input_i;\n";
};
writer << runtime::gpu::nvrtc::helpers();
......@@ -537,7 +537,7 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_block_reduce_op(
};
auto stable_sum_lambda = [&]() {
writer << "input_i = __expf(input_i - r_max);\n";
writer << "input_i = exp(input_i - r_max);\n";
writer << "y = input_i - c;\n";
writer << "t = r_sum + y;\n";
writer << "c = (t - r_sum) - y;\n";
......@@ -547,7 +547,7 @@ void runtime::gpu::CudaKernelBuilder::get_softmax_block_reduce_op(
auto max_lambda = [&]() { writer << "r_max = r_max > input_i ? r_max : input_i;\n"; };
auto divide_lambda = [&]() {
writer << "input_i = __expf(input_i - r_max) / r_sum;\n";
writer << "input_i = exp(input_i - r_max) / r_sum;\n";
writer << "out[input_idx] = input_i;\n";
};
......
......@@ -473,7 +473,7 @@ TEST(onnx_${BACKEND_NAME}, model_matmul)
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
}
TEST(onnx_${BACKEND_NAME}, DISABLED_model_softmax)
TEST(onnx_${BACKEND_NAME}, model_softmax)
{
auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/softmax.onnx"));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment