Commit c9f896fc authored by Yixing Lao's avatar Yixing Lao Committed by GitHub

Merge pull request #130 from NervanaSystems/bob/new_directory

new directory layout
parents 1c78e9f3 036eafe0
......@@ -15,29 +15,29 @@ include_directories(SYSTEM ${EIGEN_INCLUDE_DIR})
set (SRC
log.cpp
ngraph/descriptor/input.cpp
ngraph/descriptor/output.cpp
ngraph/descriptor/tensor_view.cpp
ngraph/descriptor/tensor.cpp
ngraph/function.cpp
ngraph/node.cpp
ngraph/shape.cpp
ngraph/pass/assign_tensors.cpp
ngraph/pass/call_pass.cpp
ngraph/pass/dump_sorted.cpp
ngraph/pass/liveness.cpp
ngraph/pass/manager.cpp
ngraph/pass/memory_layout.cpp
ngraph/pass/pass.cpp
ngraph/pass/propagate_types.cpp
ngraph/pass/topological_sort.cpp
ngraph/pass/tree_pass.cpp
ngraph/pass/visualize_tree.cpp
ngraph/runtime/call_frame.cpp
ngraph/runtime/eigen/external_function.cpp
ngraph/runtime/eigen/tensor_view.cpp
ngraph/shape.cpp
ngraph/visualize.cpp
descriptor/input.cpp
descriptor/output.cpp
descriptor/tensor_view.cpp
descriptor/tensor.cpp
function.cpp
node.cpp
shape.cpp
pass/assign_tensors.cpp
pass/call_pass.cpp
pass/dump_sorted.cpp
pass/liveness.cpp
pass/manager.cpp
pass/memory_layout.cpp
pass/pass.cpp
pass/propagate_types.cpp
pass/topological_sort.cpp
pass/tree_pass.cpp
pass/visualize_tree.cpp
runtime/call_frame.cpp
runtime/eigen/external_function.cpp
runtime/eigen/tensor_view.cpp
shape.cpp
visualize.cpp
ops/binary_elementwise_arithmetic.cpp
ops/binary_elementwise_builtin.cpp
ops/binary_elementwise_comparison.cpp
......
......@@ -17,8 +17,8 @@
#include <memory>
#include <vector>
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/function.hpp"
#include "descriptor/tensor_view.hpp"
#include "function.hpp"
namespace ngraph
{
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -16,7 +16,7 @@
#include <memory>
#include "ngraph/descriptor/tensor.hpp"
#include "descriptor/tensor.hpp"
namespace ngraph
{
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -17,7 +17,7 @@
#include <memory>
#include <set>
#include "ngraph/descriptor/tensor_view.hpp"
#include "descriptor/tensor_view.hpp"
namespace ngraph
{
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "descriptor/tensor.hpp"
#include "node.hpp"
namespace ngraph
{
namespace descriptor
{
Tensor::Tensor(const element::Type& element_type,
PrimaryTensorView* primary_tensor_view,
const Node* parent,
size_t value_index)
: m_element_type(element_type)
, m_primary_tensor_view(primary_tensor_view)
, m_is_output{false}
, m_is_input{parent->is_parameter()}
, m_is_persistent{false}
, m_name{parent->get_node_id() + "_" + std::to_string(value_index)}
, m_next_view_id{0}
{
}
std::string Tensor::get_next_view_name()
{
return m_name + "_TV" + std::to_string(m_next_view_id++);
}
std::ostream& operator<<(std::ostream& out, const Tensor& tensor)
{
out << "Tensor(" << tensor.get_name() << ")";
return out;
}
}
}
......@@ -14,9 +14,9 @@
#pragma once
#include <iostream>
#include <memory>
#include <vector>
#include <iostream>
namespace ngraph
{
......@@ -40,19 +40,18 @@ namespace ngraph
Tensor(const Tensor&) = delete;
Tensor& operator=(const Tensor&) = delete;
Tensor(const element::Type& element_type, PrimaryTensorView* tensor_view,
const Node* parent, size_t value_index);
Tensor(const element::Type& element_type,
PrimaryTensorView* tensor_view,
const Node* parent,
size_t value_index);
std::string get_next_view_name();
public:
bool is_output() const { return m_is_output; }
bool is_input() const { return m_is_input; }
bool is_persistent() const { return m_is_persistent; }
bool is_output() const { return m_is_output; }
bool is_input() const { return m_is_input; }
bool is_persistent() const { return m_is_persistent; }
const std::string& get_name() const { return m_name; }
friend std::ostream& operator<<(std::ostream&, const Tensor&);
protected:
const element::Type& m_element_type;
PrimaryTensorView* m_primary_tensor_view;
......@@ -62,5 +61,7 @@ namespace ngraph
std::string m_name;
size_t m_next_view_id;
};
std::ostream& operator<<(std::ostream&, const Tensor&);
}
}
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/descriptor/tensor_view.hpp"
#include "descriptor/tensor_view.hpp"
using namespace ngraph;
using namespace descriptor;
......
......@@ -14,9 +14,9 @@
#pragma once
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type.hpp"
#include "descriptor/tensor.hpp"
#include "shape.hpp"
#include "types/type.hpp"
#include "log.hpp"
namespace ngraph
......
......@@ -14,7 +14,7 @@
#include <memory>
#include "ngraph/function.hpp"
#include "function.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -14,12 +14,12 @@
#pragma once
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/type.hpp"
#include "descriptor/tensor_view.hpp"
#include "node.hpp"
#include "ops/op.hpp"
#include "ops/parameter.hpp"
#include "runtime/instruction.hpp"
#include "types/type.hpp"
namespace ngraph
{
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
//
// The public API for ngraph++
//
#pragma once
#include "common.hpp"
#include "descriptor/buffer.hpp"
#include "descriptor/call_frame.hpp"
#include "descriptor/input.hpp"
#include "descriptor/output.hpp"
#include "descriptor/tensor.hpp"
#include "descriptor/tensor_view.hpp"
#include "descriptor/tensor_view_layout.hpp"
#include "types/element_type.hpp"
#include "except.hpp"
#include "function.hpp"
#include "node.hpp"
#include "ops/abs.hpp"
#include "ops/add.hpp"
#include "ops/broadcast.hpp"
#include "ops/ceiling.hpp"
#include "ops/concatenate.hpp"
#include "ops/constant.hpp"
#include "ops/convert.hpp"
#include "ops/divide.hpp"
#include "ops/dot.hpp"
#include "ops/equal.hpp"
#include "ops/exp.hpp"
#include "ops/floor.hpp"
#include "ops/greater.hpp"
#include "ops/less.hpp"
#include "ops/log.hpp"
#include "ops/maximum.hpp"
#include "ops/minimum.hpp"
#include "ops/multiply.hpp"
#include "ops/negative.hpp"
#include "ops/op.hpp"
#include "ops/parameter.hpp"
#include "ops/power.hpp"
#include "ops/remainder.hpp"
#include "ops/subtract.hpp"
#include "ops/tuple.hpp"
#include "runtime/eigen/add.hpp"
#include "runtime/eigen/external_function.hpp"
#include "runtime/eigen/multiply.hpp"
#include "runtime/eigen/return.hpp"
#include "runtime/eigen/tensor_view.hpp"
#include "runtime/call_frame.hpp"
#include "function.hpp"
#include "runtime/instruction.hpp"
#include "runtime/tensor_view.hpp"
#include "shape.hpp"
#include "types/type.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/node.hpp"
using namespace ngraph;
using namespace descriptor;
Tensor::Tensor(const element::Type& element_type, PrimaryTensorView* primary_tensor_view,
const Node* parent, size_t value_index)
: m_element_type(element_type)
, m_primary_tensor_view(primary_tensor_view)
, m_is_output{false}
, m_is_input{parent->is_parameter()}
, m_is_persistent{false}
, m_name{parent->get_node_id()+"_"+std::to_string(value_index)}
, m_next_view_id{0}
{
}
std::string Tensor::get_next_view_name()
{
return m_name + "_TV" + std::to_string(m_next_view_id++);
}
std::ostream& ngraph::descriptor::operator<<(std::ostream& out, const Tensor& tensor)
{
out << "Tensor(" << tensor.get_name() << ")";
return out;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
//
// The public API for ngraph++
//
#pragma once
#include "ngraph/common.hpp"
#include "ngraph/descriptor/buffer.hpp"
#include "ngraph/descriptor/call_frame.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/descriptor/tensor_view_layout.hpp"
#include "ngraph/element_type.hpp"
#include "ngraph/except.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/ops/abs.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/equal.hpp"
#include "ngraph/ops/exp.hpp"
#include "ngraph/ops/floor.hpp"
#include "ngraph/ops/greater.hpp"
#include "ngraph/ops/less.hpp"
#include "ngraph/ops/log.hpp"
#include "ngraph/ops/maximum.hpp"
#include "ngraph/ops/minimum.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/negative.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/power.hpp"
#include "ngraph/ops/remainder.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/eigen/add.hpp"
#include "ngraph/runtime/eigen/external_function.hpp"
#include "ngraph/runtime/eigen/multiply.hpp"
#include "ngraph/runtime/eigen/return.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/function.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type.hpp"
......@@ -21,8 +21,8 @@
#include <iostream>
#include "ngraph/common.hpp"
#include "ngraph/type.hpp"
#include "common.hpp"
#include "types/type.hpp"
namespace ngraph
{
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -14,7 +14,7 @@
#include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
#include "log.hpp"
using namespace std;
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph::op;
......
......@@ -14,7 +14,7 @@
#include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph::op;
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace ngraph::op;
......
......@@ -16,8 +16,8 @@
#include <sstream>
#include "ngraph/element_type.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp"
#include "types/element_type.hpp"
#include "runtime/eigen/tensor_view.hpp"
namespace ngraph
{
......
......@@ -14,7 +14,7 @@
#include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph::op;
......
......@@ -14,7 +14,7 @@
#include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph::op;
......
......@@ -15,7 +15,7 @@
#include <algorithm>
#include <sstream>
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace ngraph;
using namespace std;
......@@ -16,9 +16,9 @@
#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/type.hpp"
#include "node.hpp"
#include "ops/parameter.hpp"
#include "types/type.hpp"
namespace ngraph
{
......
......@@ -14,7 +14,7 @@
#include <sstream>
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph::op;
......
......@@ -14,8 +14,8 @@
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/type.hpp"
#include "node.hpp"
#include "types/type.hpp"
namespace ngraph
{
......
......@@ -14,7 +14,7 @@
#include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph::op;
......
......@@ -14,7 +14,7 @@
#include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -14,7 +14,7 @@
#include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -18,7 +18,7 @@
#include <sstream>
#include "log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
#include "propagate_types.hpp"
using namespace std;
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "call_pass.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "call_pass.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <list>
#include <memory>
#include <vector>
#include "pass.hpp"
namespace ngraph
{
namespace pass
{
class CallBase;
}
class Node;
}
class ngraph::pass::CallBase : public Base
{
public:
virtual ~CallBase() {}
virtual bool run_on_call_list(std::list<Node*>&) = 0;
// derived class throws exception if its dependencies have not been met
virtual void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const {}
private:
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <list>
#include <memory>
#include <vector>
#include "pass.hpp"
namespace ngraph
{
namespace pass
{
class CallBase;
}
class Node;
}
class ngraph::pass::CallBase : public Base
{
public:
virtual ~CallBase() {}
virtual bool run_on_call_list(std::list<Node*>&) = 0;
// derived class throws exception if its dependencies have not been met
virtual void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const {}
private:
};
......@@ -15,7 +15,7 @@
#include <fstream>
#include "dump_sorted.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
#include "util.hpp"
using namespace ngraph;
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <exception>
#include <sstream>
#include <unordered_set>
#include "log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/liveness.hpp"
#include "util.hpp"
#include "log.hpp"
using namespace std;
using namespace ngraph;
bool pass::Liveness::run_on_call_list(list<Node*>& ops)
{
unordered_set<descriptor::Tensor*> currently_live;
for(auto it=ops.rbegin(); it!=ops.rend(); it++)
{
Node& exop = **it;
exop.liveness_live_list.clear();
exop.liveness_new_list.clear();
exop.liveness_free_list.clear();
unordered_set<descriptor::Tensor*> input_tensor_decls;
for (auto input_decl : exop.get_inputs())
{
descriptor::Tensor& tensor = input_decl.get_tensor();
if (is_temporary(tensor))
{
input_tensor_decls.insert(&tensor);
}
}
unordered_set<descriptor::Tensor*> output_tensor_decls;
for (auto output_decl : exop.get_outputs())
{
descriptor::Tensor& tensor = output_decl.get_tensor();
if (is_temporary(tensor))
{
output_tensor_decls.insert(&tensor);
}
}
unordered_set<descriptor::Tensor*> free_tensor_decls;
unordered_set<descriptor::Tensor*> new_tensor_decls;
unordered_set<descriptor::Tensor*> all_tensor_decls = input_tensor_decls;
for (auto decls : {input_tensor_decls, output_tensor_decls})
{
for (descriptor::Tensor* tensor_decl : decls)
{
if (!contains(currently_live, tensor_decl))
{
// this is the last node that value is seen in
// delete it at the end of the op
currently_live.insert(tensor_decl);
free_tensor_decls.insert(tensor_decl);
}
}
}
exop.liveness_live_list = currently_live;
for (descriptor::Tensor* output_decl : output_tensor_decls)
{
if (contains(currently_live, output_decl))
{
new_tensor_decls.insert(output_decl);
currently_live.erase(output_decl);
}
}
exop.liveness_free_list = free_tensor_decls;
exop.liveness_new_list = new_tensor_decls;
}
// Anything marked as output must remain live for the remainder of the graph
// Add outputs to live_list and remove from free_list
unordered_set<descriptor::Tensor*> outputs;
unordered_set<descriptor::Tensor*> seen;
for (Node* exop : ops)
{
for (descriptor::Tensor* tensor : exop->liveness_live_list)
{
if (tensor->is_output())
{
outputs.insert(tensor);
}
}
for (descriptor::Tensor* tensor : outputs)
{
exop->liveness_live_list.insert(tensor);
exop->liveness_free_list.erase(tensor);
if (contains(exop->liveness_new_list, tensor))
{
if (contains(seen, tensor))
{
exop->liveness_new_list.erase(tensor);
}
else
{
seen.insert(tensor);
}
}
}
}
validate_liveness(ops);
return false;
}
void pass::Liveness::check_dependencies(
const std::vector<std::shared_ptr<CallBase>>& registered_passes) const
{
bool found_propagate_types = false;
for (auto pass : registered_passes)
{
if (dynamic_pointer_cast<AssignTensors>(pass))
{
found_propagate_types = true;
}
}
if (!found_propagate_types)
{
throw runtime_error("Dependency 'PropagateTypes' not found for pass 'AssignTensors'");
}
}
bool pass::Liveness::is_temporary(const descriptor::Tensor& tensor)
{
return
tensor.is_persistent() == false
&& tensor.is_input() == false
;
// && tensor.is_constant() == false
// && tensor.is_compile_only() == false;
}
void pass::Liveness::validate_liveness(const list<Node*>& ops)
{
unordered_set<descriptor::Tensor*> dead_tensors;
for (const Node* exop : ops)
{
auto active = exop->liveness_live_list;
active.insert(exop->liveness_new_list.begin(), exop->liveness_new_list.end());
active.insert(exop->liveness_free_list.begin(), exop->liveness_free_list.end());
for (const descriptor::Tensor* tensor : active)
{
if (contains(dead_tensors, tensor))
{
throw runtime_error("Liveness: Dead tensors intersect active tensors");
}
}
dead_tensors.insert(exop->liveness_free_list.begin(), exop->liveness_free_list.end());
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <exception>
#include <sstream>
#include <unordered_set>
#include "log.hpp"
#include "ngraph.hpp"
#include "pass/assign_tensors.hpp"
#include "pass/liveness.hpp"
#include "util.hpp"
#include "log.hpp"
using namespace std;
using namespace ngraph;
bool pass::Liveness::run_on_call_list(list<Node*>& ops)
{
unordered_set<descriptor::Tensor*> currently_live;
for(auto it=ops.rbegin(); it!=ops.rend(); it++)
{
Node& exop = **it;
exop.liveness_live_list.clear();
exop.liveness_new_list.clear();
exop.liveness_free_list.clear();
unordered_set<descriptor::Tensor*> input_tensor_decls;
for (auto input_decl : exop.get_inputs())
{
descriptor::Tensor& tensor = input_decl.get_tensor();
if (is_temporary(tensor))
{
input_tensor_decls.insert(&tensor);
}
}
unordered_set<descriptor::Tensor*> output_tensor_decls;
for (auto output_decl : exop.get_outputs())
{
descriptor::Tensor& tensor = output_decl.get_tensor();
if (is_temporary(tensor))
{
output_tensor_decls.insert(&tensor);
}
}
unordered_set<descriptor::Tensor*> free_tensor_decls;
unordered_set<descriptor::Tensor*> new_tensor_decls;
unordered_set<descriptor::Tensor*> all_tensor_decls = input_tensor_decls;
for (auto decls : {input_tensor_decls, output_tensor_decls})
{
for (descriptor::Tensor* tensor_decl : decls)
{
if (!contains(currently_live, tensor_decl))
{
// this is the last node that value is seen in
// delete it at the end of the op
currently_live.insert(tensor_decl);
free_tensor_decls.insert(tensor_decl);
}
}
}
exop.liveness_live_list = currently_live;
for (descriptor::Tensor* output_decl : output_tensor_decls)
{
if (contains(currently_live, output_decl))
{
new_tensor_decls.insert(output_decl);
currently_live.erase(output_decl);
}
}
exop.liveness_free_list = free_tensor_decls;
exop.liveness_new_list = new_tensor_decls;
}
// Anything marked as output must remain live for the remainder of the graph
// Add outputs to live_list and remove from free_list
unordered_set<descriptor::Tensor*> outputs;
unordered_set<descriptor::Tensor*> seen;
for (Node* exop : ops)
{
for (descriptor::Tensor* tensor : exop->liveness_live_list)
{
if (tensor->is_output())
{
outputs.insert(tensor);
}
}
for (descriptor::Tensor* tensor : outputs)
{
exop->liveness_live_list.insert(tensor);
exop->liveness_free_list.erase(tensor);
if (contains(exop->liveness_new_list, tensor))
{
if (contains(seen, tensor))
{
exop->liveness_new_list.erase(tensor);
}
else
{
seen.insert(tensor);
}
}
}
}
validate_liveness(ops);
return false;
}
void pass::Liveness::check_dependencies(
const std::vector<std::shared_ptr<CallBase>>& registered_passes) const
{
bool found_propagate_types = false;
for (auto pass : registered_passes)
{
if (dynamic_pointer_cast<AssignTensors>(pass))
{
found_propagate_types = true;
}
}
if (!found_propagate_types)
{
throw runtime_error("Dependency 'PropagateTypes' not found for pass 'AssignTensors'");
}
}
bool pass::Liveness::is_temporary(const descriptor::Tensor& tensor)
{
return
tensor.is_persistent() == false
&& tensor.is_input() == false
;
// && tensor.is_constant() == false
// && tensor.is_compile_only() == false;
}
void pass::Liveness::validate_liveness(const list<Node*>& ops)
{
unordered_set<descriptor::Tensor*> dead_tensors;
for (const Node* exop : ops)
{
auto active = exop->liveness_live_list;
active.insert(exop->liveness_new_list.begin(), exop->liveness_new_list.end());
active.insert(exop->liveness_free_list.begin(), exop->liveness_free_list.end());
for (const descriptor::Tensor* tensor : active)
{
if (contains(dead_tensors, tensor))
{
throw runtime_error("Liveness: Dead tensors intersect active tensors");
}
}
dead_tensors.insert(exop->liveness_free_list.begin(), exop->liveness_free_list.end());
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "call_pass.hpp"
#include "ngraph/descriptor/tensor.hpp"
namespace ngraph
{
namespace pass
{
class Liveness;
}
class Node;
}
class ngraph::pass::Liveness : public CallBase
{
public:
virtual bool run_on_call_list(std::list<Node*>&) override;
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
private:
bool is_temporary(const descriptor::Tensor&);
void validate_liveness(const std::list<Node*>& ops);
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "call_pass.hpp"
#include "descriptor/tensor.hpp"
namespace ngraph
{
namespace pass
{
class Liveness;
}
class Node;
}
class ngraph::pass::Liveness : public CallBase
{
public:
virtual bool run_on_call_list(std::list<Node*>&) override;
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
private:
bool is_temporary(const descriptor::Tensor&);
void validate_liveness(const std::list<Node*>& ops);
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <iostream>
#include <memory>
#include "log.hpp"
#include "manager.hpp"
#include "ngraph/node.hpp"
using namespace std;
ngraph::pass::Manager::Manager()
{
}
ngraph::pass::Manager::~Manager()
{
}
void ngraph::pass::Manager::initialize_default_passes()
{
}
void ngraph::pass::Manager::register_pass(std::shared_ptr<TreeBase> p)
{
if (p == nullptr)
{
throw invalid_argument("null pass registered");
}
p->check_dependencies(m_tree_passes);
m_tree_passes.push_back(p);
}
void ngraph::pass::Manager::register_pass(std::shared_ptr<CallBase> p)
{
if (p == nullptr)
{
throw invalid_argument("null pass registered");
}
p->check_dependencies(m_call_passes);
m_call_passes.push_back(p);
}
void ngraph::pass::Manager::run_passes(std::shared_ptr<Node> nodes)
{
for (shared_ptr<TreeBase> p : m_tree_passes)
{
p->run_on_tree(nodes);
if (p->call_graph_produced())
{
m_sorted_list = p->get_call_graph();
}
}
for (shared_ptr<CallBase>& p : m_call_passes)
{
p->run_on_call_list(m_sorted_list);
}
}
const std::list<ngraph::Node*>& ngraph::pass::Manager::get_sorted_list() const
{
return m_sorted_list;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <iostream>
#include <memory>
#include "log.hpp"
#include "manager.hpp"
#include "node.hpp"
using namespace std;
ngraph::pass::Manager::Manager()
{
}
ngraph::pass::Manager::~Manager()
{
}
void ngraph::pass::Manager::initialize_default_passes()
{
}
void ngraph::pass::Manager::register_pass(std::shared_ptr<TreeBase> p)
{
if (p == nullptr)
{
throw invalid_argument("null pass registered");
}
p->check_dependencies(m_tree_passes);
m_tree_passes.push_back(p);
}
void ngraph::pass::Manager::register_pass(std::shared_ptr<CallBase> p)
{
if (p == nullptr)
{
throw invalid_argument("null pass registered");
}
p->check_dependencies(m_call_passes);
m_call_passes.push_back(p);
}
void ngraph::pass::Manager::run_passes(std::shared_ptr<Node> nodes)
{
for (shared_ptr<TreeBase> p : m_tree_passes)
{
p->run_on_tree(nodes);
if (p->call_graph_produced())
{
m_sorted_list = p->get_call_graph();
}
}
for (shared_ptr<CallBase>& p : m_call_passes)
{
p->run_on_call_list(m_sorted_list);
}
}
const std::list<ngraph::Node*>& ngraph::pass::Manager::get_sorted_list() const
{
return m_sorted_list;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <vector>
#include "call_pass.hpp"
#include "tree_pass.hpp"
namespace ngraph
{
namespace pass
{
class Manager;
}
class Node;
}
class ngraph::pass::Manager
{
public:
Manager();
~Manager();
void initialize_default_passes();
void register_pass(std::shared_ptr<TreeBase>);
void register_pass(std::shared_ptr<CallBase>);
void run_passes(std::shared_ptr<Node> nodes);
const std::list<Node*>& get_sorted_list() const;
private:
std::vector<std::shared_ptr<TreeBase>> m_tree_passes;
std::vector<std::shared_ptr<CallBase>> m_call_passes;
std::list<Node*> m_sorted_list;
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <vector>
#include "call_pass.hpp"
#include "tree_pass.hpp"
namespace ngraph
{
namespace pass
{
class Manager;
}
class Node;
}
class ngraph::pass::Manager
{
public:
Manager();
~Manager();
void initialize_default_passes();
void register_pass(std::shared_ptr<TreeBase>);
void register_pass(std::shared_ptr<CallBase>);
void run_passes(std::shared_ptr<Node> nodes);
const std::list<Node*>& get_sorted_list() const;
private:
std::vector<std::shared_ptr<TreeBase>> m_tree_passes;
std::vector<std::shared_ptr<CallBase>> m_call_passes;
std::list<Node*> m_sorted_list;
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <exception>
#include <sstream>
#include "log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "log.hpp"
using namespace std;
using namespace ngraph;
bool pass::MemoryLayout::run_on_call_list(std::list<Node*>& node_list)
{
for (Node* node : node_list)
{
}
return false;
}
void pass::MemoryLayout::check_dependencies(
const std::vector<std::shared_ptr<CallBase>>& registered_passes) const
{
bool found_propagate_types = false;
for (auto pass : registered_passes)
{
if (dynamic_pointer_cast<Liveness>(pass))
{
found_propagate_types = true;
}
}
if (!found_propagate_types)
{
throw runtime_error("Dependency 'PropagateTypes' not found for pass 'AssignTensors'");
}
}
pass::MemoryManager::node::node(size_t size, block_state state)
: m_size{size}
, m_state{state}
{
}
pass::MemoryManager::MemoryManager(size_t alignment)
: m_alignment{alignment}
, m_scheme{allocation_scheme::BEST_FIT}
, m_max_allocated{0}
{
// assert(m_base_offset % m_alignment == 0);
m_node_list.emplace_back(numeric_limits<size_t>::max(), block_state::FREE);
}
size_t pass::MemoryManager::allocate(size_t size)
{
size_t rc;
switch(m_scheme)
{
case allocation_scheme::FIRST_FIT:
rc = first_fit(size);
break;
case allocation_scheme::BEST_FIT:
rc = best_fit(size);
break;
}
return rc;
}
size_t pass::MemoryManager::best_fit(size_t size)
{
size = align(size, m_alignment);
size_t offset = 0;
size_t min_delta = numeric_limits<size_t>::max();
auto best_fit = m_node_list.end();
size_t best_offset = offset;
for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it)
{
if (it->m_state == block_state::FREE && it->m_size >= size)
{
size_t delta = it->m_size - size;
if (delta < min_delta)
{
min_delta = delta;
best_fit = it;
best_offset = offset;
}
}
offset += it->m_size;
}
if (best_fit == m_node_list.end())
{
throw bad_alloc();
}
if (min_delta == 0)
{
// exact fit
best_fit->m_state = block_state::ALLOCATED;
}
else
{
m_node_list.insert(best_fit, node{size, block_state::ALLOCATED});
best_fit->m_size -= size;
}
m_max_allocated = std::max(m_max_allocated, best_offset + size);
return best_offset;
}
size_t pass::MemoryManager::first_fit(size_t size)
{
size = align(size, m_alignment);
size_t offset = 0;
bool found = false;
for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it)
{
if (it->m_state == block_state::FREE && it->m_size >= size)
{
if (it->m_size > size)
{
m_node_list.insert(it, node{size, block_state::ALLOCATED});
it->m_size -= size;
}
else
{
// exact fit
it->m_state = block_state::ALLOCATED;
}
found = true;
break;
}
offset += it->m_size;
}
if (!found)
{
throw bad_alloc();
}
m_max_allocated = std::max(m_max_allocated, offset + size);
return offset;
}
void pass::MemoryManager::free(size_t offset)
{
size_t search_offset = 0;
bool found = false;
for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it)
{
if (offset == search_offset)
{
list<node>::iterator it_next = std::next(it);
if (it == m_node_list.begin())
{
// free the first node in the list
it->m_state = block_state::FREE;
}
else
{
// node has predecessor
list<node>::iterator it_prev = std::prev(it);
if (it_prev->m_state == block_state::FREE)
{
it->m_size += it_prev->m_size;
m_node_list.erase(it_prev);
}
}
if (it_next != m_node_list.end() && it_next->m_state == block_state::FREE)
{
// join this node with next
it->m_size += it_next->m_size;
m_node_list.erase(it_next);
}
it->m_state = block_state::FREE;
found = true;
break;
}
search_offset += it->m_size;
}
if (!found)
{
throw runtime_error("bad free");
}
}
void pass::MemoryManager::dump(std::ostream& out)
{
for (const node& n : m_node_list)
{
out << "size=" << n.m_size << ", ";
out << (n.m_state == block_state::FREE ? "FREE" : "ALLOCATED");
out << "\n";
}
}
size_t pass::MemoryManager::align(size_t size, size_t alignment)
{
if (size == 0)
{
size = alignment;
}
else
{
auto remainder = size % alignment;
if (remainder > 0)
{
size += (alignment - remainder);
}
}
return size;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <exception>
#include <sstream>
#include "log.hpp"
#include "ngraph.hpp"
#include "pass/liveness.hpp"
#include "pass/memory_layout.hpp"
#include "log.hpp"
using namespace std;
using namespace ngraph;
bool pass::MemoryLayout::run_on_call_list(std::list<Node*>& node_list)
{
for (Node* node : node_list)
{
}
return false;
}
void pass::MemoryLayout::check_dependencies(
const std::vector<std::shared_ptr<CallBase>>& registered_passes) const
{
bool found_propagate_types = false;
for (auto pass : registered_passes)
{
if (dynamic_pointer_cast<Liveness>(pass))
{
found_propagate_types = true;
}
}
if (!found_propagate_types)
{
throw runtime_error("Dependency 'PropagateTypes' not found for pass 'AssignTensors'");
}
}
pass::MemoryManager::node::node(size_t size, block_state state)
: m_size{size}
, m_state{state}
{
}
pass::MemoryManager::MemoryManager(size_t alignment)
: m_alignment{alignment}
, m_scheme{allocation_scheme::BEST_FIT}
, m_max_allocated{0}
{
// assert(m_base_offset % m_alignment == 0);
m_node_list.emplace_back(numeric_limits<size_t>::max(), block_state::FREE);
}
size_t pass::MemoryManager::allocate(size_t size)
{
size_t rc;
switch(m_scheme)
{
case allocation_scheme::FIRST_FIT:
rc = first_fit(size);
break;
case allocation_scheme::BEST_FIT:
rc = best_fit(size);
break;
}
return rc;
}
size_t pass::MemoryManager::best_fit(size_t size)
{
size = align(size, m_alignment);
size_t offset = 0;
size_t min_delta = numeric_limits<size_t>::max();
auto best_fit = m_node_list.end();
size_t best_offset = offset;
for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it)
{
if (it->m_state == block_state::FREE && it->m_size >= size)
{
size_t delta = it->m_size - size;
if (delta < min_delta)
{
min_delta = delta;
best_fit = it;
best_offset = offset;
}
}
offset += it->m_size;
}
if (best_fit == m_node_list.end())
{
throw bad_alloc();
}
if (min_delta == 0)
{
// exact fit
best_fit->m_state = block_state::ALLOCATED;
}
else
{
m_node_list.insert(best_fit, node{size, block_state::ALLOCATED});
best_fit->m_size -= size;
}
m_max_allocated = std::max(m_max_allocated, best_offset + size);
return best_offset;
}
size_t pass::MemoryManager::first_fit(size_t size)
{
size = align(size, m_alignment);
size_t offset = 0;
bool found = false;
for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it)
{
if (it->m_state == block_state::FREE && it->m_size >= size)
{
if (it->m_size > size)
{
m_node_list.insert(it, node{size, block_state::ALLOCATED});
it->m_size -= size;
}
else
{
// exact fit
it->m_state = block_state::ALLOCATED;
}
found = true;
break;
}
offset += it->m_size;
}
if (!found)
{
throw bad_alloc();
}
m_max_allocated = std::max(m_max_allocated, offset + size);
return offset;
}
void pass::MemoryManager::free(size_t offset)
{
size_t search_offset = 0;
bool found = false;
for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it)
{
if (offset == search_offset)
{
list<node>::iterator it_next = std::next(it);
if (it == m_node_list.begin())
{
// free the first node in the list
it->m_state = block_state::FREE;
}
else
{
// node has predecessor
list<node>::iterator it_prev = std::prev(it);
if (it_prev->m_state == block_state::FREE)
{
it->m_size += it_prev->m_size;
m_node_list.erase(it_prev);
}
}
if (it_next != m_node_list.end() && it_next->m_state == block_state::FREE)
{
// join this node with next
it->m_size += it_next->m_size;
m_node_list.erase(it_next);
}
it->m_state = block_state::FREE;
found = true;
break;
}
search_offset += it->m_size;
}
if (!found)
{
throw runtime_error("bad free");
}
}
void pass::MemoryManager::dump(std::ostream& out)
{
for (const node& n : m_node_list)
{
out << "size=" << n.m_size << ", ";
out << (n.m_state == block_state::FREE ? "FREE" : "ALLOCATED");
out << "\n";
}
}
size_t pass::MemoryManager::align(size_t size, size_t alignment)
{
if (size == 0)
{
size = alignment;
}
else
{
auto remainder = size % alignment;
if (remainder > 0)
{
size += (alignment - remainder);
}
}
return size;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <limits>
#include <list>
#include <sstream>
#include "call_pass.hpp"
namespace ngraph
{
namespace pass
{
class MemoryLayout;
class MemoryNode;
class MemoryManager;
}
class Node;
}
class ngraph::pass::MemoryLayout : public CallBase
{
public:
virtual bool run_on_call_list(std::list<Node*>&) override;
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
private:
};
class ngraph::pass::MemoryManager
{
public:
enum class block_state
{
FREE,
ALLOCATED
};
enum class allocation_scheme
{
FIRST_FIT,
BEST_FIT
};
class node
{
public:
node(size_t size, block_state state);
bool is_free() const { return m_state == block_state::FREE; }
size_t m_size;
block_state m_state;
};
MemoryManager(size_t alignment=1);
// memory_manager& alignment(size_t a);
size_t allocate(size_t size);
void free(size_t offset);
void dump(std::ostream&);
static size_t align(size_t x, size_t alignment);
std::list<node>::iterator begin() { return m_node_list.begin(); }
std::list<node>::iterator end() { return m_node_list.end(); }
std::list<node>::const_iterator begin() const { return m_node_list.cbegin(); }
std::list<node>::const_iterator end() const { return m_node_list.cend(); }
const std::list<node>& get_node_list() const { return m_node_list; }
size_t max_allocated() const { return m_max_allocated; }
private:
size_t first_fit(size_t size);
size_t best_fit(size_t size);
std::list<node> m_node_list;
size_t m_alignment;
allocation_scheme m_scheme;
size_t m_max_allocated;
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <limits>
#include <list>
#include <sstream>
#include "call_pass.hpp"
namespace ngraph
{
namespace pass
{
class MemoryLayout;
class MemoryNode;
class MemoryManager;
}
class Node;
}
class ngraph::pass::MemoryLayout : public CallBase
{
public:
virtual bool run_on_call_list(std::list<Node*>&) override;
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
private:
};
class ngraph::pass::MemoryManager
{
public:
enum class block_state
{
FREE,
ALLOCATED
};
enum class allocation_scheme
{
FIRST_FIT,
BEST_FIT
};
class node
{
public:
node(size_t size, block_state state);
bool is_free() const { return m_state == block_state::FREE; }
size_t m_size;
block_state m_state;
};
MemoryManager(size_t alignment=1);
// memory_manager& alignment(size_t a);
size_t allocate(size_t size);
void free(size_t offset);
void dump(std::ostream&);
static size_t align(size_t x, size_t alignment);
std::list<node>::iterator begin() { return m_node_list.begin(); }
std::list<node>::iterator end() { return m_node_list.end(); }
std::list<node>::const_iterator begin() const { return m_node_list.cbegin(); }
std::list<node>::const_iterator end() const { return m_node_list.cend(); }
const std::list<node>& get_node_list() const { return m_node_list; }
size_t max_allocated() const { return m_max_allocated; }
private:
size_t first_fit(size_t size);
size_t best_fit(size_t size);
std::list<node> m_node_list;
size_t m_alignment;
allocation_scheme m_scheme;
size_t m_max_allocated;
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "pass.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "pass.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace pass
{
class Base;
}
}
class ngraph::pass::Base
{
public:
private:
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace pass
{
class Base;
}
}
class ngraph::pass::Base
{
public:
private:
};
......@@ -14,7 +14,7 @@
#include <sstream>
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
#include "propagate_types.hpp"
using namespace std;
......
......@@ -16,8 +16,8 @@
#include <unordered_map>
#include "log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "node.hpp"
#include "pass/topological_sort.hpp"
#include "util.hpp"
using namespace ngraph;
......
......@@ -17,7 +17,7 @@
#include <list>
#include <memory>
#include "ngraph/pass/tree_pass.hpp"
#include "pass/tree_pass.hpp"
namespace ngraph
{
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "tree_pass.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "tree_pass.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <list>
#include <memory>
#include <vector>
#include "pass.hpp"
namespace ngraph
{
namespace pass
{
class TreeBase;
}
class Node;
}
class ngraph::pass::TreeBase : public Base
{
public:
virtual ~TreeBase() {}
// return true if changes were made to the tree
virtual bool run_on_tree(std::shared_ptr<Node>) = 0;
virtual bool call_graph_produced() const { return false; }
virtual std::list<Node*> get_call_graph() const { return std::list<Node*>(); }
// derived class throws exception if its dependencies have not been met
virtual void check_dependencies(const std::vector<std::shared_ptr<TreeBase>>&) const {}
private:
std::list<Node*> m_sorted_list;
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <list>
#include <memory>
#include <vector>
#include "pass.hpp"
namespace ngraph
{
namespace pass
{
class TreeBase;
}
class Node;
}
class ngraph::pass::TreeBase : public Base
{
public:
virtual ~TreeBase() {}
// return true if changes were made to the tree
virtual bool run_on_tree(std::shared_ptr<Node>) = 0;
virtual bool call_graph_produced() const { return false; }
virtual std::list<Node*> get_call_graph() const { return std::list<Node*>(); }
// derived class throws exception if its dependencies have not been met
virtual void check_dependencies(const std::vector<std::shared_ptr<TreeBase>>&) const {}
private:
std::list<Node*> m_sorted_list;
};
......@@ -15,7 +15,7 @@
#include <fstream>
#include "visualize_tree.hpp"
#include "ngraph/node.hpp"
#include "node.hpp"
#include "util.hpp"
using namespace ngraph;
......
......@@ -18,7 +18,7 @@
#include <string>
#include <set>
#include "ngraph/pass/tree_pass.hpp"
#include "pass/tree_pass.hpp"
namespace ngraph
{
......
......@@ -14,7 +14,7 @@
#include <algorithm>
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -17,9 +17,9 @@
#include <memory>
#include <vector>
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/function.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "runtime/tensor_view.hpp"
#include "function.hpp"
#include "runtime/instruction.hpp"
namespace ngraph
{
......
......@@ -14,9 +14,9 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "runtime/call_frame.hpp"
#include "runtime/eigen/tensor_view.hpp"
#include "runtime/instruction.hpp"
namespace ngraph
{
......
......@@ -18,17 +18,17 @@
#include <typeinfo>
#include <unordered_map>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/eigen/add.hpp"
#include "ngraph/runtime/eigen/external_function.hpp"
#include "ngraph/runtime/eigen/multiply.hpp"
#include "ngraph/runtime/eigen/return.hpp"
#include "descriptor/input.hpp"
#include "descriptor/output.hpp"
#include "function.hpp"
#include "node.hpp"
#include "ops/add.hpp"
#include "ops/multiply.hpp"
#include "pass/topological_sort.hpp"
#include "runtime/eigen/add.hpp"
#include "runtime/eigen/external_function.hpp"
#include "runtime/eigen/multiply.hpp"
#include "runtime/eigen/return.hpp"
using namespace std;
using namespace ngraph::runtime::eigen;
......
......@@ -19,7 +19,7 @@
#include <typeinfo>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "function.hpp"
namespace ngraph
{
......
......@@ -14,8 +14,8 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "runtime/call_frame.hpp"
#include "runtime/instruction.hpp"
namespace ngraph
{
......
......@@ -14,8 +14,8 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "runtime/call_frame.hpp"
#include "runtime/instruction.hpp"
namespace ngraph
{
......
......@@ -17,9 +17,9 @@
#include <Eigen/Dense>
#include <vector>
#include "ngraph/shape.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
#include "shape.hpp"
#include "runtime/tensor_view.hpp"
#include "descriptor/tensor_view.hpp"
namespace ngraph
{
......
......@@ -15,7 +15,7 @@
#include <algorithm>
#include <vector>
#include "ngraph/shape.hpp"
#include "shape.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -16,7 +16,7 @@
#include <cmath>
#include <iostream>
#include "ngraph/element_type.hpp"
#include "element_type.hpp"
#include "log.hpp"
using namespace ngraph;
......
......@@ -22,7 +22,7 @@
#include <string>
#include <type_traits>
#include "ngraph/except.hpp"
#include "except.hpp"
namespace ngraph
{
......
......@@ -14,7 +14,7 @@
#include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
#include "log.hpp"
#include "util.hpp"
......
......@@ -17,8 +17,8 @@
#include <memory>
#include <vector>
#include "ngraph/element_type.hpp"
#include "ngraph/shape.hpp"
#include "types/element_type.hpp"
#include "shape.hpp"
namespace ngraph
{
......
......@@ -19,7 +19,7 @@
#include <unordered_set>
#include "util.hpp"
#include "ngraph/node.hpp"
#include "node.hpp"
#include "log.hpp"
using namespace std;
......
......@@ -16,8 +16,8 @@
#include <fstream>
#include <list>
#include "ngraph/node.hpp"
#include "ngraph/visualize.hpp"
#include "node.hpp"
#include "visualize.hpp"
#include "util.hpp"
using namespace ngraph;
......
......@@ -14,7 +14,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
#include <memory>
using namespace std;
......
......@@ -18,6 +18,6 @@
#include "gtest/gtest.h"
#include "ngraph/element_type.hpp"
#include "types/element_type.hpp"
using namespace ngraph;
......@@ -14,7 +14,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -14,7 +14,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
#include <memory>
using namespace std;
......
......@@ -16,7 +16,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -19,15 +19,15 @@
#include "gtest/gtest.h"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/ngraph.hpp"
#include "pass/liveness.hpp"
#include "pass/assign_tensors.hpp"
#include "pass/manager.hpp"
#include "pass/propagate_types.hpp"
#include "pass/topological_sort.hpp"
#include "pass/liveness.hpp"
#include "pass/visualize_tree.hpp"
#include "pass/dump_sorted.hpp"
#include "ngraph.hpp"
#include "test_tools.hpp"
#include "log.hpp"
......
......@@ -19,11 +19,11 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph.hpp"
#include "pass/assign_tensors.hpp"
#include "pass/manager.hpp"
#include "pass/propagate_types.hpp"
#include "pass/topological_sort.hpp"
#include "test_tools.hpp"
using namespace ngraph;
......
......@@ -19,8 +19,8 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph.hpp"
#include "pass/memory_layout.hpp"
#include "test_tools.hpp"
using namespace ngraph;
......
......@@ -17,7 +17,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -16,7 +16,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -15,7 +15,7 @@
#include <algorithm>
#include "test_tools.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
#include "util.hpp"
using namespace std;
......
......@@ -19,9 +19,9 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/visualize.hpp"
#include "ngraph.hpp"
#include "pass/topological_sort.hpp"
#include "visualize.hpp"
#include "util.hpp"
#include "log.hpp"
#include "test_tools.hpp"
......
......@@ -14,7 +14,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph.hpp"
#include <memory>
using namespace std;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment