Unverified Commit 21613f88 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Some type info cleanups/improvements (#4043)

* Some type info cleanups/improvements

* Missed change
parent 46ff13c7
...@@ -602,6 +602,7 @@ set (SRC ...@@ -602,6 +602,7 @@ set (SRC
type/float16.cpp type/float16.cpp
type/float16.hpp type/float16.hpp
type/element_type.cpp type/element_type.cpp
type.cpp
type.hpp type.hpp
util.cpp util.cpp
util.hpp util.hpp
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
#pragma once #pragma once
#include <functional> #include <functional>
#include <map>
#include <mutex> #include <mutex>
#include <unordered_map>
#include "ngraph/ngraph_visibility.hpp" #include "ngraph/ngraph_visibility.hpp"
...@@ -32,7 +32,7 @@ namespace ngraph ...@@ -32,7 +32,7 @@ namespace ngraph
{ {
public: public:
using Factory = std::function<BASE_TYPE*()>; using Factory = std::function<BASE_TYPE*()>;
using FactoryMap = std::map<decltype(BASE_TYPE::type_info), Factory>; using FactoryMap = std::unordered_map<typename BASE_TYPE::type_info_t, Factory>;
// \brief Get the default factory for DERIVED_TYPE. Specialize as needed. // \brief Get the default factory for DERIVED_TYPE. Specialize as needed.
template <typename DERIVED_TYPE> template <typename DERIVED_TYPE>
...@@ -42,7 +42,7 @@ namespace ngraph ...@@ -42,7 +42,7 @@ namespace ngraph
} }
/// \brief Register a custom factory for type_info /// \brief Register a custom factory for type_info
void register_factory(const decltype(BASE_TYPE::type_info) & type_info, Factory factory) void register_factory(const typename BASE_TYPE::type_info_t& type_info, Factory factory)
{ {
std::lock_guard<std::mutex> guard(get_registry_mutex()); std::lock_guard<std::mutex> guard(get_registry_mutex());
m_factory_map[type_info] = factory; m_factory_map[type_info] = factory;
...@@ -63,7 +63,7 @@ namespace ngraph ...@@ -63,7 +63,7 @@ namespace ngraph
} }
/// \brief Check to see if a factory is registered /// \brief Check to see if a factory is registered
bool has_factory(const decltype(BASE_TYPE::type_info) & info) bool has_factory(const typename BASE_TYPE::type_info_t& info)
{ {
std::lock_guard<std::mutex> guard(get_registry_mutex()); std::lock_guard<std::mutex> guard(get_registry_mutex());
return m_factory_map.find(info) != m_factory_map.end(); return m_factory_map.find(info) != m_factory_map.end();
...@@ -77,7 +77,7 @@ namespace ngraph ...@@ -77,7 +77,7 @@ namespace ngraph
} }
/// \brief Create an instance for type_info /// \brief Create an instance for type_info
BASE_TYPE* create(const decltype(BASE_TYPE::type_info) & type_info) BASE_TYPE* create(const typename BASE_TYPE::type_info_t& type_info)
{ {
std::lock_guard<std::mutex> guard(get_registry_mutex()); std::lock_guard<std::mutex> guard(get_registry_mutex());
auto it = m_factory_map.find(type_info); auto it = m_factory_map.find(type_info);
......
...@@ -32,8 +32,6 @@ ...@@ -32,8 +32,6 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo Node::type_info;
atomic<size_t> Node::m_next_instance_id(0); atomic<size_t> Node::m_next_instance_id(0);
Node::Node(size_t output_size) Node::Node(size_t output_size)
......
...@@ -92,8 +92,6 @@ namespace ngraph ...@@ -92,8 +92,6 @@ namespace ngraph
/// Alias useful for cloning /// Alias useful for cloning
using NodeMap = std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>>; using NodeMap = std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>>;
using NodeTypeInfo = DiscreteTypeInfo;
/// Nodes are the backbone of the graph of Value dataflow. Every node has /// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor /// zero or more nodes as arguments and one value, which is either a tensor
/// or a (possibly empty) tuple of values. /// or a (possibly empty) tuple of values.
...@@ -120,6 +118,8 @@ namespace ngraph ...@@ -120,6 +118,8 @@ namespace ngraph
// Called in constructors during transition // Called in constructors during transition
void constructor_validate_and_infer_types(); void constructor_validate_and_infer_types();
using type_info_t = DiscreteTypeInfo;
protected: protected:
std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args( std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args(
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec()); const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
...@@ -158,8 +158,6 @@ namespace ngraph ...@@ -158,8 +158,6 @@ namespace ngraph
void safe_delete(NodeVector& nodes, bool recurse); void safe_delete(NodeVector& nodes, bool recurse);
public: public:
static constexpr NodeTypeInfo type_info{"Node", 0};
virtual ~Node(); virtual ~Node();
virtual bool visit_attributes(AttributeVisitor& visitor) { return false; } virtual bool visit_attributes(AttributeVisitor& visitor) { return false; }
...@@ -181,17 +179,8 @@ namespace ngraph ...@@ -181,17 +179,8 @@ namespace ngraph
/// Returns the NodeTypeInfo for the node's class. /// Returns the NodeTypeInfo for the node's class.
/// During transition to type_info, returns a dummy type_info for Node if the class /// During transition to type_info, returns a dummy type_info for Node if the class
/// has not been updated yet. /// has not been updated yet.
virtual const NodeTypeInfo& get_type_info() const = 0; virtual const type_info_t& get_type_info() const = 0;
virtual const char* get_type_name() const const char* get_type_name() const { return get_type_info().name; }
{
auto& info = get_type_info();
if (is_type<Node>(this))
{
// Transitional definition
return description().c_str();
}
return info.name;
}
/// Sets/replaces the arguments with new arguments. /// Sets/replaces the arguments with new arguments.
void set_arguments(const NodeVector& arguments); void set_arguments(const NodeVector& arguments);
/// Sets/replaces the arguments with new arguments. /// Sets/replaces the arguments with new arguments.
...@@ -230,7 +219,7 @@ namespace ngraph ...@@ -230,7 +219,7 @@ namespace ngraph
/// graph against the graph. /// graph against the graph.
bool is_same_op_type(const std::shared_ptr<Node>& node) const bool is_same_op_type(const std::shared_ptr<Node>& node) const
{ {
return description() == node->description(); return get_type_info() == node->get_type_info();
} }
/// \brief Marks an input as being relevant or irrelevant to the output shapes of this /// \brief Marks an input as being relevant or irrelevant to the output shapes of this
...@@ -533,6 +522,8 @@ namespace ngraph ...@@ -533,6 +522,8 @@ namespace ngraph
std::map<std::string, std::shared_ptr<Variant>> m_rt_info; std::map<std::string, std::shared_ptr<Variant>> m_rt_info;
}; };
using NodeTypeInfo = Node::type_info_t;
template <typename NodeType> template <typename NodeType>
class Input class Input
{ {
......
...@@ -28,7 +28,8 @@ ...@@ -28,7 +28,8 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
using visualize_tree_ops_map_t = using visualize_tree_ops_map_t =
std::unordered_map<std::type_index, std::function<void(const ngraph::Node&, std::ostream& ss)>>; std::unordered_map<ngraph::Node::type_info_t,
std::function<void(const ngraph::Node&, std::ostream& ss)>>;
namespace ngraph namespace ngraph
{ {
......
...@@ -31,8 +31,6 @@ ...@@ -31,8 +31,6 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
#define TI(x) type_index(typeid(x))
// //
// As we are visualizing the graph, we will make some tweaks to the generated dot file to make // As we are visualizing the graph, we will make some tweaks to the generated dot file to make
// routing more tractable for Graphviz as well as (hopefully) more legible for the user. // routing more tractable for Graphviz as well as (hopefully) more legible for the user.
...@@ -260,8 +258,9 @@ void pass::VisualizeTree::add_node_arguments(shared_ptr<Node> node, ...@@ -260,8 +258,9 @@ void pass::VisualizeTree::add_node_arguments(shared_ptr<Node> node,
size_t& fake_node_ctr) size_t& fake_node_ctr)
{ {
size_t arg_index = 0; size_t arg_index = 0;
for (auto arg : node->get_arguments()) for (auto input_value : node->input_values())
{ {
auto arg = input_value.get_node_shared_ptr();
size_t jump_distance = height_maps[arg.get()].max_jump_to(height_maps[node.get()]); size_t jump_distance = height_maps[arg.get()].max_jump_to(height_maps[node.get()]);
if (is_type<ngraph::op::Constant>(arg) || is_type<ngraph::op::Parameter>(arg)) if (is_type<ngraph::op::Constant>(arg) || is_type<ngraph::op::Parameter>(arg))
{ {
...@@ -388,11 +387,10 @@ string pass::VisualizeTree::get_attributes(shared_ptr<Node> node) ...@@ -388,11 +387,10 @@ string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
: node->get_element_type().c_type_string()); : node->get_element_type().c_type_string());
} }
const Node& n = *node; auto eh = m_ops_to_details.find(node->get_type_info());
auto eh = m_ops_to_details.find(TI(n));
if (eh != m_ops_to_details.end()) if (eh != m_ops_to_details.end())
{ {
eh->second(n, label); eh->second(*node, label);
} }
label << "\""; label << "\"";
attributes.push_back(label.str()); attributes.push_back(label.str());
......
...@@ -61,8 +61,7 @@ protected: ...@@ -61,8 +61,7 @@ protected:
std::stringstream m_ss; std::stringstream m_ss;
std::string m_name; std::string m_name;
std::set<std::shared_ptr<Node>> m_nodes_with_attributes; std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
std::unordered_map<std::type_index, std::function<void(const Node&, std::ostream& ss)>> visualize_tree_ops_map_t m_ops_to_details;
m_ops_to_details;
node_modifiers_t m_node_modifiers = nullptr; node_modifiers_t m_node_modifiers = nullptr;
bool m_dot_only; bool m_dot_only;
static const int max_jump_distance; static const int max_jump_distance;
......
...@@ -14,11 +14,12 @@ ...@@ -14,11 +14,12 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "cpu_visualize_tree.hpp"
#include <string> #include <string>
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp" #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp" #include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/cpu_visualize_tree.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp" #include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp" #include "ngraph/runtime/cpu/op/convert_layout.hpp"
...@@ -26,8 +27,6 @@ using namespace mkldnn; ...@@ -26,8 +27,6 @@ using namespace mkldnn;
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
#define TI(x) std::type_index(typeid(x))
static void visualize_layout_format(const Node& node, ostream& ss) static void visualize_layout_format(const Node& node, ostream& ss)
{ {
try try
...@@ -90,8 +89,8 @@ namespace ngraph ...@@ -90,8 +89,8 @@ namespace ngraph
const visualize_tree_ops_map_t& get_visualize_tree_ops_map() const visualize_tree_ops_map_t& get_visualize_tree_ops_map()
{ {
const static visualize_tree_ops_map_t vtom{ const static visualize_tree_ops_map_t vtom{
{TI(runtime::cpu::op::ConvertLayout), visualize_layout_format}, {runtime::cpu::op::ConvertLayout::type_info, visualize_layout_format},
{TI(ngraph::op::Reshape), visualize_layout_format}}; {ngraph::op::Reshape::type_info, visualize_layout_format}};
return vtom; return vtom;
} }
} }
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/type.hpp"
#include "ngraph/util.hpp"
namespace std
{
size_t std::hash<ngraph::DiscreteTypeInfo>::operator()(const ngraph::DiscreteTypeInfo& k) const
{
size_t name_hash = hash<decltype(k.name)>()(k.name);
size_t version_hash = hash<decltype(k.version)>()(k.version);
return ngraph::hash_combine(vector<size_t>{name_hash, version_hash});
}
}
...@@ -18,9 +18,11 @@ ...@@ -18,9 +18,11 @@
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
#include <functional>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector>
#include "ngraph/ngraph_visibility.hpp" #include "ngraph/ngraph_visibility.hpp"
...@@ -102,3 +104,12 @@ namespace ngraph ...@@ -102,3 +104,12 @@ namespace ngraph
: std::shared_ptr<Type>(); : std::shared_ptr<Type>();
} }
} }
namespace std
{
template <>
struct hash<ngraph::DiscreteTypeInfo>
{
size_t operator()(const ngraph::DiscreteTypeInfo& k) const;
};
}
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <typeinfo> #include <typeinfo>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "ngraph/axis_vector.hpp" #include "ngraph/axis_vector.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment