Commit 6f0c8190 authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

Fix Erf CPU Builder for int32 (#2773)

* - pass tensor through reference in the lambad closure of cpu reference_erf kernnel
- add Erf unit test case to verify codepath for int32 values

* fix clang errors
parent 03f13e4b
......@@ -58,8 +58,8 @@ namespace ngraph
std::function<decltype(runtime::cpu::kernel::reference_erf<float>)> kernel;
SELECT_KERNEL(
kernel, args[0].get_element_type(), runtime::cpu::kernel::reference_erf);
auto functor = [&, kernel, arg0_tensor, out0_tensor](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
kernel(arg0_tensor, out0_tensor, element_count);
};
......
......@@ -3103,7 +3103,7 @@ namespace ngraph
writer << "cpu::kernel::reference_erf<"
<< args[0].get_element_type().c_type_string() << ">("
<< args[0].get_name() << ", " << out[0].get_name() << ", "
<< ", " << element_count << ");\n";
<< element_count << ");\n";
}
writer.block_end();
}
......
......@@ -1191,7 +1191,7 @@ TEST(cpu_test, conv_negative_padding)
compare_backends(make_f(), make_f(), "CPU", "INTERPRETER");
}
TEST(cpu_test, guass_error_function_erf)
TEST(cpu_test, gauss_error_function_erf_float32)
{
auto make_function = []() -> std::shared_ptr<Function> {
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 4, 10, 6, 10});
......@@ -1219,3 +1219,34 @@ TEST(cpu_test, guass_error_function_erf)
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i)));
}
}
TEST(cpu_test, gauss_error_function_erf_int32)
{
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::i32, shape);
auto make_function = [&]() -> std::shared_ptr<Function> {
auto erf = make_shared<op::Erf>(A);
return make_shared<Function>(erf, ParameterVector{A});
};
auto backend = runtime::Backend::create("CPU");
auto cpu_f = make_function();
auto input_nd_array = test::NDArray<int, 2>({{45, 2}, {7, 9}});
auto expected_result_nd_array =
test::NDArray<int, 2>({{static_cast<int>(std::erf(45)), static_cast<int>(std::erf(2))},
{static_cast<int>(std::erf(7)), static_cast<int>(std::erf(9))}});
// Create some tensors for input/output
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::i32, shape);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::i32, shape);
copy_data(a, input_nd_array.get_vector());
auto handle = backend->compile(cpu_f);
handle->call_with_validate({result}, {a});
auto result_values = read_vector<int>(result);
auto expected_values = expected_result_nd_array.get_vector();
ASSERT_EQ(result_values, expected_values);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment