Commit ee54282a authored by Robert Kimball's avatar Robert Kimball Committed by Sang Ik Lee

Move non-templated constructor implementation to the source file (#4126)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent b5e030d9
...@@ -47,6 +47,72 @@ string to_cpp_string(T value) ...@@ -47,6 +47,72 @@ string to_cpp_string(T value)
constexpr NodeTypeInfo op::Constant::type_info; constexpr NodeTypeInfo op::Constant::type_info;
op::Constant::Constant(const element::Type& type,
Shape shape,
const std::vector<std::string>& values)
: m_element_type(type)
, m_shape(shape)
, m_data(new runtime::AlignedBuffer(
std::ceil(shape_size(m_shape) * m_element_type.bitwidth() / 8.f), host_alignment()))
{
NODE_VALIDATION_CHECK(this,
values.size() == shape_size(m_shape) || values.size() == 1,
"Did not get the expected number of literals for a constant of shape ",
m_shape,
" (got ",
values.size(),
", expected ",
shape_size(m_shape),
".");
if (values.size())
{
if (type.is_integral())
{
if (type.is_signed())
{
std::vector<int64_t> dvalues = parse_string<int64_t>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
dvalues = std::vector<int64_t>(shape_size(m_shape), dvalues[0]);
}
write_values(dvalues);
}
else
{
std::vector<uint64_t> dvalues = parse_string<uint64_t>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
dvalues = std::vector<uint64_t>(shape_size(m_shape), dvalues[0]);
}
write_values(dvalues);
}
}
else
{
std::vector<double> dvalues = parse_string<double>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
dvalues = std::vector<double>(shape_size(m_shape), dvalues[0]);
}
write_values(dvalues);
}
}
constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
op::Constant::Constant(const element::Type& type, const Shape& shape, const void* data)
: m_element_type(type)
, m_shape(shape)
, m_data(nullptr)
{
size_t size = std::ceil(shape_size(m_shape) * m_element_type.bitwidth() / 8.f);
m_data.reset(new runtime::AlignedBuffer(size, host_alignment()));
std::memcpy(m_data->get_ptr(), data, size);
constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
op::Constant::~Constant() op::Constant::~Constant()
{ {
} }
......
...@@ -81,59 +81,9 @@ namespace ngraph ...@@ -81,59 +81,9 @@ namespace ngraph
/// \param type The element type of the tensor constant. /// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant. /// \param shape The shape of the tensor constant.
/// \param values A list of string values to use as the constant data. /// \param values A list of string values to use as the constant data.
Constant(const element::Type& type, Shape shape, const std::vector<std::string>& values) Constant(const element::Type& type,
: m_element_type(type) Shape shape,
, m_shape(shape) const std::vector<std::string>& values);
, m_data(new runtime::AlignedBuffer(
std::ceil(shape_size(m_shape) * m_element_type.bitwidth() / 8.f),
host_alignment()))
{
NODE_VALIDATION_CHECK(
this,
values.size() == shape_size(m_shape) || values.size() == 1,
"Did not get the expected number of literals for a constant of shape ",
m_shape,
" (got ",
values.size(),
", expected ",
shape_size(m_shape),
".");
if (values.size())
{
if (type.is_integral())
{
if (type.is_signed())
{
std::vector<int64_t> dvalues = parse_string<int64_t>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
dvalues = std::vector<int64_t>(shape_size(m_shape), dvalues[0]);
}
write_values(dvalues);
}
else
{
std::vector<uint64_t> dvalues = parse_string<uint64_t>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
dvalues = std::vector<uint64_t>(shape_size(m_shape), dvalues[0]);
}
write_values(dvalues);
}
}
else
{
std::vector<double> dvalues = parse_string<double>(values);
if (values.size() == 1 && shape_size(m_shape) != 1)
{
dvalues = std::vector<double>(shape_size(m_shape), dvalues[0]);
}
write_values(dvalues);
}
}
constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
/// \brief Constructs a tensor constant with the same initialization value copied across /// \brief Constructs a tensor constant with the same initialization value copied across
// the tensor. This constructor is to support deserialization of constants. // the tensor. This constructor is to support deserialization of constants.
...@@ -141,17 +91,7 @@ namespace ngraph ...@@ -141,17 +91,7 @@ namespace ngraph
/// \param type The element type of the tensor constant. /// \param type The element type of the tensor constant.
/// \param shape The shape of the tensor constant. /// \param shape The shape of the tensor constant.
/// \param data A void* to constant data. /// \param data A void* to constant data.
Constant(const element::Type& type, const Shape& shape, const void* data) Constant(const element::Type& type, const Shape& shape, const void* data);
: m_element_type(type)
, m_shape(shape)
, m_data(nullptr)
{
size_t size = std::ceil(shape_size(m_shape) * m_element_type.bitwidth() / 8.f);
m_data.reset(new runtime::AlignedBuffer(size, host_alignment()));
std::memcpy(m_data->get_ptr(), data, size);
constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
virtual ~Constant() override; virtual ~Constant() override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment