Unverified Commit b408bbe6 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Refactor, cleanup type info to make it safer to use for non-ops (#3672)

parent 3aa4db1d
......@@ -39,6 +39,7 @@
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/placement.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/type.hpp"
namespace ngraph
{
......@@ -81,20 +82,7 @@ namespace ngraph
/// Alias useful for cloning
using NodeMap = std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>>;
struct TypeInfo
{
const char* name;
uint64_t version;
};
using NodeTypeInfo = TypeInfo;
/// Tests if a node is of op type T
template <typename NodeType, typename T>
bool is_type(T value)
{
return &value->get_type_info() == &NodeType::type_info;
}
using NodeTypeInfo = DiscreteTypeInfo;
/// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor
......@@ -494,36 +482,6 @@ namespace ngraph
size_t m_placement_index = placement_invalid;
};
/// Casts a Node* to a NodeType* if it is of type NodeType, nullptr otherwise
template <typename NodeType>
NodeType* as_type(Node* node)
{
return is_type<NodeType>(node) ? static_cast<NodeType*>(node) : nullptr;
}
/// Casts a Node* to a NodePtr* if it is of type NodePtr, nullptr otherwise
template <typename NodeType>
const NodeType* as_type(const Node* node)
{
return is_type<NodeType>(node) ? static_cast<const NodeType*>(node) : nullptr;
}
/// Casts a Node to a shared_ptr<NodePtr> if it is of type NodePtr, nullptr otherwise
template <typename NodeType>
std::shared_ptr<NodeType> as_type_ptr(std::shared_ptr<Node> node_ptr)
{
return is_type<NodeType>(node_ptr) ? std::static_pointer_cast<NodeType>(node_ptr)
: std::shared_ptr<NodeType>();
}
/// Casts a Node to a shared_ptr<NodePtr> if it is of type NodePtr, nullptr otherwise
template <typename NodeType>
std::shared_ptr<const NodeType> as_type_ptr(std::shared_ptr<const Node> node_ptr)
{
return is_type<NodeType>(node_ptr) ? std::static_pointer_cast<NodeType>(node_ptr)
: std::shared_ptr<NodeType>();
}
/// \brief A handle for one of a node's inputs.
template <typename NodeType>
class Input
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include <utility>
namespace ngraph
{
/// Supports three functions, is_type<Type>, as_type<Type>, and as_type_ptr<Type> for type-safe
/// dynamic conversions via static_cast/static_ptr_cast without using C++ RTTI.
/// Type must have a static constexpr type_info member and a virtual get_type_info() member that
/// returns a reference to its type_info member.
/// Type information for a type system without inheritance; instances have exactly one type not
/// related to any other type.
struct DiscreteTypeInfo
{
const char* name;
uint64_t version;
bool is_castable(const DiscreteTypeInfo& target_type) const { return this == &target_type; }
};
/// \brief Tests if value is a pointer/shared_ptr that can be statically cast to a
/// Type*/shared_ptr<Type>
template <typename Type, typename Value>
typename std::enable_if<
std::is_convertible<
decltype(std::declval<Value>()->get_type_info().is_castable(Type::type_info)),
bool>::value,
bool>::type
is_type(Value value)
{
return value->get_type_info().is_castable(Type::type_info);
}
/// Casts a Value* to a Type* if it is of type Type, nullptr otherwise
template <typename Type, typename Value>
typename std::enable_if<
std::is_convertible<decltype(static_cast<Type*>(std::declval<Value>())), Type*>::value,
Type*>::type
as_type(Value value)
{
return is_type<Type>(value) ? static_cast<Type*>(value) : nullptr;
}
/// Casts a std::shared_ptr<Value> to a std::shared_ptr<Type> if it is of type
/// Type, nullptr otherwise
template <typename Type, typename Value>
typename std::enable_if<
std::is_convertible<decltype(std::static_pointer_cast<Type>(std::declval<Value>())),
std::shared_ptr<Type>>::value,
std::shared_ptr<Type>>::type
as_type_ptr(Value value)
{
return is_type<Type>(value) ? std::static_pointer_cast<Type>(value)
: std::shared_ptr<Type>();
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment