Unverified Commit 2c23cf20 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Allow Constant ops to share internal buffer (#4216)

* Move non-templated constructor implementation to the source file

* Optimize constant constructor for uniform constant

* Cleanup

* Much faster deserialize constant

* Adding unit tests

* Unit tests

* Update unit test

* Cleanup

* style

* Cleanup nbench output

* wip

* Fix specializations

* Change from unique to shared_ptr internally

* Enable copy of Constant

* cleanup Constant ctors

* Fix copy contructor

* Fix compile error
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 14cac0dd
...@@ -52,8 +52,8 @@ op::Constant::Constant(const element::Type& type, ...@@ -52,8 +52,8 @@ op::Constant::Constant(const element::Type& type,
const std::vector<std::string>& values) const std::vector<std::string>& values)
: m_element_type(type) : m_element_type(type)
, m_shape(shape) , m_shape(shape)
, m_data(new runtime::AlignedBuffer( , m_data(
std::ceil(shape_size(m_shape) * m_element_type.bitwidth() / 8.f), host_alignment())) new runtime::AlignedBuffer(shape_size(m_shape) * m_element_type.size(), host_alignment()))
{ {
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
values.size() == shape_size(m_shape) || values.size() == 1, values.size() == shape_size(m_shape) || values.size() == 1,
...@@ -290,15 +290,24 @@ op::Constant::Constant(const element::Type& type, ...@@ -290,15 +290,24 @@ op::Constant::Constant(const element::Type& type,
op::Constant::Constant(const element::Type& type, const Shape& shape, const void* data) op::Constant::Constant(const element::Type& type, const Shape& shape, const void* data)
: m_element_type(type) : m_element_type(type)
, m_shape(shape) , m_shape(shape)
, m_data(nullptr) , m_data(
new runtime::AlignedBuffer(shape_size(m_shape) * m_element_type.size(), host_alignment()))
{ {
size_t size = std::ceil(shape_size(m_shape) * m_element_type.bitwidth() / 8.f); size_t size = shape_size(m_shape) * m_element_type.size();
m_data.reset(new runtime::AlignedBuffer(size, host_alignment()));
std::memcpy(m_data->get_ptr(), data, size); std::memcpy(m_data->get_ptr(), data, size);
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical(); m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
} }
op::Constant::Constant(const Constant& other)
: m_element_type(other.m_element_type)
, m_shape(other.m_shape)
, m_data(other.m_data)
, m_all_elements_bitwise_identical(other.m_all_elements_bitwise_identical)
{
constructor_validate_and_infer_types();
}
op::Constant::~Constant() op::Constant::~Constant()
{ {
} }
...@@ -516,7 +525,7 @@ AxisSet op::Constant::get_axis_set_val() const ...@@ -516,7 +525,7 @@ AxisSet op::Constant::get_axis_set_val() const
shared_ptr<Node> op::Constant::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Constant::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Constant>(m_element_type, m_shape, m_data->get_ptr()); return make_shared<Constant>(*this);
} }
template <typename T> template <typename T>
......
...@@ -49,8 +49,7 @@ namespace ngraph ...@@ -49,8 +49,7 @@ namespace ngraph
Constant(const element::Type& type, Shape shape, const std::vector<T>& values) Constant(const element::Type& type, Shape shape, const std::vector<T>& values)
: m_element_type(type) : m_element_type(type)
, m_shape(shape) , m_shape(shape)
, m_data(new runtime::AlignedBuffer( , m_data(new runtime::AlignedBuffer(shape_size(m_shape) * m_element_type.size(),
std::ceil(shape_size(m_shape) * m_element_type.bitwidth() / 8.f),
host_alignment())) host_alignment()))
{ {
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
...@@ -94,6 +93,8 @@ namespace ngraph ...@@ -94,6 +93,8 @@ namespace ngraph
/// \param data A void* to constant data. /// \param data A void* to constant data.
Constant(const element::Type& type, const Shape& shape, const void* data); Constant(const element::Type& type, const Shape& shape, const void* data);
Constant(const Constant& other);
virtual ~Constant() override; virtual ~Constant() override;
void validate_and_infer_types() override void validate_and_infer_types() override
...@@ -374,11 +375,9 @@ namespace ngraph ...@@ -374,11 +375,9 @@ namespace ngraph
static constexpr size_t host_alignment() { return 64; } static constexpr size_t host_alignment() { return 64; }
element::Type m_element_type; element::Type m_element_type;
Shape m_shape{}; Shape m_shape{};
std::unique_ptr<runtime::AlignedBuffer> m_data; std::shared_ptr<runtime::AlignedBuffer> m_data;
bool m_all_elements_bitwise_identical; bool m_all_elements_bitwise_identical;
bool are_all_data_elements_bitwise_identical() const; bool are_all_data_elements_bitwise_identical() const;
Constant(const Constant&) = delete;
Constant operator=(const Constant&) = delete;
}; };
/// \brief A scalar constant whose element type is the same as like. /// \brief A scalar constant whose element type is the same as like.
......
...@@ -1010,3 +1010,13 @@ TEST(constant, float16_vector_broadcast) ...@@ -1010,3 +1010,13 @@ TEST(constant, float16_vector_broadcast)
EXPECT_EQ(p[2], float16(1)); EXPECT_EQ(p[2], float16(1));
EXPECT_EQ(p[3], float16(1)); EXPECT_EQ(p[3], float16(1));
} }
TEST(constant, shared_data)
{
Shape shape{100, 200};
auto c1 = make_shared<op::Constant>(element::f16, shape, vector<float16>{123});
auto c2 = static_pointer_cast<op::Constant>(c1->copy_with_new_args({}));
const float* p1 = c1->get_data_ptr<float>();
const float* p2 = c2->get_data_ptr<float>();
EXPECT_EQ(p1, p2);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment