Commit c555b36a authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Remove tensor offset from tensor read/write calls because it was never used (#2979)

* remove tensor offset from tensor read/write calls because it was never used

* fix build errors

* fix build errors

* fix python test errors

* more python fixes

* revert change

* Make old version of read/write deprecated

* fix python read overload

* one more try to fix python binding

* fix python

* yet another try

* why is this so hard

* fix?

* add text to changes.md
parent 54bbb154
......@@ -53,6 +53,15 @@ need to be adapted.
## `Parameter` and `Function` no longer take a type argument.
## Changes to Tensor read and write methods
The `read` and `write` methods on ngraph::runtime::Tensor which take a `tensor_offset` as the
second of three arguments have been deprecated. The replacement `read` and `write` methods take
two arguments, the buffer pointer and the size. For any references to the deprecated methods
remove the second argument, the tensor offset, to update to the new API. These old read/write
methods have been decorated with deprecated warnings which may be enabled by setting
`-DNGRAPH_DEPRECATED_ENABLE=ON`.
To update, remove the passed argument. For example,
```C++
// Old
......
......@@ -133,11 +133,11 @@ class Computation(object):
tensor_view.element_type, tensor_view.element_count)
nparray = np.ascontiguousarray(value)
tensor_view.write(util.numpy_to_c(nparray), 0, buffer_size)
tensor_view.write(util.numpy_to_c(nparray), buffer_size)
@staticmethod
def _read_tensor_view_to_ndarray(tensor_view, output):
# type: (Tensor, np.ndarray) -> None
buffer_size = Computation._get_buffer_size(
tensor_view.element_type, tensor_view.element_count)
tensor_view.read(util.numpy_to_c(output), 0, buffer_size)
tensor_view.read(util.numpy_to_c(output), buffer_size)
......@@ -23,15 +23,23 @@
namespace py = pybind11;
static void read_(ngraph::runtime::Tensor* self, void* p, size_t n)
{
self->read(p, n);
}
static void write_(ngraph::runtime::Tensor* self, void* p, size_t n)
{
self->write(p, n);
}
void regclass_pyngraph_runtime_Tensor(py::module m)
{
py::class_<ngraph::runtime::Tensor, std::shared_ptr<ngraph::runtime::Tensor>> tensor(m,
"Tensor");
tensor.doc() = "ngraph.impl.runtime.Tensor wraps ngraph::runtime::Tensor";
tensor.def("write",
(void (ngraph::runtime::Tensor::*)(const void*, size_t, size_t)) &
ngraph::runtime::Tensor::write);
tensor.def("read", &ngraph::runtime::Tensor::read);
tensor.def("write", &write_);
tensor.def("read", &read_);
tensor.def_property_readonly("shape", &ngraph::runtime::Tensor::get_shape);
tensor.def_property_readonly("element_count", &ngraph::runtime::Tensor::get_element_count);
......
This diff is collapsed.
......@@ -98,43 +98,43 @@ namespace ngraph
case ONNXIFI_DATATYPE_FLOAT16:
case ONNXIFI_DATATYPE_FLOAT32:
tensor = backend.create_tensor(element::f32, m_shape);
tensor->write(data(), 0, sizeof(float) * size());
tensor->write(data(), sizeof(float) * size());
break;
case ONNXIFI_DATATYPE_FLOAT64:
tensor = backend.create_tensor(element::f64, m_shape);
tensor->write(data(), 0, sizeof(double) * size());
tensor->write(data(), sizeof(double) * size());
break;
case ONNXIFI_DATATYPE_INT8:
tensor = backend.create_tensor(element::i8, m_shape);
tensor->write(data(), 0, sizeof(int8_t) * size());
tensor->write(data(), sizeof(int8_t) * size());
break;
case ONNXIFI_DATATYPE_INT16:
tensor = backend.create_tensor(element::i16, m_shape);
tensor->write(data(), 0, sizeof(int16_t) * size());
tensor->write(data(), sizeof(int16_t) * size());
break;
case ONNXIFI_DATATYPE_INT32:
tensor = backend.create_tensor(element::i32, m_shape);
tensor->write(data(), 0, sizeof(int32_t) * size());
tensor->write(data(), sizeof(int32_t) * size());
break;
case ONNXIFI_DATATYPE_INT64:
tensor = backend.create_tensor(element::i64, m_shape);
tensor->write(data(), 0, sizeof(int64_t) * size());
tensor->write(data(), sizeof(int64_t) * size());
break;
case ONNXIFI_DATATYPE_UINT8:
tensor = backend.create_tensor(element::u8, m_shape);
tensor->write(data(), 0, sizeof(uint8_t) * size());
tensor->write(data(), sizeof(uint8_t) * size());
break;
case ONNXIFI_DATATYPE_UINT16:
tensor = backend.create_tensor(element::u16, m_shape);
tensor->write(data(), 0, sizeof(uint16_t) * size());
tensor->write(data(), sizeof(uint16_t) * size());
break;
case ONNXIFI_DATATYPE_UINT32:
tensor = backend.create_tensor(element::u32, m_shape);
tensor->write(data(), 0, sizeof(uint32_t) * size());
tensor->write(data(), sizeof(uint32_t) * size());
break;
case ONNXIFI_DATATYPE_UINT64:
tensor = backend.create_tensor(element::u64, m_shape);
tensor->write(data(), 0, sizeof(uint64_t) * size());
tensor->write(data(), sizeof(uint64_t) * size());
break;
default: throw status::unsupported_datatype{};
}
......@@ -159,7 +159,7 @@ namespace ngraph
case ONNXIFI_DATATYPE_UINT64: readSize *= sizeof(uint64_t); break;
default: break;
}
tensor.read(reinterpret_cast<void*>(m_tensor->buffer), 0, readSize);
tensor.read(reinterpret_cast<void*>(m_tensor->buffer), readSize);
}
} // namespace onnxifi
......
......@@ -92,19 +92,19 @@ const char* runtime::cpu::CPUTensorView::get_data_ptr() const
return aligned_buffer;
}
void runtime::cpu::CPUTensorView::write(const void* source, size_t tensor_offset, size_t n)
void runtime::cpu::CPUTensorView::write(const void* source, size_t n)
{
if (tensor_offset + n > buffer_size)
if (n > buffer_size)
{
throw out_of_range("write access past end of tensor");
}
char* target = get_data_ptr();
memcpy(&target[tensor_offset], source, n);
memcpy(target, source, n);
}
void runtime::cpu::CPUTensorView::read(void* target, size_t tensor_offset, size_t n) const
void runtime::cpu::CPUTensorView::read(void* target, size_t n) const
{
if (tensor_offset + n > buffer_size)
if (n > buffer_size)
{
throw out_of_range("read access past end of tensor");
}
......@@ -150,6 +150,6 @@ void runtime::cpu::CPUTensorView::read(void* target, size_t tensor_offset, size_
else
{
const char* source = get_data_ptr();
memcpy(target, &source[tensor_offset], n);
memcpy(target, source, n);
}
}
......@@ -44,15 +44,13 @@ namespace ngraph
/// \brief Write bytes directly into the tensor
/// \param p Pointer to source of data
/// \param tensor_offset Offset into tensor storage to begin writing. Must be element-aligned.
/// \param n Number of bytes to write, must be integral number of elements.
void write(const void* p, size_t tensor_offset, size_t n) override;
void write(const void* p, size_t n) override;
/// \brief Read bytes directly from the tensor
/// \param p Pointer to destination for data
/// \param tensor_offset Offset into tensor storage to begin reading. Must be element-aligned.
/// \param n Number of bytes to read, must be integral number of elements.
void read(void* p, size_t tensor_offset, size_t n) const override;
void read(void* p, size_t n) const override;
static constexpr int BufferAlignment = NGRAPH_CPU_ALIGNMENT;
......
......@@ -113,7 +113,7 @@ bool runtime::dynamic::DynamicExecutable::call(
// TODO(amprocte): For host-resident tensors we should be able to skip the read,
// but no API for that yet.
input->read(arg_value_base_pointers[i], 0, input->get_size_in_bytes());
input->read(arg_value_base_pointers[i], input->get_size_in_bytes());
}
else
{
......@@ -225,18 +225,18 @@ const ngraph::Shape& runtime::dynamic::DynamicTensor::get_shape() const
return m_wrapped_tensor->get_shape();
}
void runtime::dynamic::DynamicTensor::write(const void* p, size_t offset, size_t n)
void runtime::dynamic::DynamicTensor::write(const void* p, size_t n)
{
NGRAPH_CHECK(m_wrapped_tensor != nullptr,
"tried to write to a dynamic tensor with no allocated storage");
m_wrapped_tensor->write(p, offset, n);
m_wrapped_tensor->write(p, n);
}
void runtime::dynamic::DynamicTensor::read(void* p, size_t offset, size_t n) const
void runtime::dynamic::DynamicTensor::read(void* p, size_t n) const
{
NGRAPH_CHECK(m_wrapped_tensor != nullptr,
"tried to read from a dynamic tensor with no allocated storage");
m_wrapped_tensor->read(p, offset, n);
m_wrapped_tensor->read(p, n);
}
void runtime::dynamic::DynamicTensor::copy_from(const ngraph::runtime::Tensor& source)
......
......@@ -132,8 +132,8 @@ public:
virtual size_t get_element_count() const override;
virtual const element::Type& get_element_type() const override;
virtual const ngraph::Shape& get_shape() const override;
virtual void write(const void* p, size_t offset, size_t n) override;
virtual void read(void* p, size_t offset, size_t n) const override;
virtual void write(const void* p, size_t n) override;
virtual void read(void* p, size_t n) const override;
virtual void copy_from(const ngraph::runtime::Tensor& source) override;
bool has_storage() const;
void release_storage();
......
......@@ -68,12 +68,12 @@ runtime::gpu::GPUTensor::~GPUTensor()
}
}
void runtime::gpu::GPUTensor::write(const void* source, size_t tensor_offset, size_t n_bytes)
void runtime::gpu::GPUTensor::write(const void* source, size_t n_bytes)
{
runtime::gpu::cuda_memcpyHtD(m_allocated_buffer_pool, source, n_bytes);
}
void runtime::gpu::GPUTensor::read(void* target, size_t tensor_offset, size_t n_bytes) const
void runtime::gpu::GPUTensor::read(void* target, size_t n_bytes) const
{
runtime::gpu::cuda_memcpyDtH(target, m_allocated_buffer_pool, n_bytes);
}
......
......@@ -42,15 +42,13 @@ public:
/// \brief Write bytes directly into the tensor
/// \param p Pointer to source of data
/// \param tensor_offset Offset into tensor storage to begin writing. Must be element-aligned.
/// \param n_bytes Number of bytes to write, must be integral number of elements.
void write(const void* p, size_t tensor_offset, size_t n_bytes) override;
void write(const void* p, size_t n_bytes) override;
/// \brief Read bytes directly from the tensor
/// \param p Pointer to destination for data
/// \param tensor_offset Offset into tensor storage to begin reading. Must be element-aligned.
/// \param n_bytes Number of bytes to read, must be integral number of elements.
void read(void* p, size_t tensor_offset, size_t n_bytes) const override;
void read(void* p, size_t n_bytes) const override;
/// \brief Copy directly from the another GPU tensor
/// \param source Another GPU tensor
......
......@@ -94,22 +94,22 @@ const char* runtime::HostTensor::get_data_ptr() const
return m_aligned_buffer_pool;
}
void runtime::HostTensor::write(const void* source, size_t tensor_offset, size_t n)
void runtime::HostTensor::write(const void* source, size_t n)
{
if (tensor_offset + n > m_buffer_size)
if (n > m_buffer_size)
{
throw out_of_range("write access past end of tensor");
}
char* target = get_data_ptr();
memcpy(&target[tensor_offset], source, n);
memcpy(target, source, n);
}
void runtime::HostTensor::read(void* target, size_t tensor_offset, size_t n) const
void runtime::HostTensor::read(void* target, size_t n) const
{
if (tensor_offset + n > m_buffer_size)
if (n > m_buffer_size)
{
throw out_of_range("read access past end of tensor");
}
const char* source = get_data_ptr();
memcpy(target, &source[tensor_offset], n);
memcpy(target, source, n);
}
......@@ -61,15 +61,13 @@ public:
/// \brief Write bytes directly into the tensor
/// \param p Pointer to source of data
/// \param tensor_offset Offset into tensor storage to begin writing. Must be element-aligned.
/// \param n Number of bytes to write, must be integral number of elements.
void write(const void* p, size_t tensor_offset, size_t n) override;
void write(const void* p, size_t n) override;
/// \brief Read bytes directly from the tensor
/// \param p Pointer to destination for data
/// \param tensor_offset Offset into tensor storage to begin reading. Must be element-aligned.
/// \param n Number of bytes to read, must be integral number of elements.
void read(void* p, size_t tensor_offset, size_t n) const override;
void read(void* p, size_t n) const override;
private:
HostTensor(const HostTensor&) = delete;
......
......@@ -78,9 +78,9 @@ const char* runtime::HybridTensor::get_data_ptr() const
return m_aligned_buffer_pool;
}
void runtime::HybridTensor::write(const void* source, size_t tensor_offset, size_t n)
void runtime::HybridTensor::write(const void* source, size_t n)
{
if (tensor_offset + n > m_buffer_size)
if (n > m_buffer_size)
{
throw out_of_range("write access past end of tensor");
}
......@@ -88,12 +88,12 @@ void runtime::HybridTensor::write(const void* source, size_t tensor_offset, size
memcpy(target, source, n);
}
void runtime::HybridTensor::read(void* target, size_t tensor_offset, size_t n) const
void runtime::HybridTensor::read(void* target, size_t n) const
{
if (tensor_offset + n > m_buffer_size)
if (n > m_buffer_size)
{
throw out_of_range("read access past end of tensor");
}
const char* source = get_data_ptr();
memcpy(target, &source[tensor_offset], n);
memcpy(target, source, n);
}
......@@ -56,15 +56,13 @@ public:
/// \brief Write bytes directly into the tensor
/// \param p Pointer to source of data
/// \param tensor_offset Offset into tensor storage to begin writing. Must be element-aligned.
/// \param n Number of bytes to write, must be integral number of elements.
void write(const void* p, size_t tensor_offset, size_t n) override;
void write(const void* p, size_t n) override;
/// \brief Read bytes directly from the tensor
/// \param p Pointer to destination for data
/// \param tensor_offset Offset into tensor storage to begin reading. Must be element-aligned.
/// \param n Number of bytes to read, must be integral number of elements.
void read(void* p, size_t tensor_offset, size_t n) const override;
void read(void* p, size_t n) const override;
protected:
HybridTensor(const HybridTensor&) = delete;
......
......@@ -140,7 +140,7 @@ bool runtime::intelgpu::IntelGPUExecutable::call(const vector<shared_ptr<runtime
memory_size_check(result_memory.size(), dst_node, m_function->get_name());
ngraph_res->write(result_memory.data(), 0, result_memory.size());
ngraph_res->write(result_memory.data(), result_memory.size());
}
if (m_profile_enable)
......
......@@ -47,28 +47,26 @@ runtime::intelgpu::IntelGPUTensorView::IntelGPUTensorView(const element::Type& e
}
}
void runtime::intelgpu::IntelGPUTensorView::write(const void* source,
size_t tensor_offset,
size_t n)
void runtime::intelgpu::IntelGPUTensorView::write(const void* source, size_t n)
{
if (tensor_offset + n > ocl_memory->size())
if (n > ocl_memory->size())
{
throw out_of_range("write access past end of tensor");
}
auto ptr = ocl_memory->pointer<char>();
char* target = ptr.data();
memcpy(&target[tensor_offset], source, n);
memcpy(target, source, n);
}
void runtime::intelgpu::IntelGPUTensorView::read(void* target, size_t tensor_offset, size_t n) const
void runtime::intelgpu::IntelGPUTensorView::read(void* target, size_t n) const
{
if (tensor_offset + n > ocl_memory->size())
if (n > ocl_memory->size())
{
throw out_of_range("read access past end of tensor");
}
const auto ptr = ocl_memory->pointer<char>();
const char* source = ptr.data();
memcpy(target, &source[tensor_offset], n);
memcpy(target, source, n);
}
......@@ -42,15 +42,13 @@ public:
/// \brief Write bytes directly into the tensor
/// \param p Pointer to source of data
/// \param tensor_offset Offset into tensor storage to begin writing. Must be element-aligned.
/// \param n Number of bytes to write, must be integral number of elements.
void write(const void* p, size_t tensor_offset, size_t n) override;
void write(const void* p, size_t n) override;
/// \brief Read bytes directly from the tensor
/// \param p Pointer to destination for data
/// \param tensor_offset Offset into tensor storage to begin reading. Must be element-aligned.
/// \param n Number of bytes to read, must be integral number of elements.
void read(void* p, size_t tensor_offset, size_t n) const override;
void read(void* p, size_t n) const override;
cldnn::memory* get_data_ptr() { return ocl_memory.get(); }
private:
......
......@@ -40,20 +40,19 @@ ngraph::runtime::plaidml::PlaidML_Tensor::PlaidML_Tensor(Config* config,
<< " type=" << element_type << " shape=" << shape;
}
void ngraph::runtime::plaidml::PlaidML_Tensor::write(const void* p, size_t tensor_offset, size_t n)
void ngraph::runtime::plaidml::PlaidML_Tensor::write(const void* p, size_t n)
{
NGRAPH_DEBUG << "Write " << this << " offset=" << tensor_offset << " n=" << n
<< " is_logically_zero=" << m_is_logically_zero;
NGRAPH_DEBUG << "Write " << this << " n=" << n << " is_logically_zero=" << m_is_logically_zero;
// As a special case: if we get a zero-sized write to offset zero, fill the tensor with zero.
if (n == 0 && tensor_offset == 0)
if (n == 0)
{
NGRAPH_DEBUG << "Logically zeroing tensor " << this;
m_is_logically_zero = true;
return;
}
bool is_full_write = (tensor_offset == 0 && n == m_tensor.get_shape().buffer_size());
bool is_full_write = (n == m_tensor.get_shape().buffer_size());
vp::mapping<char> mp;
if (m_is_logically_zero || is_full_write)
......@@ -77,14 +76,13 @@ void ngraph::runtime::plaidml::PlaidML_Tensor::write(const void* p, size_t tenso
m_is_logically_zero = false;
const char* src = static_cast<const char*>(p);
char* dest = mp.raw() + tensor_offset;
char* dest = mp.raw();
std::copy(src, src + n, dest);
}
void ngraph::runtime::plaidml::PlaidML_Tensor::read(void* p, size_t tensor_offset, size_t n) const
void ngraph::runtime::plaidml::PlaidML_Tensor::read(void* p, size_t n) const
{
NGRAPH_DEBUG << "Read " << this << " offset=" << tensor_offset << " n=" << n
<< " is_logically_zero=" << m_is_logically_zero;
NGRAPH_DEBUG << "Read " << this << " n=" << n << " is_logically_zero=" << m_is_logically_zero;
char* dest = static_cast<char*>(p);
......@@ -95,7 +93,7 @@ void ngraph::runtime::plaidml::PlaidML_Tensor::read(void* p, size_t tensor_offse
}
vp::mapping<char> mp = m_tensor.map(vp::map_for_read);
const char* src = mp.raw() + tensor_offset;
const char* src = mp.raw();
std::copy(src, src + n, dest);
}
......
......@@ -42,8 +42,8 @@ public:
void* memory);
~PlaidML_Tensor() final {}
const vertexai::plaidml::tensor<char>& tensor() const { return m_tensor; }
void write(const void* p, size_t tensor_offset, size_t n) final;
void read(void* p, size_t tensor_offset, size_t n) const final;
void write(const void* p, size_t n) final;
void read(void* p, size_t n) const final;
// Copy the backing memory to the tensor, if needed.
void sync_input();
......
......@@ -92,6 +92,6 @@ void runtime::Tensor::copy_from(const ngraph::runtime::Tensor& source)
// This is be replaced with more optimial implementations in later PRs
auto size = get_size_in_bytes();
AlignedBuffer buffer{size, 64};
source.read(buffer.get_ptr(), 0, size);
write(buffer.get_ptr(), 0, size);
source.read(buffer.get_ptr(), size);
write(buffer.get_ptr(), size);
}
......@@ -95,20 +95,40 @@ namespace ngraph
/// \brief Write bytes directly into the tensor
/// \param p Pointer to source of data
/// \param offset Offset into tensor storage to begin writing. Must be element-aligned.
/// \param n Number of bytes to write, must be integral number of elements.
virtual void write(const void* p, size_t offset, size_t n) = 0;
virtual void write(const void* p, size_t n) = 0;
/// \brief Read bytes directly from the tensor
/// \param p Pointer to destination for data
/// \param offset Offset into tensor storage to begin writing. Must be element-aligned.
/// \param n Number of bytes to read, must be integral number of elements.
virtual void read(void* p, size_t offset, size_t n) const = 0;
virtual void read(void* p, size_t n) const = 0;
/// \brief copy bytes directly from source to this tensor
/// \param source The source tensor
virtual void copy_from(const ngraph::runtime::Tensor& source);
NGRAPH_DEPRECATED_DOC
/// \brief Write bytes directly into the tensor
/// \param p Pointer to source of data
/// \param offset Offset into tensor storage to begin writing. Must be element-aligned.
/// \param n Number of bytes to write, must be integral number of elements.
void write(const void* p, size_t offset, size_t n)
NGRAPH_DEPRECATED("Use two-parameter write")
{
write(p, n);
}
NGRAPH_DEPRECATED_DOC
/// \brief Read bytes directly from the tensor
/// \param p Pointer to destination for data
/// \param offset Offset into tensor storage to begin writing. Must be element-aligned.
/// \param n Number of bytes to read, must be integral number of elements.
void read(void* p, size_t offset, size_t n) const
NGRAPH_DEPRECATED("Use two-parameter read")
{
read(p, n);
}
protected:
std::shared_ptr<ngraph::descriptor::Tensor> m_descriptor;
bool m_stale;
......
......@@ -51,7 +51,7 @@ void init_int_tv(shared_ptr<runtime::Tensor> tv, T min, T max)
{
element = dist(s_random_engine);
}
tv->write(vec.data(), 0, vec.size() * sizeof(T));
tv->write(vec.data(), vec.size() * sizeof(T));
}
template <>
......@@ -64,7 +64,7 @@ void init_int_tv<char>(shared_ptr<runtime::Tensor> tv, char min, char max)
{
element = static_cast<char>(dist(s_random_engine));
}
tv->write(vec.data(), 0, vec.size() * sizeof(char));
tv->write(vec.data(), vec.size() * sizeof(char));
}
template <>
......@@ -77,7 +77,7 @@ void init_int_tv<int8_t>(shared_ptr<runtime::Tensor> tv, int8_t min, int8_t max)
{
element = static_cast<int8_t>(dist(s_random_engine));
}
tv->write(vec.data(), 0, vec.size() * sizeof(int8_t));
tv->write(vec.data(), vec.size() * sizeof(int8_t));
}
template <>
......@@ -90,7 +90,7 @@ void init_int_tv<uint8_t>(shared_ptr<runtime::Tensor> tv, uint8_t min, uint8_t m
{
element = static_cast<uint8_t>(dist(s_random_engine));
}
tv->write(vec.data(), 0, vec.size() * sizeof(uint8_t));
tv->write(vec.data(), vec.size() * sizeof(uint8_t));
}
template <typename T>
......@@ -103,7 +103,7 @@ void init_real_tv(shared_ptr<runtime::Tensor> tv, T min, T max)
{
element = dist(s_random_engine);
}
tv->write(vec.data(), 0, vec.size() * sizeof(T));
tv->write(vec.data(), vec.size() * sizeof(T));
}
static void random_init(shared_ptr<runtime::Tensor> tv)
......@@ -151,7 +151,6 @@ vector<runtime::PerformanceCounter> run_benchmark(shared_ptr<Function> f,
make_shared<runtime::HostTensor>(param->get_element_type(), param->get_shape());
random_init(tensor_data);
tensor->write(tensor_data->get_data_ptr(),
0,
tensor_data->get_element_count() * tensor_data->get_element_type().size());
args.push_back(tensor);
arg_data.push_back(tensor_data);
......@@ -194,7 +193,6 @@ vector<runtime::PerformanceCounter> run_benchmark(shared_ptr<Function> f,
{
const shared_ptr<runtime::HostTensor>& data = arg_data[arg_index];
arg->write(data->get_data_ptr(),
0,
data->get_element_count() * data->get_element_type().size());
}
}
......@@ -207,7 +205,6 @@ vector<runtime::PerformanceCounter> run_benchmark(shared_ptr<Function> f,
const shared_ptr<runtime::HostTensor>& data = result_data[result_index];
const shared_ptr<runtime::Tensor>& result = results[result_index];
result->read(data->get_data_ptr(),
0,
data->get_element_count() * data->get_element_type().size());
}
}
......
......@@ -382,10 +382,10 @@ NGRAPH_TEST_P(${BACKEND_NAME}, serialized_graph_files, compare_backends_with_gra
random_init(data.get(), engine);
auto ref_tensor = ref->create_tensor(param->get_element_type(), param->get_shape());
auto bk_tensor = backend->create_tensor(param->get_element_type(), param->get_shape());
ref_tensor->write(
data->get_data_ptr(), 0, data->get_element_count() * data->get_element_type().size());
bk_tensor->write(
data->get_data_ptr(), 0, data->get_element_count() * data->get_element_type().size());
ref_tensor->write(data->get_data_ptr(),
data->get_element_count() * data->get_element_type().size());
bk_tensor->write(data->get_data_ptr(),
data->get_element_count() * data->get_element_type().size());
ref_args.push_back(ref_tensor);
bk_args.push_back(bk_tensor);
}
......@@ -459,8 +459,8 @@ NGRAPH_TEST_P(${BACKEND_NAME}, serialized_graph_files, compare_backends_with_gra
auto bk_tensor =
backend->create_tensor(param->get_element_type(), param->get_shape());
size_t size_in_bytes = ref_tensor->get_size_in_bytes();
ref_tensor->read(data->get_data_ptr(), 0, size_in_bytes);
bk_tensor->write(data->get_data_ptr(), 0, size_in_bytes);
ref_tensor->read(data->get_data_ptr(), size_in_bytes);
bk_tensor->write(data->get_data_ptr(), size_in_bytes);
bk_args.push_back(bk_tensor);
}
}
......
......@@ -73,39 +73,6 @@ TEST(tensor, size)
}
}
template <typename T>
void test_read_write(const vector<T>& x)
{
auto backend = runtime::Backend::create("INTERPRETER");
auto a = backend->create_tensor(element::from<T>(), Shape{2, x.size()});
vector<T> result(2 * x.size());
a->write(&x[0], 0, x.size() * sizeof(T));
copy(x.begin(), x.end(), result.begin());
a->write(&x[0], x.size() * sizeof(T), x.size() * sizeof(T));
copy(x.begin(), x.end(), result.begin() + x.size());
vector<T> af_vector(2 * x.size());
a->read(af_vector.data(), 0, af_vector.size() * sizeof(T));
ASSERT_EQ(af_vector, result);
vector<T> result1(x.size());
vector<T> result2(x.size());
copy(result.begin() + 1, result.begin() + 1 + x.size(), result1.begin());
a->read(&result2[0], sizeof(T), sizeof(T) * x.size());
ASSERT_EQ(result1, result2);
}
#if defined(NGRAPH_INTERPRETER_ENABLE)
TEST(tensor, read_write)
{
test_read_write<float>({1.0, 3.0, 5.0});
test_read_write<int64_t>({-1, 2, 4});
}
#endif
TEST(tensor, output_flag)
{
pass::Manager pass_manager;
......
......@@ -195,7 +195,7 @@ void init_int_tv<char>(ngraph::runtime::Tensor* tv,
{
element = static_cast<char>(dist(engine));
}
tv->write(vec.data(), 0, vec.size() * sizeof(char));
tv->write(vec.data(), vec.size() * sizeof(char));
}
template <>
......@@ -211,7 +211,7 @@ void init_int_tv<int8_t>(ngraph::runtime::Tensor* tv,
{
element = static_cast<int8_t>(dist(engine));
}
tv->write(vec.data(), 0, vec.size() * sizeof(int8_t));
tv->write(vec.data(), vec.size() * sizeof(int8_t));
}
template <>
......@@ -227,7 +227,7 @@ void init_int_tv<uint8_t>(ngraph::runtime::Tensor* tv,
{
element = static_cast<uint8_t>(dist(engine));
}
tv->write(vec.data(), 0, vec.size() * sizeof(uint8_t));
tv->write(vec.data(), vec.size() * sizeof(uint8_t));
}
void random_init(ngraph::runtime::Tensor* tv, std::default_random_engine& engine)
......
......@@ -48,7 +48,7 @@ template <typename T>
void copy_data(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<T>& data)
{
size_t data_size = data.size() * sizeof(T);
tv->write(data.data(), 0, data_size);
tv->write(data.data(), data_size);
}
template <typename T>
......@@ -61,7 +61,7 @@ std::vector<T> read_vector(std::shared_ptr<ngraph::runtime::Tensor> tv)
size_t element_count = ngraph::shape_size(tv->get_shape());
size_t size = element_count * sizeof(T);
std::vector<T> rc(element_count);
tv->read(rc.data(), 0, size);
tv->read(rc.data(), size);
return rc;
}
......@@ -70,7 +70,7 @@ std::vector<float> read_float_vector(std::shared_ptr<ngraph::runtime::Tensor> tv
template <typename T>
void write_vector(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<T>& values)
{
tv->write(values.data(), 0, values.size() * sizeof(T));
tv->write(values.data(), values.size() * sizeof(T));
}
template <typename T>
......@@ -113,7 +113,7 @@ void init_int_tv(ngraph::runtime::Tensor* tv, std::default_random_engine& engine
{
element = dist(engine);
}
tv->write(vec.data(), 0, vec.size() * sizeof(T));
tv->write(vec.data(), vec.size() * sizeof(T));
}
template <typename T>
......@@ -126,7 +126,7 @@ void init_real_tv(ngraph::runtime::Tensor* tv, std::default_random_engine& engin
{
element = dist(engine);
}
tv->write(vec.data(), 0, vec.size() * sizeof(T));
tv->write(vec.data(), vec.size() * sizeof(T));
}
void random_init(ngraph::runtime::Tensor* tv, std::default_random_engine& engine);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment