Commit c405c3bc authored by Jaikrishnan Menon

WIP

parent 93a2efda
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "ngraph/runtime/cpu/cpu_backend.hpp" #include "ngraph/runtime/cpu/cpu_backend.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/runtime/external_function.hpp" #include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/host_tensor_view.hpp" #include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
...@@ -30,6 +30,6 @@ std::shared_ptr<ngraph::runtime::TensorView> ...@@ -30,6 +30,6 @@ std::shared_ptr<ngraph::runtime::TensorView>
runtime::cpu::CPU_Backend::make_primary_tensor_view(const ngraph::element::Type& element_type, runtime::cpu::CPU_Backend::make_primary_tensor_view(const ngraph::element::Type& element_type,
const Shape& shape) const Shape& shape)
{ {
auto rc = make_shared<runtime::HostTensorView>(element_type, shape); auto rc = make_shared<runtime::cpu::CPUTensorView>(element_type, shape);
return dynamic_pointer_cast<runtime::TensorView>(rc); return dynamic_pointer_cast<runtime::TensorView>(rc);
} }
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "ngraph/runtime/cpu/cpu_call_frame.hpp" #include "ngraph/runtime/cpu/cpu_call_frame.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp" #include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/host_tensor_view.hpp" #include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -36,14 +36,14 @@ void runtime::cpu::CPU_CallFrame::tensor_call( ...@@ -36,14 +36,14 @@ void runtime::cpu::CPU_CallFrame::tensor_call(
vector<void*> outputs; vector<void*> outputs;
for (size_t i = 0; i < input_tvs.size(); i++) for (size_t i = 0; i < input_tvs.size(); i++)
{ {
shared_ptr<runtime::HostTensorView> tv = shared_ptr<runtime::cpu::CPUTensorView> tv =
static_pointer_cast<runtime::HostTensorView>(input_tvs[i]); static_pointer_cast<runtime::cpu::CPUTensorView>(input_tvs[i]);
inputs.push_back(tv->get_data_ptr()); inputs.push_back(tv->get_data_ptr());
} }
for (size_t i = 0; i < output_tvs.size(); i++) for (size_t i = 0; i < output_tvs.size(); i++)
{ {
shared_ptr<runtime::HostTensorView> tv = shared_ptr<runtime::cpu::CPUTensorView> tv =
static_pointer_cast<runtime::HostTensorView>(output_tvs[i]); static_pointer_cast<runtime::cpu::CPUTensorView>(output_tvs[i]);
outputs.push_back(tv->get_data_ptr()); outputs.push_back(tv->get_data_ptr());
} }
......
...@@ -15,13 +15,20 @@ ...@@ -15,13 +15,20 @@
#include <cstring> #include <cstring>
#include <memory> #include <memory>
#include "cpu_tensor_view.hpp" #include "ngraph/except.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/descriptor/layout/tensor_view_layout.hpp" #include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/descriptor/primary_tensor_view.hpp" #include "ngraph/descriptor/primary_tensor_view.hpp"
#include "cpu_tensor_view.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
// TODO(jmenon): Refactor all the alignment specifications into
// a single place and allow lower or no alignment when possible
const size_t runtime::cpu::CPUTensorView::BufferAlignment = 64;
runtime::cpu::CPUTensorView::CPUTensorView(const ngraph::element::Type& element_type, runtime::cpu::CPUTensorView::CPUTensorView(const ngraph::element::Type& element_type,
const Shape& shape, const Shape& shape,
const string& name) const string& name)
...@@ -30,6 +37,19 @@ runtime::cpu::CPUTensorView::CPUTensorView(const ngraph::element::Type& element_ ...@@ -30,6 +37,19 @@ runtime::cpu::CPUTensorView::CPUTensorView(const ngraph::element::Type& element_
, buffer(nullptr) , buffer(nullptr)
, aligned_buffer(nullptr) , aligned_buffer(nullptr)
{ {
buffer_size = shape_size(shape) * element_type.size();
if (buffer_size)
{
size_t allocation_size = buffer_size + BufferAlignment;
auto ptr = malloc(allocation_size);
if (!ptr)
{
throw ngraph_error("Error allocating CPU Tensor View memory");
}
buffer = static_cast<char*>(ptr);
std::align(BufferAlignment, buffer_size, ptr, allocation_size);
aligned_buffer = static_cast<char*>(ptr);
}
} }
runtime::cpu::CPUTensorView::~CPUTensorView() runtime::cpu::CPUTensorView::~CPUTensorView()
...@@ -69,10 +89,5 @@ void runtime::cpu::CPUTensorView::read(void* target, size_t tensor_offset, size_ ...@@ -69,10 +89,5 @@ void runtime::cpu::CPUTensorView::read(void* target, size_t tensor_offset, size_
size_t runtime::cpu::CPUTensorView::get_size() const size_t runtime::cpu::CPUTensorView::get_size() const
{ {
return get_tensor_view_layout()->get_size(); return get_element_count();
}
const element::Type& runtime::cpu::CPUTensorView::get_element_type() const
{
return get_tensor_view_layout()->get_element_type();
} }
...@@ -52,6 +52,8 @@ namespace ngraph ...@@ -52,6 +52,8 @@ namespace ngraph
void read(void* p, size_t tensor_offset, size_t n) const override; void read(void* p, size_t tensor_offset, size_t n) const override;
private: private:
static const size_t BufferAlignment;
char* buffer; char* buffer;
char* aligned_buffer; char* aligned_buffer;
size_t buffer_size; size_t buffer_size;
......
...@@ -47,6 +47,11 @@ const Strides& runtime::TensorView::get_strides() const ...@@ -47,6 +47,11 @@ const Strides& runtime::TensorView::get_strides() const
return m_descriptor->get_tensor_view_layout()->get_strides(); return m_descriptor->get_tensor_view_layout()->get_strides();
} }
// Returns the element type of this tensor view, obtained from the
// descriptor's tensor view type (not from its layout).
const element::Type& runtime::TensorView::get_element_type() const
{
return m_descriptor->get_tensor_view_type()->get_element_type();
}
shared_ptr<descriptor::layout::TensorViewLayout> runtime::TensorView::get_tensor_view_layout() const shared_ptr<descriptor::layout::TensorViewLayout> runtime::TensorView::get_tensor_view_layout() const
{ {
return m_descriptor->get_tensor_view_layout(); return m_descriptor->get_tensor_view_layout();
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "ngraph/descriptor/tensor_view.hpp" #include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "ngraph/types/element_type.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -52,6 +53,7 @@ namespace ngraph ...@@ -52,6 +53,7 @@ namespace ngraph
const ngraph::Shape& get_shape() const; const ngraph::Shape& get_shape() const;
const ngraph::Strides& get_strides() const; const ngraph::Strides& get_strides() const;
const ngraph::element::Type& get_element_type() const;
size_t get_element_count() const; size_t get_element_count() const;
const ngraph::descriptor::Tensor& get_tensor() const; const ngraph::descriptor::Tensor& get_tensor() const;
......
...@@ -40,7 +40,7 @@ void copy_data(std::shared_ptr<ngraph::runtime::TensorView> tv, const std::vecto ...@@ -40,7 +40,7 @@ void copy_data(std::shared_ptr<ngraph::runtime::TensorView> tv, const std::vecto
template <typename T> template <typename T>
std::vector<T> read_vector(std::shared_ptr<ngraph::runtime::TensorView> tv) std::vector<T> read_vector(std::shared_ptr<ngraph::runtime::TensorView> tv)
{ {
if (ngraph::element::from<T>() != tv->get_tensor_view_layout()->get_element_type()) if (ngraph::element::from<T>() != tv->get_element_type())
{ {
throw std::invalid_argument("read_vector type must match TensorView type"); throw std::invalid_argument("read_vector type must match TensorView type");
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment