Commit 1c78e9f3 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #126 from NervanaSystems/bob/liveness4

add liveness pass
parents 0344dca9 d3f3bb1b
...@@ -24,6 +24,7 @@ set (SRC ...@@ -24,6 +24,7 @@ set (SRC
ngraph/shape.cpp ngraph/shape.cpp
ngraph/pass/assign_tensors.cpp ngraph/pass/assign_tensors.cpp
ngraph/pass/call_pass.cpp ngraph/pass/call_pass.cpp
ngraph/pass/dump_sorted.cpp
ngraph/pass/liveness.cpp ngraph/pass/liveness.cpp
ngraph/pass/manager.cpp ngraph/pass/manager.cpp
ngraph/pass/memory_layout.cpp ngraph/pass/memory_layout.cpp
...@@ -31,6 +32,7 @@ set (SRC ...@@ -31,6 +32,7 @@ set (SRC
ngraph/pass/propagate_types.cpp ngraph/pass/propagate_types.cpp
ngraph/pass/topological_sort.cpp ngraph/pass/topological_sort.cpp
ngraph/pass/tree_pass.cpp ngraph/pass/tree_pass.cpp
ngraph/pass/visualize_tree.cpp
ngraph/runtime/call_frame.cpp ngraph/runtime/call_frame.cpp
ngraph/runtime/eigen/external_function.cpp ngraph/runtime/eigen/external_function.cpp
ngraph/runtime/eigen/tensor_view.cpp ngraph/runtime/eigen/tensor_view.cpp
......
...@@ -33,3 +33,13 @@ std::shared_ptr<Node> Input::get_node() ...@@ -33,3 +33,13 @@ std::shared_ptr<Node> Input::get_node()
{ {
return m_node->shared_from_this(); return m_node->shared_from_this();
} }
const Tensor& Input::get_tensor() const
{
return m_output.get_tensor();
}
Tensor& Input::get_tensor()
{
return m_output.get_tensor();
}
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include <memory> #include <memory>
#include "ngraph/descriptor/tensor.hpp"
namespace ngraph namespace ngraph
{ {
class Node; class Node;
...@@ -47,6 +49,8 @@ namespace ngraph ...@@ -47,6 +49,8 @@ namespace ngraph
size_t get_index() const { return m_index; } size_t get_index() const { return m_index; }
const Output& get_output() const { return m_output; } const Output& get_output() const { return m_output; }
Output& get_output() { return m_output; } Output& get_output() { return m_output; }
const Tensor& get_tensor() const;
Tensor& get_tensor();
protected: protected:
Node* m_node; // The node we are an input for Node* m_node; // The node we are an input for
......
...@@ -35,3 +35,13 @@ std::shared_ptr<Node> Output::get_node() const ...@@ -35,3 +35,13 @@ std::shared_ptr<Node> Output::get_node() const
{ {
return m_node->shared_from_this(); return m_node->shared_from_this();
} }
const Tensor& Output::get_tensor() const
{
return m_tensor_view->get_tensor();
}
Tensor& Output::get_tensor()
{
return m_tensor_view->get_tensor();
}
...@@ -42,6 +42,8 @@ namespace ngraph ...@@ -42,6 +42,8 @@ namespace ngraph
std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; } std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; }
void add_input(Input* input); void add_input(Input* input);
const std::set<Input*>& get_inputs() const { return m_inputs; } const std::set<Input*>& get_inputs() const { return m_inputs; }
const Tensor& get_tensor() const;
Tensor& get_tensor();
protected: protected:
Node* m_node; Node* m_node;
......
...@@ -13,12 +13,30 @@ ...@@ -13,12 +13,30 @@
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/node.hpp"
using namespace ngraph; using namespace ngraph;
using namespace descriptor; using namespace descriptor;
Tensor::Tensor(const element::Type& element_type, PrimaryTensorView* primary_tensor_view) Tensor::Tensor(const element::Type& element_type, PrimaryTensorView* primary_tensor_view,
const Node* parent, size_t value_index)
: m_element_type(element_type) : m_element_type(element_type)
, m_primary_tensor_view(primary_tensor_view) , m_primary_tensor_view(primary_tensor_view)
, m_is_output{false}
, m_is_input{parent->is_parameter()}
, m_is_persistent{false}
, m_name{parent->get_node_id()+"_"+std::to_string(value_index)}
, m_next_view_id{0}
{ {
} }
std::string Tensor::get_next_view_name()
{
return m_name + "_TV" + std::to_string(m_next_view_id++);
}
std::ostream& ngraph::descriptor::operator<<(std::ostream& out, const Tensor& tensor)
{
out << "Tensor(" << tensor.get_name() << ")";
return out;
}
...@@ -16,9 +16,12 @@ ...@@ -16,9 +16,12 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <iostream>
namespace ngraph namespace ngraph
{ {
class Node;
namespace element namespace element
{ {
class Type; class Type;
...@@ -33,14 +36,31 @@ namespace ngraph ...@@ -33,14 +36,31 @@ namespace ngraph
{ {
friend class PrimaryTensorView; friend class PrimaryTensorView;
private:
Tensor(const Tensor&) = delete; Tensor(const Tensor&) = delete;
Tensor& operator=(const Tensor&) = delete; Tensor& operator=(const Tensor&) = delete;
Tensor(const element::Type& element_type, PrimaryTensorView* tensor_view); Tensor(const element::Type& element_type, PrimaryTensorView* tensor_view,
const Node* parent, size_t value_index);
std::string get_next_view_name();
public:
bool is_output() const { return m_is_output; }
bool is_input() const { return m_is_input; }
bool is_persistent() const { return m_is_persistent; }
const std::string& get_name() const { return m_name; }
friend std::ostream& operator<<(std::ostream&, const Tensor&);
protected: protected:
const element::Type& m_element_type; const element::Type& m_element_type;
PrimaryTensorView* m_primary_tensor_view; PrimaryTensorView* m_primary_tensor_view;
bool m_is_output;
bool m_is_input;
bool m_is_persistent;
std::string m_name;
size_t m_next_view_id;
}; };
} }
} }
...@@ -17,9 +17,12 @@ ...@@ -17,9 +17,12 @@
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
#include "log.hpp"
namespace ngraph namespace ngraph
{ {
class Node;
namespace descriptor namespace descriptor
{ {
class Tensor; class Tensor;
...@@ -41,6 +44,7 @@ namespace ngraph ...@@ -41,6 +44,7 @@ namespace ngraph
virtual ~TensorView() {} virtual ~TensorView() {}
virtual const Tensor& get_tensor() const = 0; virtual const Tensor& get_tensor() const = 0;
virtual Tensor& get_tensor() = 0; virtual Tensor& get_tensor() = 0;
const std::string& get_name() const { return m_name; }
std::shared_ptr<const TensorViewType> get_tensor_view_type() const std::shared_ptr<const TensorViewType> get_tensor_view_type() const
{ {
...@@ -51,6 +55,7 @@ namespace ngraph ...@@ -51,6 +55,7 @@ namespace ngraph
{ {
return m_tensor_view_layout; return m_tensor_view_layout;
} }
void set_tensor_view_layout(const std::shared_ptr<TensorViewLayout>& tensor_view_layout) void set_tensor_view_layout(const std::shared_ptr<TensorViewLayout>& tensor_view_layout)
{ {
m_tensor_view_layout = tensor_view_layout; m_tensor_view_layout = tensor_view_layout;
...@@ -59,6 +64,7 @@ namespace ngraph ...@@ -59,6 +64,7 @@ namespace ngraph
protected: protected:
std::shared_ptr<const TensorViewType> m_tensor_view_type; std::shared_ptr<const TensorViewType> m_tensor_view_type;
std::shared_ptr<TensorViewLayout> m_tensor_view_layout; std::shared_ptr<TensorViewLayout> m_tensor_view_layout;
std::string m_name;
}; };
// A PrimaryTensorView owns the tensor. All other views are the result // A PrimaryTensorView owns the tensor. All other views are the result
...@@ -66,10 +72,14 @@ namespace ngraph ...@@ -66,10 +72,14 @@ namespace ngraph
class PrimaryTensorView : public TensorView class PrimaryTensorView : public TensorView
{ {
public: public:
PrimaryTensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type) PrimaryTensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type,
const Node* parent, size_t value_index)
: TensorView(tensor_view_type) : TensorView(tensor_view_type)
, m_tensor(tensor_view_type->get_element_type(), this) , m_tensor(tensor_view_type->get_element_type(), this, parent, value_index)
{ {
// Set the name in the parent TensorView.
// This can't be done until after the m_tensor is constructed.
m_name = m_tensor.get_next_view_name();
} }
virtual const Tensor& get_tensor() const override; virtual const Tensor& get_tensor() const override;
......
...@@ -57,8 +57,9 @@ void Node::assign_tensors() ...@@ -57,8 +57,9 @@ void Node::assign_tensors()
size_t i = 0; size_t i = 0;
for (auto tvt : tensor_view_types) for (auto tvt : tensor_view_types)
{ {
auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(tvt); auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(tvt, this, i);
m_outputs.emplace_back(this, i++, tensor_view_descriptor); m_outputs.emplace_back(this, i, tensor_view_descriptor);
i++;
} }
i = 0; i = 0;
...@@ -68,7 +69,8 @@ void Node::assign_tensors() ...@@ -68,7 +69,8 @@ void Node::assign_tensors()
size_t arg_index = 0; size_t arg_index = 0;
for (descriptor::Output& output : arg->get_outputs()) for (descriptor::Output& output : arg->get_outputs())
{ {
m_inputs.emplace_back(this, i++, argno, arg_index++, output); m_inputs.emplace_back(this, i, argno, arg_index++, output);
i++;
} }
argno++; argno++;
} }
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <set> #include <set>
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_set>
#include <iostream> #include <iostream>
...@@ -29,6 +30,7 @@ namespace ngraph ...@@ -29,6 +30,7 @@ namespace ngraph
{ {
class Input; class Input;
class Output; class Output;
class Tensor;
} }
/// Nodes are the backbone of the graph of Value dataflow. Every node has /// Nodes are the backbone of the graph of Value dataflow. Every node has
...@@ -51,7 +53,7 @@ namespace ngraph ...@@ -51,7 +53,7 @@ namespace ngraph
virtual ~Node(); virtual ~Node();
public: public:
/// A "one-liner" describing this node. /// The class name, must not contain spaces
virtual std::string description() const = 0; virtual std::string description() const = 0;
/// Propagate types and check arguments for consistency /// Propagate types and check arguments for consistency
...@@ -66,9 +68,6 @@ namespace ngraph ...@@ -66,9 +68,6 @@ namespace ngraph
const std::multiset<Node*>& users() const { return m_users; } const std::multiset<Node*>& users() const { return m_users; }
std::string get_name() const { return m_name; }
void set_name(const std::string& name) { m_name = name; }
virtual std::string get_node_id() const; virtual std::string get_node_id() const;
/// Return true if this has the same implementing class as node. This /// Return true if this has the same implementing class as node. This
...@@ -104,7 +103,13 @@ namespace ngraph ...@@ -104,7 +103,13 @@ namespace ngraph
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
std::vector<descriptor::Input>& get_inputs() { return m_inputs; } std::vector<descriptor::Input>& get_inputs() { return m_inputs; }
const std::vector<descriptor::Input>& get_inputs() const { return m_inputs; }
std::vector<descriptor::Output>& get_outputs() { return m_outputs; } std::vector<descriptor::Output>& get_outputs() { return m_outputs; }
const std::vector<descriptor::Output>& get_outputs() const { return m_outputs; }
std::unordered_set<descriptor::Tensor*> liveness_live_list;
std::unordered_set<descriptor::Tensor*> liveness_new_list;
std::unordered_set<descriptor::Tensor*> liveness_free_list;
protected: protected:
Nodes m_arguments; Nodes m_arguments;
......
...@@ -42,7 +42,6 @@ namespace ngraph ...@@ -42,7 +42,6 @@ namespace ngraph
std::string description() const override { return "Parameter"; } std::string description() const override { return "Parameter"; }
virtual void propagate_types() override; virtual void propagate_types() override;
virtual std::string get_node_id() const override;
protected: protected:
Function* m_function; Function* m_function;
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <fstream>
#include "dump_sorted.hpp"
#include "ngraph/ngraph.hpp"
#include "util.hpp"
using namespace ngraph;
using namespace std;
pass::DumpSorted::DumpSorted(const string& output_file)
: m_output_file{output_file}
{
}
bool pass::DumpSorted::run_on_call_list(list<Node*>& nodes)
{
ofstream out{m_output_file};
if (out)
{
for (const Node* node : nodes)
{
out << node->get_node_id() << "(";
vector<string> inputs;
for (const descriptor::Input& input : node->get_inputs())
{
inputs.push_back(input.get_tensor().get_name());
}
out << join(inputs);
out << ") -> ";
vector<string> outputs;
for (const descriptor::Output& output : node->get_outputs())
{
outputs.push_back(output.get_tensor().get_name());
}
out << join(outputs);
out << "\n";
for (const descriptor::Tensor* tensor : node->liveness_live_list)
{
out << " ";
if (contains(node->liveness_new_list, tensor))
{
out << "N ";
}
else if (contains(node->liveness_free_list, tensor))
{
out << "F ";
}
else
{
out << "L ";
}
out << tensor->get_name() << "\n";
}
}
}
return false;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <string>
#include "call_pass.hpp"
namespace ngraph
{
namespace pass
{
class DumpSorted;
}
class Node;
}
class ngraph::pass::DumpSorted : public CallBase
{
public:
DumpSorted(const std::string& output_file);
virtual bool run_on_call_list(std::list<Node*>&) override;
private:
const std::string m_output_file;
};
...@@ -14,111 +14,112 @@ ...@@ -14,111 +14,112 @@
#include <exception> #include <exception>
#include <sstream> #include <sstream>
#include <unordered_set>
#include "log.hpp" #include "log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp" #include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "util.hpp"
#include "log.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
bool pass::Liveness::run_on_call_list(list<Node*>& ops) bool pass::Liveness::run_on_call_list(list<Node*>& ops)
{ {
// list<Node*> live_list; unordered_set<descriptor::Tensor*> currently_live;
// list<Node*> free_list;
// list<Node*> new_list; for(auto it=ops.rbegin(); it!=ops.rend(); it++)
// currently_live = list(); {
Node& exop = **it;
// size_t i = 0; exop.liveness_live_list.clear();
// for (i, exop in enumerate(reversed(ops) exop.liveness_new_list.clear();
// for(auto it=ops.rbegin(); it!=ops.rend(); it++) exop.liveness_free_list.clear();
// { unordered_set<descriptor::Tensor*> input_tensor_decls;
// Node& exop = **it; for (auto input_decl : exop.get_inputs())
// input_tensor_decls = list() {
// for (auto input_decl : exop.get_inputs()) descriptor::Tensor& tensor = input_decl.get_tensor();
// { if (is_temporary(tensor))
// if (is_interesting(input_decl.tensor_decl)) {
// { input_tensor_decls.insert(&tensor);
// input_tensor_decls.append(input_decl.tensor_decl); }
// } }
// }
unordered_set<descriptor::Tensor*> output_tensor_decls;
// output_tensor_decls = list() for (auto output_decl : exop.get_outputs())
// for (output_decl : exop.output_decls) {
// { descriptor::Tensor& tensor = output_decl.get_tensor();
// if (is_interesting(output_decl.tensor_decl)) if (is_temporary(tensor))
// { {
// output_tensor_decls.append(output_decl.tensor_decl); output_tensor_decls.insert(&tensor);
// } }
// } }
// free_tensor_decls = list(); unordered_set<descriptor::Tensor*> free_tensor_decls;
// new_tensor_decls = list(); unordered_set<descriptor::Tensor*> new_tensor_decls;
// for tensor_decl in input_tensor_decls + output_tensor_decls unordered_set<descriptor::Tensor*> all_tensor_decls = input_tensor_decls;
// {
// if tensor_decl not in currently_live for (auto decls : {input_tensor_decls, output_tensor_decls})
// { {
// // this is the last node that value is seen in for (descriptor::Tensor* tensor_decl : decls)
// // delete it at the end of the op {
// currently_live.append(tensor_decl); if (!contains(currently_live, tensor_decl))
// free_tensor_decls.append(tensor_decl); {
// } // this is the last node that value is seen in
// } // delete it at the end of the op
// live_list.insert(0, list(currently_live)) currently_live.insert(tensor_decl);
// for output_decl in output_tensor_decls free_tensor_decls.insert(tensor_decl);
// { }
// if output_decl in currently_live }
// { }
// new_tensor_decls.append(output_decl);
// currently_live.remove(output_decl); exop.liveness_live_list = currently_live;
// } for (descriptor::Tensor* output_decl : output_tensor_decls)
// } {
// free_list.insert(0, free_tensor_decls); if (contains(currently_live, output_decl))
// new_list.insert(0, new_tensor_decls); {
// } new_tensor_decls.insert(output_decl);
currently_live.erase(output_decl);
// // Anything marked as output must remain live for the remainder of the graph }
// // Add outputs to live_list and remove from free_list }
// outputs = list(); exop.liveness_free_list = free_tensor_decls;
// seen = list(); exop.liveness_new_list = new_tensor_decls;
// for i, exop in enumerate(ops) }
// {
// for tensor in live_list[i] // Anything marked as output must remain live for the remainder of the graph
// { // Add outputs to live_list and remove from free_list
// if tensor.is_output and tensor not in outputs unordered_set<descriptor::Tensor*> outputs;
// { unordered_set<descriptor::Tensor*> seen;
// outputs.append(tensor); for (Node* exop : ops)
// } {
// } for (descriptor::Tensor* tensor : exop->liveness_live_list)
// for tensor in outputs {
// { if (tensor->is_output())
// if tensor not in live_list[i] {
// { outputs.insert(tensor);
// live_list[i].append(tensor); }
// } }
// if tensor in free_list[i] for (descriptor::Tensor* tensor : outputs)
// { {
// free_list[i].remove(tensor); exop->liveness_live_list.insert(tensor);
// } exop->liveness_free_list.erase(tensor);
// if tensor in new_list[i]
// { if (contains(exop->liveness_new_list, tensor))
// if tensor in seen {
// { if (contains(seen, tensor))
// new_list[i].remove(tensor); {
// } exop->liveness_new_list.erase(tensor);
// else }
// { else
// seen.append(tensor); {
// } seen.insert(tensor);
// } }
// } }
// exop.liveness_live_list = live_list[i]; }
// exop.liveness_new_list = new_list[i]; }
// exop.liveness_free_list = free_list[i];
// } validate_liveness(ops);
// self.validate_liveness(ops)
return false; return false;
} }
...@@ -140,30 +141,32 @@ void pass::Liveness::check_dependencies( ...@@ -140,30 +141,32 @@ void pass::Liveness::check_dependencies(
} }
} }
// bool pass::Liveness::is_interesting(tensor_decl) bool pass::Liveness::is_temporary(const descriptor::Tensor& tensor)
// { {
// return return
// tensor_decl.is_persistent == false && tensor.is_persistent() == false
// tensor_decl.is_constant == false && && tensor.is_input() == false
// tensor_decl.is_compile_only == false; ;
// } // && tensor.is_constant() == false
// && tensor.is_compile_only() == false;
// void pass::Liveness::validate_liveness(ops) }
// {
// dead_tensors = set(); void pass::Liveness::validate_liveness(const list<Node*>& ops)
// for i, exop in enumerate(ops) {
// { unordered_set<descriptor::Tensor*> dead_tensors;
// active = set(exop.liveness_live_list); for (const Node* exop : ops)
// active |= set(exop.liveness_new_list); {
// active |= set(exop.liveness_free_list); auto active = exop->liveness_live_list;
// if bool(dead_tensors.intersection(active)) is True active.insert(exop->liveness_new_list.begin(), exop->liveness_new_list.end());
// { active.insert(exop->liveness_free_list.begin(), exop->liveness_free_list.end());
// raise RuntimeError("Liveness: Dead tensors intersect active tensors"); for (const descriptor::Tensor* tensor : active)
// } {
// for tensor in exop.liveness_free_list if (contains(dead_tensors, tensor))
// { {
// dead_tensors.add(tensor); throw runtime_error("Liveness: Dead tensors intersect active tensors");
// } }
// } }
// } dead_tensors.insert(exop->liveness_free_list.begin(), exop->liveness_free_list.end());
}
}
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include "call_pass.hpp" #include "call_pass.hpp"
#include "ngraph/descriptor/tensor.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -33,6 +34,6 @@ public: ...@@ -33,6 +34,6 @@ public:
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override; void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
private: private:
// bool is_interesting(tensor_decl); bool is_temporary(const descriptor::Tensor&);
// void validate_liveness(std::list<Node*> ops); void validate_liveness(const std::list<Node*>& ops);
}; };
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <fstream>
#include "visualize_tree.hpp"
#include "ngraph/node.hpp"
#include "util.hpp"
using namespace ngraph;
using namespace std;
bool pass::VisualizeTree::run_on_tree(std::shared_ptr<Node> base_node)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(base_node, [&](Node* node)
{
for (auto arg : node->get_arguments())
{
m_ss << add_attributes(arg.get());
m_ss << add_attributes(node);
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id();
m_ss << ";\n";
}
});
render();
return false;
}
pass::VisualizeTree::VisualizeTree(const string& file_name)
: m_name{file_name}
{
}
std::string pass::VisualizeTree::add_attributes(const Node* node)
{
string rc;
if (!contains(m_nodes_with_attributes, node))
{
m_nodes_with_attributes.insert(node);
rc = get_attributes(node);
}
return rc;
}
std::string pass::VisualizeTree::get_attributes(const Node* node)
{
stringstream ss;
if (node->is_parameter())
{
ss << " " << node->get_node_id() << " [shape=box color=blue]\n";
}
else
{
ss << " " << node->get_node_id() << " [shape=ellipse color=black]\n";
}
return ss.str();
}
void pass::VisualizeTree::render() const
{
#if GRAPHVIZ_FOUND
auto tmp_file = m_name + ".tmp";
ofstream out(tmp_file);
if (out)
{
out << "digraph ngraph\n{\n";
out << m_ss.str();
out << "}\n";
out.close();
stringstream ss;
ss << "dot -Tpng " << tmp_file << " -o " << m_name;
auto cmd = ss.str();
auto stream = popen(cmd.c_str(), "r");
pclose(stream);
remove(tmp_file.c_str());
}
#endif
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <sstream>
#include <string>
#include <set>
#include "ngraph/pass/tree_pass.hpp"
namespace ngraph
{
namespace pass
{
class VisualizeTree;
}
class Node;
}
class ngraph::pass::VisualizeTree : public TreeBase
{
public:
VisualizeTree(const std::string& file_name);
bool run_on_tree(std::shared_ptr<Node>) override;
private:
std::string add_attributes(const Node* node);
std::string get_attributes(const Node* node);
void render() const;
std::stringstream m_ss;
std::string m_name;
std::set<const Node*> m_nodes_with_attributes;
};
...@@ -42,10 +42,3 @@ void Parameter::assign_function(Function* function, size_t index) ...@@ -42,10 +42,3 @@ void Parameter::assign_function(Function* function, size_t index)
} }
void Parameter::propagate_types() {} void Parameter::propagate_types() {}
std::string ngraph::op::Parameter::get_node_id() const
{
stringstream ss;
ss << "parameter_" << m_instance_id;
return ss.str();
}
...@@ -64,5 +64,5 @@ target_link_libraries(unit-test ${CMAKE_DL_LIBS}) ...@@ -64,5 +64,5 @@ target_link_libraries(unit-test ${CMAKE_DL_LIBS})
add_dependencies(unit-test ngraph libgtest eigen) add_dependencies(unit-test ngraph libgtest eigen)
add_custom_target(check add_custom_target(check
COMMAND ${PROJECT_BINARY_DIR}/test/unit-test COMMAND ${PROJECT_BINARY_DIR}/test/unit-test \${ARGS}
DEPENDS unit-test) DEPENDS unit-test)
...@@ -20,14 +20,54 @@ ...@@ -20,14 +20,54 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
#include "log.hpp"
using namespace std; using namespace std;
using namespace ngraph;
namespace ng = ngraph; namespace ng = ngraph;
TEST(liveness, test) TEST(pass, liveness)
{ {
string image = "liveness.png";
string dump_file = "liveness.txt";
pass::Manager pass_manager;
auto visualize = make_shared<pass::VisualizeTree>(image);
auto topological_sort = make_shared<pass::TopologicalSort>();
auto propagate_types = make_shared<pass::PropagateTypes>();
auto assign_tensors = make_shared<pass::AssignTensors>();
auto liveness = make_shared<pass::Liveness>();
auto dump_sorted = make_shared<pass::DumpSorted>(dump_file);
pass_manager.register_pass(visualize);
pass_manager.register_pass(topological_sort);
pass_manager.register_pass(propagate_types);
pass_manager.register_pass(assign_tensors);
pass_manager.register_pass(liveness);
pass_manager.register_pass(dump_sorted);
auto graph = make_test_graph();
size_t node_count = get_node_count(graph);
pass_manager.run_passes(graph);
auto sorted = pass_manager.get_sorted_list();
// for (const Node* node : sorted)
// {
// INFO << *node;
// for (const descriptor::Tensor* tensor : node->liveness_live_list)
// {
// INFO << " " << *tensor;
// }
// }
// auto x = ng.variable(axes=[]).named('x'); // auto x = ng.variable(axes=[]).named('x');
// auto y = ng.variable(axes=[]).named('y'); // auto y = ng.variable(axes=[]).named('y');
// auto w1 = ng.variable(axes=[]).named('w1'); // auto w1 = ng.variable(axes=[]).named('w1');
...@@ -48,7 +88,7 @@ TEST(liveness, test) ...@@ -48,7 +88,7 @@ TEST(liveness, test)
// return exc; // return exc;
// lg = LivenessGraph(exc.exop.ops) // lg = LivenessGraph(exc.exop.ops)
// lg.layout_memory() // lg.layout_memory()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment