Commit 32fb97d1 authored by Scott Cyphers's avatar Scott Cyphers

Review comments, fix as_ methods

parent 321f32f5
......@@ -2,8 +2,8 @@
## Op Definition
* Every Op class must declare a `static constexpr NodeTypeInfo type_info{name, version}` in the class definition and define it in the .cpp file. See any op definition for an example.
* The boolean node method `has_type<T>` is for testing if a node is the op `T`.
* `T as_type<T>()` will upcast `Node` to an explicit op class if it is of class `T`, or `nullptr` if it is not.
* The boolean node method `is_type<T>` is for testing if a node is the op `T`.
* `T as_type_ptr<T>()` and `T as_typr<T>()` will upcast `Node` to an explicit op class if it is of class `T`, or `nullptr` if it is not.
## Passes
* `LikeReplacement` pass must be run by all transformers.
......
......@@ -93,7 +93,7 @@ std::shared_ptr<Node> Node::get_output_as_single_output_node(size_t i, bool for_
{
for (auto in : output(i).get_target_inputs())
{
if (in.get_node()->has_type<op::GetOutputElement>())
if (in.get_node()->is_type<op::GetOutputElement>())
{
return in.get_node()->shared_from_this();
}
......@@ -105,7 +105,7 @@ std::shared_ptr<Node>
Node::copy_with_new_inputs(const OutputVector& inputs,
const std::vector<std::shared_ptr<Node>>& control_dependencies) const
{
bool for_get_output_element = has_type<op::GetOutputElement>();
bool for_get_output_element = is_type<op::GetOutputElement>();
NodeVector args;
for (const Output<Node>& input : inputs)
{
......@@ -261,7 +261,7 @@ const std::deque<descriptor::Output>& Node::get_outputs() const
bool Node::is_parameter() const
{
return dynamic_cast<const op::Parameter*>(this) != nullptr;
return is_type<op::Parameter>();
}
bool Node::is_output() const
......
......@@ -91,7 +91,7 @@ namespace ngraph
/// or a (possibly empty) tuple of values.
class Node : public std::enable_shared_from_this<Node>
{
static constexpr NodeTypeInfo type_info{"Node_0", 0};
static constexpr NodeTypeInfo type_info{"Node", 0};
// For access to generate_adjoints.
friend class autodiff::Adjoints;
......@@ -150,33 +150,49 @@ namespace ngraph
/// Tests if a node is of op type T
template <typename NodeType>
bool has_type() const
bool is_type() const
{
return &get_type_info() == &NodeType::type_info;
}
/// Casts a Node to a shared_ptr<T> if is of type T, nullptr otherwise;
template <typename NodeType>
std::shared_ptr<NodeType> as_type()
std::shared_ptr<NodeType> as_type_ptr()
{
return has_type<NodeType>() ? shared_from_this() : std::shared_ptr<NodeType>();
return is_type<NodeType>() ? std::static_pointer_cast<NodeType>(shared_from_this())
: std::shared_ptr<NodeType>();
}
/// Casts a Node to a shared_ptr<T> if is of type T, nullptr otherwise;
template <typename NodeType>
std::shared_ptr<const NodeType> as_type() const
std::shared_ptr<const NodeType> as_type_ptr() const
{
return has_type<NodeType>() ? shared_from_this() : std::shared_ptr<NodeType>();
return is_type<NodeType>() ? std::static_pointer_cast<NodeType>(shared_from_this())
: std::shared_ptr<NodeType>();
}
/// Casts a Node to a T* if is of type T, nullptr otherwise;
template <typename NodeType>
NodeType* as_type()
{
return is_type<NodeType>() ? static_cast<NodeType*>(this) : nullptr;
}
/// Casts a Node to a T* if is of type T, nullptr otherwise;
template <typename NodeType>
const NodeType* as_type() const
{
return is_type<NodeType>() ? static_cast<const NodeType*>(this) : nullptr;
}
/// Returns the NodeTypeInfo for the node's class.
/// During transition to type_info, return's a dummy type_info for Node if the class
/// During transition to type_info, returns a dummy type_info for Node if the class
/// has not been updated yet.
virtual const NodeTypeInfo& get_type_info() const { return type_info; }
virtual const char* get_type_name() const
{
auto& info = get_type_info();
if (has_type<Node>())
if (is_type<Node>())
{
// Transitional definition
return description().c_str();
......
......@@ -32,7 +32,7 @@ namespace ngraph
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Add", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an unitialized addition operation
/// \brief Constructs an uninitialized addition operation
Add() = default;
/// \brief Constructs an addition operation.
......
......@@ -30,7 +30,7 @@ namespace ngraph
static constexpr NodeTypeInfo type_info{"Subtract", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Subtract() = default;
/// \brief Constructs an subtraction operation.
/// \brief Constructs a subtraction operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
......
......@@ -163,7 +163,7 @@ static std::string label_edge(const std::shared_ptr<Node>& src,
if (getenv("NGRAPH_VISUALIZE_EDGE_LABELS") != nullptr)
{
size_t output = 0;
if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(dst))
if (auto goe = dst->as_type<op::GetOutputElement>())
{
output = goe->get_as_output().get_index();
}
......@@ -263,7 +263,7 @@ void pass::VisualizeTree::add_node_arguments(shared_ptr<Node> node,
for (auto arg : node->get_arguments())
{
size_t jump_distance = height_maps[arg.get()].max_jump_to(height_maps[node.get()]);
if (arg->has_type<ngraph::op::Constant>() || arg->has_type<ngraph::op::Parameter>())
if (arg->is_type<ngraph::op::Constant>() || arg->is_type<ngraph::op::Parameter>())
{
auto clone_name = "CLONE_" + to_string(fake_node_ctr);
auto color = (arg->description() == "Parameter" ? "blue" : "black");
......@@ -416,7 +416,7 @@ string pass::VisualizeTree::get_node_name(shared_ptr<Node> node)
{
rc += "\\n" + node->get_name();
}
if (auto ck = dynamic_pointer_cast<ngraph::op::CompiledKernel>(node))
if (auto ck = node->as_type<ngraph::op::CompiledKernel>())
{
rc += "\\n{";
// add sub-graph node names
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment