Commit 7c337e5d authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Store Constant arrays where data is constant as a single value (#2880)

* wip

* Add support for storing constant array of constant values as a single values that is automatically broadcast on deserialize

* revert some changes to serializer.cpp

* fix all_close_f to support nan and inf to allow for unit test

* update unit tests to pass for all_close_f update

* fix bug with i64

* address compile issues?

* change function name to be more accurate

* fix compiler error
parent ae352fa4
......@@ -34,14 +34,7 @@ string to_cpp_string(T value)
}
else if (std::isinf(value))
{
if (value > 0)
{
rc = "INFINITY";
}
else
{
rc = "-INFINITY";
}
rc = (value > 0 ? "INFINITY" : "-INFINITY");
}
else
{
......@@ -60,96 +53,93 @@ vector<string> op::Constant::get_value_strings() const
{
vector<string> rc;
if (m_element_type == element::boolean)
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
switch (get_element_type().get_type_enum())
{
case element::Type_t::boolean:
for (int value : get_vector<char>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::bf16)
{
float temp = 0;
break;
case element::Type_t::bf16:
for (bfloat16 value : get_vector<bfloat16>())
{
temp = static_cast<float>(value);
rc.push_back(to_cpp_string(temp));
rc.push_back(to_cpp_string(static_cast<float>(value)));
}
}
else if (m_element_type == element::f32)
{
break;
case element::Type_t::f16:
for (float16 value : get_vector<float16>())
{
rc.push_back(to_cpp_string(static_cast<float>(value)));
}
break;
case element::Type_t::f32:
for (float value : get_vector<float>())
{
rc.push_back(to_cpp_string(value));
}
}
else if (m_element_type == element::f64)
{
break;
case element::Type_t::f64:
for (double value : get_vector<double>())
{
rc.push_back(to_cpp_string(value));
}
}
else if (m_element_type == element::i8)
{
break;
case element::Type_t::i8:
for (int value : get_vector<int8_t>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::i16)
{
break;
case element::Type_t::i16:
for (int value : get_vector<int16_t>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::i32)
{
break;
case element::Type_t::i32:
for (int32_t value : get_vector<int32_t>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::i64)
{
break;
case element::Type_t::i64:
for (int64_t value : get_vector<int64_t>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::u8)
{
break;
case element::Type_t::u8:
for (uint32_t value : get_vector<uint8_t>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::u16)
{
break;
case element::Type_t::u16:
for (uint32_t value : get_vector<uint16_t>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::u32)
{
break;
case element::Type_t::u32:
for (uint32_t value : get_vector<uint32_t>())
{
rc.push_back(to_string(value));
}
}
else if (m_element_type == element::u64)
{
break;
case element::Type_t::u64:
for (uint64_t value : get_vector<uint64_t>())
{
rc.push_back(to_string(value));
}
break;
case element::Type_t::undefined: throw runtime_error("unsupported type");
case element::Type_t::dynamic: throw runtime_error("unsupported type");
}
else
{
throw runtime_error("unsupported type");
}
#pragma GCC diagnostic pop
return rc;
}
......@@ -160,6 +150,71 @@ shared_ptr<Node> op::Constant::copy_with_new_args(const NodeVector& new_args) co
return make_shared<Constant>(m_element_type, m_shape, m_data->get_ptr());
}
template <typename T>
static bool test_bitwise_identical(const op::Constant* constant)
{
const size_t size = shape_size(constant->get_shape());
bool data_is_constant = true;
if (size > 0)
{
const T* data = constant->get_data_ptr<T>();
const T compare = data[0];
for (size_t i = 1; i < size; i++)
{
if (data[i] != compare)
{
data_is_constant = false;
break;
}
}
}
return data_is_constant;
}
bool op::Constant::are_all_data_elements_bitwise_identical() const
{
bool rc = false;
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
switch (get_element_type().get_type_enum())
{
case element::Type_t::boolean:
case element::Type_t::i8:
case element::Type_t::u8:
{
rc = test_bitwise_identical<uint8_t>(this);
break;
}
case element::Type_t::bf16:
case element::Type_t::f16:
case element::Type_t::i16:
case element::Type_t::u16:
{
rc = test_bitwise_identical<uint16_t>(this);
break;
}
case element::Type_t::f32:
case element::Type_t::i32:
case element::Type_t::u32:
{
rc = test_bitwise_identical<uint32_t>(this);
break;
}
case element::Type_t::f64:
case element::Type_t::i64:
case element::Type_t::u64:
{
rc = test_bitwise_identical<uint64_t>(this);
break;
}
case element::Type_t::undefined:
case element::Type_t::dynamic: break;
}
#pragma GCC diagnostic pop
return rc;
}
shared_ptr<op::Constant> op::ScalarConstantLikeBase::as_constant() const
{
return std::make_shared<op::Constant>(m_element_type, m_shape, m_data->get_ptr());
......
......@@ -85,7 +85,7 @@ namespace ngraph
{
NODE_VALIDATION_CHECK(
this,
values.size() == shape_size(m_shape),
values.size() == shape_size(m_shape) || values.size() == 1,
"Did not get the expected number of literals for a constant of shape ",
m_shape,
" (got ",
......@@ -94,8 +94,34 @@ namespace ngraph
shape_size(m_shape),
".");
std::vector<double> dvalues = parse_string<double>(values);
write_values(dvalues);
std::vector<std::string> tmp_values;
if (values.size() == 1 && shape_size(m_shape) != 1)
{
tmp_values = std::vector<std::string>(shape_size(m_shape), values[0]);
}
else
{
tmp_values = values;
}
if (type.is_integral())
{
if (type.is_signed())
{
std::vector<int64_t> dvalues = parse_string<int64_t>(tmp_values);
write_values(dvalues);
}
else
{
std::vector<uint64_t> dvalues = parse_string<uint64_t>(tmp_values);
write_values(dvalues);
}
}
else
{
std::vector<double> dvalues = parse_string<double>(tmp_values);
write_values(dvalues);
}
constructor_validate_and_infer_types();
}
......@@ -185,6 +211,8 @@ namespace ngraph
}
bool is_constant() const override { return true; }
bool are_all_data_elements_bitwise_identical() const;
protected:
void* get_data_ptr_nc() { return (m_data ? m_data->get_ptr() : nullptr); }
Constant(const std::string& name, const NodeVector& args)
......@@ -222,58 +250,54 @@ namespace ngraph
{
throw std::runtime_error("Constant initializer does not match shape");
}
if (target_type == element::boolean)
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
switch (target_type.get_type_enum())
{
case element::Type_t::boolean:
write_buffer<char, T>(target, source, target_element_count);
}
else if (target_type == element::bf16)
{
break;
case element::Type_t::bf16:
write_buffer<bfloat16, T>(target, source, target_element_count);
}
else if (target_type == element::f32)
{
break;
case element::Type_t::f16:
write_buffer<float16, T>(target, source, target_element_count);
break;
case element::Type_t::f32:
write_buffer<float, T>(target, source, target_element_count);
}
else if (target_type == element::f64)
{
break;
case element::Type_t::f64:
write_buffer<double, T>(target, source, target_element_count);
}
else if (target_type == element::i8)
{
break;
case element::Type_t::i8:
write_buffer<int8_t, T>(target, source, target_element_count);
}
else if (target_type == element::i16)
{
break;
case element::Type_t::i16:
write_buffer<int16_t, T>(target, source, target_element_count);
}
else if (target_type == element::i32)
{
break;
case element::Type_t::i32:
write_buffer<int32_t, T>(target, source, target_element_count);
}
else if (target_type == element::i64)
{
break;
case element::Type_t::i64:
write_buffer<int64_t, T>(target, source, target_element_count);
}
else if (target_type == element::u8)
{
break;
case element::Type_t::u8:
write_buffer<uint8_t, T>(target, source, target_element_count);
}
else if (target_type == element::u16)
{
break;
case element::Type_t::u16:
write_buffer<uint16_t, T>(target, source, target_element_count);
}
else if (target_type == element::u32)
{
break;
case element::Type_t::u32:
write_buffer<uint32_t, T>(target, source, target_element_count);
}
else if (target_type == element::u64)
{
break;
case element::Type_t::u64:
write_buffer<uint64_t, T>(target, source, target_element_count);
break;
case element::Type_t::undefined: throw std::runtime_error("unsupported type");
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
}
else
{
throw std::runtime_error("unsupported type");
}
#pragma GCC diagnostic pop
}
static constexpr size_t host_alignment() { return 64; }
......
......@@ -22,40 +22,6 @@
using namespace std;
using namespace ngraph;
template <typename T>
static bool is_data_constant(shared_ptr<op::Constant> constant)
{
const size_t size = shape_size(constant->get_shape());
bool data_is_constant = true;
if (size > 0)
{
const T* data = constant->get_data_ptr<T>();
const T compare = data[0];
for (size_t i = 1; i < size; i++)
{
if (data[i] != compare)
{
data_is_constant = false;
break;
}
}
if (data_is_constant)
{
auto scalar_constant = make_shared<op::Constant>(
constant->get_element_type(), Shape{}, constant->get_data_ptr());
AxisSet broadcast_axes;
for (size_t i = 0; i < constant->get_output_shape(0).size(); i++)
{
broadcast_axes.insert(i);
}
auto broadcast = make_shared<op::Broadcast>(
scalar_constant, constant->get_output_shape(0), broadcast_axes);
replace_node(constant, broadcast);
}
}
return data_is_constant;
}
bool pass::ConstantToBroadcast::run_on_node(shared_ptr<Node> node)
{
const size_t minimum_size_of_interest = 32;
......@@ -66,38 +32,18 @@ bool pass::ConstantToBroadcast::run_on_node(shared_ptr<Node> node)
size_t size = shape_size(constant->get_shape());
if (size > minimum_size_of_interest)
{
switch (constant->get_element_type().get_type_enum())
{
case element::Type_t::boolean:
case element::Type_t::i8:
case element::Type_t::u8:
{
modified = is_data_constant<uint8_t>(constant);
break;
}
case element::Type_t::bf16:
case element::Type_t::i16:
case element::Type_t::u16:
if (constant->are_all_data_elements_bitwise_identical())
{
modified = is_data_constant<uint16_t>(constant);
break;
}
case element::Type_t::f32:
case element::Type_t::i32:
case element::Type_t::u32:
{
modified = is_data_constant<uint32_t>(constant);
break;
}
case element::Type_t::f64:
case element::Type_t::i64:
case element::Type_t::u64:
{
modified = is_data_constant<uint64_t>(constant);
break;
}
case element::Type_t::undefined:
case element::Type_t::dynamic: break;
auto scalar_constant = make_shared<op::Constant>(
constant->get_element_type(), Shape{}, constant->get_data_ptr());
AxisSet broadcast_axes;
for (size_t i = 0; i < constant->get_output_shape(0).size(); i++)
{
broadcast_axes.insert(i);
}
auto broadcast = make_shared<op::Broadcast>(
scalar_constant, constant->get_output_shape(0), broadcast_axes);
replace_node(constant, broadcast);
}
}
}
......
......@@ -208,6 +208,7 @@ void runtime::interpreter::INTExecutable::generate_calls(const element::Type& ty
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::bf16:
case element::Type_t::f16:
ss << "unsupported element type " << type << " op " << op.get_node()->get_name();
throw ngraph_error(ss.str());
}
......
......@@ -661,16 +661,8 @@ static shared_ptr<ngraph::Function>
node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape");
auto value_it = node_js.find("value");
if (value_it != node_js.end())
{
auto value = value_it->get<vector<string>>();
node = make_shared<op::Constant>(element_type, shape, value);
}
else
{
node = const_data_callback(node_name, element_type, shape);
}
auto value = node_js.at("value").get<vector<string>>();
node = make_shared<op::Constant>(element_type, shape, value);
break;
}
case OP_TYPEID::Convert:
......@@ -1681,7 +1673,13 @@ static json write(const Node& n, bool binary_constant_data)
case OP_TYPEID::Constant:
{
auto tmp = dynamic_cast<const op::Constant*>(&n);
if (!binary_constant_data)
if (tmp->are_all_data_elements_bitwise_identical())
{
vector<string> vs;
vs.push_back(tmp->get_value_strings()[0]);
node["value"] = vs;
}
else
{
node["value"] = tmp->get_value_strings();
}
......
......@@ -893,14 +893,14 @@ TEST(all_close_f, inf_nan)
EXPECT_FALSE(test::close_f(zero, signaling_nan));
EXPECT_FALSE(test::all_close_f(vector<float>({zero}), vector<float>({signaling_nan})));
EXPECT_FALSE(test::close_f(infinity, infinity));
EXPECT_FALSE(test::all_close_f(vector<float>({infinity}), vector<float>({infinity})));
EXPECT_FALSE(test::close_f(neg_infinity, neg_infinity));
EXPECT_FALSE(test::all_close_f(vector<float>({neg_infinity}), vector<float>({neg_infinity})));
EXPECT_FALSE(test::close_f(quiet_nan, quiet_nan));
EXPECT_FALSE(test::all_close_f(vector<float>({quiet_nan}), vector<float>({quiet_nan})));
EXPECT_FALSE(test::close_f(signaling_nan, signaling_nan));
EXPECT_FALSE(test::all_close_f(vector<float>({signaling_nan}), vector<float>({signaling_nan})));
EXPECT_TRUE(test::close_f(infinity, infinity));
EXPECT_TRUE(test::all_close_f(vector<float>({infinity}), vector<float>({infinity})));
EXPECT_TRUE(test::close_f(neg_infinity, neg_infinity));
EXPECT_TRUE(test::all_close_f(vector<float>({neg_infinity}), vector<float>({neg_infinity})));
EXPECT_TRUE(test::close_f(quiet_nan, quiet_nan));
EXPECT_TRUE(test::all_close_f(vector<float>({quiet_nan}), vector<float>({quiet_nan})));
EXPECT_TRUE(test::close_f(signaling_nan, signaling_nan));
EXPECT_TRUE(test::all_close_f(vector<float>({signaling_nan}), vector<float>({signaling_nan})));
}
TEST(all_close_f, double_inf_nan)
......@@ -920,13 +920,13 @@ TEST(all_close_f, double_inf_nan)
EXPECT_FALSE(test::close_f(zero, signaling_nan));
EXPECT_FALSE(test::all_close_f(vector<double>({zero}), vector<double>({signaling_nan})));
EXPECT_FALSE(test::close_f(infinity, infinity));
EXPECT_FALSE(test::all_close_f(vector<double>({infinity}), vector<double>({infinity})));
EXPECT_FALSE(test::close_f(neg_infinity, neg_infinity));
EXPECT_FALSE(test::all_close_f(vector<double>({neg_infinity}), vector<double>({neg_infinity})));
EXPECT_FALSE(test::close_f(quiet_nan, quiet_nan));
EXPECT_FALSE(test::all_close_f(vector<double>({quiet_nan}), vector<double>({quiet_nan})));
EXPECT_FALSE(test::close_f(signaling_nan, signaling_nan));
EXPECT_FALSE(
EXPECT_TRUE(test::close_f(infinity, infinity));
EXPECT_TRUE(test::all_close_f(vector<double>({infinity}), vector<double>({infinity})));
EXPECT_TRUE(test::close_f(neg_infinity, neg_infinity));
EXPECT_TRUE(test::all_close_f(vector<double>({neg_infinity}), vector<double>({neg_infinity})));
EXPECT_TRUE(test::close_f(quiet_nan, quiet_nan));
EXPECT_TRUE(test::all_close_f(vector<double>({quiet_nan}), vector<double>({quiet_nan})));
EXPECT_TRUE(test::close_f(signaling_nan, signaling_nan));
EXPECT_TRUE(
test::all_close_f(vector<double>({signaling_nan}), vector<double>({signaling_nan})));
}
......@@ -22,11 +22,13 @@
#include "ngraph/file_util.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
#include "util/all_close_f.hpp"
#include "util/test_tools.hpp"
using namespace std;
......@@ -256,3 +258,55 @@ TEST(serialize, passthrough)
ElementsAre(IsOutputShape(element::f32, Shape{2, 3}),
IsOutputShape(element::i8, Shape{4, 5})));
}
TEST(serialize, constant_infinity_nan)
{
vector<float> a_data{123, 456, INFINITY, -INFINITY, NAN};
vector<float> b_data{5, 5, 5, 5, 5, 5};
vector<float> c_data{0.05, 0.05, 0.05, 0.05, 0.05, 0.05001, 0.05};
vector<int64_t> d_data{-100, -10, -1, 0, 50, 5000000000001};
auto A = make_shared<op::Constant>(element::f32, Shape{5}, a_data);
auto B = make_shared<op::Constant>(element::f32, Shape{6}, b_data);
auto C = make_shared<op::Constant>(element::f32, Shape{7}, c_data);
auto D = make_shared<op::Constant>(element::i64, Shape{d_data.size()}, d_data);
A->set_friendly_name("A");
B->set_friendly_name("B");
C->set_friendly_name("C");
D->set_friendly_name("D");
auto f = make_shared<Function>(NodeVector{A, B, C, D}, ParameterVector{});
string s = serialize(f, 4);
shared_ptr<Function> g = deserialize(s);
shared_ptr<op::Constant> a;
shared_ptr<op::Constant> b;
shared_ptr<op::Constant> c;
shared_ptr<op::Constant> d;
for (auto node : g->get_ops())
{
if (node->get_friendly_name() == "A")
{
a = static_pointer_cast<op::Constant>(node);
}
else if (node->get_friendly_name() == "B")
{
b = static_pointer_cast<op::Constant>(node);
}
else if (node->get_friendly_name() == "C")
{
c = static_pointer_cast<op::Constant>(node);
}
else if (node->get_friendly_name() == "D")
{
d = static_pointer_cast<op::Constant>(node);
}
}
ASSERT_NE(a, nullptr);
ASSERT_NE(b, nullptr);
ASSERT_NE(c, nullptr);
ASSERT_NE(d, nullptr);
EXPECT_TRUE(test::all_close_f(a->get_vector<float>(), a_data));
EXPECT_TRUE(test::all_close_f(b->get_vector<float>(), b_data));
EXPECT_TRUE(test::all_close_f(c->get_vector<float>(), c_data));
EXPECT_EQ(d->get_vector<int64_t>(), d_data);
}
......@@ -39,8 +39,20 @@ constexpr uint64_t DOUBLE_MAX_DIFF = ULLONG_MAX - 1;
uint32_t test::float_distance(float a, float b, float min_signal)
{
if (!isfinite(a) || !isfinite(b))
if (std::isnan(a) && std::isnan(b))
{
return 0;
}
else if (std::isinf(a) && std::isinf(b))
{
if (a > 0 && b > 0)
{
return 0;
}
else if (a < 0 && b < 0)
{
return 0;
}
return FLOAT_MAX_DIFF;
}
......@@ -82,8 +94,20 @@ uint32_t test::float_distance(float a, float b, float min_signal)
uint64_t test::float_distance(double a, double b, double min_signal)
{
if (!isfinite(a) || !isfinite(b))
if (std::isnan(a) && std::isnan(b))
{
return 0;
}
else if (std::isinf(a) && std::isinf(b))
{
if (a > 0 && b > 0)
{
return 0;
}
else if (a < 0 && b < 0)
{
return 0;
}
return DOUBLE_MAX_DIFF;
}
......@@ -125,9 +149,20 @@ uint64_t test::float_distance(double a, double b, double min_signal)
bool test::close_f(float a, float b, int tolerance_bits, float min_signal)
{
// isfinite(a) => !isinf(a) && !isnan(a)
if (!isfinite(a) || !isfinite(b))
if (std::isnan(a) && std::isnan(b))
{
return true;
}
else if (std::isinf(a) && std::isinf(b))
{
if (a > 0 && b > 0)
{
return true;
}
else if (a < 0 && b < 0)
{
return true;
}
return false;
}
......@@ -144,9 +179,20 @@ bool test::close_f(float a, float b, int tolerance_bits, float min_signal)
bool test::close_f(double a, double b, int tolerance_bits, double min_signal)
{
// isfinite(a) => !isinf(a) && !isnan(a)
if (!isfinite(a) || !isfinite(b))
if (std::isnan(a) && std::isnan(b))
{
return true;
}
else if (std::isinf(a) && std::isinf(b))
{
if (a > 0 && b > 0)
{
return true;
}
else if (a < 0 && b < 0)
{
return true;
}
return false;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment