Commit d3f3bb1b authored by Bob Kimball's avatar Bob Kimball Committed by Robert Kimball

add liveness pass

add naming of tensors and tensor views
add visualization pass
add graph dump pass
parent d93260a9
...@@ -24,6 +24,7 @@ set (SRC ...@@ -24,6 +24,7 @@ set (SRC
ngraph/shape.cpp ngraph/shape.cpp
ngraph/pass/assign_tensors.cpp ngraph/pass/assign_tensors.cpp
ngraph/pass/call_pass.cpp ngraph/pass/call_pass.cpp
ngraph/pass/dump_sorted.cpp
ngraph/pass/liveness.cpp ngraph/pass/liveness.cpp
ngraph/pass/manager.cpp ngraph/pass/manager.cpp
ngraph/pass/memory_layout.cpp ngraph/pass/memory_layout.cpp
...@@ -31,6 +32,7 @@ set (SRC ...@@ -31,6 +32,7 @@ set (SRC
ngraph/pass/propagate_types.cpp ngraph/pass/propagate_types.cpp
ngraph/pass/topological_sort.cpp ngraph/pass/topological_sort.cpp
ngraph/pass/tree_pass.cpp ngraph/pass/tree_pass.cpp
ngraph/pass/visualize_tree.cpp
ngraph/runtime/call_frame.cpp ngraph/runtime/call_frame.cpp
ngraph/runtime/eigen/external_function.cpp ngraph/runtime/eigen/external_function.cpp
ngraph/runtime/eigen/tensor_view.cpp ngraph/runtime/eigen/tensor_view.cpp
......
...@@ -33,3 +33,13 @@ std::shared_ptr<Node> Input::get_node() ...@@ -33,3 +33,13 @@ std::shared_ptr<Node> Input::get_node()
{ {
return m_node->shared_from_this(); return m_node->shared_from_this();
} }
const Tensor& Input::get_tensor() const
{
return m_output.get_tensor();
}
Tensor& Input::get_tensor()
{
return m_output.get_tensor();
}
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include <memory> #include <memory>
#include "ngraph/descriptor/tensor.hpp"
namespace ngraph namespace ngraph
{ {
class Node; class Node;
...@@ -47,6 +49,8 @@ namespace ngraph ...@@ -47,6 +49,8 @@ namespace ngraph
size_t get_index() const { return m_index; } size_t get_index() const { return m_index; }
const Output& get_output() const { return m_output; } const Output& get_output() const { return m_output; }
Output& get_output() { return m_output; } Output& get_output() { return m_output; }
const Tensor& get_tensor() const;
Tensor& get_tensor();
protected: protected:
Node* m_node; // The node we are an input for Node* m_node; // The node we are an input for
......
...@@ -35,3 +35,13 @@ std::shared_ptr<Node> Output::get_node() const ...@@ -35,3 +35,13 @@ std::shared_ptr<Node> Output::get_node() const
{ {
return m_node->shared_from_this(); return m_node->shared_from_this();
} }
const Tensor& Output::get_tensor() const
{
return m_tensor_view->get_tensor();
}
Tensor& Output::get_tensor()
{
return m_tensor_view->get_tensor();
}
...@@ -42,6 +42,8 @@ namespace ngraph ...@@ -42,6 +42,8 @@ namespace ngraph
std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; } std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; }
void add_input(Input* input); void add_input(Input* input);
const std::set<Input*>& get_inputs() const { return m_inputs; } const std::set<Input*>& get_inputs() const { return m_inputs; }
const Tensor& get_tensor() const;
Tensor& get_tensor();
protected: protected:
Node* m_node; Node* m_node;
......
...@@ -13,12 +13,30 @@ ...@@ -13,12 +13,30 @@
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/node.hpp"
using namespace ngraph; using namespace ngraph;
using namespace descriptor; using namespace descriptor;
Tensor::Tensor(const element::Type& element_type, PrimaryTensorView* primary_tensor_view) Tensor::Tensor(const element::Type& element_type, PrimaryTensorView* primary_tensor_view,
const Node* parent, size_t value_index)
: m_element_type(element_type) : m_element_type(element_type)
, m_primary_tensor_view(primary_tensor_view) , m_primary_tensor_view(primary_tensor_view)
, m_is_output{false}
, m_is_input{parent->is_parameter()}
, m_is_persistent{false}
, m_name{parent->get_node_id()+"_"+std::to_string(value_index)}
, m_next_view_id{0}
{ {
} }
std::string Tensor::get_next_view_name()
{
return m_name + "_TV" + std::to_string(m_next_view_id++);
}
std::ostream& ngraph::descriptor::operator<<(std::ostream& out, const Tensor& tensor)
{
out << "Tensor(" << tensor.get_name() << ")";
return out;
}
...@@ -16,9 +16,12 @@ ...@@ -16,9 +16,12 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <iostream>
namespace ngraph namespace ngraph
{ {
class Node;
namespace element namespace element
{ {
class Type; class Type;
...@@ -33,14 +36,31 @@ namespace ngraph ...@@ -33,14 +36,31 @@ namespace ngraph
{ {
friend class PrimaryTensorView; friend class PrimaryTensorView;
private:
Tensor(const Tensor&) = delete; Tensor(const Tensor&) = delete;
Tensor& operator=(const Tensor&) = delete; Tensor& operator=(const Tensor&) = delete;
Tensor(const element::Type& element_type, PrimaryTensorView* tensor_view); Tensor(const element::Type& element_type, PrimaryTensorView* tensor_view,
const Node* parent, size_t value_index);
std::string get_next_view_name();
public:
bool is_output() const { return m_is_output; }
bool is_input() const { return m_is_input; }
bool is_persistent() const { return m_is_persistent; }
const std::string& get_name() const { return m_name; }
friend std::ostream& operator<<(std::ostream&, const Tensor&);
protected: protected:
const element::Type& m_element_type; const element::Type& m_element_type;
PrimaryTensorView* m_primary_tensor_view; PrimaryTensorView* m_primary_tensor_view;
bool m_is_output;
bool m_is_input;
bool m_is_persistent;
std::string m_name;
size_t m_next_view_id;
}; };
} }
} }
...@@ -17,9 +17,12 @@ ...@@ -17,9 +17,12 @@
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
#include "log.hpp"
namespace ngraph namespace ngraph
{ {
class Node;
namespace descriptor namespace descriptor
{ {
class Tensor; class Tensor;
...@@ -41,6 +44,7 @@ namespace ngraph ...@@ -41,6 +44,7 @@ namespace ngraph
virtual ~TensorView() {} virtual ~TensorView() {}
virtual const Tensor& get_tensor() const = 0; virtual const Tensor& get_tensor() const = 0;
virtual Tensor& get_tensor() = 0; virtual Tensor& get_tensor() = 0;
const std::string& get_name() const { return m_name; }
std::shared_ptr<const TensorViewType> get_tensor_view_type() const std::shared_ptr<const TensorViewType> get_tensor_view_type() const
{ {
...@@ -51,6 +55,7 @@ namespace ngraph ...@@ -51,6 +55,7 @@ namespace ngraph
{ {
return m_tensor_view_layout; return m_tensor_view_layout;
} }
void set_tensor_view_layout(const std::shared_ptr<TensorViewLayout>& tensor_view_layout) void set_tensor_view_layout(const std::shared_ptr<TensorViewLayout>& tensor_view_layout)
{ {
m_tensor_view_layout = tensor_view_layout; m_tensor_view_layout = tensor_view_layout;
...@@ -59,6 +64,7 @@ namespace ngraph ...@@ -59,6 +64,7 @@ namespace ngraph
protected: protected:
std::shared_ptr<const TensorViewType> m_tensor_view_type; std::shared_ptr<const TensorViewType> m_tensor_view_type;
std::shared_ptr<TensorViewLayout> m_tensor_view_layout; std::shared_ptr<TensorViewLayout> m_tensor_view_layout;
std::string m_name;
}; };
// A PrimaryTensorView owns the tensor. All other views are the result // A PrimaryTensorView owns the tensor. All other views are the result
...@@ -66,10 +72,14 @@ namespace ngraph ...@@ -66,10 +72,14 @@ namespace ngraph
class PrimaryTensorView : public TensorView class PrimaryTensorView : public TensorView
{ {
public: public:
PrimaryTensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type) PrimaryTensorView(const std::shared_ptr<const TensorViewType>& tensor_view_type,
const Node* parent, size_t value_index)
: TensorView(tensor_view_type) : TensorView(tensor_view_type)
, m_tensor(tensor_view_type->get_element_type(), this) , m_tensor(tensor_view_type->get_element_type(), this, parent, value_index)
{ {
// Set the name in the parent TensorView.
// This can't be done until after the m_tensor is constructed.
m_name = m_tensor.get_next_view_name();
} }
virtual const Tensor& get_tensor() const override; virtual const Tensor& get_tensor() const override;
......
...@@ -57,8 +57,9 @@ void Node::assign_tensors() ...@@ -57,8 +57,9 @@ void Node::assign_tensors()
size_t i = 0; size_t i = 0;
for (auto tvt : tensor_view_types) for (auto tvt : tensor_view_types)
{ {
auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(tvt); auto tensor_view_descriptor = make_shared<descriptor::PrimaryTensorView>(tvt, this, i);
m_outputs.emplace_back(this, i++, tensor_view_descriptor); m_outputs.emplace_back(this, i, tensor_view_descriptor);
i++;
} }
i = 0; i = 0;
...@@ -68,7 +69,8 @@ void Node::assign_tensors() ...@@ -68,7 +69,8 @@ void Node::assign_tensors()
size_t arg_index = 0; size_t arg_index = 0;
for (descriptor::Output& output : arg->get_outputs()) for (descriptor::Output& output : arg->get_outputs())
{ {
m_inputs.emplace_back(this, i++, argno, arg_index++, output); m_inputs.emplace_back(this, i, argno, arg_index++, output);
i++;
} }
argno++; argno++;
} }
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <set> #include <set>
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_set>
#include <iostream> #include <iostream>
...@@ -29,6 +30,7 @@ namespace ngraph ...@@ -29,6 +30,7 @@ namespace ngraph
{ {
class Input; class Input;
class Output; class Output;
class Tensor;
} }
/// Nodes are the backbone of the graph of Value dataflow. Every node has /// Nodes are the backbone of the graph of Value dataflow. Every node has
...@@ -51,7 +53,7 @@ namespace ngraph ...@@ -51,7 +53,7 @@ namespace ngraph
virtual ~Node(); virtual ~Node();
public: public:
/// A "one-liner" describing this node. /// The class name, must not contain spaces
virtual std::string description() const = 0; virtual std::string description() const = 0;
/// Propagate types and check arguments for consistency /// Propagate types and check arguments for consistency
...@@ -66,9 +68,6 @@ namespace ngraph ...@@ -66,9 +68,6 @@ namespace ngraph
const std::multiset<Node*>& users() const { return m_users; } const std::multiset<Node*>& users() const { return m_users; }
std::string get_name() const { return m_name; }
void set_name(const std::string& name) { m_name = name; }
virtual std::string get_node_id() const; virtual std::string get_node_id() const;
/// Return true if this has the same implementing class as node. This /// Return true if this has the same implementing class as node. This
...@@ -104,7 +103,13 @@ namespace ngraph ...@@ -104,7 +103,13 @@ namespace ngraph
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
std::vector<descriptor::Input>& get_inputs() { return m_inputs; } std::vector<descriptor::Input>& get_inputs() { return m_inputs; }
const std::vector<descriptor::Input>& get_inputs() const { return m_inputs; }
std::vector<descriptor::Output>& get_outputs() { return m_outputs; } std::vector<descriptor::Output>& get_outputs() { return m_outputs; }
const std::vector<descriptor::Output>& get_outputs() const { return m_outputs; }
std::unordered_set<descriptor::Tensor*> liveness_live_list;
std::unordered_set<descriptor::Tensor*> liveness_new_list;
std::unordered_set<descriptor::Tensor*> liveness_free_list;
protected: protected:
Nodes m_arguments; Nodes m_arguments;
......
...@@ -42,7 +42,6 @@ namespace ngraph ...@@ -42,7 +42,6 @@ namespace ngraph
std::string description() const override { return "Parameter"; } std::string description() const override { return "Parameter"; }
virtual void propagate_types() override; virtual void propagate_types() override;
virtual std::string get_node_id() const override;
protected: protected:
Function* m_function; Function* m_function;
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <fstream>
#include "dump_sorted.hpp"
#include "ngraph/ngraph.hpp"
#include "util.hpp"
using namespace ngraph;
using namespace std;
pass::DumpSorted::DumpSorted(const string& output_file)
: m_output_file{output_file}
{
}
bool pass::DumpSorted::run_on_call_list(list<Node*>& nodes)
{
ofstream out{m_output_file};
if (out)
{
for (const Node* node : nodes)
{
out << node->get_node_id() << "(";
vector<string> inputs;
for (const descriptor::Input& input : node->get_inputs())
{
inputs.push_back(input.get_tensor().get_name());
}
out << join(inputs);
out << ") -> ";
vector<string> outputs;
for (const descriptor::Output& output : node->get_outputs())
{
outputs.push_back(output.get_tensor().get_name());
}
out << join(outputs);
out << "\n";
for (const descriptor::Tensor* tensor : node->liveness_live_list)
{
out << " ";
if (contains(node->liveness_new_list, tensor))
{
out << "N ";
}
else if (contains(node->liveness_free_list, tensor))
{
out << "F ";
}
else
{
out << "L ";
}
out << tensor->get_name() << "\n";
}
}
}
return false;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <string>
#include "call_pass.hpp"
namespace ngraph
{
namespace pass
{
class DumpSorted;
}
class Node;
}
class ngraph::pass::DumpSorted : public CallBase
{
public:
DumpSorted(const std::string& output_file);
virtual bool run_on_call_list(std::list<Node*>&) override;
private:
const std::string m_output_file;
};
...@@ -14,111 +14,112 @@ ...@@ -14,111 +14,112 @@
#include <exception> #include <exception>
#include <sstream> #include <sstream>
#include <unordered_set>
#include "log.hpp" #include "log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp" #include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "util.hpp"
#include "log.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
bool pass::Liveness::run_on_call_list(list<Node*>& ops) bool pass::Liveness::run_on_call_list(list<Node*>& ops)
{ {
// list<Node*> live_list; unordered_set<descriptor::Tensor*> currently_live;
// list<Node*> free_list;
// list<Node*> new_list; for(auto it=ops.rbegin(); it!=ops.rend(); it++)
// currently_live = list(); {
Node& exop = **it;
// size_t i = 0; exop.liveness_live_list.clear();
// for (i, exop in enumerate(reversed(ops) exop.liveness_new_list.clear();
// for(auto it=ops.rbegin(); it!=ops.rend(); it++) exop.liveness_free_list.clear();
// { unordered_set<descriptor::Tensor*> input_tensor_decls;
// Node& exop = **it; for (auto input_decl : exop.get_inputs())
// input_tensor_decls = list() {
// for (auto input_decl : exop.get_inputs()) descriptor::Tensor& tensor = input_decl.get_tensor();
// { if (is_temporary(tensor))
// if (is_interesting(input_decl.tensor_decl)) {
// { input_tensor_decls.insert(&tensor);
// input_tensor_decls.append(input_decl.tensor_decl); }
// } }
// }
unordered_set<descriptor::Tensor*> output_tensor_decls;
// output_tensor_decls = list() for (auto output_decl : exop.get_outputs())
// for (output_decl : exop.output_decls) {
// { descriptor::Tensor& tensor = output_decl.get_tensor();
// if (is_interesting(output_decl.tensor_decl)) if (is_temporary(tensor))
// { {
// output_tensor_decls.append(output_decl.tensor_decl); output_tensor_decls.insert(&tensor);
// } }
// } }
// free_tensor_decls = list(); unordered_set<descriptor::Tensor*> free_tensor_decls;
// new_tensor_decls = list(); unordered_set<descriptor::Tensor*> new_tensor_decls;
// for tensor_decl in input_tensor_decls + output_tensor_decls unordered_set<descriptor::Tensor*> all_tensor_decls = input_tensor_decls;
// {
// if tensor_decl not in currently_live for (auto decls : {input_tensor_decls, output_tensor_decls})
// { {
// // this is the last node that value is seen in for (descriptor::Tensor* tensor_decl : decls)
// // delete it at the end of the op {
// currently_live.append(tensor_decl); if (!contains(currently_live, tensor_decl))
// free_tensor_decls.append(tensor_decl); {
// } // this is the last node that value is seen in
// } // delete it at the end of the op
// live_list.insert(0, list(currently_live)) currently_live.insert(tensor_decl);
// for output_decl in output_tensor_decls free_tensor_decls.insert(tensor_decl);
// { }
// if output_decl in currently_live }
// { }
// new_tensor_decls.append(output_decl);
// currently_live.remove(output_decl); exop.liveness_live_list = currently_live;
// } for (descriptor::Tensor* output_decl : output_tensor_decls)
// } {
// free_list.insert(0, free_tensor_decls); if (contains(currently_live, output_decl))
// new_list.insert(0, new_tensor_decls); {
// } new_tensor_decls.insert(output_decl);
currently_live.erase(output_decl);
// // Anything marked as output must remain live for the remainder of the graph }
// // Add outputs to live_list and remove from free_list }
// outputs = list(); exop.liveness_free_list = free_tensor_decls;
// seen = list(); exop.liveness_new_list = new_tensor_decls;
// for i, exop in enumerate(ops) }
// {
// for tensor in live_list[i] // Anything marked as output must remain live for the remainder of the graph
// { // Add outputs to live_list and remove from free_list
// if tensor.is_output and tensor not in outputs unordered_set<descriptor::Tensor*> outputs;
// { unordered_set<descriptor::Tensor*> seen;
// outputs.append(tensor); for (Node* exop : ops)
// } {
// } for (descriptor::Tensor* tensor : exop->liveness_live_list)
// for tensor in outputs {
// { if (tensor->is_output())
// if tensor not in live_list[i] {
// { outputs.insert(tensor);
// live_list[i].append(tensor); }
// } }
// if tensor in free_list[i] for (descriptor::Tensor* tensor : outputs)
// { {
// free_list[i].remove(tensor); exop->liveness_live_list.insert(tensor);
// } exop->liveness_free_list.erase(tensor);
// if tensor in new_list[i]
// { if (contains(exop->liveness_new_list, tensor))
// if tensor in seen {
// { if (contains(seen, tensor))
// new_list[i].remove(tensor); {
// } exop->liveness_new_list.erase(tensor);
// else }
// { else
// seen.append(tensor); {
// } seen.insert(tensor);
// } }
// } }
// exop.liveness_live_list = live_list[i]; }
// exop.liveness_new_list = new_list[i]; }
// exop.liveness_free_list = free_list[i];
// } validate_liveness(ops);
// self.validate_liveness(ops)
return false; return false;
} }
...@@ -140,30 +141,32 @@ void pass::Liveness::check_dependencies( ...@@ -140,30 +141,32 @@ void pass::Liveness::check_dependencies(
} }
} }
// bool pass::Liveness::is_interesting(tensor_decl) bool pass::Liveness::is_temporary(const descriptor::Tensor& tensor)
// { {
// return return
// tensor_decl.is_persistent == false && tensor.is_persistent() == false
// tensor_decl.is_constant == false && && tensor.is_input() == false
// tensor_decl.is_compile_only == false; ;
// } // && tensor.is_constant() == false
// && tensor.is_compile_only() == false;
// void pass::Liveness::validate_liveness(ops) }
// {
// dead_tensors = set(); void pass::Liveness::validate_liveness(const list<Node*>& ops)
// for i, exop in enumerate(ops) {
// { unordered_set<descriptor::Tensor*> dead_tensors;
// active = set(exop.liveness_live_list); for (const Node* exop : ops)
// active |= set(exop.liveness_new_list); {
// active |= set(exop.liveness_free_list); auto active = exop->liveness_live_list;
// if bool(dead_tensors.intersection(active)) is True active.insert(exop->liveness_new_list.begin(), exop->liveness_new_list.end());
// { active.insert(exop->liveness_free_list.begin(), exop->liveness_free_list.end());
// raise RuntimeError("Liveness: Dead tensors intersect active tensors"); for (const descriptor::Tensor* tensor : active)
// } {
// for tensor in exop.liveness_free_list if (contains(dead_tensors, tensor))
// { {
// dead_tensors.add(tensor); throw runtime_error("Liveness: Dead tensors intersect active tensors");
// } }
// } }
// } dead_tensors.insert(exop->liveness_free_list.begin(), exop->liveness_free_list.end());
}
}
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include "call_pass.hpp" #include "call_pass.hpp"
#include "ngraph/descriptor/tensor.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -33,6 +34,6 @@ public: ...@@ -33,6 +34,6 @@ public:
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override; void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
private: private:
// bool is_interesting(tensor_decl); bool is_temporary(const descriptor::Tensor&);
// void validate_liveness(std::list<Node*> ops); void validate_liveness(const std::list<Node*>& ops);
}; };
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <fstream>
#include "visualize_tree.hpp"
#include "ngraph/node.hpp"
#include "util.hpp"
using namespace ngraph;
using namespace std;
bool pass::VisualizeTree::run_on_tree(std::shared_ptr<Node> base_node)
{
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(base_node, [&](Node* node)
{
for (auto arg : node->get_arguments())
{
m_ss << add_attributes(arg.get());
m_ss << add_attributes(node);
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id();
m_ss << ";\n";
}
});
render();
return false;
}
pass::VisualizeTree::VisualizeTree(const string& file_name)
: m_name{file_name}
{
}
std::string pass::VisualizeTree::add_attributes(const Node* node)
{
string rc;
if (!contains(m_nodes_with_attributes, node))
{
m_nodes_with_attributes.insert(node);
rc = get_attributes(node);
}
return rc;
}
std::string pass::VisualizeTree::get_attributes(const Node* node)
{
stringstream ss;
if (node->is_parameter())
{
ss << " " << node->get_node_id() << " [shape=box color=blue]\n";
}
else
{
ss << " " << node->get_node_id() << " [shape=ellipse color=black]\n";
}
return ss.str();
}
void pass::VisualizeTree::render() const
{
#if GRAPHVIZ_FOUND
auto tmp_file = m_name + ".tmp";
ofstream out(tmp_file);
if (out)
{
out << "digraph ngraph\n{\n";
out << m_ss.str();
out << "}\n";
out.close();
stringstream ss;
ss << "dot -Tpng " << tmp_file << " -o " << m_name;
auto cmd = ss.str();
auto stream = popen(cmd.c_str(), "r");
pclose(stream);
remove(tmp_file.c_str());
}
#endif
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <sstream>
#include <string>
#include <set>
#include "ngraph/pass/tree_pass.hpp"
namespace ngraph
{
namespace pass
{
class VisualizeTree;
}
class Node;
}
class ngraph::pass::VisualizeTree : public TreeBase
{
public:
VisualizeTree(const std::string& file_name);
bool run_on_tree(std::shared_ptr<Node>) override;
private:
std::string add_attributes(const Node* node);
std::string get_attributes(const Node* node);
void render() const;
std::stringstream m_ss;
std::string m_name;
std::set<const Node*> m_nodes_with_attributes;
};
...@@ -42,10 +42,3 @@ void Parameter::assign_function(Function* function, size_t index) ...@@ -42,10 +42,3 @@ void Parameter::assign_function(Function* function, size_t index)
} }
void Parameter::propagate_types() {} void Parameter::propagate_types() {}
std::string ngraph::op::Parameter::get_node_id() const
{
stringstream ss;
ss << "parameter_" << m_instance_id;
return ss.str();
}
...@@ -64,5 +64,5 @@ target_link_libraries(unit-test ${CMAKE_DL_LIBS}) ...@@ -64,5 +64,5 @@ target_link_libraries(unit-test ${CMAKE_DL_LIBS})
add_dependencies(unit-test ngraph libgtest eigen) add_dependencies(unit-test ngraph libgtest eigen)
add_custom_target(check add_custom_target(check
COMMAND ${PROJECT_BINARY_DIR}/test/unit-test COMMAND ${PROJECT_BINARY_DIR}/test/unit-test \${ARGS}
DEPENDS unit-test) DEPENDS unit-test)
...@@ -20,14 +20,54 @@ ...@@ -20,14 +20,54 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
#include "log.hpp"
using namespace std; using namespace std;
using namespace ngraph;
namespace ng = ngraph; namespace ng = ngraph;
TEST(liveness, test) TEST(pass, liveness)
{ {
string image = "liveness.png";
string dump_file = "liveness.txt";
pass::Manager pass_manager;
auto visualize = make_shared<pass::VisualizeTree>(image);
auto topological_sort = make_shared<pass::TopologicalSort>();
auto propagate_types = make_shared<pass::PropagateTypes>();
auto assign_tensors = make_shared<pass::AssignTensors>();
auto liveness = make_shared<pass::Liveness>();
auto dump_sorted = make_shared<pass::DumpSorted>(dump_file);
pass_manager.register_pass(visualize);
pass_manager.register_pass(topological_sort);
pass_manager.register_pass(propagate_types);
pass_manager.register_pass(assign_tensors);
pass_manager.register_pass(liveness);
pass_manager.register_pass(dump_sorted);
auto graph = make_test_graph();
size_t node_count = get_node_count(graph);
pass_manager.run_passes(graph);
auto sorted = pass_manager.get_sorted_list();
// for (const Node* node : sorted)
// {
// INFO << *node;
// for (const descriptor::Tensor* tensor : node->liveness_live_list)
// {
// INFO << " " << *tensor;
// }
// }
// auto x = ng.variable(axes=[]).named('x'); // auto x = ng.variable(axes=[]).named('x');
// auto y = ng.variable(axes=[]).named('y'); // auto y = ng.variable(axes=[]).named('y');
// auto w1 = ng.variable(axes=[]).named('w1'); // auto w1 = ng.variable(axes=[]).named('w1');
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment