Commit c7cc170c authored by Bob Kimball's avatar Bob Kimball

new directory layout

parent 1c78e9f3
...@@ -15,29 +15,29 @@ include_directories(SYSTEM ${EIGEN_INCLUDE_DIR}) ...@@ -15,29 +15,29 @@ include_directories(SYSTEM ${EIGEN_INCLUDE_DIR})
set (SRC set (SRC
log.cpp log.cpp
ngraph/descriptor/input.cpp descriptor/input.cpp
ngraph/descriptor/output.cpp descriptor/output.cpp
ngraph/descriptor/tensor_view.cpp descriptor/tensor_view.cpp
ngraph/descriptor/tensor.cpp descriptor/tensor.cpp
ngraph/function.cpp function.cpp
ngraph/node.cpp node.cpp
ngraph/shape.cpp shape.cpp
ngraph/pass/assign_tensors.cpp pass/assign_tensors.cpp
ngraph/pass/call_pass.cpp pass/call_pass.cpp
ngraph/pass/dump_sorted.cpp pass/dump_sorted.cpp
ngraph/pass/liveness.cpp pass/liveness.cpp
ngraph/pass/manager.cpp pass/manager.cpp
ngraph/pass/memory_layout.cpp pass/memory_layout.cpp
ngraph/pass/pass.cpp pass/pass.cpp
ngraph/pass/propagate_types.cpp pass/propagate_types.cpp
ngraph/pass/topological_sort.cpp pass/topological_sort.cpp
ngraph/pass/tree_pass.cpp pass/tree_pass.cpp
ngraph/pass/visualize_tree.cpp pass/visualize_tree.cpp
ngraph/runtime/call_frame.cpp runtime/call_frame.cpp
ngraph/runtime/eigen/external_function.cpp runtime/eigen/external_function.cpp
ngraph/runtime/eigen/tensor_view.cpp runtime/eigen/tensor_view.cpp
ngraph/shape.cpp shape.cpp
ngraph/visualize.cpp visualize.cpp
ops/binary_elementwise_arithmetic.cpp ops/binary_elementwise_arithmetic.cpp
ops/binary_elementwise_builtin.cpp ops/binary_elementwise_builtin.cpp
ops/binary_elementwise_comparison.cpp ops/binary_elementwise_comparison.cpp
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/descriptor/tensor_view.hpp" #include "descriptor/tensor_view.hpp"
#include "ngraph/function.hpp" #include "function.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <memory> #include <memory>
#include "ngraph/descriptor/tensor.hpp" #include "descriptor/tensor.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <memory> #include <memory>
#include <set> #include <set>
#include "ngraph/descriptor/tensor_view.hpp" #include "descriptor/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/descriptor/tensor.hpp" #include "descriptor/tensor.hpp"
#include "ngraph/node.hpp" #include "node.hpp"
using namespace ngraph; using namespace ngraph;
using namespace descriptor; using namespace descriptor;
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/descriptor/tensor_view.hpp" #include "descriptor/tensor_view.hpp"
using namespace ngraph; using namespace ngraph;
using namespace descriptor; using namespace descriptor;
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include "ngraph/descriptor/tensor.hpp" #include "descriptor/tensor.hpp"
#include "ngraph/shape.hpp" #include "shape.hpp"
#include "ngraph/type.hpp" #include "types/type.hpp"
#include "log.hpp" #include "log.hpp"
namespace ngraph namespace ngraph
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <memory> #include <memory>
#include "ngraph/function.hpp" #include "function.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
#pragma once #pragma once
#include "ngraph/descriptor/tensor_view.hpp" #include "descriptor/tensor_view.hpp"
#include "ngraph/node.hpp" #include "node.hpp"
#include "ngraph/op.hpp" #include "ops/op.hpp"
#include "ngraph/ops/parameter.hpp" #include "ops/parameter.hpp"
#include "ngraph/runtime/instruction.hpp" #include "runtime/instruction.hpp"
#include "ngraph/type.hpp" #include "types/type.hpp"
namespace ngraph namespace ngraph
{ {
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
//
// The public API for ngraph++
//
#pragma once
#include "common.hpp"
#include "descriptor/buffer.hpp"
#include "descriptor/call_frame.hpp"
#include "descriptor/input.hpp"
#include "descriptor/output.hpp"
#include "descriptor/tensor.hpp"
#include "descriptor/tensor_view.hpp"
#include "descriptor/tensor_view_layout.hpp"
#include "types/element_type.hpp"
#include "except.hpp"
#include "function.hpp"
#include "node.hpp"
#include "ops/abs.hpp"
#include "ops/add.hpp"
#include "ops/broadcast.hpp"
#include "ops/ceiling.hpp"
#include "ops/concatenate.hpp"
#include "ops/constant.hpp"
#include "ops/convert.hpp"
#include "ops/divide.hpp"
#include "ops/dot.hpp"
#include "ops/equal.hpp"
#include "ops/exp.hpp"
#include "ops/floor.hpp"
#include "ops/greater.hpp"
#include "ops/less.hpp"
#include "ops/log.hpp"
#include "ops/maximum.hpp"
#include "ops/minimum.hpp"
#include "ops/multiply.hpp"
#include "ops/negative.hpp"
#include "ops/op.hpp"
#include "ops/parameter.hpp"
#include "ops/power.hpp"
#include "ops/remainder.hpp"
#include "ops/subtract.hpp"
#include "ops/tuple.hpp"
#include "runtime/eigen/add.hpp"
#include "runtime/eigen/external_function.hpp"
#include "runtime/eigen/multiply.hpp"
#include "runtime/eigen/return.hpp"
#include "runtime/eigen/tensor_view.hpp"
#include "runtime/call_frame.hpp"
#include "function.hpp"
#include "runtime/instruction.hpp"
#include "runtime/tensor_view.hpp"
#include "shape.hpp"
#include "types/type.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
//
// The public API for ngraph++
//
#pragma once
#include "ngraph/common.hpp"
#include "ngraph/descriptor/buffer.hpp"
#include "ngraph/descriptor/call_frame.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/descriptor/tensor_view_layout.hpp"
#include "ngraph/element_type.hpp"
#include "ngraph/except.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/ops/abs.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/equal.hpp"
#include "ngraph/ops/exp.hpp"
#include "ngraph/ops/floor.hpp"
#include "ngraph/ops/greater.hpp"
#include "ngraph/ops/less.hpp"
#include "ngraph/ops/log.hpp"
#include "ngraph/ops/maximum.hpp"
#include "ngraph/ops/minimum.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/negative.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/power.hpp"
#include "ngraph/ops/remainder.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/eigen/add.hpp"
#include "ngraph/runtime/eigen/external_function.hpp"
#include "ngraph/runtime/eigen/multiply.hpp"
#include "ngraph/runtime/eigen/return.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/function.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type.hpp"
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
#include <iostream> #include <iostream>
#include "ngraph/common.hpp" #include "common.hpp"
#include "ngraph/type.hpp" #include "types/type.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
#include "log.hpp" #include "log.hpp"
using namespace std; using namespace std;
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
#include <sstream> #include <sstream>
#include "ngraph/element_type.hpp" #include "types/element_type.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp" #include "runtime/eigen/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <algorithm> #include <algorithm>
#include <sstream> #include <sstream>
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
...@@ -16,9 +16,9 @@ ...@@ -16,9 +16,9 @@
#include <memory> #include <memory>
#include "ngraph/node.hpp" #include "node.hpp"
#include "ngraph/ops/parameter.hpp" #include "ops/parameter.hpp"
#include "ngraph/type.hpp" #include "types/type.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <sstream> #include <sstream>
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#pragma once #pragma once
#include "ngraph/node.hpp" #include "node.hpp"
#include "ngraph/type.hpp" #include "types/type.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <sstream> #include <sstream>
#include "log.hpp" #include "log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
#include "propagate_types.hpp" #include "propagate_types.hpp"
using namespace std; using namespace std;
......
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "call_pass.hpp" #include "call_pass.hpp"
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#pragma once #pragma once
#include <list> #include <list>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "pass.hpp" #include "pass.hpp"
namespace ngraph namespace ngraph
{ {
namespace pass namespace pass
{ {
class CallBase; class CallBase;
} }
class Node; class Node;
} }
class ngraph::pass::CallBase : public Base class ngraph::pass::CallBase : public Base
{ {
public: public:
virtual ~CallBase() {} virtual ~CallBase() {}
virtual bool run_on_call_list(std::list<Node*>&) = 0; virtual bool run_on_call_list(std::list<Node*>&) = 0;
// derived class throws exception if its dependencies have not been met // derived class throws exception if its dependencies have not been met
virtual void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const {} virtual void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const {}
private: private:
}; };
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <fstream> #include <fstream>
#include "dump_sorted.hpp" #include "dump_sorted.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
#include "util.hpp" #include "util.hpp"
using namespace ngraph; using namespace ngraph;
......
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <exception> #include <exception>
#include <sstream> #include <sstream>
#include <unordered_set> #include <unordered_set>
#include "log.hpp" #include "log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp" #include "pass/assign_tensors.hpp"
#include "ngraph/pass/liveness.hpp" #include "pass/liveness.hpp"
#include "util.hpp" #include "util.hpp"
#include "log.hpp" #include "log.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
bool pass::Liveness::run_on_call_list(list<Node*>& ops) bool pass::Liveness::run_on_call_list(list<Node*>& ops)
{ {
unordered_set<descriptor::Tensor*> currently_live; unordered_set<descriptor::Tensor*> currently_live;
for(auto it=ops.rbegin(); it!=ops.rend(); it++) for(auto it=ops.rbegin(); it!=ops.rend(); it++)
{ {
Node& exop = **it; Node& exop = **it;
exop.liveness_live_list.clear(); exop.liveness_live_list.clear();
exop.liveness_new_list.clear(); exop.liveness_new_list.clear();
exop.liveness_free_list.clear(); exop.liveness_free_list.clear();
unordered_set<descriptor::Tensor*> input_tensor_decls; unordered_set<descriptor::Tensor*> input_tensor_decls;
for (auto input_decl : exop.get_inputs()) for (auto input_decl : exop.get_inputs())
{ {
descriptor::Tensor& tensor = input_decl.get_tensor(); descriptor::Tensor& tensor = input_decl.get_tensor();
if (is_temporary(tensor)) if (is_temporary(tensor))
{ {
input_tensor_decls.insert(&tensor); input_tensor_decls.insert(&tensor);
} }
} }
unordered_set<descriptor::Tensor*> output_tensor_decls; unordered_set<descriptor::Tensor*> output_tensor_decls;
for (auto output_decl : exop.get_outputs()) for (auto output_decl : exop.get_outputs())
{ {
descriptor::Tensor& tensor = output_decl.get_tensor(); descriptor::Tensor& tensor = output_decl.get_tensor();
if (is_temporary(tensor)) if (is_temporary(tensor))
{ {
output_tensor_decls.insert(&tensor); output_tensor_decls.insert(&tensor);
} }
} }
unordered_set<descriptor::Tensor*> free_tensor_decls; unordered_set<descriptor::Tensor*> free_tensor_decls;
unordered_set<descriptor::Tensor*> new_tensor_decls; unordered_set<descriptor::Tensor*> new_tensor_decls;
unordered_set<descriptor::Tensor*> all_tensor_decls = input_tensor_decls; unordered_set<descriptor::Tensor*> all_tensor_decls = input_tensor_decls;
for (auto decls : {input_tensor_decls, output_tensor_decls}) for (auto decls : {input_tensor_decls, output_tensor_decls})
{ {
for (descriptor::Tensor* tensor_decl : decls) for (descriptor::Tensor* tensor_decl : decls)
{ {
if (!contains(currently_live, tensor_decl)) if (!contains(currently_live, tensor_decl))
{ {
// this is the last node that value is seen in // this is the last node that value is seen in
// delete it at the end of the op // delete it at the end of the op
currently_live.insert(tensor_decl); currently_live.insert(tensor_decl);
free_tensor_decls.insert(tensor_decl); free_tensor_decls.insert(tensor_decl);
} }
} }
} }
exop.liveness_live_list = currently_live; exop.liveness_live_list = currently_live;
for (descriptor::Tensor* output_decl : output_tensor_decls) for (descriptor::Tensor* output_decl : output_tensor_decls)
{ {
if (contains(currently_live, output_decl)) if (contains(currently_live, output_decl))
{ {
new_tensor_decls.insert(output_decl); new_tensor_decls.insert(output_decl);
currently_live.erase(output_decl); currently_live.erase(output_decl);
} }
} }
exop.liveness_free_list = free_tensor_decls; exop.liveness_free_list = free_tensor_decls;
exop.liveness_new_list = new_tensor_decls; exop.liveness_new_list = new_tensor_decls;
} }
// Anything marked as output must remain live for the remainder of the graph // Anything marked as output must remain live for the remainder of the graph
// Add outputs to live_list and remove from free_list // Add outputs to live_list and remove from free_list
unordered_set<descriptor::Tensor*> outputs; unordered_set<descriptor::Tensor*> outputs;
unordered_set<descriptor::Tensor*> seen; unordered_set<descriptor::Tensor*> seen;
for (Node* exop : ops) for (Node* exop : ops)
{ {
for (descriptor::Tensor* tensor : exop->liveness_live_list) for (descriptor::Tensor* tensor : exop->liveness_live_list)
{ {
if (tensor->is_output()) if (tensor->is_output())
{ {
outputs.insert(tensor); outputs.insert(tensor);
} }
} }
for (descriptor::Tensor* tensor : outputs) for (descriptor::Tensor* tensor : outputs)
{ {
exop->liveness_live_list.insert(tensor); exop->liveness_live_list.insert(tensor);
exop->liveness_free_list.erase(tensor); exop->liveness_free_list.erase(tensor);
if (contains(exop->liveness_new_list, tensor)) if (contains(exop->liveness_new_list, tensor))
{ {
if (contains(seen, tensor)) if (contains(seen, tensor))
{ {
exop->liveness_new_list.erase(tensor); exop->liveness_new_list.erase(tensor);
} }
else else
{ {
seen.insert(tensor); seen.insert(tensor);
} }
} }
} }
} }
validate_liveness(ops); validate_liveness(ops);
return false; return false;
} }
void pass::Liveness::check_dependencies( void pass::Liveness::check_dependencies(
const std::vector<std::shared_ptr<CallBase>>& registered_passes) const const std::vector<std::shared_ptr<CallBase>>& registered_passes) const
{ {
bool found_propagate_types = false; bool found_propagate_types = false;
for (auto pass : registered_passes) for (auto pass : registered_passes)
{ {
if (dynamic_pointer_cast<AssignTensors>(pass)) if (dynamic_pointer_cast<AssignTensors>(pass))
{ {
found_propagate_types = true; found_propagate_types = true;
} }
} }
if (!found_propagate_types) if (!found_propagate_types)
{ {
throw runtime_error("Dependency 'PropagateTypes' not found for pass 'AssignTensors'"); throw runtime_error("Dependency 'PropagateTypes' not found for pass 'AssignTensors'");
} }
} }
bool pass::Liveness::is_temporary(const descriptor::Tensor& tensor) bool pass::Liveness::is_temporary(const descriptor::Tensor& tensor)
{ {
return return
tensor.is_persistent() == false tensor.is_persistent() == false
&& tensor.is_input() == false && tensor.is_input() == false
; ;
// && tensor.is_constant() == false // && tensor.is_constant() == false
// && tensor.is_compile_only() == false; // && tensor.is_compile_only() == false;
} }
void pass::Liveness::validate_liveness(const list<Node*>& ops) void pass::Liveness::validate_liveness(const list<Node*>& ops)
{ {
unordered_set<descriptor::Tensor*> dead_tensors; unordered_set<descriptor::Tensor*> dead_tensors;
for (const Node* exop : ops) for (const Node* exop : ops)
{ {
auto active = exop->liveness_live_list; auto active = exop->liveness_live_list;
active.insert(exop->liveness_new_list.begin(), exop->liveness_new_list.end()); active.insert(exop->liveness_new_list.begin(), exop->liveness_new_list.end());
active.insert(exop->liveness_free_list.begin(), exop->liveness_free_list.end()); active.insert(exop->liveness_free_list.begin(), exop->liveness_free_list.end());
for (const descriptor::Tensor* tensor : active) for (const descriptor::Tensor* tensor : active)
{ {
if (contains(dead_tensors, tensor)) if (contains(dead_tensors, tensor))
{ {
throw runtime_error("Liveness: Dead tensors intersect active tensors"); throw runtime_error("Liveness: Dead tensors intersect active tensors");
} }
} }
dead_tensors.insert(exop->liveness_free_list.begin(), exop->liveness_free_list.end()); dead_tensors.insert(exop->liveness_free_list.begin(), exop->liveness_free_list.end());
} }
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#pragma once #pragma once
#include "call_pass.hpp" #include "call_pass.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "descriptor/tensor.hpp"
namespace ngraph namespace ngraph
{ {
namespace pass namespace pass
{ {
class Liveness; class Liveness;
} }
class Node; class Node;
} }
class ngraph::pass::Liveness : public CallBase class ngraph::pass::Liveness : public CallBase
{ {
public: public:
virtual bool run_on_call_list(std::list<Node*>&) override; virtual bool run_on_call_list(std::list<Node*>&) override;
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override; void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
private: private:
bool is_temporary(const descriptor::Tensor&); bool is_temporary(const descriptor::Tensor&);
void validate_liveness(const std::list<Node*>& ops); void validate_liveness(const std::list<Node*>& ops);
}; };
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include "log.hpp" #include "log.hpp"
#include "manager.hpp" #include "manager.hpp"
#include "ngraph/node.hpp" #include "node.hpp"
using namespace std; using namespace std;
ngraph::pass::Manager::Manager() ngraph::pass::Manager::Manager()
{ {
} }
ngraph::pass::Manager::~Manager() ngraph::pass::Manager::~Manager()
{ {
} }
void ngraph::pass::Manager::initialize_default_passes() void ngraph::pass::Manager::initialize_default_passes()
{ {
} }
void ngraph::pass::Manager::register_pass(std::shared_ptr<TreeBase> p) void ngraph::pass::Manager::register_pass(std::shared_ptr<TreeBase> p)
{ {
if (p == nullptr) if (p == nullptr)
{ {
throw invalid_argument("null pass registered"); throw invalid_argument("null pass registered");
} }
p->check_dependencies(m_tree_passes); p->check_dependencies(m_tree_passes);
m_tree_passes.push_back(p); m_tree_passes.push_back(p);
} }
void ngraph::pass::Manager::register_pass(std::shared_ptr<CallBase> p) void ngraph::pass::Manager::register_pass(std::shared_ptr<CallBase> p)
{ {
if (p == nullptr) if (p == nullptr)
{ {
throw invalid_argument("null pass registered"); throw invalid_argument("null pass registered");
} }
p->check_dependencies(m_call_passes); p->check_dependencies(m_call_passes);
m_call_passes.push_back(p); m_call_passes.push_back(p);
} }
void ngraph::pass::Manager::run_passes(std::shared_ptr<Node> nodes) void ngraph::pass::Manager::run_passes(std::shared_ptr<Node> nodes)
{ {
for (shared_ptr<TreeBase> p : m_tree_passes) for (shared_ptr<TreeBase> p : m_tree_passes)
{ {
p->run_on_tree(nodes); p->run_on_tree(nodes);
if (p->call_graph_produced()) if (p->call_graph_produced())
{ {
m_sorted_list = p->get_call_graph(); m_sorted_list = p->get_call_graph();
} }
} }
for (shared_ptr<CallBase>& p : m_call_passes) for (shared_ptr<CallBase>& p : m_call_passes)
{ {
p->run_on_call_list(m_sorted_list); p->run_on_call_list(m_sorted_list);
} }
} }
const std::list<ngraph::Node*>& ngraph::pass::Manager::get_sorted_list() const const std::list<ngraph::Node*>& ngraph::pass::Manager::get_sorted_list() const
{ {
return m_sorted_list; return m_sorted_list;
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#pragma once #pragma once
#include <vector> #include <vector>
#include "call_pass.hpp" #include "call_pass.hpp"
#include "tree_pass.hpp" #include "tree_pass.hpp"
namespace ngraph namespace ngraph
{ {
namespace pass namespace pass
{ {
class Manager; class Manager;
} }
class Node; class Node;
} }
class ngraph::pass::Manager class ngraph::pass::Manager
{ {
public: public:
Manager(); Manager();
~Manager(); ~Manager();
void initialize_default_passes(); void initialize_default_passes();
void register_pass(std::shared_ptr<TreeBase>); void register_pass(std::shared_ptr<TreeBase>);
void register_pass(std::shared_ptr<CallBase>); void register_pass(std::shared_ptr<CallBase>);
void run_passes(std::shared_ptr<Node> nodes); void run_passes(std::shared_ptr<Node> nodes);
const std::list<Node*>& get_sorted_list() const; const std::list<Node*>& get_sorted_list() const;
private: private:
std::vector<std::shared_ptr<TreeBase>> m_tree_passes; std::vector<std::shared_ptr<TreeBase>> m_tree_passes;
std::vector<std::shared_ptr<CallBase>> m_call_passes; std::vector<std::shared_ptr<CallBase>> m_call_passes;
std::list<Node*> m_sorted_list; std::list<Node*> m_sorted_list;
}; };
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <exception> #include <exception>
#include <sstream> #include <sstream>
#include "log.hpp" #include "log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
#include "ngraph/pass/liveness.hpp" #include "pass/liveness.hpp"
#include "ngraph/pass/memory_layout.hpp" #include "pass/memory_layout.hpp"
#include "log.hpp" #include "log.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
bool pass::MemoryLayout::run_on_call_list(std::list<Node*>& node_list) bool pass::MemoryLayout::run_on_call_list(std::list<Node*>& node_list)
{ {
for (Node* node : node_list) for (Node* node : node_list)
{ {
} }
return false; return false;
} }
void pass::MemoryLayout::check_dependencies( void pass::MemoryLayout::check_dependencies(
const std::vector<std::shared_ptr<CallBase>>& registered_passes) const const std::vector<std::shared_ptr<CallBase>>& registered_passes) const
{ {
bool found_propagate_types = false; bool found_propagate_types = false;
for (auto pass : registered_passes) for (auto pass : registered_passes)
{ {
if (dynamic_pointer_cast<Liveness>(pass)) if (dynamic_pointer_cast<Liveness>(pass))
{ {
found_propagate_types = true; found_propagate_types = true;
} }
} }
if (!found_propagate_types) if (!found_propagate_types)
{ {
throw runtime_error("Dependency 'PropagateTypes' not found for pass 'AssignTensors'"); throw runtime_error("Dependency 'PropagateTypes' not found for pass 'AssignTensors'");
} }
} }
pass::MemoryManager::node::node(size_t size, block_state state) pass::MemoryManager::node::node(size_t size, block_state state)
: m_size{size} : m_size{size}
, m_state{state} , m_state{state}
{ {
} }
pass::MemoryManager::MemoryManager(size_t alignment) pass::MemoryManager::MemoryManager(size_t alignment)
: m_alignment{alignment} : m_alignment{alignment}
, m_scheme{allocation_scheme::BEST_FIT} , m_scheme{allocation_scheme::BEST_FIT}
, m_max_allocated{0} , m_max_allocated{0}
{ {
// assert(m_base_offset % m_alignment == 0); // assert(m_base_offset % m_alignment == 0);
m_node_list.emplace_back(numeric_limits<size_t>::max(), block_state::FREE); m_node_list.emplace_back(numeric_limits<size_t>::max(), block_state::FREE);
} }
size_t pass::MemoryManager::allocate(size_t size) size_t pass::MemoryManager::allocate(size_t size)
{ {
size_t rc; size_t rc;
switch(m_scheme) switch(m_scheme)
{ {
case allocation_scheme::FIRST_FIT: case allocation_scheme::FIRST_FIT:
rc = first_fit(size); rc = first_fit(size);
break; break;
case allocation_scheme::BEST_FIT: case allocation_scheme::BEST_FIT:
rc = best_fit(size); rc = best_fit(size);
break; break;
} }
return rc; return rc;
} }
size_t pass::MemoryManager::best_fit(size_t size) size_t pass::MemoryManager::best_fit(size_t size)
{ {
size = align(size, m_alignment); size = align(size, m_alignment);
size_t offset = 0; size_t offset = 0;
size_t min_delta = numeric_limits<size_t>::max(); size_t min_delta = numeric_limits<size_t>::max();
auto best_fit = m_node_list.end(); auto best_fit = m_node_list.end();
size_t best_offset = offset; size_t best_offset = offset;
for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it) for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it)
{ {
if (it->m_state == block_state::FREE && it->m_size >= size) if (it->m_state == block_state::FREE && it->m_size >= size)
{ {
size_t delta = it->m_size - size; size_t delta = it->m_size - size;
if (delta < min_delta) if (delta < min_delta)
{ {
min_delta = delta; min_delta = delta;
best_fit = it; best_fit = it;
best_offset = offset; best_offset = offset;
} }
} }
offset += it->m_size; offset += it->m_size;
} }
if (best_fit == m_node_list.end()) if (best_fit == m_node_list.end())
{ {
throw bad_alloc(); throw bad_alloc();
} }
if (min_delta == 0) if (min_delta == 0)
{ {
// exact fit // exact fit
best_fit->m_state = block_state::ALLOCATED; best_fit->m_state = block_state::ALLOCATED;
} }
else else
{ {
m_node_list.insert(best_fit, node{size, block_state::ALLOCATED}); m_node_list.insert(best_fit, node{size, block_state::ALLOCATED});
best_fit->m_size -= size; best_fit->m_size -= size;
} }
m_max_allocated = std::max(m_max_allocated, best_offset + size); m_max_allocated = std::max(m_max_allocated, best_offset + size);
return best_offset; return best_offset;
} }
size_t pass::MemoryManager::first_fit(size_t size) size_t pass::MemoryManager::first_fit(size_t size)
{ {
size = align(size, m_alignment); size = align(size, m_alignment);
size_t offset = 0; size_t offset = 0;
bool found = false; bool found = false;
for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it) for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it)
{ {
if (it->m_state == block_state::FREE && it->m_size >= size) if (it->m_state == block_state::FREE && it->m_size >= size)
{ {
if (it->m_size > size) if (it->m_size > size)
{ {
m_node_list.insert(it, node{size, block_state::ALLOCATED}); m_node_list.insert(it, node{size, block_state::ALLOCATED});
it->m_size -= size; it->m_size -= size;
} }
else else
{ {
// exact fit // exact fit
it->m_state = block_state::ALLOCATED; it->m_state = block_state::ALLOCATED;
} }
found = true; found = true;
break; break;
} }
offset += it->m_size; offset += it->m_size;
} }
if (!found) if (!found)
{ {
throw bad_alloc(); throw bad_alloc();
} }
m_max_allocated = std::max(m_max_allocated, offset + size); m_max_allocated = std::max(m_max_allocated, offset + size);
return offset; return offset;
} }
void pass::MemoryManager::free(size_t offset) void pass::MemoryManager::free(size_t offset)
{ {
size_t search_offset = 0; size_t search_offset = 0;
bool found = false; bool found = false;
for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it) for (auto it=m_node_list.begin(); it != m_node_list.end(); ++it)
{ {
if (offset == search_offset) if (offset == search_offset)
{ {
list<node>::iterator it_next = std::next(it); list<node>::iterator it_next = std::next(it);
if (it == m_node_list.begin()) if (it == m_node_list.begin())
{ {
// free the first node in the list // free the first node in the list
it->m_state = block_state::FREE; it->m_state = block_state::FREE;
} }
else else
{ {
// node has predecessor // node has predecessor
list<node>::iterator it_prev = std::prev(it); list<node>::iterator it_prev = std::prev(it);
if (it_prev->m_state == block_state::FREE) if (it_prev->m_state == block_state::FREE)
{ {
it->m_size += it_prev->m_size; it->m_size += it_prev->m_size;
m_node_list.erase(it_prev); m_node_list.erase(it_prev);
} }
} }
if (it_next != m_node_list.end() && it_next->m_state == block_state::FREE) if (it_next != m_node_list.end() && it_next->m_state == block_state::FREE)
{ {
// join this node with next // join this node with next
it->m_size += it_next->m_size; it->m_size += it_next->m_size;
m_node_list.erase(it_next); m_node_list.erase(it_next);
} }
it->m_state = block_state::FREE; it->m_state = block_state::FREE;
found = true; found = true;
break; break;
} }
search_offset += it->m_size; search_offset += it->m_size;
} }
if (!found) if (!found)
{ {
throw runtime_error("bad free"); throw runtime_error("bad free");
} }
} }
void pass::MemoryManager::dump(std::ostream& out) void pass::MemoryManager::dump(std::ostream& out)
{ {
for (const node& n : m_node_list) for (const node& n : m_node_list)
{ {
out << "size=" << n.m_size << ", "; out << "size=" << n.m_size << ", ";
out << (n.m_state == block_state::FREE ? "FREE" : "ALLOCATED"); out << (n.m_state == block_state::FREE ? "FREE" : "ALLOCATED");
out << "\n"; out << "\n";
} }
} }
size_t pass::MemoryManager::align(size_t size, size_t alignment) size_t pass::MemoryManager::align(size_t size, size_t alignment)
{ {
if (size == 0) if (size == 0)
{ {
size = alignment; size = alignment;
} }
else else
{ {
auto remainder = size % alignment; auto remainder = size % alignment;
if (remainder > 0) if (remainder > 0)
{ {
size += (alignment - remainder); size += (alignment - remainder);
} }
} }
return size; return size;
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#pragma once #pragma once
#include <limits> #include <limits>
#include <list> #include <list>
#include <sstream> #include <sstream>
#include "call_pass.hpp" #include "call_pass.hpp"
namespace ngraph namespace ngraph
{ {
namespace pass namespace pass
{ {
class MemoryLayout; class MemoryLayout;
class MemoryNode; class MemoryNode;
class MemoryManager; class MemoryManager;
} }
class Node; class Node;
} }
class ngraph::pass::MemoryLayout : public CallBase class ngraph::pass::MemoryLayout : public CallBase
{ {
public: public:
virtual bool run_on_call_list(std::list<Node*>&) override; virtual bool run_on_call_list(std::list<Node*>&) override;
void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override; void check_dependencies(const std::vector<std::shared_ptr<CallBase>>&) const override;
private: private:
}; };
class ngraph::pass::MemoryManager class ngraph::pass::MemoryManager
{ {
public: public:
enum class block_state enum class block_state
{ {
FREE, FREE,
ALLOCATED ALLOCATED
}; };
enum class allocation_scheme enum class allocation_scheme
{ {
FIRST_FIT, FIRST_FIT,
BEST_FIT BEST_FIT
}; };
class node class node
{ {
public: public:
node(size_t size, block_state state); node(size_t size, block_state state);
bool is_free() const { return m_state == block_state::FREE; } bool is_free() const { return m_state == block_state::FREE; }
size_t m_size; size_t m_size;
block_state m_state; block_state m_state;
}; };
MemoryManager(size_t alignment=1); MemoryManager(size_t alignment=1);
// memory_manager& alignment(size_t a); // memory_manager& alignment(size_t a);
size_t allocate(size_t size); size_t allocate(size_t size);
void free(size_t offset); void free(size_t offset);
void dump(std::ostream&); void dump(std::ostream&);
static size_t align(size_t x, size_t alignment); static size_t align(size_t x, size_t alignment);
std::list<node>::iterator begin() { return m_node_list.begin(); } std::list<node>::iterator begin() { return m_node_list.begin(); }
std::list<node>::iterator end() { return m_node_list.end(); } std::list<node>::iterator end() { return m_node_list.end(); }
std::list<node>::const_iterator begin() const { return m_node_list.cbegin(); } std::list<node>::const_iterator begin() const { return m_node_list.cbegin(); }
std::list<node>::const_iterator end() const { return m_node_list.cend(); } std::list<node>::const_iterator end() const { return m_node_list.cend(); }
const std::list<node>& get_node_list() const { return m_node_list; } const std::list<node>& get_node_list() const { return m_node_list; }
size_t max_allocated() const { return m_max_allocated; } size_t max_allocated() const { return m_max_allocated; }
private: private:
size_t first_fit(size_t size); size_t first_fit(size_t size);
size_t best_fit(size_t size); size_t best_fit(size_t size);
std::list<node> m_node_list; std::list<node> m_node_list;
size_t m_alignment; size_t m_alignment;
allocation_scheme m_scheme; allocation_scheme m_scheme;
size_t m_max_allocated; size_t m_max_allocated;
}; };
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "pass.hpp" #include "pass.hpp"
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#pragma once #pragma once
namespace ngraph namespace ngraph
{ {
namespace pass namespace pass
{ {
class Base; class Base;
} }
} }
class ngraph::pass::Base class ngraph::pass::Base
{ {
public: public:
private: private:
}; };
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <sstream> #include <sstream>
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
#include "propagate_types.hpp" #include "propagate_types.hpp"
using namespace std; using namespace std;
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
#include <unordered_map> #include <unordered_map>
#include "log.hpp" #include "log.hpp"
#include "ngraph/node.hpp" #include "node.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "pass/topological_sort.hpp"
#include "util.hpp" #include "util.hpp"
using namespace ngraph; using namespace ngraph;
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <list> #include <list>
#include <memory> #include <memory>
#include "ngraph/pass/tree_pass.hpp" #include "pass/tree_pass.hpp"
namespace ngraph namespace ngraph
{ {
......
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "tree_pass.hpp" #include "tree_pass.hpp"
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#pragma once #pragma once
#include <list> #include <list>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "pass.hpp" #include "pass.hpp"
namespace ngraph namespace ngraph
{ {
namespace pass namespace pass
{ {
class TreeBase; class TreeBase;
} }
class Node; class Node;
} }
class ngraph::pass::TreeBase : public Base class ngraph::pass::TreeBase : public Base
{ {
public: public:
virtual ~TreeBase() {} virtual ~TreeBase() {}
// return true if changes were made to the tree // return true if changes were made to the tree
virtual bool run_on_tree(std::shared_ptr<Node>) = 0; virtual bool run_on_tree(std::shared_ptr<Node>) = 0;
virtual bool call_graph_produced() const { return false; } virtual bool call_graph_produced() const { return false; }
virtual std::list<Node*> get_call_graph() const { return std::list<Node*>(); } virtual std::list<Node*> get_call_graph() const { return std::list<Node*>(); }
// derived class throws exception if its dependencies have not been met // derived class throws exception if its dependencies have not been met
virtual void check_dependencies(const std::vector<std::shared_ptr<TreeBase>>&) const {} virtual void check_dependencies(const std::vector<std::shared_ptr<TreeBase>>&) const {}
private: private:
std::list<Node*> m_sorted_list; std::list<Node*> m_sorted_list;
}; };
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <fstream> #include <fstream>
#include "visualize_tree.hpp" #include "visualize_tree.hpp"
#include "ngraph/node.hpp" #include "node.hpp"
#include "util.hpp" #include "util.hpp"
using namespace ngraph; using namespace ngraph;
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <string> #include <string>
#include <set> #include <set>
#include "ngraph/pass/tree_pass.hpp" #include "pass/tree_pass.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <algorithm> #include <algorithm>
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/runtime/tensor_view.hpp" #include "runtime/tensor_view.hpp"
#include "ngraph/function.hpp" #include "function.hpp"
#include "ngraph/runtime/instruction.hpp" #include "runtime/instruction.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include "ngraph/runtime/call_frame.hpp" #include "runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp" #include "runtime/eigen/tensor_view.hpp"
#include "ngraph/runtime/instruction.hpp" #include "runtime/instruction.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -18,17 +18,17 @@ ...@@ -18,17 +18,17 @@
#include <typeinfo> #include <typeinfo>
#include <unordered_map> #include <unordered_map>
#include "ngraph/descriptor/input.hpp" #include "descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp" #include "descriptor/output.hpp"
#include "ngraph/function.hpp" #include "function.hpp"
#include "ngraph/node.hpp" #include "node.hpp"
#include "ngraph/ops/add.hpp" #include "ops/add.hpp"
#include "ngraph/ops/multiply.hpp" #include "ops/multiply.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "pass/topological_sort.hpp"
#include "ngraph/runtime/eigen/add.hpp" #include "runtime/eigen/add.hpp"
#include "ngraph/runtime/eigen/external_function.hpp" #include "runtime/eigen/external_function.hpp"
#include "ngraph/runtime/eigen/multiply.hpp" #include "runtime/eigen/multiply.hpp"
#include "ngraph/runtime/eigen/return.hpp" #include "runtime/eigen/return.hpp"
using namespace std; using namespace std;
using namespace ngraph::runtime::eigen; using namespace ngraph::runtime::eigen;
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <typeinfo> #include <typeinfo>
#include <unordered_map> #include <unordered_map>
#include "ngraph/function.hpp" #include "function.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#pragma once #pragma once
#include "ngraph/runtime/call_frame.hpp" #include "runtime/call_frame.hpp"
#include "ngraph/runtime/instruction.hpp" #include "runtime/instruction.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#pragma once #pragma once
#include "ngraph/runtime/call_frame.hpp" #include "runtime/call_frame.hpp"
#include "ngraph/runtime/instruction.hpp" #include "runtime/instruction.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
#include <Eigen/Dense> #include <Eigen/Dense>
#include <vector> #include <vector>
#include "ngraph/shape.hpp" #include "shape.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "runtime/tensor_view.hpp"
#include "ngraph/descriptor/tensor_view.hpp" #include "descriptor/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "ngraph/shape.hpp" #include "shape.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <cmath> #include <cmath>
#include <iostream> #include <iostream>
#include "ngraph/element_type.hpp" #include "element_type.hpp"
#include "log.hpp" #include "log.hpp"
using namespace ngraph; using namespace ngraph;
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include "ngraph/except.hpp" #include "except.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
#include "log.hpp" #include "log.hpp"
#include "util.hpp" #include "util.hpp"
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/element_type.hpp" #include "types/element_type.hpp"
#include "ngraph/shape.hpp" #include "shape.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <unordered_set> #include <unordered_set>
#include "util.hpp" #include "util.hpp"
#include "ngraph/node.hpp" #include "node.hpp"
#include "log.hpp" #include "log.hpp"
using namespace std; using namespace std;
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
#include <fstream> #include <fstream>
#include <list> #include <list>
#include "ngraph/node.hpp" #include "node.hpp"
#include "ngraph/visualize.hpp" #include "visualize.hpp"
#include "util.hpp" #include "util.hpp"
using namespace ngraph; using namespace ngraph;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
#include <memory> #include <memory>
using namespace std; using namespace std;
......
...@@ -18,6 +18,6 @@ ...@@ -18,6 +18,6 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/element_type.hpp" #include "types/element_type.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
#include <memory> #include <memory>
using namespace std; using namespace std;
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -19,15 +19,15 @@ ...@@ -19,15 +19,15 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/pass/liveness.hpp" #include "pass/liveness.hpp"
#include "ngraph/pass/assign_tensors.hpp" #include "pass/assign_tensors.hpp"
#include "ngraph/pass/manager.hpp" #include "pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp" #include "pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "pass/topological_sort.hpp"
#include "ngraph/pass/liveness.hpp" #include "pass/liveness.hpp"
#include "ngraph/pass/visualize_tree.hpp" #include "pass/visualize_tree.hpp"
#include "ngraph/pass/dump_sorted.hpp" #include "pass/dump_sorted.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
#include "log.hpp" #include "log.hpp"
......
...@@ -19,11 +19,11 @@ ...@@ -19,11 +19,11 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp" #include "pass/assign_tensors.hpp"
#include "ngraph/pass/manager.hpp" #include "pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp" #include "pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "pass/topological_sort.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
using namespace ngraph; using namespace ngraph;
......
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
#include "ngraph/pass/memory_layout.hpp" #include "pass/memory_layout.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
using namespace ngraph; using namespace ngraph;
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <algorithm> #include <algorithm>
#include "test_tools.hpp" #include "test_tools.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
#include "util.hpp" #include "util.hpp"
using namespace std; using namespace std;
......
...@@ -19,9 +19,9 @@ ...@@ -19,9 +19,9 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "pass/topological_sort.hpp"
#include "ngraph/visualize.hpp" #include "visualize.hpp"
#include "util.hpp" #include "util.hpp"
#include "log.hpp" #include "log.hpp"
#include "test_tools.hpp" #include "test_tools.hpp"
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph.hpp"
#include <memory> #include <memory>
using namespace std; using namespace std;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment