Unverified Commit 21613f88 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Some type info cleanups/improvements (#4043)

* Some type info cleanups/improvements

* Missed change
parent 46ff13c7
......@@ -602,6 +602,7 @@ set (SRC
type/float16.cpp
type/float16.hpp
type/element_type.cpp
type.cpp
type.hpp
util.cpp
util.hpp
......
......@@ -17,8 +17,8 @@
#pragma once
#include <functional>
#include <map>
#include <mutex>
#include <unordered_map>
#include "ngraph/ngraph_visibility.hpp"
......@@ -32,7 +32,7 @@ namespace ngraph
{
public:
using Factory = std::function<BASE_TYPE*()>;
using FactoryMap = std::map<decltype(BASE_TYPE::type_info), Factory>;
using FactoryMap = std::unordered_map<typename BASE_TYPE::type_info_t, Factory>;
// \brief Get the default factory for DERIVED_TYPE. Specialize as needed.
template <typename DERIVED_TYPE>
......@@ -42,7 +42,7 @@ namespace ngraph
}
/// \brief Register a custom factory for type_info
void register_factory(const decltype(BASE_TYPE::type_info) & type_info, Factory factory)
void register_factory(const typename BASE_TYPE::type_info_t& type_info, Factory factory)
{
std::lock_guard<std::mutex> guard(get_registry_mutex());
m_factory_map[type_info] = factory;
......@@ -63,7 +63,7 @@ namespace ngraph
}
/// \brief Check to see if a factory is registered
bool has_factory(const decltype(BASE_TYPE::type_info) & info)
bool has_factory(const typename BASE_TYPE::type_info_t& info)
{
std::lock_guard<std::mutex> guard(get_registry_mutex());
return m_factory_map.find(info) != m_factory_map.end();
......@@ -77,7 +77,7 @@ namespace ngraph
}
/// \brief Create an instance for type_info
BASE_TYPE* create(const decltype(BASE_TYPE::type_info) & type_info)
BASE_TYPE* create(const typename BASE_TYPE::type_info_t& type_info)
{
std::lock_guard<std::mutex> guard(get_registry_mutex());
auto it = m_factory_map.find(type_info);
......
......@@ -32,8 +32,6 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo Node::type_info;
atomic<size_t> Node::m_next_instance_id(0);
Node::Node(size_t output_size)
......
......@@ -92,8 +92,6 @@ namespace ngraph
/// Alias useful for cloning
using NodeMap = std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>>;
using NodeTypeInfo = DiscreteTypeInfo;
/// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor
/// or a (possibly empty) tuple of values.
......@@ -120,6 +118,8 @@ namespace ngraph
// Called in constructors during transition
void constructor_validate_and_infer_types();
using type_info_t = DiscreteTypeInfo;
protected:
std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args(
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
......@@ -158,8 +158,6 @@ namespace ngraph
void safe_delete(NodeVector& nodes, bool recurse);
public:
static constexpr NodeTypeInfo type_info{"Node", 0};
virtual ~Node();
virtual bool visit_attributes(AttributeVisitor& visitor) { return false; }
......@@ -181,17 +179,8 @@ namespace ngraph
/// Returns the NodeTypeInfo for the node's class.
/// During transition to type_info, returns a dummy type_info for Node if the class
/// has not been updated yet.
virtual const NodeTypeInfo& get_type_info() const = 0;
virtual const char* get_type_name() const
{
auto& info = get_type_info();
if (is_type<Node>(this))
{
// Transitional definition
return description().c_str();
}
return info.name;
}
virtual const type_info_t& get_type_info() const = 0;
const char* get_type_name() const { return get_type_info().name; }
/// Sets/replaces the arguments with new arguments.
void set_arguments(const NodeVector& arguments);
/// Sets/replaces the arguments with new arguments.
......@@ -230,7 +219,7 @@ namespace ngraph
/// graph against the graph.
bool is_same_op_type(const std::shared_ptr<Node>& node) const
{
return description() == node->description();
return get_type_info() == node->get_type_info();
}
/// \brief Marks an input as being relevant or irrelevant to the output shapes of this
......@@ -533,6 +522,8 @@ namespace ngraph
std::map<std::string, std::shared_ptr<Variant>> m_rt_info;
};
using NodeTypeInfo = Node::type_info_t;
template <typename NodeType>
class Input
{
......
......@@ -28,7 +28,8 @@
#include "ngraph/node.hpp"
using visualize_tree_ops_map_t =
std::unordered_map<std::type_index, std::function<void(const ngraph::Node&, std::ostream& ss)>>;
std::unordered_map<ngraph::Node::type_info_t,
std::function<void(const ngraph::Node&, std::ostream& ss)>>;
namespace ngraph
{
......
......@@ -31,8 +31,6 @@
using namespace ngraph;
using namespace std;
#define TI(x) type_index(typeid(x))
//
// As we are visualizing the graph, we will make some tweaks to the generated dot file to make
// routing more tractable for Graphviz as well as (hopefully) more legible for the user.
......@@ -260,8 +258,9 @@ void pass::VisualizeTree::add_node_arguments(shared_ptr<Node> node,
size_t& fake_node_ctr)
{
size_t arg_index = 0;
for (auto arg : node->get_arguments())
for (auto input_value : node->input_values())
{
auto arg = input_value.get_node_shared_ptr();
size_t jump_distance = height_maps[arg.get()].max_jump_to(height_maps[node.get()]);
if (is_type<ngraph::op::Constant>(arg) || is_type<ngraph::op::Parameter>(arg))
{
......@@ -388,11 +387,10 @@ string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
: node->get_element_type().c_type_string());
}
const Node& n = *node;
auto eh = m_ops_to_details.find(TI(n));
auto eh = m_ops_to_details.find(node->get_type_info());
if (eh != m_ops_to_details.end())
{
eh->second(n, label);
eh->second(*node, label);
}
label << "\"";
attributes.push_back(label.str());
......
......@@ -61,8 +61,7 @@ protected:
std::stringstream m_ss;
std::string m_name;
std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
std::unordered_map<std::type_index, std::function<void(const Node&, std::ostream& ss)>>
m_ops_to_details;
visualize_tree_ops_map_t m_ops_to_details;
node_modifiers_t m_node_modifiers = nullptr;
bool m_dot_only;
static const int max_jump_distance;
......
......@@ -14,11 +14,12 @@
// limitations under the License.
//*****************************************************************************
#include "cpu_visualize_tree.hpp"
#include <string>
#include "ngraph/op/reshape.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/cpu_visualize_tree.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
......@@ -26,8 +27,6 @@ using namespace mkldnn;
using namespace ngraph;
using namespace std;
#define TI(x) std::type_index(typeid(x))
static void visualize_layout_format(const Node& node, ostream& ss)
{
try
......@@ -90,8 +89,8 @@ namespace ngraph
const visualize_tree_ops_map_t& get_visualize_tree_ops_map()
{
const static visualize_tree_ops_map_t vtom{
{TI(runtime::cpu::op::ConvertLayout), visualize_layout_format},
{TI(ngraph::op::Reshape), visualize_layout_format}};
{runtime::cpu::op::ConvertLayout::type_info, visualize_layout_format},
{ngraph::op::Reshape::type_info, visualize_layout_format}};
return vtom;
}
}
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/type.hpp"
#include "ngraph/util.hpp"
namespace std
{
size_t std::hash<ngraph::DiscreteTypeInfo>::operator()(const ngraph::DiscreteTypeInfo& k) const
{
size_t name_hash = hash<decltype(k.name)>()(k.name);
size_t version_hash = hash<decltype(k.version)>()(k.version);
return ngraph::hash_combine(vector<size_t>{name_hash, version_hash});
}
}
......@@ -18,9 +18,11 @@
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ngraph/ngraph_visibility.hpp"
......@@ -102,3 +104,12 @@ namespace ngraph
: std::shared_ptr<Type>();
}
}
namespace std
{
template <>
struct hash<ngraph::DiscreteTypeInfo>
{
size_t operator()(const ngraph::DiscreteTypeInfo& k) const;
};
}
......@@ -30,6 +30,7 @@
#include <typeinfo>
#include <unordered_map>
#include <vector>
#include "ngraph/axis_vector.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment