Unverified Commit 2c23cf20 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Allow Constant ops to share internal buffer (#4216)

* Move non-templated constructor implementation to the source file

* Optimize constant constructor for uniform constant

* Cleanup

* Much faster deserialize constant

* Adding unit tests

* Unit tests

* Update unit test

* Cleanup

* style

* Cleanup nbench output

* wip

* Fix specializations

* Change from unique to shared_ptr internally

* Enable copy of Constant

* cleanup Constant ctors

* Fix copy contructor

* Fix compile error
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 14cac0dd
......@@ -52,8 +52,8 @@ op::Constant::Constant(const element::Type& type,
const std::vector<std::string>& values)
: m_element_type(type)
, m_shape(shape)
, m_data(new runtime::AlignedBuffer(
std::ceil(shape_size(m_shape) * m_element_type.bitwidth() / 8.f), host_alignment()))
, m_data(
new runtime::AlignedBuffer(shape_size(m_shape) * m_element_type.size(), host_alignment()))
{
NODE_VALIDATION_CHECK(this,
values.size() == shape_size(m_shape) || values.size() == 1,
......@@ -290,15 +290,24 @@ op::Constant::Constant(const element::Type& type,
op::Constant::Constant(const element::Type& type, const Shape& shape, const void* data)
: m_element_type(type)
, m_shape(shape)
, m_data(nullptr)
, m_data(
new runtime::AlignedBuffer(shape_size(m_shape) * m_element_type.size(), host_alignment()))
{
size_t size = std::ceil(shape_size(m_shape) * m_element_type.bitwidth() / 8.f);
m_data.reset(new runtime::AlignedBuffer(size, host_alignment()));
size_t size = shape_size(m_shape) * m_element_type.size();
std::memcpy(m_data->get_ptr(), data, size);
constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
op::Constant::Constant(const Constant& other)
: m_element_type(other.m_element_type)
, m_shape(other.m_shape)
, m_data(other.m_data)
, m_all_elements_bitwise_identical(other.m_all_elements_bitwise_identical)
{
constructor_validate_and_infer_types();
}
op::Constant::~Constant()
{
}
......@@ -516,7 +525,7 @@ AxisSet op::Constant::get_axis_set_val() const
shared_ptr<Node> op::Constant::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Constant>(m_element_type, m_shape, m_data->get_ptr());
return make_shared<Constant>(*this);
}
template <typename T>
......
......@@ -49,9 +49,8 @@ namespace ngraph
Constant(const element::Type& type, Shape shape, const std::vector<T>& values)
: m_element_type(type)
, m_shape(shape)
, m_data(new runtime::AlignedBuffer(
std::ceil(shape_size(m_shape) * m_element_type.bitwidth() / 8.f),
host_alignment()))
, m_data(new runtime::AlignedBuffer(shape_size(m_shape) * m_element_type.size(),
host_alignment()))
{
NODE_VALIDATION_CHECK(
this,
......@@ -94,6 +93,8 @@ namespace ngraph
/// \param data A void* to constant data.
Constant(const element::Type& type, const Shape& shape, const void* data);
Constant(const Constant& other);
virtual ~Constant() override;
void validate_and_infer_types() override
......@@ -374,11 +375,9 @@ namespace ngraph
static constexpr size_t host_alignment() { return 64; }
element::Type m_element_type;
Shape m_shape{};
std::unique_ptr<runtime::AlignedBuffer> m_data;
std::shared_ptr<runtime::AlignedBuffer> m_data;
bool m_all_elements_bitwise_identical;
bool are_all_data_elements_bitwise_identical() const;
Constant(const Constant&) = delete;
Constant operator=(const Constant&) = delete;
};
/// \brief A scalar constant whose element type is the same as like.
......
......@@ -1010,3 +1010,13 @@ TEST(constant, float16_vector_broadcast)
EXPECT_EQ(p[2], float16(1));
EXPECT_EQ(p[3], float16(1));
}
TEST(constant, shared_data)
{
Shape shape{100, 200};
auto c1 = make_shared<op::Constant>(element::f16, shape, vector<float16>{123});
auto c2 = static_pointer_cast<op::Constant>(c1->copy_with_new_args({}));
const float* p1 = c1->get_data_ptr<float>();
const float* p2 = c2->get_data_ptr<float>();
EXPECT_EQ(p1, p2);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment