Unverified Commit f6eec91f authored by Yixing Lao's avatar Yixing Lao Committed by GitHub

Clean up is_constant is_input is_output in Tensor (#943)

* remove in tensor.hpp and cpp

* remove in constructors

* more clean up at tv_wrapper and set_is_output()

* fix liveness

* fix liveness.cpp

* finally fixed liveness

* fix PrimaryTensorView constructor in node.cpp

* fix PrimaryTensorView constructor in cpu_tensor_view

* clang-format

* update tensor print

* clean comments

* rename
parent b88fa59d
......@@ -50,8 +50,6 @@ namespace ngraph
const std::set<Input*>& get_inputs() const { return m_inputs; }
Tensor& get_tensor() const;
void set_is_output() { get_tensor().set_is_output(); }
bool is_output() { return get_tensor().is_output(); }
protected:
/// @return the tensor view type for the output
std::shared_ptr<const TensorViewType> get_tensor_view_type() const;
......
......@@ -20,12 +20,9 @@ using namespace ngraph;
using namespace descriptor;
PrimaryTensorView::PrimaryTensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type,
const std::string& name,
bool is_output,
bool is_input,
bool is_constant)
const std::string& name)
: TensorView(tensor_view_type)
, m_tensor(tensor_view_type->get_element_type(), this, name, is_output, is_input, is_constant)
, m_tensor(tensor_view_type->get_element_type(), this, name)
{
// Set the name in the parent TensorView.
// This can't be done until after the m_tensor is constructed.
......
......@@ -34,13 +34,8 @@ namespace ngraph
public:
/// @param tensor_view_type The type for this view.
/// @param name Description of the tensor, for debugging.
/// @param is_output The view can be read from the host at the end of a computation.
/// @param is_input The view can be written from the host at the beginning of a computation.
PrimaryTensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type,
const std::string& name,
bool is_output,
bool is_input,
bool is_constant);
const std::string& name);
virtual const Tensor& get_tensor() const override;
virtual Tensor& get_tensor() override;
......
......@@ -23,16 +23,9 @@ using namespace std;
descriptor::Tensor::Tensor(const element::Type& element_type,
PrimaryTensorView* primary_tensor_view,
const string& name,
bool is_output,
bool is_input,
bool is_constant)
const string& name)
: m_element_type(element_type)
, m_primary_tensor_view(primary_tensor_view)
, m_is_output{is_output}
, m_is_input{is_input}
, m_is_persistent{false}
, m_is_constant{is_constant}
, m_name{name}
, m_next_view_id{0}
{
......@@ -71,18 +64,6 @@ size_t descriptor::Tensor::get_pool_offset() const
ostream& operator<<(ostream& out, const descriptor::Tensor& tensor)
{
out << "Tensor(" << tensor.get_name() << ", ";
out << (tensor.is_persistent() ? "P" : "");
out << (tensor.is_constant() ? "C" : "");
out << (tensor.is_input() ? "I" : "");
out << (tensor.is_output() ? "O" : "");
if (!tensor.is_persistent() && !tensor.is_constant() && !tensor.is_input() &&
!tensor.is_output())
{
out << "T";
}
out << ")";
out << "Tensor(" << tensor.get_name() << ")";
return out;
}
......@@ -48,32 +48,21 @@ private:
Tensor(const element::Type& element_type,
PrimaryTensorView* tensor_view,
const std::string& name,
bool is_output,
bool is_input,
bool is_constant);
const std::string& name);
std::string get_next_view_name();
public:
bool is_output() const { return m_is_output; }
bool is_input() const { return m_is_input; }
bool is_persistent() const { return m_is_persistent; }
bool is_constant() const { return m_is_constant; }
const std::string& get_name() const { return m_name; }
size_t size() const;
void set_pool_offset(size_t);
size_t get_pool_offset() const;
const element::Type& get_element_type() const { return m_element_type; }
static std::string make_tensor_name(const Node* node, size_t value_index);
void set_is_output() { m_is_output = true; }
protected:
const element::Type m_element_type;
PrimaryTensorView* m_primary_tensor_view;
bool m_is_output;
bool m_is_input;
bool m_is_persistent;
bool m_is_constant;
std::string m_name;
size_t m_next_view_id;
size_t m_size;
......
......@@ -74,14 +74,6 @@ Function::Function(const std::shared_ptr<Node>& result,
void Function::init()
{
for (auto r : m_results)
{
for (descriptor::Output& output : r->get_outputs())
{
output.get_tensor().set_is_output();
}
}
traverse_nodes(this, [&](shared_ptr<Node> node) {
std::shared_ptr<op::Parameter> p = std::dynamic_pointer_cast<op::Parameter>(node);
if (nullptr != p)
......
......@@ -68,11 +68,7 @@ void Node::add_output(const element::Type& element_type, const Shape& shape)
shared_ptr<TensorViewType> tensor_view_type = make_shared<TensorViewType>(element_type, shape);
size_t i = m_outputs.size();
auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(
tensor_view_type,
ngraph::descriptor::Tensor::make_tensor_name(this, i),
false,
is_parameter(),
is_constant());
tensor_view_type, ngraph::descriptor::Tensor::make_tensor_name(this, i));
m_outputs.emplace_back(this, i, tensor_view_descriptor);
}
......
......@@ -20,19 +20,55 @@
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
bool pass::Liveness::run_on_call_graph(const list<shared_ptr<Node>>& ops)
bool pass::Liveness::run_on_function(shared_ptr<ngraph::Function> function)
{
unordered_set<descriptor::Tensor*> currently_live;
list<shared_ptr<Node>> ops = function->get_ordered_ops();
unordered_set<descriptor::Tensor*> persistent_tensors;
unordered_set<descriptor::Tensor*> output_tensors;
for (shared_ptr<op::Parameter> node : function->get_parameters())
{
for (size_t i = 0; i < node->get_output_size(); ++i)
{
descriptor::Tensor& tensor = node->get_output_tensor(i);
persistent_tensors.insert(&tensor);
}
}
for (shared_ptr<op::Result> node : function->get_results())
{
for (size_t i = 0; i < node->get_output_size(); ++i)
{
descriptor::Tensor& tensor = node->get_output_tensor(i);
persistent_tensors.insert(&tensor);
output_tensors.insert(&tensor);
}
}
for (shared_ptr<Node> node : function->get_ordered_ops())
{
if (auto constant_node = dynamic_pointer_cast<op::Constant>(node))
{
for (size_t i = 0; i < constant_node->get_output_size(); ++i)
{
descriptor::Tensor& tensor = constant_node->get_output_tensor(i);
persistent_tensors.insert(&tensor);
}
}
}
unordered_set<descriptor::Tensor*> currently_live;
for (auto it = ops.rbegin(); it != ops.rend(); it++)
{
shared_ptr<Node> node = *it;
......@@ -43,7 +79,7 @@ bool pass::Liveness::run_on_call_graph(const list<shared_ptr<Node>>& ops)
for (descriptor::Input& input_decl : node->get_inputs())
{
descriptor::Tensor& tensor = input_decl.get_tensor();
if (is_temporary(tensor))
if (!contains(persistent_tensors, &tensor))
{
input_tensor_decls.insert(&tensor);
}
......@@ -53,7 +89,7 @@ bool pass::Liveness::run_on_call_graph(const list<shared_ptr<Node>>& ops)
for (size_t i = 0; i < node->get_output_size(); ++i)
{
descriptor::Tensor& tensor = node->get_output_tensor(i);
if (is_temporary(tensor))
if (!contains(persistent_tensors, &tensor))
{
output_tensor_decls.insert(&tensor);
}
......@@ -96,7 +132,7 @@ bool pass::Liveness::run_on_call_graph(const list<shared_ptr<Node>>& ops)
{
for (descriptor::Tensor* tensor : node->liveness_live_list)
{
if (tensor->is_output())
if (contains(output_tensors, tensor))
{
outputs.insert(tensor);
}
......@@ -124,13 +160,6 @@ bool pass::Liveness::run_on_call_graph(const list<shared_ptr<Node>>& ops)
return false;
}
bool pass::Liveness::is_temporary(const descriptor::Tensor& tensor)
{
return tensor.is_persistent() == false && tensor.is_input() == false &&
tensor.is_output() == false && tensor.is_constant() == false;
// && tensor.is_compile_only() == false;
}
void pass::Liveness::validate_liveness(const list<Node*>& ops)
{
unordered_set<descriptor::Tensor*> dead_tensors;
......
......@@ -27,12 +27,11 @@ namespace ngraph
}
}
class ngraph::pass::Liveness : public CallGraphPass
class ngraph::pass::Liveness : public FunctionPass
{
public:
virtual bool run_on_call_graph(const std::list<std::shared_ptr<Node>>&) override;
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
private:
bool is_temporary(const descriptor::Tensor&);
void validate_liveness(const std::list<Node*>& ops);
};
......@@ -39,10 +39,6 @@ ngraph::pass::Manager::Manager()
}
}
ngraph::pass::Manager::Manager(bool to_set_is_output)
{
}
ngraph::pass::Manager::~Manager()
{
}
......
......@@ -37,7 +37,6 @@ class ngraph::pass::Manager
{
public:
Manager();
Manager(bool to_set_is_output);
~Manager();
void initialize_default_passes();
......
......@@ -70,10 +70,7 @@ bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<ngraph::Function>>&
}
for (descriptor::Tensor* tensor : tensors)
{
if (tensor->is_persistent() == false)
{
temp_max_size += tensor->size();
}
temp_max_size += tensor->size();
}
// file << "<table>\n";
......@@ -244,27 +241,13 @@ void pass::MemoryVisualize::draw_op_influence(ostream& file, const list<shared_p
int pass::MemoryVisualize::compute_op_weight(const shared_ptr<Node> exop)
{
int mass = 0;
// for input_decl in exop.input_decls:
// tensor = input_decl.source_output_decl.tensor
// if tensor.is_persistent is False:
// mass += tensor->size()
// for output_decl in exop.output_decls:
// tensor = output_decl.tensor
// if tensor.is_persistent is False:
// mass -= tensor->size()
for (const descriptor::Tensor* tensor : exop->liveness_new_list)
{
if (tensor->is_persistent() == false)
{
mass += tensor->size();
}
mass += tensor->size();
}
for (const descriptor::Tensor* tensor : exop->liveness_free_list)
{
if (tensor->is_persistent() == false)
{
mass -= tensor->size();
}
mass -= tensor->size();
}
return mass;
}
......
......@@ -39,7 +39,7 @@ runtime::cpu::CPUTensorView::CPUTensorView(const ngraph::element::Type& element_
void* memory_pointer,
const string& name)
: runtime::TensorView(std::make_shared<ngraph::descriptor::PrimaryTensorView>(
std::make_shared<ngraph::TensorViewType>(element_type, shape), name, true, true, false))
std::make_shared<ngraph::TensorViewType>(element_type, shape), name))
, buffer(nullptr)
, aligned_buffer(nullptr)
{
......
......@@ -65,11 +65,6 @@ const std::string& runtime::cpu::TensorViewWrapper::get_type() const
return get_element_type().c_type_string();
}
bool runtime::cpu::TensorViewWrapper::is_output() const
{
return m_tensor_view->get_tensor().is_output();
}
const std::shared_ptr<descriptor::TensorView>
runtime::cpu::TensorViewWrapper::get_tensor_view() const
{
......
......@@ -44,7 +44,6 @@ public:
const element::Type& get_element_type() const;
const std::string& get_name() const;
const std::string& get_type() const;
bool is_output() const;
const std::shared_ptr<descriptor::TensorView> get_tensor_view() const;
private:
......
......@@ -30,11 +30,7 @@ runtime::gpu::GPU_TensorView::GPU_TensorView(const ngraph::element::Type& elemen
const Shape& shape,
void* memory_pointer)
: runtime::TensorView(std::make_shared<ngraph::descriptor::PrimaryTensorView>(
std::make_shared<ngraph::TensorViewType>(element_type, shape),
"external",
true,
true,
false))
std::make_shared<ngraph::TensorViewType>(element_type, shape), "external"))
, m_custom_memory(false)
{
m_descriptor->set_tensor_view_layout(
......
......@@ -64,8 +64,3 @@ const std::string& runtime::gpu::GPU_TensorViewWrapper::get_type() const
{
return get_element_type().c_type_string();
}
bool runtime::gpu::GPU_TensorViewWrapper::is_output() const
{
return m_tensor_view->get_tensor().is_output();
}
......@@ -44,7 +44,6 @@ public:
const element::Type& get_element_type() const;
const std::string& get_name() const;
const std::string& get_type() const;
bool is_output() const;
private:
std::shared_ptr<descriptor::TensorView> m_tensor_view;
......
......@@ -29,7 +29,7 @@ runtime::HostTensorView::HostTensorView(const ngraph::element::Type& element_typ
void* memory_pointer,
const string& name)
: runtime::TensorView(std::make_shared<ngraph::descriptor::PrimaryTensorView>(
std::make_shared<ngraph::TensorViewType>(element_type, shape), name, true, true, false))
std::make_shared<ngraph::TensorViewType>(element_type, shape), name))
, m_allocated_buffer_pool(nullptr)
, m_aligned_buffer_pool(nullptr)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment