Commit 7c337e5d authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Store Constant arrays where data is constant as a single value (#2880)

* wip

* Add support for storing constant array of constant values as a single values that is automatically broadcast on deserialize

* revert some changes to serializer.cpp

* fix all_close_f to support nan and inf to allow for unit test

* update unit tests to pass for all_close_f update

* fix bug with i64

* address compile issues?

* change function name to be more accurate

* fix compiler error
parent ae352fa4
...@@ -34,14 +34,7 @@ string to_cpp_string(T value) ...@@ -34,14 +34,7 @@ string to_cpp_string(T value)
} }
else if (std::isinf(value)) else if (std::isinf(value))
{ {
if (value > 0) rc = (value > 0 ? "INFINITY" : "-INFINITY");
{
rc = "INFINITY";
}
else
{
rc = "-INFINITY";
}
} }
else else
{ {
...@@ -60,96 +53,93 @@ vector<string> op::Constant::get_value_strings() const ...@@ -60,96 +53,93 @@ vector<string> op::Constant::get_value_strings() const
{ {
vector<string> rc; vector<string> rc;
if (m_element_type == element::boolean) #pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
switch (get_element_type().get_type_enum())
{ {
case element::Type_t::boolean:
for (int value : get_vector<char>()) for (int value : get_vector<char>())
{ {
rc.push_back(to_string(value)); rc.push_back(to_string(value));
} }
} break;
else if (m_element_type == element::bf16) case element::Type_t::bf16:
{
float temp = 0;
for (bfloat16 value : get_vector<bfloat16>()) for (bfloat16 value : get_vector<bfloat16>())
{ {
temp = static_cast<float>(value); rc.push_back(to_cpp_string(static_cast<float>(value)));
rc.push_back(to_cpp_string(temp));
} }
} break;
else if (m_element_type == element::f32) case element::Type_t::f16:
for (float16 value : get_vector<float16>())
{ {
rc.push_back(to_cpp_string(static_cast<float>(value)));
}
break;
case element::Type_t::f32:
for (float value : get_vector<float>()) for (float value : get_vector<float>())
{ {
rc.push_back(to_cpp_string(value)); rc.push_back(to_cpp_string(value));
} }
} break;
else if (m_element_type == element::f64) case element::Type_t::f64:
{
for (double value : get_vector<double>()) for (double value : get_vector<double>())
{ {
rc.push_back(to_cpp_string(value)); rc.push_back(to_cpp_string(value));
} }
} break;
else if (m_element_type == element::i8) case element::Type_t::i8:
{
for (int value : get_vector<int8_t>()) for (int value : get_vector<int8_t>())
{ {
rc.push_back(to_string(value)); rc.push_back(to_string(value));
} }
} break;
else if (m_element_type == element::i16) case element::Type_t::i16:
{
for (int value : get_vector<int16_t>()) for (int value : get_vector<int16_t>())
{ {
rc.push_back(to_string(value)); rc.push_back(to_string(value));
} }
} break;
else if (m_element_type == element::i32) case element::Type_t::i32:
{
for (int32_t value : get_vector<int32_t>()) for (int32_t value : get_vector<int32_t>())
{ {
rc.push_back(to_string(value)); rc.push_back(to_string(value));
} }
} break;
else if (m_element_type == element::i64) case element::Type_t::i64:
{
for (int64_t value : get_vector<int64_t>()) for (int64_t value : get_vector<int64_t>())
{ {
rc.push_back(to_string(value)); rc.push_back(to_string(value));
} }
} break;
else if (m_element_type == element::u8) case element::Type_t::u8:
{
for (uint32_t value : get_vector<uint8_t>()) for (uint32_t value : get_vector<uint8_t>())
{ {
rc.push_back(to_string(value)); rc.push_back(to_string(value));
} }
} break;
else if (m_element_type == element::u16) case element::Type_t::u16:
{
for (uint32_t value : get_vector<uint16_t>()) for (uint32_t value : get_vector<uint16_t>())
{ {
rc.push_back(to_string(value)); rc.push_back(to_string(value));
} }
} break;
else if (m_element_type == element::u32) case element::Type_t::u32:
{
for (uint32_t value : get_vector<uint32_t>()) for (uint32_t value : get_vector<uint32_t>())
{ {
rc.push_back(to_string(value)); rc.push_back(to_string(value));
} }
} break;
else if (m_element_type == element::u64) case element::Type_t::u64:
{
for (uint64_t value : get_vector<uint64_t>()) for (uint64_t value : get_vector<uint64_t>())
{ {
rc.push_back(to_string(value)); rc.push_back(to_string(value));
} }
break;
case element::Type_t::undefined: throw runtime_error("unsupported type");
case element::Type_t::dynamic: throw runtime_error("unsupported type");
} }
else #pragma GCC diagnostic pop
{
throw runtime_error("unsupported type");
}
return rc; return rc;
} }
...@@ -160,6 +150,71 @@ shared_ptr<Node> op::Constant::copy_with_new_args(const NodeVector& new_args) co ...@@ -160,6 +150,71 @@ shared_ptr<Node> op::Constant::copy_with_new_args(const NodeVector& new_args) co
return make_shared<Constant>(m_element_type, m_shape, m_data->get_ptr()); return make_shared<Constant>(m_element_type, m_shape, m_data->get_ptr());
} }
template <typename T>
static bool test_bitwise_identical(const op::Constant* constant)
{
const size_t size = shape_size(constant->get_shape());
bool data_is_constant = true;
if (size > 0)
{
const T* data = constant->get_data_ptr<T>();
const T compare = data[0];
for (size_t i = 1; i < size; i++)
{
if (data[i] != compare)
{
data_is_constant = false;
break;
}
}
}
return data_is_constant;
}
bool op::Constant::are_all_data_elements_bitwise_identical() const
{
bool rc = false;
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
switch (get_element_type().get_type_enum())
{
case element::Type_t::boolean:
case element::Type_t::i8:
case element::Type_t::u8:
{
rc = test_bitwise_identical<uint8_t>(this);
break;
}
case element::Type_t::bf16:
case element::Type_t::f16:
case element::Type_t::i16:
case element::Type_t::u16:
{
rc = test_bitwise_identical<uint16_t>(this);
break;
}
case element::Type_t::f32:
case element::Type_t::i32:
case element::Type_t::u32:
{
rc = test_bitwise_identical<uint32_t>(this);
break;
}
case element::Type_t::f64:
case element::Type_t::i64:
case element::Type_t::u64:
{
rc = test_bitwise_identical<uint64_t>(this);
break;
}
case element::Type_t::undefined:
case element::Type_t::dynamic: break;
}
#pragma GCC diagnostic pop
return rc;
}
shared_ptr<op::Constant> op::ScalarConstantLikeBase::as_constant() const shared_ptr<op::Constant> op::ScalarConstantLikeBase::as_constant() const
{ {
return std::make_shared<op::Constant>(m_element_type, m_shape, m_data->get_ptr()); return std::make_shared<op::Constant>(m_element_type, m_shape, m_data->get_ptr());
......
...@@ -85,7 +85,7 @@ namespace ngraph ...@@ -85,7 +85,7 @@ namespace ngraph
{ {
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, this,
values.size() == shape_size(m_shape), values.size() == shape_size(m_shape) || values.size() == 1,
"Did not get the expected number of literals for a constant of shape ", "Did not get the expected number of literals for a constant of shape ",
m_shape, m_shape,
" (got ", " (got ",
...@@ -94,8 +94,34 @@ namespace ngraph ...@@ -94,8 +94,34 @@ namespace ngraph
shape_size(m_shape), shape_size(m_shape),
"."); ".");
std::vector<double> dvalues = parse_string<double>(values); std::vector<std::string> tmp_values;
if (values.size() == 1 && shape_size(m_shape) != 1)
{
tmp_values = std::vector<std::string>(shape_size(m_shape), values[0]);
}
else
{
tmp_values = values;
}
if (type.is_integral())
{
if (type.is_signed())
{
std::vector<int64_t> dvalues = parse_string<int64_t>(tmp_values);
write_values(dvalues);
}
else
{
std::vector<uint64_t> dvalues = parse_string<uint64_t>(tmp_values);
write_values(dvalues);
}
}
else
{
std::vector<double> dvalues = parse_string<double>(tmp_values);
write_values(dvalues); write_values(dvalues);
}
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -185,6 +211,8 @@ namespace ngraph ...@@ -185,6 +211,8 @@ namespace ngraph
} }
bool is_constant() const override { return true; } bool is_constant() const override { return true; }
bool are_all_data_elements_bitwise_identical() const;
protected: protected:
void* get_data_ptr_nc() { return (m_data ? m_data->get_ptr() : nullptr); } void* get_data_ptr_nc() { return (m_data ? m_data->get_ptr() : nullptr); }
Constant(const std::string& name, const NodeVector& args) Constant(const std::string& name, const NodeVector& args)
...@@ -222,58 +250,54 @@ namespace ngraph ...@@ -222,58 +250,54 @@ namespace ngraph
{ {
throw std::runtime_error("Constant initializer does not match shape"); throw std::runtime_error("Constant initializer does not match shape");
} }
if (target_type == element::boolean) #pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
switch (target_type.get_type_enum())
{ {
case element::Type_t::boolean:
write_buffer<char, T>(target, source, target_element_count); write_buffer<char, T>(target, source, target_element_count);
} break;
else if (target_type == element::bf16) case element::Type_t::bf16:
{
write_buffer<bfloat16, T>(target, source, target_element_count); write_buffer<bfloat16, T>(target, source, target_element_count);
} break;
else if (target_type == element::f32) case element::Type_t::f16:
{ write_buffer<float16, T>(target, source, target_element_count);
break;
case element::Type_t::f32:
write_buffer<float, T>(target, source, target_element_count); write_buffer<float, T>(target, source, target_element_count);
} break;
else if (target_type == element::f64) case element::Type_t::f64:
{
write_buffer<double, T>(target, source, target_element_count); write_buffer<double, T>(target, source, target_element_count);
} break;
else if (target_type == element::i8) case element::Type_t::i8:
{
write_buffer<int8_t, T>(target, source, target_element_count); write_buffer<int8_t, T>(target, source, target_element_count);
} break;
else if (target_type == element::i16) case element::Type_t::i16:
{
write_buffer<int16_t, T>(target, source, target_element_count); write_buffer<int16_t, T>(target, source, target_element_count);
} break;
else if (target_type == element::i32) case element::Type_t::i32:
{
write_buffer<int32_t, T>(target, source, target_element_count); write_buffer<int32_t, T>(target, source, target_element_count);
} break;
else if (target_type == element::i64) case element::Type_t::i64:
{
write_buffer<int64_t, T>(target, source, target_element_count); write_buffer<int64_t, T>(target, source, target_element_count);
} break;
else if (target_type == element::u8) case element::Type_t::u8:
{
write_buffer<uint8_t, T>(target, source, target_element_count); write_buffer<uint8_t, T>(target, source, target_element_count);
} break;
else if (target_type == element::u16) case element::Type_t::u16:
{
write_buffer<uint16_t, T>(target, source, target_element_count); write_buffer<uint16_t, T>(target, source, target_element_count);
} break;
else if (target_type == element::u32) case element::Type_t::u32:
{
write_buffer<uint32_t, T>(target, source, target_element_count); write_buffer<uint32_t, T>(target, source, target_element_count);
} break;
else if (target_type == element::u64) case element::Type_t::u64:
{
write_buffer<uint64_t, T>(target, source, target_element_count); write_buffer<uint64_t, T>(target, source, target_element_count);
break;
case element::Type_t::undefined: throw std::runtime_error("unsupported type");
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
} }
else #pragma GCC diagnostic pop
{
throw std::runtime_error("unsupported type");
}
} }
static constexpr size_t host_alignment() { return 64; } static constexpr size_t host_alignment() { return 64; }
......
...@@ -22,24 +22,17 @@ ...@@ -22,24 +22,17 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
template <typename T> bool pass::ConstantToBroadcast::run_on_node(shared_ptr<Node> node)
static bool is_data_constant(shared_ptr<op::Constant> constant)
{ {
const size_t size = shape_size(constant->get_shape()); const size_t minimum_size_of_interest = 32;
bool data_is_constant = true; bool modified = false;
if (size > 0) if (node->description() == "Constant")
{ {
const T* data = constant->get_data_ptr<T>(); auto constant = static_pointer_cast<op::Constant>(node);
const T compare = data[0]; size_t size = shape_size(constant->get_shape());
for (size_t i = 1; i < size; i++) if (size > minimum_size_of_interest)
{ {
if (data[i] != compare) if (constant->are_all_data_elements_bitwise_identical())
{
data_is_constant = false;
break;
}
}
if (data_is_constant)
{ {
auto scalar_constant = make_shared<op::Constant>( auto scalar_constant = make_shared<op::Constant>(
constant->get_element_type(), Shape{}, constant->get_data_ptr()); constant->get_element_type(), Shape{}, constant->get_data_ptr());
...@@ -53,53 +46,6 @@ static bool is_data_constant(shared_ptr<op::Constant> constant) ...@@ -53,53 +46,6 @@ static bool is_data_constant(shared_ptr<op::Constant> constant)
replace_node(constant, broadcast); replace_node(constant, broadcast);
} }
} }
return data_is_constant;
}
bool pass::ConstantToBroadcast::run_on_node(shared_ptr<Node> node)
{
const size_t minimum_size_of_interest = 32;
bool modified = false;
if (node->description() == "Constant")
{
auto constant = static_pointer_cast<op::Constant>(node);
size_t size = shape_size(constant->get_shape());
if (size > minimum_size_of_interest)
{
switch (constant->get_element_type().get_type_enum())
{
case element::Type_t::boolean:
case element::Type_t::i8:
case element::Type_t::u8:
{
modified = is_data_constant<uint8_t>(constant);
break;
}
case element::Type_t::bf16:
case element::Type_t::i16:
case element::Type_t::u16:
{
modified = is_data_constant<uint16_t>(constant);
break;
}
case element::Type_t::f32:
case element::Type_t::i32:
case element::Type_t::u32:
{
modified = is_data_constant<uint32_t>(constant);
break;
}
case element::Type_t::f64:
case element::Type_t::i64:
case element::Type_t::u64:
{
modified = is_data_constant<uint64_t>(constant);
break;
}
case element::Type_t::undefined:
case element::Type_t::dynamic: break;
}
}
} }
return modified; return modified;
} }
...@@ -208,6 +208,7 @@ void runtime::interpreter::INTExecutable::generate_calls(const element::Type& ty ...@@ -208,6 +208,7 @@ void runtime::interpreter::INTExecutable::generate_calls(const element::Type& ty
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
case element::Type_t::bf16: case element::Type_t::bf16:
case element::Type_t::f16:
ss << "unsupported element type " << type << " op " << op.get_node()->get_name(); ss << "unsupported element type " << type << " op " << op.get_node()->get_name();
throw ngraph_error(ss.str()); throw ngraph_error(ss.str());
} }
......
...@@ -661,16 +661,8 @@ static shared_ptr<ngraph::Function> ...@@ -661,16 +661,8 @@ static shared_ptr<ngraph::Function>
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js; node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
auto element_type = read_element_type(type_node_js.at("element_type")); auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape"); auto shape = type_node_js.at("shape");
auto value_it = node_js.find("value"); auto value = node_js.at("value").get<vector<string>>();
if (value_it != node_js.end())
{
auto value = value_it->get<vector<string>>();
node = make_shared<op::Constant>(element_type, shape, value); node = make_shared<op::Constant>(element_type, shape, value);
}
else
{
node = const_data_callback(node_name, element_type, shape);
}
break; break;
} }
case OP_TYPEID::Convert: case OP_TYPEID::Convert:
...@@ -1681,7 +1673,13 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1681,7 +1673,13 @@ static json write(const Node& n, bool binary_constant_data)
case OP_TYPEID::Constant: case OP_TYPEID::Constant:
{ {
auto tmp = dynamic_cast<const op::Constant*>(&n); auto tmp = dynamic_cast<const op::Constant*>(&n);
if (!binary_constant_data) if (tmp->are_all_data_elements_bitwise_identical())
{
vector<string> vs;
vs.push_back(tmp->get_value_strings()[0]);
node["value"] = vs;
}
else
{ {
node["value"] = tmp->get_value_strings(); node["value"] = tmp->get_value_strings();
} }
......
...@@ -893,14 +893,14 @@ TEST(all_close_f, inf_nan) ...@@ -893,14 +893,14 @@ TEST(all_close_f, inf_nan)
EXPECT_FALSE(test::close_f(zero, signaling_nan)); EXPECT_FALSE(test::close_f(zero, signaling_nan));
EXPECT_FALSE(test::all_close_f(vector<float>({zero}), vector<float>({signaling_nan}))); EXPECT_FALSE(test::all_close_f(vector<float>({zero}), vector<float>({signaling_nan})));
EXPECT_FALSE(test::close_f(infinity, infinity)); EXPECT_TRUE(test::close_f(infinity, infinity));
EXPECT_FALSE(test::all_close_f(vector<float>({infinity}), vector<float>({infinity}))); EXPECT_TRUE(test::all_close_f(vector<float>({infinity}), vector<float>({infinity})));
EXPECT_FALSE(test::close_f(neg_infinity, neg_infinity)); EXPECT_TRUE(test::close_f(neg_infinity, neg_infinity));
EXPECT_FALSE(test::all_close_f(vector<float>({neg_infinity}), vector<float>({neg_infinity}))); EXPECT_TRUE(test::all_close_f(vector<float>({neg_infinity}), vector<float>({neg_infinity})));
EXPECT_FALSE(test::close_f(quiet_nan, quiet_nan)); EXPECT_TRUE(test::close_f(quiet_nan, quiet_nan));
EXPECT_FALSE(test::all_close_f(vector<float>({quiet_nan}), vector<float>({quiet_nan}))); EXPECT_TRUE(test::all_close_f(vector<float>({quiet_nan}), vector<float>({quiet_nan})));
EXPECT_FALSE(test::close_f(signaling_nan, signaling_nan)); EXPECT_TRUE(test::close_f(signaling_nan, signaling_nan));
EXPECT_FALSE(test::all_close_f(vector<float>({signaling_nan}), vector<float>({signaling_nan}))); EXPECT_TRUE(test::all_close_f(vector<float>({signaling_nan}), vector<float>({signaling_nan})));
} }
TEST(all_close_f, double_inf_nan) TEST(all_close_f, double_inf_nan)
...@@ -920,13 +920,13 @@ TEST(all_close_f, double_inf_nan) ...@@ -920,13 +920,13 @@ TEST(all_close_f, double_inf_nan)
EXPECT_FALSE(test::close_f(zero, signaling_nan)); EXPECT_FALSE(test::close_f(zero, signaling_nan));
EXPECT_FALSE(test::all_close_f(vector<double>({zero}), vector<double>({signaling_nan}))); EXPECT_FALSE(test::all_close_f(vector<double>({zero}), vector<double>({signaling_nan})));
EXPECT_FALSE(test::close_f(infinity, infinity)); EXPECT_TRUE(test::close_f(infinity, infinity));
EXPECT_FALSE(test::all_close_f(vector<double>({infinity}), vector<double>({infinity}))); EXPECT_TRUE(test::all_close_f(vector<double>({infinity}), vector<double>({infinity})));
EXPECT_FALSE(test::close_f(neg_infinity, neg_infinity)); EXPECT_TRUE(test::close_f(neg_infinity, neg_infinity));
EXPECT_FALSE(test::all_close_f(vector<double>({neg_infinity}), vector<double>({neg_infinity}))); EXPECT_TRUE(test::all_close_f(vector<double>({neg_infinity}), vector<double>({neg_infinity})));
EXPECT_FALSE(test::close_f(quiet_nan, quiet_nan)); EXPECT_TRUE(test::close_f(quiet_nan, quiet_nan));
EXPECT_FALSE(test::all_close_f(vector<double>({quiet_nan}), vector<double>({quiet_nan}))); EXPECT_TRUE(test::all_close_f(vector<double>({quiet_nan}), vector<double>({quiet_nan})));
EXPECT_FALSE(test::close_f(signaling_nan, signaling_nan)); EXPECT_TRUE(test::close_f(signaling_nan, signaling_nan));
EXPECT_FALSE( EXPECT_TRUE(
test::all_close_f(vector<double>({signaling_nan}), vector<double>({signaling_nan}))); test::all_close_f(vector<double>({signaling_nan}), vector<double>({signaling_nan})));
} }
...@@ -22,11 +22,13 @@ ...@@ -22,11 +22,13 @@
#include "ngraph/file_util.hpp" #include "ngraph/file_util.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/passthrough.hpp" #include "ngraph/op/passthrough.hpp"
#include "ngraph/serializer.hpp" #include "ngraph/serializer.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "util/all_close_f.hpp"
#include "util/test_tools.hpp" #include "util/test_tools.hpp"
using namespace std; using namespace std;
...@@ -256,3 +258,55 @@ TEST(serialize, passthrough) ...@@ -256,3 +258,55 @@ TEST(serialize, passthrough)
ElementsAre(IsOutputShape(element::f32, Shape{2, 3}), ElementsAre(IsOutputShape(element::f32, Shape{2, 3}),
IsOutputShape(element::i8, Shape{4, 5}))); IsOutputShape(element::i8, Shape{4, 5})));
} }
TEST(serialize, constant_infinity_nan)
{
vector<float> a_data{123, 456, INFINITY, -INFINITY, NAN};
vector<float> b_data{5, 5, 5, 5, 5, 5};
vector<float> c_data{0.05, 0.05, 0.05, 0.05, 0.05, 0.05001, 0.05};
vector<int64_t> d_data{-100, -10, -1, 0, 50, 5000000000001};
auto A = make_shared<op::Constant>(element::f32, Shape{5}, a_data);
auto B = make_shared<op::Constant>(element::f32, Shape{6}, b_data);
auto C = make_shared<op::Constant>(element::f32, Shape{7}, c_data);
auto D = make_shared<op::Constant>(element::i64, Shape{d_data.size()}, d_data);
A->set_friendly_name("A");
B->set_friendly_name("B");
C->set_friendly_name("C");
D->set_friendly_name("D");
auto f = make_shared<Function>(NodeVector{A, B, C, D}, ParameterVector{});
string s = serialize(f, 4);
shared_ptr<Function> g = deserialize(s);
shared_ptr<op::Constant> a;
shared_ptr<op::Constant> b;
shared_ptr<op::Constant> c;
shared_ptr<op::Constant> d;
for (auto node : g->get_ops())
{
if (node->get_friendly_name() == "A")
{
a = static_pointer_cast<op::Constant>(node);
}
else if (node->get_friendly_name() == "B")
{
b = static_pointer_cast<op::Constant>(node);
}
else if (node->get_friendly_name() == "C")
{
c = static_pointer_cast<op::Constant>(node);
}
else if (node->get_friendly_name() == "D")
{
d = static_pointer_cast<op::Constant>(node);
}
}
ASSERT_NE(a, nullptr);
ASSERT_NE(b, nullptr);
ASSERT_NE(c, nullptr);
ASSERT_NE(d, nullptr);
EXPECT_TRUE(test::all_close_f(a->get_vector<float>(), a_data));
EXPECT_TRUE(test::all_close_f(b->get_vector<float>(), b_data));
EXPECT_TRUE(test::all_close_f(c->get_vector<float>(), c_data));
EXPECT_EQ(d->get_vector<int64_t>(), d_data);
}
...@@ -39,8 +39,20 @@ constexpr uint64_t DOUBLE_MAX_DIFF = ULLONG_MAX - 1; ...@@ -39,8 +39,20 @@ constexpr uint64_t DOUBLE_MAX_DIFF = ULLONG_MAX - 1;
uint32_t test::float_distance(float a, float b, float min_signal) uint32_t test::float_distance(float a, float b, float min_signal)
{ {
if (!isfinite(a) || !isfinite(b)) if (std::isnan(a) && std::isnan(b))
{ {
return 0;
}
else if (std::isinf(a) && std::isinf(b))
{
if (a > 0 && b > 0)
{
return 0;
}
else if (a < 0 && b < 0)
{
return 0;
}
return FLOAT_MAX_DIFF; return FLOAT_MAX_DIFF;
} }
...@@ -82,8 +94,20 @@ uint32_t test::float_distance(float a, float b, float min_signal) ...@@ -82,8 +94,20 @@ uint32_t test::float_distance(float a, float b, float min_signal)
uint64_t test::float_distance(double a, double b, double min_signal) uint64_t test::float_distance(double a, double b, double min_signal)
{ {
if (!isfinite(a) || !isfinite(b)) if (std::isnan(a) && std::isnan(b))
{ {
return 0;
}
else if (std::isinf(a) && std::isinf(b))
{
if (a > 0 && b > 0)
{
return 0;
}
else if (a < 0 && b < 0)
{
return 0;
}
return DOUBLE_MAX_DIFF; return DOUBLE_MAX_DIFF;
} }
...@@ -125,9 +149,20 @@ uint64_t test::float_distance(double a, double b, double min_signal) ...@@ -125,9 +149,20 @@ uint64_t test::float_distance(double a, double b, double min_signal)
bool test::close_f(float a, float b, int tolerance_bits, float min_signal) bool test::close_f(float a, float b, int tolerance_bits, float min_signal)
{ {
// isfinite(a) => !isinf(a) && !isnan(a) if (std::isnan(a) && std::isnan(b))
if (!isfinite(a) || !isfinite(b))
{ {
return true;
}
else if (std::isinf(a) && std::isinf(b))
{
if (a > 0 && b > 0)
{
return true;
}
else if (a < 0 && b < 0)
{
return true;
}
return false; return false;
} }
...@@ -144,9 +179,20 @@ bool test::close_f(float a, float b, int tolerance_bits, float min_signal) ...@@ -144,9 +179,20 @@ bool test::close_f(float a, float b, int tolerance_bits, float min_signal)
bool test::close_f(double a, double b, int tolerance_bits, double min_signal) bool test::close_f(double a, double b, int tolerance_bits, double min_signal)
{ {
// isfinite(a) => !isinf(a) && !isnan(a) if (std::isnan(a) && std::isnan(b))
if (!isfinite(a) || !isfinite(b))
{ {
return true;
}
else if (std::isinf(a) && std::isinf(b))
{
if (a > 0 && b > 0)
{
return true;
}
else if (a < 0 && b < 0)
{
return true;
}
return false; return false;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment