Commit 22008460 authored by Ilya Churaev, committed by Sang Ik Lee

Fixed constant operation for u1 format (#4045)

* Fixed bin constant ops

* Added export

* Fixed buffer size

* Fixed code style
parent 005660f0
@@ -16,6 +16,7 @@
 #pragma once
+#include <cmath>
 #include <cstring>
 #include <sstream>
@@ -46,8 +47,9 @@ namespace ngraph
             Constant(const element::Type& type, Shape shape, const std::vector<T>& values)
                 : m_element_type(type)
                 , m_shape(shape)
-                , m_data(new runtime::AlignedBuffer(shape_size(m_shape) * m_element_type.size(),
-                                                    host_alignment()))
+                , m_data(new runtime::AlignedBuffer(
+                      std::ceil(shape_size(m_shape) * m_element_type.bitwidth() / 8.f),
+                      host_alignment()))
             {
                 NODE_VALIDATION_CHECK(
                     this,
@@ -82,8 +84,9 @@ namespace ngraph
             Constant(const element::Type& type, Shape shape, const std::vector<std::string>& values)
                 : m_element_type(type)
                 , m_shape(shape)
-                , m_data(new runtime::AlignedBuffer(shape_size(m_shape) * m_element_type.size(),
-                                                    host_alignment()))
+                , m_data(new runtime::AlignedBuffer(
+                      std::ceil(shape_size(m_shape) * m_element_type.bitwidth() / 8.f),
+                      host_alignment()))
             {
                 NODE_VALIDATION_CHECK(
                     this,
@@ -143,9 +146,8 @@ namespace ngraph
                 , m_shape(shape)
                 , m_data(nullptr)
             {
-                size_t size = shape_size(m_shape) * m_element_type.size();
-                m_data.reset(new runtime::AlignedBuffer(shape_size(m_shape) * m_element_type.size(),
-                                                        host_alignment()));
+                size_t size = std::ceil(shape_size(m_shape) * m_element_type.bitwidth() / 8.f);
+                m_data.reset(new runtime::AlignedBuffer(size, host_alignment()));
                 std::memcpy(m_data->get_ptr(), data, size);
                 constructor_validate_and_infer_types();
                 m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
......
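The three constructor changes above replace the whole-byte size computation (element count times element size in bytes) with a bit-aware one, so that sub-byte types such as u1 get a buffer rounded up to the next full byte instead of one that is too small. A minimal standalone sketch of the arithmetic, using a hypothetical helper name (this is not the ngraph API):

#include <cmath>
#include <cstddef>
#include <iostream>

// Bytes needed to hold `count` elements of `bitwidth` bits each,
// rounded up to whole bytes, mirroring the expression in the diff above.
std::size_t buffer_bytes(std::size_t count, std::size_t bitwidth)
{
    return static_cast<std::size_t>(std::ceil(count * bitwidth / 8.f));
}

int main()
{
    std::cout << buffer_bytes(10, 1) << '\n';  // 2 bytes for ten u1 values
    std::cout << buffer_bytes(10, 8) << '\n';  // 10 bytes for ten u8 values
    std::cout << buffer_bytes(3, 32) << '\n';  // 12 bytes for three f32 values
}

For whole-byte element types the result is unchanged (bitwidth / 8 equals the old per-element byte size), so only sub-byte formats are affected by the fix.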
@@ -28,10 +28,9 @@ namespace ngraph
             ///
             /// The reduction is performed over slices of the first input. The slices shape depends
             /// on the values passed to the second input - the axes.
-            class ReduceLogicalAnd : public util::LogicalReductionKeepDims
+            class NGRAPH_API ReduceLogicalAnd : public util::LogicalReductionKeepDims
             {
             public:
-                NGRAPH_API
                 static constexpr NodeTypeInfo type_info{"ReduceLogicalAnd", 1};
                 const NodeTypeInfo& get_type_info() const override { return type_info; }
                 ReduceLogicalAnd() = default;
......
@@ -28,10 +28,9 @@ namespace ngraph
             ///
             /// The reduction is performed over slices of the first input. The slices shape depends
             /// on the values passed to the second input - the axes.
-            class ReduceLogicalOr : public util::LogicalReductionKeepDims
+            class NGRAPH_API ReduceLogicalOr : public util::LogicalReductionKeepDims
             {
             public:
-                NGRAPH_API
                 static constexpr NodeTypeInfo type_info{"ReduceLogicalOr", 1};
                 const NodeTypeInfo& get_type_info() const override { return type_info; }
                 ReduceLogicalOr() = default;
......
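The ReduceLogicalAnd/ReduceLogicalOr hunks above, together with the get_opset*() and specialize_function() hunks below, are the "Added export" part of the commit: the NGRAPH_API symbol-visibility macro moves from the static type_info member onto the class (or onto the free-function declaration), so the vtable, RTTI, and static members are exported from the shared library rather than a single symbol. As a rough, assumed sketch of what such a macro typically expands to (the real definition lives elsewhere in ngraph and is not shown in this diff; BUILDING_NGRAPH is a made-up define):

#if defined(_WIN32)
#  ifdef BUILDING_NGRAPH
#    define NGRAPH_API __declspec(dllexport)   // building the library itself
#  else
#    define NGRAPH_API __declspec(dllimport)   // consuming the library
#  endif
#else
#  define NGRAPH_API __attribute__((visibility("default")))
#endif

Annotating the class declaration is the usual way to make a polymorphic class with static data members usable across a DLL/shared-object boundary; annotating only the type_info member exports just that one symbol.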
@@ -98,6 +98,6 @@ namespace ngraph
         std::map<std::string, NodeTypeInfo> m_name_type_info_map;
     };
-    const OpSet& get_opset0();
-    const OpSet& get_opset1();
+    const NGRAPH_API OpSet& get_opset0();
+    const NGRAPH_API OpSet& get_opset1();
 }
\ No newline at end of file
@@ -103,6 +103,7 @@ namespace ngraph
     /// parameter_shapes[i] can be created.
     ///
     /// TODO(amprocte): convert this to a pass.
+    NGRAPH_API
     std::shared_ptr<Function>
         specialize_function(std::shared_ptr<Function> f,
                             const std::vector<element::Type>& parameter_element_types,
@@ -195,6 +196,7 @@ namespace ngraph
     /// parameter_shapes[i] can be created.
     ///
     /// TODO(amprocte): convert this to a pass.
+    NGRAPH_API
     std::shared_ptr<Function>
         specialize_function(std::shared_ptr<Function> f,
                             const std::vector<element::Type>& parameter_element_types,
......
@@ -32,4 +32,4 @@ TEST(convert_u1_to_string, convert_u1_to_string)
     {
         ASSERT_EQ(constant->convert_value_to_string(i), ref[i]);
     }
-}
\ No newline at end of file
+}