Commit d9a9ae69 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Optimize Constant for deserialization (#4208)

* Move non-templated constructor implementation to the source file

* Optimize constant constructor for uniform constant

* Cleanup

* Much faster deserialize constant

* Adding unit tests

* Unit tests

* Update unit test

* Cleanup

* style

* Cleanup nbench output

* Fix specializations
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 8b246c5d
......@@ -64,41 +64,227 @@ op::Constant::Constant(const element::Type& type,
", expected ",
shape_size(m_shape),
".");
if (values.size())
constructor_validate_and_infer_types();
if (values.size() == 1 && shape_size(m_shape) != 1)
{
if (type.is_integral())
// broadcast single value
switch (m_element_type)
{
if (type.is_signed())
{
std::vector<int64_t> dvalues = parse_string<int64_t>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
dvalues = std::vector<int64_t>(shape_size(m_shape), dvalues[0]);
}
write_values(dvalues);
}
else
case element::Type_t::boolean:
{
bool value = stoi(values[0]) != 0;
bool* target = m_data->get_ptr<bool>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::bf16:
{
bfloat16 value = parse_string<float>(values[0]);
bfloat16* target = m_data->get_ptr<bfloat16>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::f16:
{
float16 value = parse_string<float>(values[0]);
float16* target = m_data->get_ptr<float16>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::f32:
{
float value = parse_string<float>(values[0]);
float* target = m_data->get_ptr<float>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::f64:
{
double value = parse_string<double>(values[0]);
double* target = m_data->get_ptr<double>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::i8:
{
int8_t value = parse_string<int64_t>(values[0]);
int8_t* target = m_data->get_ptr<int8_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::i16:
{
int16_t value = parse_string<int64_t>(values[0]);
int16_t* target = m_data->get_ptr<int16_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::i32:
{
int32_t value = parse_string<int64_t>(values[0]);
int32_t* target = m_data->get_ptr<int32_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::i64:
{
int64_t value = parse_string<int64_t>(values[0]);
int64_t* target = m_data->get_ptr<int64_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::u8:
{
uint8_t value = parse_string<uint64_t>(values[0]);
uint8_t* target = m_data->get_ptr<uint8_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::u16:
{
uint16_t value = parse_string<uint64_t>(values[0]);
uint16_t* target = m_data->get_ptr<uint16_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::u32:
{
uint32_t value = parse_string<uint64_t>(values[0]);
uint32_t* target = m_data->get_ptr<uint32_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::u64:
{
uint64_t value = parse_string<uint64_t>(values[0]);
uint64_t* target = m_data->get_ptr<uint64_t>();
std::fill(target, target + shape_size(m_shape), value);
break;
}
case element::Type_t::undefined:
{
throw std::runtime_error("deserialize unsupported type undefined");
}
case element::Type_t::dynamic:
{
throw std::runtime_error("deserialize unsupported type dynamic");
}
case element::Type_t::u1: { throw std::runtime_error("deserialize unsupported type u1");
}
}
m_all_elements_bitwise_identical = true;
}
else
{
switch (m_element_type)
{
case element::Type_t::boolean:
{
vector<uint8_t> value = parse_string<uint8_t>(values);
uint8_t* target = m_data->get_ptr<uint8_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::bf16:
{
vector<float> value = parse_string<float>(values);
bfloat16* target = m_data->get_ptr<bfloat16>();
for (size_t i = 0; i < value.size(); i++)
{
std::vector<uint64_t> dvalues = parse_string<uint64_t>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
dvalues = std::vector<uint64_t>(shape_size(m_shape), dvalues[0]);
}
write_values(dvalues);
target[i] = value[i];
}
break;
}
else
case element::Type_t::f16:
{
std::vector<double> dvalues = parse_string<double>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
vector<float> value = parse_string<float>(values);
float16* target = m_data->get_ptr<float16>();
for (size_t i = 0; i < value.size(); i++)
{
dvalues = std::vector<double>(shape_size(m_shape), dvalues[0]);
target[i] = value[i];
}
write_values(dvalues);
break;
}
case element::Type_t::f32:
{
vector<float> value = parse_string<float>(values);
float* target = m_data->get_ptr<float>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::f64:
{
vector<double> value = parse_string<double>(values);
double* target = m_data->get_ptr<double>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::i8:
{
vector<int8_t> value = parse_string<int8_t>(values);
int8_t* target = m_data->get_ptr<int8_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::i16:
{
vector<int16_t> value = parse_string<int16_t>(values);
int16_t* target = m_data->get_ptr<int16_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::i32:
{
vector<int32_t> value = parse_string<int32_t>(values);
int32_t* target = m_data->get_ptr<int32_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::i64:
{
vector<int64_t> value = parse_string<int64_t>(values);
int64_t* target = m_data->get_ptr<int64_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::u8:
{
vector<uint8_t> value = parse_string<uint8_t>(values);
uint8_t* target = m_data->get_ptr<uint8_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::u16:
{
vector<uint16_t> value = parse_string<uint16_t>(values);
uint16_t* target = m_data->get_ptr<uint16_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::u32:
{
vector<uint32_t> value = parse_string<uint32_t>(values);
uint32_t* target = m_data->get_ptr<uint32_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::u64:
{
vector<uint64_t> value = parse_string<uint64_t>(values);
uint64_t* target = m_data->get_ptr<uint64_t>();
std::copy(value.begin(), value.end(), target);
break;
}
case element::Type_t::undefined:
throw std::runtime_error("deserialize unsupported type undefined");
case element::Type_t::dynamic:
throw std::runtime_error("deserialize unsupported type dynamic");
case element::Type_t::u1: throw std::runtime_error("deserialize unsupported type u1");
}
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
op::Constant::Constant(const element::Type& type, const Shape& shape, const void* data)
......
......@@ -87,9 +87,7 @@ namespace ngraph
Shape shape,
const std::vector<std::string>& values);
/// \brief Constructs a tensor constant with the same initialization value copied
/// across the tensor. This constructor is to support deserialization of
/// constants.
/// \brief Constructs a tensor constant with the supplied data
///
/// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
......
......@@ -401,6 +401,36 @@ namespace ngraph
}
return result;
}
template <>
int8_t parse_string<int8_t>(const std::string& s)
{
char* err;
int8_t result = strtol(s.c_str(), &err, 10);
// Check that (1) parsing succeeded and (2) the entire string was used.
if (*err != 0)
{
throw std::runtime_error("Could not parse literal '" + s + "'");
}
return result;
}
template <>
uint8_t parse_string<uint8_t>(const std::string& s)
{
char* err;
uint8_t result = strtol(s.c_str(), &err, 10);
// Check that (1) parsing succeeded and (2) the entire string was used.
if (*err != 0)
{
throw std::runtime_error("Could not parse literal '" + s + "'");
}
return result;
}
}
std::ostream& operator<<(std::ostream& os, const ngraph::NodeVector& nv)
......
......@@ -163,6 +163,14 @@ namespace ngraph
template <>
double parse_string<double>(const std::string& s);
/// template specializations for int8_t and uint8_t to handle the fact that default
/// implementation ends up treating values as characters so that the number "0" turns into
/// the parsed value 48, which is it's ASCII value
template <>
int8_t parse_string<int8_t>(const std::string& s);
template <>
uint8_t parse_string<uint8_t>(const std::string& s);
/// Parses a list of strings containing literals of the underlying type.
template <typename T>
std::vector<T> parse_string(const std::vector<std::string>& ss)
......
......@@ -431,7 +431,13 @@ OPTIONS
if (!backend.empty())
{
cout << "\n---- Benchmark ----\n";
stopwatch t1;
t1.start();
shared_ptr<Function> f = deserialize(model);
stringstream ss;
ss.imbue(locale(""));
ss << t1.get_milliseconds();
cout << "deserialize took " << ss.str() << "ms\n";
vector<runtime::PerformanceCounter> perf_data;
if (double_buffer)
{
......
......@@ -51,6 +51,7 @@ set(SRC
build_graph.cpp
builder_autobroadcast.cpp
check.cpp
constant.cpp
constant_folding.cpp
concat_fusion.cpp
control_dependencies.cpp
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment