Unverified commit 4971bdf1, authored by Robert Kimball and committed by GitHub

Much faster serialize/deserialize of broadcast constants (#2993)

* serialize constant faster

* more speedup
parent 36422810
......@@ -49,6 +49,42 @@ op::Constant::~Constant()
{
}
// Formats the single element stored at position `index` of this constant as text.
//
// The element type reported by get_element_type() selects how the value is read
// and formatted: floating-point types go through to_cpp_string (bf16 and f16 are
// first widened to float), every other supported type goes through to_string.
// Throws runtime_error("unsupported type") for the `undefined` and `dynamic`
// element types. No range check is applied to `index` here.
//
// NOTE(review): get_vector<T>() is called only to read one element; if it returns
// the vector by value this copies the whole buffer on every call -- confirm.
string op::Constant::convert_value_to_string(size_t index) const
{
    const auto kind = get_element_type().get_type_enum();

// Turn a missing enumerator in the switch below into a hard compile error
// (skipped on GCC 4.8).
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
    switch (kind)
    {
    // Floating-point element types.
    case element::Type_t::bf16:
        return to_cpp_string(static_cast<float>(get_vector<bfloat16>()[index]));
    case element::Type_t::f16:
        return to_cpp_string(static_cast<float>(get_vector<float16>()[index]));
    case element::Type_t::f32:
        return to_cpp_string(get_vector<float>()[index]);
    case element::Type_t::f64:
        return to_cpp_string(get_vector<double>()[index]);

    // Boolean is read back as char.
    case element::Type_t::boolean:
        return to_string(get_vector<char>()[index]);

    // Signed integer element types.
    case element::Type_t::i8:
        return to_string(get_vector<int8_t>()[index]);
    case element::Type_t::i16:
        return to_string(get_vector<int16_t>()[index]);
    case element::Type_t::i32:
        return to_string(get_vector<int32_t>()[index]);
    case element::Type_t::i64:
        return to_string(get_vector<int64_t>()[index]);

    // Unsigned integer element types.
    case element::Type_t::u8:
        return to_string(get_vector<uint8_t>()[index]);
    case element::Type_t::u16:
        return to_string(get_vector<uint16_t>()[index]);
    case element::Type_t::u32:
        return to_string(get_vector<uint32_t>()[index]);
    case element::Type_t::u64:
        return to_string(get_vector<uint64_t>()[index]);

    // Types that carry no concrete element representation.
    case element::Type_t::undefined:
    case element::Type_t::dynamic:
        throw runtime_error("unsupported type");
    }
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
    // Reached only if the enum holds a value outside the listed enumerators.
    return string();
}
vector<string> op::Constant::get_value_strings() const
{
vector<string> rc;
......
......@@ -95,32 +95,43 @@ namespace ngraph
shape_size(m_shape),
".");
std::vector<std::string> tmp_values;
if (values.size() == 1 && shape_size(m_shape) != 1)
{
tmp_values = std::vector<std::string>(shape_size(m_shape), values[0]);
}
else
{
tmp_values = values;
}
if (type.is_integral())
{
if (type.is_signed())
{
std::vector<int64_t> dvalues = parse_string<int64_t>(tmp_values);
std::vector<int64_t> dvalues = parse_string<int64_t>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
for (size_t i = 1; i < shape_size(m_shape); i++)
{
dvalues.push_back(dvalues[0]);
}
}
write_values(dvalues);
}
else
{
std::vector<uint64_t> dvalues = parse_string<uint64_t>(tmp_values);
std::vector<uint64_t> dvalues = parse_string<uint64_t>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
for (size_t i = 1; i < shape_size(m_shape); i++)
{
dvalues.push_back(dvalues[0]);
}
}
write_values(dvalues);
}
}
else
{
std::vector<double> dvalues = parse_string<double>(tmp_values);
std::vector<double> dvalues = parse_string<double>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
for (size_t i = 1; i < shape_size(m_shape); i++)
{
dvalues.push_back(dvalues[0]);
}
}
write_values(dvalues);
}
constructor_validate_and_infer_types();
......@@ -243,6 +254,7 @@ namespace ngraph
// Override that unconditionally reports true.
bool is_constant() const override
{
    return true;
}
// NOTE(review): judging by the name, reports whether every stored element has the
// same bit pattern; the implementation is not visible here -- confirm.
bool are_all_data_elements_bitwise_identical() const;
// Returns the element at position `index` formatted as a string, chosen by the
// constant's element type. Throws std::runtime_error("unsupported type") for the
// `undefined` and `dynamic` element types.
std::string convert_value_to_string(size_t index) const;
protected:
// Returns the pointer obtained from m_data->get_ptr(), or nullptr when m_data
// evaluates to false (no buffer attached).
void* get_data_ptr_nc()
{
    if (!m_data)
    {
        return nullptr;
    }
    return m_data->get_ptr();
}
......
......@@ -1797,7 +1797,7 @@ static json write(const Node& n, bool binary_constant_data)
if (tmp->are_all_data_elements_bitwise_identical())
{
vector<string> vs;
vs.push_back(tmp->get_value_strings()[0]);
vs.push_back(tmp->convert_value_to_string(0));
node["value"] = vs;
}
else
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment