Unverified Commit 4971bdf1 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Much faster serialize/deserialize of broadcast constants (#2993)

* serialize constant faster

* more speedup
parent 36422810
...@@ -49,6 +49,42 @@ op::Constant::~Constant() ...@@ -49,6 +49,42 @@ op::Constant::~Constant()
{ {
} }
string op::Constant::convert_value_to_string(size_t index) const
{
string rc;
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (get_element_type().get_type_enum())
{
case element::Type_t::boolean: rc = to_string(get_vector<char>()[index]); break;
case element::Type_t::bf16:
rc = to_cpp_string(static_cast<float>(get_vector<bfloat16>()[index]));
break;
case element::Type_t::f16:
rc = to_cpp_string(static_cast<float>(get_vector<float16>()[index]));
break;
case element::Type_t::f32: rc = to_cpp_string(get_vector<float>()[index]); break;
case element::Type_t::f64: rc = to_cpp_string(get_vector<double>()[index]); break;
case element::Type_t::i8: rc = to_string(get_vector<int8_t>()[index]); break;
case element::Type_t::i16: rc = to_string(get_vector<int16_t>()[index]); break;
case element::Type_t::i32: rc = to_string(get_vector<int32_t>()[index]); break;
case element::Type_t::i64: rc = to_string(get_vector<int64_t>()[index]); break;
case element::Type_t::u8: rc = to_string(get_vector<uint8_t>()[index]); break;
case element::Type_t::u16: rc = to_string(get_vector<uint16_t>()[index]); break;
case element::Type_t::u32: rc = to_string(get_vector<uint32_t>()[index]); break;
case element::Type_t::u64: rc = to_string(get_vector<uint64_t>()[index]); break;
case element::Type_t::undefined: throw runtime_error("unsupported type");
case element::Type_t::dynamic: throw runtime_error("unsupported type");
}
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
return rc;
}
vector<string> op::Constant::get_value_strings() const vector<string> op::Constant::get_value_strings() const
{ {
vector<string> rc; vector<string> rc;
......
...@@ -95,32 +95,43 @@ namespace ngraph ...@@ -95,32 +95,43 @@ namespace ngraph
shape_size(m_shape), shape_size(m_shape),
"."); ".");
std::vector<std::string> tmp_values;
if (values.size() == 1 && shape_size(m_shape) != 1)
{
tmp_values = std::vector<std::string>(shape_size(m_shape), values[0]);
}
else
{
tmp_values = values;
}
if (type.is_integral()) if (type.is_integral())
{ {
if (type.is_signed()) if (type.is_signed())
{ {
std::vector<int64_t> dvalues = parse_string<int64_t>(tmp_values); std::vector<int64_t> dvalues = parse_string<int64_t>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
for (size_t i = 1; i < shape_size(m_shape); i++)
{
dvalues.push_back(dvalues[0]);
}
}
write_values(dvalues); write_values(dvalues);
} }
else else
{ {
std::vector<uint64_t> dvalues = parse_string<uint64_t>(tmp_values); std::vector<uint64_t> dvalues = parse_string<uint64_t>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
for (size_t i = 1; i < shape_size(m_shape); i++)
{
dvalues.push_back(dvalues[0]);
}
}
write_values(dvalues); write_values(dvalues);
} }
} }
else else
{ {
std::vector<double> dvalues = parse_string<double>(tmp_values); std::vector<double> dvalues = parse_string<double>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
for (size_t i = 1; i < shape_size(m_shape); i++)
{
dvalues.push_back(dvalues[0]);
}
}
write_values(dvalues); write_values(dvalues);
} }
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
...@@ -243,6 +254,7 @@ namespace ngraph ...@@ -243,6 +254,7 @@ namespace ngraph
bool is_constant() const override { return true; } bool is_constant() const override { return true; }
bool are_all_data_elements_bitwise_identical() const; bool are_all_data_elements_bitwise_identical() const;
std::string convert_value_to_string(size_t index) const;
protected: protected:
void* get_data_ptr_nc() { return (m_data ? m_data->get_ptr() : nullptr); } void* get_data_ptr_nc() { return (m_data ? m_data->get_ptr() : nullptr); }
......
...@@ -1797,7 +1797,7 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1797,7 +1797,7 @@ static json write(const Node& n, bool binary_constant_data)
if (tmp->are_all_data_elements_bitwise_identical()) if (tmp->are_all_data_elements_bitwise_identical())
{ {
vector<string> vs; vector<string> vs;
vs.push_back(tmp->get_value_strings()[0]); vs.push_back(tmp->convert_value_to_string(0));
node["value"] = vs; node["value"] = vs;
} }
else else
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment