Commit 71616162 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Add info for cpu ops in VisualizeTree (#1622)

* inject details into visualize_tree pass

* address bob's feedback

* revert back to map
parent 683822ef
......@@ -135,6 +135,7 @@ void ngraph::pass::Manager::run_passes(shared_ptr<Function> func, bool transitiv
if (m_visualize)
{
pass::VisualizeTree vt(base_filename + pass::VisualizeTree::get_file_ext());
vt.set_ops_to_details(get_state().get_visualize_tree_ops_map());
vt.run_on_module(fs);
}
......
......@@ -16,10 +16,19 @@
#pragma once
#include <functional>
#include <initializer_list>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <utility>
#include <vector>
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
using visualize_tree_ops_map_t =
std::unordered_map<std::type_index, std::function<void(const ngraph::Node&, std::ostream& ss)>>;
namespace ngraph
{
......@@ -41,6 +50,17 @@ public:
m_function_list.insert(m_function_list.begin(), collection.begin(), collection.end());
}
void set_visualize_tree_ops_map(const visualize_tree_ops_map_t& ops_map)
{
m_visualize_tree_ops_map = ops_map;
}
const visualize_tree_ops_map_t& get_visualize_tree_ops_map()
{
return m_visualize_tree_ops_map;
}
private:
std::vector<std::shared_ptr<Function>> m_function_list;
visualize_tree_ops_map_t m_visualize_tree_ops_map;
};
......@@ -27,6 +27,8 @@
using namespace ngraph;
using namespace std;
#define TI(x) std::type_index(typeid(x))
bool pass::VisualizeTree::run_on_module(vector<shared_ptr<ngraph::Function>>& functions)
{
for (shared_ptr<Function> f : functions)
......@@ -101,7 +103,8 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
ss << " label=\"" << node->get_name();
if (std::getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES") != nullptr)
static const auto nvtos = std::getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES");
if (nvtos != nullptr)
{
// The shapes of the Outputs of a multi-output op
// will be printed for its corresponding `GetOutputElement`s
......@@ -109,6 +112,13 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
: vector_to_string(node->get_shape()));
}
const Node& n = *node;
auto eh = m_ops_to_details.find(TI(n));
if (eh != m_ops_to_details.end())
{
eh->second(n, ss);
}
ss << " \"]\n";
return ss.str();
......
......@@ -16,10 +16,16 @@
#pragma once
#include <functional>
#include <set>
#include <sstream>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include "ngraph/pass/manager_state.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
......@@ -37,7 +43,7 @@ public:
bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;
static std::string get_file_ext();
void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) { m_ops_to_details = ops_map; }
private:
std::string add_attributes(std::shared_ptr<Node> node);
std::string get_attributes(std::shared_ptr<Node> node);
......@@ -46,4 +52,6 @@ private:
std::stringstream m_ss;
std::string m_name;
std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
std::unordered_map<std::type_index, std::function<void(const Node&, std::ostream& ss)>>
m_ops_to_details;
};
......@@ -26,6 +26,7 @@ set(SRC
cpu_tensor_view_wrapper.cpp
cpu_tensor_view.cpp
cpu_tracing.cpp
cpu_visualize_tree.cpp
quantization_util.cpp
builder/add.cpp
builder/allreduce.cpp
......
......@@ -134,6 +134,7 @@
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
#include "ngraph/runtime/cpu/cpu_tracing.hpp"
#include "ngraph/runtime/cpu/cpu_visualize_tree.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/batch_dot.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp"
......@@ -393,6 +394,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
pass_manager.get_state().set_visualize_tree_ops_map(runtime::cpu::get_visualize_tree_ops_map());
}
void runtime::cpu::CPU_ExternalFunction::compile()
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "cpu_visualize_tree.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
using namespace mkldnn;
using namespace ngraph;
using namespace std;
#define TI(x) std::type_index(typeid(x))
static void visualize_convert_layout(const Node& node, ostream& ss)
{
auto input_desc = runtime::cpu::mkldnn_utils::get_input_mkldnn_md(&node, 0);
auto result_desc = runtime::cpu::mkldnn_utils::get_output_mkldnn_md(&node, 0);
ss << "in=" << runtime::cpu::mkldnn_utils::get_mkldnn_format_string(
static_cast<mkldnn::memory::format>(input_desc.data.format));
ss << " out=" << runtime::cpu::mkldnn_utils::get_mkldnn_format_string(
static_cast<mkldnn::memory::format>(result_desc.data.format));
ss << " ";
}
namespace ngraph
{
namespace runtime
{
namespace cpu
{
const visualize_tree_ops_map_t& get_visualize_tree_ops_map()
{
const static visualize_tree_ops_map_t vtom{
{TI(runtime::cpu::op::ConvertLayout), visualize_convert_layout}};
return vtom;
}
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/node.hpp"
#include "ngraph/pass/manager_state.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
const visualize_tree_ops_map_t& get_visualize_tree_ops_map();
}
}
}
......@@ -131,6 +131,8 @@ static const std::map<memory::format, const std::string> s_mkldnn_format_string_
{memory::format::ldsnc, "memory::format::ldsnc"},
{memory::format::ldigo, "memory::format::ldigo"},
{memory::format::ldgo, "memory::format::ldgo"},
{memory::format::ldgo, "memory::format::Goihw8g"},
{memory::format::ldgo, "memory::format::Goihw16g"},
};
static const std::set<memory::format> s_filter_formats{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment