// pass_liveness.cpp
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------

#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "gtest/gtest.h"

#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/visualize_tree.hpp"

#include "test_tools.hpp"

using namespace std;
Bob Kimball's avatar
Bob Kimball committed
36
using namespace ngraph;
37 38
namespace ng = ngraph;

Bob Kimball's avatar
Bob Kimball committed
39
// Runs the liveness pipeline over the shared test graph from test_tools:
// VisualizeTree -> TopologicalSort -> PropagateTypes -> AssignTensors ->
// Liveness -> DumpSorted, registered in that order on a pass::Manager.
// VisualizeTree and DumpSorted are handed the file names below (presumably
// they write liveness.png / liveness.txt; confirm against the pass sources).
TEST(pass, liveness)
{
    string image = "liveness.png";
    string dump_file = "liveness.txt";
    pass::Manager pass_manager;

    pass_manager.register_pass<pass::VisualizeTree>(image);
    pass_manager.register_pass<pass::TopologicalSort>();
    pass_manager.register_pass<pass::PropagateTypes>();
    pass_manager.register_pass<pass::AssignTensors>();
    pass_manager.register_pass<pass::Liveness>();
    pass_manager.register_pass<pass::DumpSorted>(dump_file);

    shared_ptr<Function> func = make_test_graph();
    pass_manager.run_passes(func);
    auto sorted = func->get_ordered_ops();

    // Minimal sanity check: after the pipeline has run, the function should
    // expose a non-empty op ordering for the liveness analysis to have
    // anything to work on.
    EXPECT_FALSE(sorted.empty());

    // Debug aid: dump each node together with the tensors live at that node.
    // for (const Node* node : sorted)
    // {
    //     NGRAPH_INFO << *node;
    //     for (const descriptor::Tensor* tensor : node->liveness_live_list)
    //     {
    //         NGRAPH_INFO << "    " << *tensor;
    //     }
    // }

    // TODO: port the checks below. This is the original Python liveness test,
    // kept as a reference. Its key assertion is that any two tensors that are
    // live at the same node occupy non-overlapping buffer_pool_offset ranges.

    // auto x = ng.variable(axes=[]).named('x');
    // auto y = ng.variable(axes=[]).named('y');
    // auto w1 = ng.variable(axes=[]).named('w1');
    // auto w2 = ng.variable(axes=[]).named('w2');

    // auto x2 = x * w1;
    // auto x3 = (x2 * w2).named('result');
    // auto cost = x3 - y;

    // auto dw1 = ng.deriv(cost, w1);
    // auto dw2 = ng.deriv(cost, w2);

    // auto upd1 = ng.assign(w1, w1 + dw1);
    // auto upd2 = ng.assign(w2, w2 + dw2);
    // auto seq_stuff = ng.sequential([upd1, upd2, x3]);

    // auto exc = ex.executor(seq_stuff);
    // return exc;

    // lg = LivenessGraph(exc.exop.ops)
    // lg.layout_memory()

    // for i, node in enumerate(lg.liveness_nodes):
    //     print i, node

    // for node in lg.liveness_nodes:
    //     for var1 in node.live_list:
    //         assert var1.buffer_pool_offset is not None
    //         for var2 in node.live_list:
    //             if var1 != var2:
    //                 if var1.buffer_pool_offset < var2.buffer_pool_offset:
    //                     assert var1.buffer_pool_offset + var1.size <= var2.buffer_pool_offset
    //                 else:
    //                     assert var2.buffer_pool_offset + var2.size <= var1.buffer_pool_offset

    // // for o in egraph.computations:
    // //     print o.values

    // print("max memory {}".format(lg.memory_footprint()))
    // print("worst case memory {}".format(lg.worst_case_memory_usage()))
    // print("memory efficiency {}".format(lg.memory_efficiency()))
    // // // print lg.liveness_json()
}