Commit f828a116 authored by Bob Kimball's avatar Bob Kimball

Add mechanism for pass dependency checking

Add memory layout and liveness passes.
parent 8a6c08df
......@@ -18,8 +18,12 @@ set (SRC
ngraph/descriptor/tensor_view.cpp
ngraph/descriptor/tensor.cpp
ngraph/node.cpp
ngraph/shape.cpp
ngraph/pass/assign_tensors.cpp
ngraph/pass/call_pass.cpp
ngraph/pass/liveness.cpp
ngraph/pass/manager.cpp
ngraph/pass/memory_layout.cpp
ngraph/pass/pass.cpp
ngraph/pass/propagate_types.cpp
ngraph/pass/topological_sort.cpp
......
......@@ -53,7 +53,7 @@ namespace ngraph
size_t m_bitwidth;
bool m_is_float;
bool m_is_signed;
const std::string& m_cname;
std::string m_cname;
};
std::ostream& operator<<(std::ostream& out, const ngraph::element::Type& obj);
......
......@@ -106,8 +106,8 @@ namespace ngraph
size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&);
std::vector<std::shared_ptr<descriptor::Input>> get_inputs() { return m_inputs; }
std::vector<std::shared_ptr<descriptor::Output>> get_outputs() { return m_outputs; }
const std::vector<std::shared_ptr<descriptor::Input>>& get_inputs() { return m_inputs; }
const std::vector<std::shared_ptr<descriptor::Output>>& get_outputs() { return m_outputs; }
protected:
Nodes m_arguments;
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "assign_tensors.hpp"
#include <exception>
#include <sstream>
#include "log.hpp"
#include "ngraph/ngraph.hpp"
#include "propagate_types.hpp"
using namespace std;
using namespace ngraph;
// Invoke assign_tensors() on every node of the call list.  If a node
// throws, the exception text is wrapped together with the node's printed
// form so the failing node can be identified.  Returns false because the
// call list itself is never modified.
bool pass::AssignTensors::run_on_call_list(std::list<Node*>& node_list)
{
    for (Node* current : node_list)
    {
        try
        {
            current->assign_tensors();
        }
        catch (exception& e)
        {
            stringstream msg;
            msg << "Error with node " << *current << ": " << e.what();
            throw invalid_argument(msg.str());
        }
    }
    return false;
}
// Verify that a PropagateTypes pass was registered before this pass.
// @param registered_passes  passes registered so far, in registration order
// @throws std::runtime_error if PropagateTypes is absent
void pass::AssignTensors::check_dependencies(
    const std::vector<std::shared_ptr<CallBase>>& registered_passes) const
{
    bool found_propagate_types = false;
    for (const auto& registered : registered_passes)
    {
        if (dynamic_pointer_cast<PropagateTypes>(registered))
        {
            found_propagate_types = true;
            break; // one match is enough
        }
    }
    if (!found_propagate_types)
    {
        // fixed typo: "Depencency" -> "Dependency"
        throw runtime_error("Dependency 'PropagateTypes' not found for pass 'AssignTensors'");
    }
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "call_pass.hpp"
namespace ngraph
{
namespace pass
{
class AssignTensors;
}
class Node;
}
// Call-list pass that asks every node to assign its tensors.
// Must be registered after PropagateTypes (enforced by check_dependencies).
class ngraph::pass::AssignTensors : public CallBase
{
public:
    // Calls Node::assign_tensors() on each node; wraps any exception with the
    // offending node's description.  Returns false (call list unchanged).
    virtual bool run_on_call_list(std::list<Node*>&) override;
    // Throws if PropagateTypes is not among the already-registered passes.
    void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;

private:
};
......@@ -15,6 +15,8 @@
#pragma once
#include <list>
#include <memory>
#include <vector>
#include "pass.hpp"
......@@ -32,7 +34,9 @@ class ngraph::pass::CallBase : public Base
{
public:
virtual ~CallBase() {}
virtual bool run_on_call_list(std::list<Node*>&) = 0;
virtual bool run_on_call_list(std::list<Node*>) = 0;
// derived class throws exception if its dependencies have not been met
virtual void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const {}
private:
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <exception>
#include <sstream>
#include "log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/liveness.hpp"
using namespace std;
using namespace ngraph;
// Liveness analysis pass — NOT YET IMPLEMENTED.
// The commented block below appears to be pseudo-code ported from a Python
// implementation (note `self.`, `enumerate`, `.append`): a reverse sweep over
// the op list computing, per op, which tensors are live, newly created, and
// freeable, followed by a forward pass pinning graph outputs as live.
// Until it is implemented this pass is a no-op and returns false (no changes).
bool pass::Liveness::run_on_call_list(list<Node*>& ops)
{
    // list<Node*> live_list;
    // list<Node*> free_list;
    // list<Node*> new_list;
    // currently_live = list();
    // size_t i = 0;
    // for (i, exop in enumerate(reversed(ops)
    // for(auto it=ops.rbegin(); it!=ops.rend(); it++)
    // {
    // Node& exop = **it;
    // input_tensor_decls = list()
    // for (auto input_decl : exop.get_inputs())
    // {
    // if (is_interesting(input_decl.tensor_decl))
    // {
    // input_tensor_decls.append(input_decl.tensor_decl);
    // }
    // }
    // output_tensor_decls = list()
    // for (output_decl : exop.output_decls)
    // {
    // if (is_interesting(output_decl.tensor_decl))
    // {
    // output_tensor_decls.append(output_decl.tensor_decl);
    // }
    // }
    // free_tensor_decls = list();
    // new_tensor_decls = list();
    // for tensor_decl in input_tensor_decls + output_tensor_decls
    // {
    // if tensor_decl not in currently_live
    // {
    // // this is the last node that value is seen in
    // // delete it at the end of the op
    // currently_live.append(tensor_decl);
    // free_tensor_decls.append(tensor_decl);
    // }
    // }
    // live_list.insert(0, list(currently_live))
    // for output_decl in output_tensor_decls
    // {
    // if output_decl in currently_live
    // {
    // new_tensor_decls.append(output_decl);
    // currently_live.remove(output_decl);
    // }
    // }
    // free_list.insert(0, free_tensor_decls);
    // new_list.insert(0, new_tensor_decls);
    // }
    // // Anything marked as output must remain live for the remainder of the graph
    // // Add outputs to live_list and remove from free_list
    // outputs = list();
    // seen = list();
    // for i, exop in enumerate(ops)
    // {
    // for tensor in live_list[i]
    // {
    // if tensor.is_output and tensor not in outputs
    // {
    // outputs.append(tensor);
    // }
    // }
    // for tensor in outputs
    // {
    // if tensor not in live_list[i]
    // {
    // live_list[i].append(tensor);
    // }
    // if tensor in free_list[i]
    // {
    // free_list[i].remove(tensor);
    // }
    // if tensor in new_list[i]
    // {
    // if tensor in seen
    // {
    // new_list[i].remove(tensor);
    // }
    // else
    // {
    // seen.append(tensor);
    // }
    // }
    // }
    // exop.liveness_live_list = live_list[i];
    // exop.liveness_new_list = new_list[i];
    // exop.liveness_free_list = free_list[i];
    // }
    // self.validate_liveness(ops)
    return false;
}
// Verify that an AssignTensors pass was registered before this pass.
// @param registered_passes  passes registered so far, in registration order
// @throws std::runtime_error if AssignTensors is absent
void pass::Liveness::check_dependencies(
    const std::vector<std::shared_ptr<CallBase>>& registered_passes) const
{
    // Renamed from 'found_propagate_types': this pass depends on AssignTensors.
    bool found_assign_tensors = false;
    for (const auto& registered : registered_passes)
    {
        if (dynamic_pointer_cast<AssignTensors>(registered))
        {
            found_assign_tensors = true;
            break; // one match is enough
        }
    }
    if (!found_assign_tensors)
    {
        // Fixed copy-paste bug: the old message reported
        // "'PropagateTypes' not found for pass 'AssignTensors'".
        throw runtime_error("Dependency 'AssignTensors' not found for pass 'Liveness'");
    }
}
// bool pass::Liveness::is_interesting(tensor_decl)
// {
// return
// tensor_decl.is_persistent == false &&
// tensor_decl.is_constant == false &&
// tensor_decl.is_compile_only == false;
// }
// void pass::Liveness::validate_liveness(ops)
// {
// dead_tensors = set();
// for i, exop in enumerate(ops)
// {
// active = set(exop.liveness_live_list);
// active |= set(exop.liveness_new_list);
// active |= set(exop.liveness_free_list);
// if bool(dead_tensors.intersection(active)) is True
// {
// raise RuntimeError("Liveness: Dead tensors intersect active tensors");
// }
// for tensor in exop.liveness_free_list
// {
// dead_tensors.add(tensor);
// }
// }
// }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "call_pass.hpp"
namespace ngraph
{
namespace pass
{
class Liveness;
}
class Node;
}
// Call-list pass intended to compute tensor liveness information.
// Currently a stub: run_on_call_list is a no-op (see the .cpp).
// Must be registered after AssignTensors (enforced by check_dependencies).
class ngraph::pass::Liveness : public CallBase
{
public:
    // No-op placeholder; returns false (call list unchanged).
    virtual bool run_on_call_list(std::list<Node*>&) override;
    // Throws if AssignTensors is not among the already-registered passes.
    void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;

private:
    // bool is_interesting(tensor_decl);
    // void validate_liveness(std::list<Node*> ops);
};
......@@ -15,9 +15,9 @@
#include <iostream>
#include <memory>
#include "log.hpp"
#include "manager.hpp"
#include "ngraph/node.hpp"
#include "log.hpp"
using namespace std;
......@@ -39,6 +39,7 @@ void ngraph::pass::Manager::register_pass(std::shared_ptr<TreeBase> p)
{
throw invalid_argument("null pass registered");
}
p->check_dependencies(m_tree_passes);
m_tree_passes.push_back(p);
}
......@@ -48,6 +49,7 @@ void ngraph::pass::Manager::register_pass(std::shared_ptr<CallBase> p)
{
throw invalid_argument("null pass registered");
}
p->check_dependencies(m_call_passes);
m_call_passes.push_back(p);
}
......
......@@ -45,7 +45,7 @@ public:
const std::list<Node*>& get_sorted_list() const;
private:
std::vector<std::shared_ptr<TreeBase>> m_tree_passes;
std::vector<std::shared_ptr<CallBase>> m_call_passes;
std::list<Node*> m_sorted_list;
std::vector<std::shared_ptr<TreeBase>> m_tree_passes;
std::vector<std::shared_ptr<CallBase>> m_call_passes;
std::list<Node*> m_sorted_list;
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <exception>
#include <sstream>
#include "log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "log.hpp"
using namespace std;
using namespace ngraph;
// Memory layout pass — NOT YET IMPLEMENTED.
// Returns false because the call list is never modified.
bool pass::MemoryLayout::run_on_call_list(std::list<Node*>& node_list)
{
    // The original body looped over node_list with an empty loop body,
    // which only produced an unused-variable warning.  Keep the pass as an
    // explicit no-op until memory assignment is implemented.
    // TODO: assign offsets to each node's tensors using pass::MemoryManager.
    (void)node_list;
    return false;
}
// Verify that a Liveness pass was registered before this pass.
// @param registered_passes  passes registered so far, in registration order
// @throws std::runtime_error if Liveness is absent
void pass::MemoryLayout::check_dependencies(
    const std::vector<std::shared_ptr<CallBase>>& registered_passes) const
{
    // Renamed from 'found_propagate_types': this pass depends on Liveness.
    bool found_liveness = false;
    for (const auto& registered : registered_passes)
    {
        if (dynamic_pointer_cast<Liveness>(registered))
        {
            found_liveness = true;
            break; // one match is enough
        }
    }
    if (!found_liveness)
    {
        // Fixed copy-paste bug: the old message reported
        // "'PropagateTypes' not found for pass 'AssignTensors'".
        throw runtime_error("Dependency 'Liveness' not found for pass 'MemoryLayout'");
    }
}
// Construct a block-list node describing a contiguous region of 'size'
// bytes in the given allocation state (FREE or ALLOCATED).
pass::MemoryManager::node::node(size_t size, block_state state)
    : m_size{size}
    , m_state{state}
{
}
// Create a memory manager whose allocations are rounded up to 'alignment'
// bytes.  The pool starts as one effectively unbounded FREE block, and
// BEST_FIT is the default allocation scheme.
pass::MemoryManager::MemoryManager(size_t alignment)
    : m_alignment{alignment}
    , m_scheme{allocation_scheme::BEST_FIT}
    , m_max_allocated{0}
{
    // assert(m_base_offset % m_alignment == 0);
    // Sentinel free block covering the whole address range.
    m_node_list.emplace_back(numeric_limits<size_t>::max(), block_state::FREE);
}
// Allocate a block of at least 'size' bytes (rounded up to the alignment)
// using the configured scheme and return its byte offset in the pool.
// @throws std::bad_alloc if no free block is large enough.
size_t pass::MemoryManager::allocate(size_t size)
{
    // Removed leftover bare `INFO;` debug-logging statements.
    // rc is initialized defensively: the switch covers every enumerator
    // today, but an uninitialized local would be UB if one were added.
    size_t rc = 0;
    switch (m_scheme)
    {
    case allocation_scheme::FIRST_FIT: rc = first_fit(size); break;
    case allocation_scheme::BEST_FIT: rc = best_fit(size); break;
    }
    return rc;
}
// Best-fit allocator: scan the whole block list and pick the FREE block
// whose size exceeds the request by the least.  Splits the chosen block
// unless it fits exactly.  Returns the byte offset of the new block.
// @throws std::bad_alloc if no free block is large enough.
size_t pass::MemoryManager::best_fit(size_t size)
{
    size = align(size, m_alignment);
    size_t offset = 0;
    size_t min_delta = numeric_limits<size_t>::max();
    auto best_fit = m_node_list.end();
    size_t best_offset = 0;
    for (auto it = m_node_list.begin(); it != m_node_list.end(); ++it)
    {
        if (it->m_state == block_state::FREE && it->m_size >= size)
        {
            size_t delta = it->m_size - size;
            if (delta < min_delta)
            {
                min_delta = delta;
                best_fit = it;
                best_offset = offset;
            }
        }
        offset += it->m_size;
    }
    if (best_fit == m_node_list.end())
    {
        throw bad_alloc();
    }
    // Removed the vestigial 'found' flag (it was never set, so the old
    // `else if (!found)` guard was always taken) and leftover INFO logging.
    if (min_delta == 0)
    {
        // exact fit: convert the free block in place
        best_fit->m_state = block_state::ALLOCATED;
    }
    else
    {
        // carve the allocation off the front of the free block
        m_node_list.insert(best_fit, node{size, block_state::ALLOCATED});
        best_fit->m_size -= size;
    }
    m_max_allocated = std::max(m_max_allocated, best_offset + size);
    return best_offset;
}
// First-fit allocator: walk the block list from the start and take the
// first FREE block large enough for the (alignment-rounded) request,
// splitting it unless the size matches exactly.  Returns the byte offset
// of the new block; throws std::bad_alloc if nothing fits.
size_t pass::MemoryManager::first_fit(size_t size)
{
    size = align(size, m_alignment);
    size_t offset = 0;
    for (auto it = m_node_list.begin(); it != m_node_list.end(); ++it)
    {
        bool usable = it->m_state == block_state::FREE && it->m_size >= size;
        if (!usable)
        {
            offset += it->m_size;
            continue;
        }
        if (it->m_size == size)
        {
            // exact fit: flip the block in place
            it->m_state = block_state::ALLOCATED;
        }
        else
        {
            // split: new allocated block in front, remainder stays free
            m_node_list.insert(it, node{size, block_state::ALLOCATED});
            it->m_size -= size;
        }
        m_max_allocated = std::max(m_max_allocated, offset + size);
        return offset;
    }
    throw bad_alloc();
}
// Return a previously allocated block, identified by the byte offset that
// allocate() handed out, to the free pool.  Adjacent FREE neighbors are
// coalesced so the list never holds two consecutive FREE nodes.
// @throws std::runtime_error if 'offset' is not the start of any block.
void pass::MemoryManager::free(size_t offset)
{
    size_t search_offset = 0;
    bool found = false;
    for (auto it = m_node_list.begin(); it != m_node_list.end(); ++it)
    {
        if (offset == search_offset)
        {
            // capture the successor before any erase below
            list<node>::iterator it_next = std::next(it);
            if (it == m_node_list.begin())
            {
                // free the first node in the list
                it->m_state = block_state::FREE;
            }
            else
            {
                // node has predecessor
                list<node>::iterator it_prev = std::prev(it);
                if (it_prev->m_state == block_state::FREE)
                {
                    // absorb the free predecessor into this node
                    it->m_size += it_prev->m_size;
                    m_node_list.erase(it_prev);
                }
            }
            if (it_next != m_node_list.end() && it_next->m_state == block_state::FREE)
            {
                // join this node with next
                it->m_size += it_next->m_size;
                m_node_list.erase(it_next);
            }
            it->m_state = block_state::FREE;
            found = true;
            break;
        }
        search_offset += it->m_size;
    }
    if (!found)
    {
        throw runtime_error("bad free");
    }
}
// Print one line per block ("size=N, FREE|ALLOCATED") to the given stream,
// in pool order.  Intended for debugging.
void pass::MemoryManager::dump(std::ostream& out)
{
    for (const node& block : m_node_list)
    {
        out << "size=" << block.m_size << ", "
            << (block.m_state == block_state::FREE ? "FREE" : "ALLOCATED") << "\n";
    }
}
// Round 'size' up to the next multiple of 'alignment'.  A request of zero
// bytes is bumped up to one full alignment unit so every allocation
// occupies a distinct, non-empty range.
// Removed leftover INFO debug-logging statements from this pure helper.
size_t pass::MemoryManager::align(size_t size, size_t alignment)
{
    if (size == 0)
    {
        size = alignment;
    }
    else
    {
        auto remainder = size % alignment;
        if (remainder > 0)
        {
            size += (alignment - remainder);
        }
    }
    return size;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <limits>
#include <list>
#include <sstream>
#include "call_pass.hpp"
namespace ngraph
{
namespace pass
{
class MemoryLayout;
class MemoryNode;
class MemoryManager;
}
class Node;
}
// Call-list pass intended to plan tensor memory placement.
// Currently a stub (see the .cpp).  Must be registered after Liveness
// (enforced by check_dependencies).
class ngraph::pass::MemoryLayout : public CallBase
{
public:
    // No-op placeholder; returns false (call list unchanged).
    virtual bool run_on_call_list(std::list<Node*>&) override;
    // Throws if Liveness is not among the already-registered passes.
    void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;

private:
};
// Offset-based memory pool planner.  Maintains an ordered list of blocks
// (FREE or ALLOCATED) covering an effectively unbounded address range and
// hands out byte offsets; no real memory is ever allocated.
class ngraph::pass::MemoryManager
{
public:
    // Allocation state of a block in the pool.
    enum class block_state
    {
        FREE,
        ALLOCATED
    };

    // Strategy used by allocate() to choose a free block.
    enum class allocation_scheme
    {
        FIRST_FIT,
        BEST_FIT
    };

    // One contiguous block of the pool.
    class node
    {
    public:
        node(size_t size, block_state state);

        bool is_free() const { return m_state == block_state::FREE; }
        size_t m_size;       // block size in bytes
        block_state m_state; // FREE or ALLOCATED
    };

    // Allocations are rounded up to 'alignment' bytes; BEST_FIT by default.
    MemoryManager(size_t alignment = 1);
    // memory_manager& alignment(size_t a);
    // Returns the byte offset of a new block of at least 'size' bytes
    // (rounded up to the alignment); throws std::bad_alloc when full.
    size_t allocate(size_t size);
    // Releases the block starting at 'offset', coalescing free neighbors;
    // throws std::runtime_error on an offset that is not a block start.
    void free(size_t offset);
    // Writes a one-line-per-block summary to the stream (debug aid).
    void dump(std::ostream&);
    // Rounds x up to a multiple of alignment; 0 rounds to one full unit.
    static size_t align(size_t x, size_t alignment);

    std::list<node>::iterator begin() { return m_node_list.begin(); }
    std::list<node>::iterator end() { return m_node_list.end(); }
    std::list<node>::const_iterator begin() const { return m_node_list.cbegin(); }
    std::list<node>::const_iterator end() const { return m_node_list.cend(); }
    const std::list<node>& get_node_list() const { return m_node_list; }
    // High-water mark: largest offset+size handed out so far.
    size_t max_allocated() const { return m_max_allocated; }

private:
    size_t first_fit(size_t size);
    size_t best_fit(size_t size);

    std::list<node> m_node_list;  // blocks in address order
    size_t m_alignment;           // allocation granularity in bytes
    allocation_scheme m_scheme;   // strategy used by allocate()
    size_t m_max_allocated;       // high-water mark of offset+size
};
......@@ -14,13 +14,13 @@
#include <sstream>
#include "propagate_types.hpp"
#include "ngraph/ngraph.hpp"
#include "propagate_types.hpp"
using namespace std;
using namespace ngraph;
bool pass::PropagateTypes::run_on_call_list(std::list<Node*> node_list)
bool pass::PropagateTypes::run_on_call_list(std::list<Node*>& node_list)
{
for (Node* node : node_list)
{
......
......@@ -28,7 +28,7 @@ namespace ngraph
class ngraph::pass::PropagateTypes : public CallBase
{
public:
virtual bool run_on_call_list(std::list<Node*>) override;
virtual bool run_on_call_list(std::list<Node*>&) override;
private:
};
......@@ -25,7 +25,7 @@ using namespace std;
bool ngraph::pass::TopologicalSort::run_on_tree(std::shared_ptr<Node> p)
{
deque<Node*> independent_nodes;
deque<Node*> independent_nodes;
unordered_map<Node*, size_t> node_depencency_count;
traverse_nodes(p, [&](Node* node) {
......
......@@ -32,7 +32,6 @@ class ngraph::pass::TopologicalSort : public TreeBase
{
public:
TopologicalSort() {}
bool run_on_tree(std::shared_ptr<Node>) override;
bool call_graph_produced() const override { return true; }
......
......@@ -14,8 +14,9 @@
#pragma once
#include <memory>
#include <list>
#include <memory>
#include <vector>
#include "pass.hpp"
......@@ -33,13 +34,13 @@ class ngraph::pass::TreeBase : public Base
{
public:
virtual ~TreeBase() {}
// return true if changes were made to the tree
virtual bool run_on_tree(std::shared_ptr<Node>) = 0;
virtual bool call_graph_produced() const { return false; }
virtual bool call_graph_produced() const { return false; }
virtual std::list<Node*> get_call_graph() const { return std::list<Node*>(); }
// derived class throws exception if its dependencies have not been met
virtual void check_dependencies(const std::vector<std::shared_ptr<TreeBase>>&) const {}
private:
std::list<Node*> m_sorted_list;
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "shape.hpp"
#include "util.hpp"
// Stream a Shape as a brace-delimited, comma-separated list of its sizes,
// e.g. {2, 3, 4}.
std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::Shape& obj)
{
    return out << "{" << join(obj.m_sizes, ", ") << "}";
}
......@@ -13,7 +13,9 @@
// ----------------------------------------------------------------------------
#pragma once
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <vector>
namespace ngraph
......@@ -39,6 +41,7 @@ namespace ngraph
operator const std::vector<size_t>&() const { return m_sizes; }
bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; }
bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; }
friend std::ostream& operator<<(std::ostream&, const Shape&);
protected:
std::vector<size_t> m_sizes;
......
......@@ -15,6 +15,7 @@
#include <memory>
#include "ngraph/ngraph.hpp"
#include "log.hpp"
using namespace std;
using namespace ngraph::op;
......
......@@ -32,7 +32,6 @@ ngraph::element::Type::Type(size_t bitwidth,
, m_is_signed{is_signed}
, m_cname{cname}
{
INFO << m_cname;
assert(m_bitwidth % 8 == 0);
}
......@@ -54,6 +53,6 @@ size_t ngraph::element::Type::size() const
// Stream an element Type as its C type name (m_cname).
std::ostream& ngraph::element::operator<<(std::ostream& out, const ngraph::element::Type& obj)
{
    // out << "ElementType(" << obj.c_type_string() << ")";
    return out << obj.m_cname;
}
\ No newline at end of file
......@@ -14,7 +14,8 @@
#include <memory>
#include <ngraph/ngraph.hpp>
#include "ngraph/ngraph.hpp"
#include "log.hpp"
using namespace std;
using namespace ngraph;
......@@ -68,7 +69,7 @@ std::ostream& ngraph::operator<<(std::ostream& out, const ValueType& obj)
std::ostream& ngraph::operator<<(std::ostream& out, const TensorViewType& obj)
{
out << "TensorViewType(" << obj.m_element_type << ")";
out << "TensorViewType(" << obj.m_element_type << ", " << obj.m_shape << ")";
return out;
}
......
......@@ -28,7 +28,9 @@ set (SRC
input_output_assign.cpp
main.cpp
op.cpp
pass_liveness.cpp
pass_manager.cpp
pass_memory_layout.cpp
tensor.cpp
test_tools.cpp
topological_sort.cpp
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/ngraph.hpp"
#include "test_tools.hpp"
using namespace std;
namespace ng = ngraph;
// Placeholder liveness test — the body is pseudo-code (Python-style) ported
// from a prior implementation and is entirely commented out; the test
// currently passes vacuously.  TODO: port to the C++ Liveness pass.
TEST(liveness, test)
{
    // auto x = ng.variable(axes=[]).named('x');
    // auto y = ng.variable(axes=[]).named('y');
    // auto w1 = ng.variable(axes=[]).named('w1');
    // auto w2 = ng.variable(axes=[]).named('w2');
    // auto x2 = x * w1;
    // auto x3 = (x2 * w2).named('result');
    // auto cost = x3 - y;
    // auto dw1 = ng.deriv(cost, w1);
    // auto dw2 = ng.deriv(cost, w2);
    // auto upd1 = ng.assign(w1, w1 + dw1);
    // auto upd2 = ng.assign(w2, w2 + dw2);
    // auto seq_stuff = ng.sequential([upd1, upd2, x3]);
    // auto exc = ex.executor(seq_stuff);
    // return exc;

    // lg = LivenessGraph(exc.exop.ops)
    // lg.layout_memory()

    // for i, node in enumerate(lg.liveness_nodes):
    //     print i, node

    // for node in lg.liveness_nodes:
    //     for var1 in node.live_list:
    //         assert var1.buffer_pool_offset is not None
    //         for var2 in node.live_list:
    //             if var1 != var2:
    //                 if var1.buffer_pool_offset < var2.buffer_pool_offset:
    //                     assert var1.buffer_pool_offset + var1.size <= var2.buffer_pool_offset
    //                 else:
    //                     assert var2.buffer_pool_offset + var2.size <= var1.buffer_pool_offset

    // // for o in egraph.computations:
    // //     print o.values

    // print("max memory {}".format(lg.memory_footprint()))
    // print("worst case memory {}".format(lg.worst_case_memory_usage()))
    // print("memory efficiency {}".format(lg.memory_efficiency()))
    // // // print lg.liveness_json()
}
......@@ -19,28 +19,42 @@
#include "gtest/gtest.h"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "test_tools.hpp"
using namespace ngraph;
using namespace std;
// TEST(pass_manager, add)
// {
// pass::Manager pass_manager;
// auto topological_sort = make_shared<pass::TopologicalSort>();
// auto propagate_types = make_shared<pass::PropagateTypes>();
// pass_manager.register_pass(topological_sort);
// pass_manager.register_pass(propagate_types);
// auto graph = make_test_graph();
// size_t node_count = get_node_count(graph);
// pass_manager.run_passes(graph);
// auto sorted = pass_manager.get_sorted_list();
// EXPECT_EQ(node_count, sorted.size());
// EXPECT_TRUE(validate_list(sorted));
// }
// Registering the full pass chain and running it over the test graph must
// produce a sorted call list containing every node of the graph.
TEST(pass_manager, add)
{
    pass::Manager pass_manager;

    // register in dependency order: sort -> propagate types -> assign tensors
    pass_manager.register_pass(make_shared<pass::TopologicalSort>());
    pass_manager.register_pass(make_shared<pass::PropagateTypes>());
    pass_manager.register_pass(make_shared<pass::AssignTensors>());

    auto graph = make_test_graph();
    size_t node_count = get_node_count(graph);
    pass_manager.run_passes(graph);
    auto sorted = pass_manager.get_sorted_list();
    EXPECT_EQ(node_count, sorted.size());
    EXPECT_TRUE(validate_list(sorted));
}
// Registering AssignTensors without first registering its PropagateTypes
// dependency must make register_pass() throw.
TEST(pass_manager, dependency)
{
    pass::Manager pass_manager;
    auto topological_sort = make_shared<pass::TopologicalSort>();
    auto assign_tensors = make_shared<pass::AssignTensors>();

    pass_manager.register_pass(topological_sort);
    // PropagateTypes is deliberately not registered, so the dependency
    // check inside register_pass() is expected to fail.
    // (Removed the unused 'propagate_types' local from the original.)
    EXPECT_THROW(pass_manager.register_pass(assign_tensors), runtime_error);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "test_tools.hpp"
using namespace ngraph;
using namespace std;
// Snapshot the manager's block list into a vector for index-based checks.
static vector<pass::MemoryManager::node> get_node_list(const pass::MemoryManager& mm)
{
    return vector<pass::MemoryManager::node>(mm.begin(), mm.end());
}
// With alignment 1, allocations are laid out back to back; a zero-byte
// request still consumes one alignment unit, so the next block starts at 1.
TEST(memory_manager, allocate)
{
    pass::MemoryManager mm{1};
    // Special case, allocating size zero bumps the size of the alloc up to the alignment size
    EXPECT_EQ(0, mm.allocate(0));
    EXPECT_EQ(1, mm.allocate(10));
    EXPECT_EQ(11, mm.allocate(10));
    EXPECT_EQ(21, mm.allocate(10));
}
// Freeing the first block marks it FREE without coalescing: its only
// neighbor (the second block) is still allocated, so the node count stays 3
// (two allocations plus the trailing free sentinel).
TEST(memory_manager, free_first_allocated)
{
    pass::MemoryManager mm{1};
    EXPECT_EQ(0, mm.allocate(10));
    EXPECT_EQ(10, mm.allocate(10));
    EXPECT_EQ(3, mm.get_node_list().size());
    mm.free(0);
    auto node_list = get_node_list(mm);
    EXPECT_EQ(3, node_list.size());
    EXPECT_TRUE(node_list[0].is_free());
    EXPECT_FALSE(node_list[1].is_free());
    EXPECT_TRUE(node_list[2].is_free());
}
// Freeing a block whose neighbors are both allocated cannot coalesce
// anything, so the node count stays at 6.
TEST(memory_manager, free_middle_allocated)
{
    pass::MemoryManager mm{1};
    EXPECT_EQ(0, mm.allocate(10));
    EXPECT_EQ(10, mm.allocate(10));
    EXPECT_EQ(20, mm.allocate(10));
    EXPECT_EQ(30, mm.allocate(10));
    EXPECT_EQ(40, mm.allocate(10));
    EXPECT_EQ(6, mm.get_node_list().size());
    mm.free(10);
    auto node_list = get_node_list(mm);
    EXPECT_EQ(6, node_list.size());
    EXPECT_FALSE(node_list[0].is_free());
    EXPECT_TRUE(node_list[1].is_free());
    EXPECT_FALSE(node_list[2].is_free());
    EXPECT_FALSE(node_list[3].is_free());
    EXPECT_FALSE(node_list[4].is_free());
}
// Freeing the last allocated block merges it with the trailing free
// sentinel, dropping the node count from 6 to 5.
TEST(memory_manager, free_last_allocated)
{
    pass::MemoryManager mm{1};
    EXPECT_EQ(0, mm.allocate(10));
    EXPECT_EQ(10, mm.allocate(10));
    EXPECT_EQ(20, mm.allocate(10));
    EXPECT_EQ(30, mm.allocate(10));
    EXPECT_EQ(40, mm.allocate(10));
    EXPECT_EQ(6, mm.get_node_list().size());
    mm.free(40);
    auto node_list = get_node_list(mm);
    EXPECT_EQ(5, node_list.size());
    EXPECT_FALSE(node_list[0].is_free());
    EXPECT_FALSE(node_list[1].is_free());
    EXPECT_FALSE(node_list[2].is_free());
    EXPECT_FALSE(node_list[3].is_free());
    EXPECT_TRUE(node_list[4].is_free());
}
// Freeing block 0 after block 1 is already free coalesces the two adjacent
// free blocks into one (node count drops from 6 to 5).
TEST(memory_manager, free_first_free)
{
    pass::MemoryManager mm{1};
    EXPECT_EQ(0, mm.allocate(10));
    EXPECT_EQ(10, mm.allocate(10));
    EXPECT_EQ(20, mm.allocate(10));
    EXPECT_EQ(30, mm.allocate(10));
    EXPECT_EQ(40, mm.allocate(10));
    EXPECT_EQ(6, mm.get_node_list().size());
    mm.free(10);
    mm.free(0);
    auto node_list = get_node_list(mm);
    EXPECT_EQ(5, node_list.size());
    EXPECT_TRUE(node_list[0].is_free());
    EXPECT_FALSE(node_list[1].is_free());
    EXPECT_FALSE(node_list[2].is_free());
    EXPECT_FALSE(node_list[3].is_free());
}
// Freeing the middle block when both neighbors are already free coalesces
// all three into a single free block (node count drops from 6 to 4).
TEST(memory_manager, free_middle_free)
{
    pass::MemoryManager mm{1};
    EXPECT_EQ(0, mm.allocate(10));
    EXPECT_EQ(10, mm.allocate(10));
    EXPECT_EQ(20, mm.allocate(10));
    EXPECT_EQ(30, mm.allocate(10));
    EXPECT_EQ(40, mm.allocate(10));
    EXPECT_EQ(6, mm.get_node_list().size());
    mm.free(0);
    mm.free(20);
    mm.free(10);
    auto node_list = get_node_list(mm);
    EXPECT_EQ(4, node_list.size());
    EXPECT_TRUE(node_list[0].is_free());
    EXPECT_FALSE(node_list[1].is_free());
    EXPECT_FALSE(node_list[2].is_free());
}
// max_allocated() is a high-water mark: it reflects the furthest byte ever
// handed out (5 x 10 = 50) and is not reduced by subsequent frees.
TEST(memory_manager, max_allocated)
{
    pass::MemoryManager mm{1};
    EXPECT_EQ(0, mm.allocate(10));
    EXPECT_EQ(10, mm.allocate(10));
    EXPECT_EQ(20, mm.allocate(10));
    EXPECT_EQ(30, mm.allocate(10));
    EXPECT_EQ(40, mm.allocate(10));
    EXPECT_EQ(6, mm.get_node_list().size());
    mm.free(0);
    mm.free(20);
    mm.free(10);
    EXPECT_EQ(mm.max_allocated(), 50);
}
// free() with an offset that does not name the start of any block throws.
TEST(memory_manager, bad_free)
{
    pass::MemoryManager mm{1};
    EXPECT_THROW(mm.free(10), std::runtime_error);
}
// align() rounds up to the next multiple of the alignment; zero is bumped
// to one full alignment unit.
TEST(memory_manager, align)
{
    EXPECT_EQ(8, pass::MemoryManager::align(0, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(1, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(2, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(3, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(4, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(5, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(6, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(7, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(8, 8));
    EXPECT_EQ(16, pass::MemoryManager::align(9, 8));
}
// With 64-byte alignment, every 4-byte request consumes a full 64-byte
// slot, so consecutive allocations land at 0, 64, 128.
TEST(memory_manager, memory_align)
{
    pass::MemoryManager mm{64};
    EXPECT_EQ(0, mm.allocate(4));
    EXPECT_EQ(64, mm.allocate(4));
    EXPECT_EQ(128, mm.allocate(4));
}
......@@ -55,12 +55,12 @@ bool validate_list(const list<Node*>& nodes)
shared_ptr<Node> make_test_graph()
{
auto arg_0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_3 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_4 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_5 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
auto arg_0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto arg_1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto arg_2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto arg_3 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto arg_4 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto arg_5 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto t0 = make_shared<op::Add>(arg_0, arg_1);
auto t1 = make_shared<op::Dot>(t0, arg_2);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment