Unverified Commit 8bab36fb authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge descriptor::TensorView into descriptor::Tensor (#1536)

* Merge descriptor::TensorView into descriptor::Tensot

* fix GPU build
parent 62e470b2
......@@ -29,7 +29,6 @@ set (SRC
descriptor/layout/tensor_view_layout.cpp
descriptor/output.cpp
descriptor/tensor.cpp
descriptor/tensor_view.cpp
file_util.cpp
function.cpp
log.cpp
......
......@@ -25,7 +25,7 @@ namespace ngraph
{
namespace descriptor
{
class TensorView;
class Tensor;
namespace layout
{
......@@ -36,7 +36,7 @@ namespace ngraph
{
public:
~DenseTensorViewLayout() override {}
DenseTensorViewLayout(const TensorView& tensor_view);
DenseTensorViewLayout(const Tensor& tensor);
virtual size_t get_size() override { return m_size; }
size_t get_offset() const { return m_offset; }
......
......@@ -15,7 +15,7 @@
//*****************************************************************************
#include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/type/element_type.hpp"
using namespace ngraph;
......
......@@ -19,7 +19,7 @@
#include <memory>
#include <vector>
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/descriptor/tensor.hpp"
namespace ngraph
{
......
......@@ -21,10 +21,10 @@
using namespace std;
using namespace ngraph;
descriptor::Output::Output(Node* node, size_t index, const shared_ptr<TensorView>& tensor_view)
descriptor::Output::Output(Node* node, size_t index, const shared_ptr<Tensor>& tensor)
: m_node(node)
, m_index(index)
, m_tensor_view(tensor_view)
, m_tensor(tensor)
{
}
......@@ -46,15 +46,15 @@ shared_ptr<Node> descriptor::Output::get_node() const
descriptor::Tensor& descriptor::Output::get_tensor() const
{
return m_tensor_view->get_tensor();
return *m_tensor;
}
const Shape& descriptor::Output::get_shape() const
{
return m_tensor_view->get_shape();
return m_tensor->get_shape();
}
const element::Type& descriptor::Output::get_element_type() const
{
return m_tensor_view->get_element_type();
return m_tensor->get_element_type();
}
......@@ -20,7 +20,7 @@
#include <set>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/descriptor/tensor.hpp"
namespace ngraph
{
......@@ -39,16 +39,13 @@ namespace ngraph
public:
/// \param node Node that owns this output.
/// \param index Position of the output tensor in all output tensors
/// \param tensor_view The view of this tensor; where the value will be written
Output(Node* node, size_t index, const std::shared_ptr<TensorView>& tensor_view);
/// \param tensor The view of this tensor; where the value will be written
Output(Node* node, size_t index, const std::shared_ptr<Tensor>& tensor);
std::shared_ptr<Node> get_node() const;
size_t get_index() const { return m_index; }
std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; }
void set_tensor_view(const std::shared_ptr<TensorView>& tensor_view)
{
m_tensor_view = tensor_view;
}
std::shared_ptr<Tensor> get_tensor_view() const { return m_tensor; }
void set_tensor_view(const std::shared_ptr<Tensor>& tensor) { m_tensor = tensor; }
void add_input(Input* input);
void remove_input(Input* input);
const std::set<Input*>& get_inputs() const { return m_inputs; }
......@@ -62,7 +59,7 @@ namespace ngraph
protected:
Node* m_node;
size_t m_index;
std::shared_ptr<TensorView> m_tensor_view;
std::shared_ptr<Tensor> m_tensor;
std::set<Input*> m_inputs;
private:
......
......@@ -22,23 +22,22 @@ using namespace ngraph;
using namespace std;
descriptor::Tensor::Tensor(const element::Type& element_type,
TensorView* tensor_view,
const string& name)
const Shape& shape,
const std::string& name)
: m_element_type(element_type)
, m_tensor_view(tensor_view)
, m_name{name}
, m_next_view_id{0}
, m_shape(shape)
, m_name(name)
{
}
string descriptor::Tensor::make_tensor_name(const Node* node, size_t value_index)
void descriptor::Tensor::set_tensor_view_type(const element::Type& element_type, const Shape& shape)
{
return node->get_name() + "_" + to_string(value_index);
}
string descriptor::Tensor::get_next_view_name()
{
return m_name + "_TV" + to_string(m_next_view_id++);
m_shape = shape;
m_element_type = element_type;
if (nullptr != m_tensor_view_layout)
{
m_tensor_view_layout->set_tensor_view_type(element_type, shape);
}
}
void descriptor::Tensor::set_pool_offset(size_t offset)
......@@ -53,21 +52,16 @@ size_t descriptor::Tensor::get_pool_offset() const
size_t descriptor::Tensor::size() const
{
if (auto tvl = m_tensor_view->get_tensor_view_layout())
if (auto tvl = get_tensor_view_layout())
{
return tvl->get_allocated_size();
}
else
{
return shape_size(m_tensor_view->get_shape()) * m_element_type.size();
return shape_size(get_shape()) * m_element_type.size();
}
}
void descriptor::Tensor::set_element_type(const element::Type& element_type)
{
m_element_type = element_type;
}
ostream& operator<<(ostream& out, const descriptor::Tensor& tensor)
{
out << "Tensor(" << tensor.get_name() << ")";
......
......@@ -16,10 +16,10 @@
#pragma once
#include <iostream>
#include <memory>
#include <vector>
#include <string>
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
......@@ -27,44 +27,58 @@ namespace ngraph
{
class Node;
namespace element
{
class Type;
}
namespace descriptor
{
class TensorView;
class Tensor;
}
}
namespace layout
{
class TensorViewLayout;
}
class ngraph::descriptor::Tensor
{
friend class TensorView;
/// \brief Compile-time descriptor of a first-class value that is a view of a tensor.
class Tensor
{
Tensor(const Tensor&) = delete;
Tensor& operator=(const Tensor&) = delete;
private:
Tensor(const Tensor&) = delete;
Tensor& operator=(const Tensor&) = delete;
public:
Tensor(const element::Type& element_type, const Shape& shape, const std::string& name);
Tensor(const element::Type& element_type, TensorView* tensor_view, const std::string& name);
std::string get_next_view_name();
const std::string& get_name() const { return m_name; }
void set_tensor_view_type(const element::Type& element_type, const Shape& shape);
public:
const std::string& get_name() const { return m_name; }
void set_pool_offset(size_t);
size_t size() const;
size_t get_pool_offset() const;
const element::Type& get_element_type() const { return m_element_type; }
void set_element_type(const element::Type& element_type);
static std::string make_tensor_name(const Node* node, size_t value_index);
const element::Type& get_element_type() const { return m_element_type; }
const Shape& get_shape() const { return m_shape; }
const std::shared_ptr<layout::TensorViewLayout>& get_tensor_view_layout() const
{
return m_tensor_view_layout;
}
protected:
element::Type m_element_type;
TensorView* m_tensor_view;
std::string m_name;
size_t m_next_view_id;
size_t m_pool_offset;
};
void set_tensor_view_layout(
const std::shared_ptr<layout::TensorViewLayout>& tensor_view_layout)
{
m_tensor_view_layout = tensor_view_layout;
}
std::ostream& operator<<(std::ostream&, const ngraph::descriptor::Tensor&);
void set_pool_offset(size_t);
size_t get_pool_offset() const;
size_t size() const;
const Tensor& get_tensor() const { return *this; }
Tensor& get_tensor() { return *this; }
const Tensor& get_tensor_view() const { return *this; }
Tensor& get_tensor_view() { return *this; }
protected:
element::Type m_element_type;
Shape m_shape;
std::string m_name;
std::shared_ptr<layout::TensorViewLayout> m_tensor_view_layout;
size_t m_pool_offset;
};
using TensorView = Tensor;
using TensorViewPtrs = std::vector<std::shared_ptr<TensorView>>;
std::ostream& operator<<(std::ostream&, const ngraph::descriptor::Tensor&);
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/descriptor/layout/tensor_view_layout.hpp"
using namespace ngraph;
using namespace std;
descriptor::TensorView::TensorView(const element::Type& element_type,
const Shape& shape,
const std::string& name)
: m_element_type(element_type)
, m_shape(shape)
, m_tensor(m_element_type, this, name)
{
// Set the name in the parent TensorView.
// This can't be done until after the m_tensor is constructed.
m_name = m_tensor.get_next_view_name();
}
const element::Type& descriptor::TensorView::get_element_type() const
{
return m_element_type;
}
const Shape& descriptor::TensorView::get_shape() const
{
return m_shape;
}
const descriptor::Tensor& descriptor::TensorView::get_tensor() const
{
return m_tensor;
}
descriptor::Tensor& descriptor::TensorView::get_tensor()
{
return m_tensor;
}
void descriptor::TensorView::set_tensor_view_type(const element::Type& element_type,
const Shape& shape)
{
m_shape = shape;
m_element_type = element_type;
m_tensor.set_element_type(element_type);
if (nullptr != m_tensor_view_layout)
{
m_tensor_view_layout->set_tensor_view_type(element_type, shape);
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include <string>
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
class Node;
namespace descriptor
{
namespace layout
{
class Tensor;
class TensorViewLayout;
}
/// \brief Compile-time descriptor of a first-class value that is a view of a tensor.
class TensorView
{
TensorView(const TensorView&) = delete;
TensorView& operator=(const TensorView&) = delete;
public:
TensorView(const element::Type& element_type,
const Shape& shape,
const std::string& name);
const Tensor& get_tensor() const;
Tensor& get_tensor();
const std::string& get_name() const { return m_name; }
void set_tensor_view_type(const element::Type& element_type, const Shape& shape);
const element::Type& get_element_type() const;
const Shape& get_shape() const;
const std::shared_ptr<layout::TensorViewLayout>& get_tensor_view_layout() const
{
return m_tensor_view_layout;
}
void set_tensor_view_layout(
const std::shared_ptr<layout::TensorViewLayout>& tensor_view_layout)
{
m_tensor_view_layout = tensor_view_layout;
}
protected:
element::Type m_element_type;
Shape m_shape;
std::shared_ptr<layout::TensorViewLayout> m_tensor_view_layout;
std::string m_name;
Tensor m_tensor;
};
using TensorViewPtrs = std::vector<std::shared_ptr<TensorView>>;
}
}
......@@ -53,7 +53,6 @@
#include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/except.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
......
......@@ -75,7 +75,7 @@ void Node::set_output_size(size_t n)
for (size_t i = m_outputs.size(); i < n; ++i)
{
auto tensor_view_descriptor = make_shared<descriptor::TensorView>(
element::unspecified, Shape(), ngraph::descriptor::Tensor::make_tensor_name(this, i));
element::unspecified, Shape(), get_name() + "_" + to_string(i));
m_outputs.emplace_back(this, i, tensor_view_descriptor);
}
}
......
......@@ -18,7 +18,7 @@
#include <memory>
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
......
......@@ -23,36 +23,36 @@ using namespace ngraph;
runtime::gpu::GPU_TensorViewWrapper::GPU_TensorViewWrapper(
const shared_ptr<descriptor::TensorView>& tv, const string& alias)
: m_tensor_view(tv)
: m_tensor(tv)
, m_alias(alias)
{
}
size_t runtime::gpu::GPU_TensorViewWrapper::get_size() const
{
return m_tensor_view->get_tensor_view_layout()->get_size();
return m_tensor->get_tensor_view_layout()->get_size();
}
const Shape& runtime::gpu::GPU_TensorViewWrapper::get_shape() const
{
return m_tensor_view->get_tensor_view_layout()->get_shape();
return m_tensor->get_tensor_view_layout()->get_shape();
}
const Strides& runtime::gpu::GPU_TensorViewWrapper::get_strides() const
{
return m_tensor_view->get_tensor_view_layout()->get_strides();
return m_tensor->get_tensor_view_layout()->get_strides();
}
const element::Type& runtime::gpu::GPU_TensorViewWrapper::get_element_type() const
{
return m_tensor_view->get_tensor_view_layout()->get_element_type();
return m_tensor->get_tensor_view_layout()->get_element_type();
}
const std::string& runtime::gpu::GPU_TensorViewWrapper::get_name() const
{
if (m_alias.empty())
{
return m_tensor_view->get_tensor().get_name();
return m_tensor->get_tensor().get_name();
}
else
{
......
......@@ -18,7 +18,7 @@
#include <memory>
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
......@@ -35,7 +35,7 @@ namespace ngraph
class ngraph::runtime::gpu::GPU_TensorViewWrapper
{
public:
GPU_TensorViewWrapper(const std::shared_ptr<descriptor::TensorView>&,
GPU_TensorViewWrapper(const std::shared_ptr<descriptor::Tensor>&,
const std::string& alias = "");
size_t get_size() const;
......@@ -46,6 +46,6 @@ public:
const std::string& get_type() const;
private:
std::shared_ptr<descriptor::TensorView> m_tensor_view;
std::shared_ptr<descriptor::Tensor> m_tensor;
std::string m_alias;
};
......@@ -19,7 +19,7 @@
#include <memory>
#include <vector>
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/type/element_type.hpp"
......@@ -36,7 +36,7 @@ namespace ngraph
class TensorView
{
protected:
TensorView(const std::shared_ptr<ngraph::descriptor::TensorView>& descriptor)
TensorView(const std::shared_ptr<ngraph::descriptor::Tensor>& descriptor)
: m_descriptor(descriptor)
, m_stale(true)
{
......@@ -46,8 +46,7 @@ namespace ngraph
virtual ~TensorView() {}
TensorView& operator=(const TensorView&) = default;
std::shared_ptr<const ngraph::descriptor::TensorView>
get_tensor_view_descriptor() const;
std::shared_ptr<const ngraph::descriptor::Tensor> get_tensor_view_descriptor() const;
virtual std::shared_ptr<descriptor::TensorView> get_descriptor() const;
......@@ -74,7 +73,7 @@ namespace ngraph
virtual void read(void* p, size_t tensor_offset, size_t n) const = 0;
protected:
std::shared_ptr<ngraph::descriptor::TensorView> m_descriptor;
std::shared_ptr<ngraph::descriptor::Tensor> m_descriptor;
bool m_stale;
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment