/* // Copyright (c) 2018 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. */ /////////////////////////////////////////////////////////////////////////////////////////////////// #include "pass_manager.h" // ToDo: remove those include with the appropriate code below once we will have support for multiple outputs of a // primitive #include "batch_norm_inst.h" #include "max_unpooling_inst.h" #include "pooling_inst.h" #include <vector> #include <list> using namespace cldnn; // This pass optimizes out nodes which have no impact on outputs void trim_to_outputs::run(program_impl& p) { const size_t actual_nodes = p.get_processing_order().size(); if (!actual_nodes) // degenerated case but can happen return; if (p.get_outputs().size() == actual_nodes) return; // do backward bfs starting from all outputs std::list<const std::vector<program_node*>*> stack = {&(p.get_outputs())}; std::vector<program_node*> special_nodes; for (auto& node : p.get_processing_order()) { if (node->is_type<input_layout>() || // input layout may become disconnected during prior boxes calculations so // it may have not been marked at this place but we don't want to remove it node->is_type<max_unpooling>() || // ToDo: remove this after support for multi-outputs in primitives will // be implemented. node->is_type<batch_norm>() || (node->is_type<pooling>() && node->as<pooling>().get_primitive()->mode == pooling_mode::max_with_argmax)) special_nodes.push_back(node); } stack.push_back(&special_nodes); while (!stack.empty()) { auto nodes_list = stack.front(); stack.pop_front(); for (auto& node : *nodes_list) { if (!node->is_marked()) { node->mark(); if (!node->get_dependencies().empty()) stack.push_back(&node->get_dependencies()); } } } // all not-marked nodes should be removed std::list<program_node*> to_rem; for (auto& node : p.get_processing_order()) { if (!node->is_marked()) to_rem.push_back(node); } p.remove_nodes(to_rem); }