Unverified Commit 15ceedb7 authored by Scott Cyphers, committed by GitHub

Remove some dynamic_ptr_casts (#3622)

* Remove some dynamic_ptr_casts

* More removals

* Review comments

* Update src/ngraph/node.hpp
Co-Authored-By: Adam Procter <adam.m.procter@intel.com>
parent ecf7a396
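
The pattern applied throughout this diff: RTTI-based std::dynamic_pointer_cast checks on nodes are replaced by the NodeTypeInfo helpers is_type, as_type, and as_type_ptr introduced in the node.hpp hunk below. A minimal sketch of the before/after call sites (the example function and its node argument are illustrative, not part of the commit):

#include <memory>
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"

void example(const std::shared_ptr<ngraph::Node>& node)
{
    // Before: RTTI-based downcast.
    // auto c_old = std::dynamic_pointer_cast<ngraph::op::Constant>(node);

    // After: type_info comparison followed by a static cast; yields nullptr
    // when the node is not a Constant.
    if (auto c = ngraph::as_type_ptr<ngraph::op::Constant>(node))
    {
        auto shape = c->get_shape();
        (void)shape;
    }

    // Plain boolean test when no downcast result is needed.
    bool is_const = node->is_type<ngraph::op::Constant>();
    (void)is_const;
}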
......@@ -52,7 +52,7 @@ Function::Function(const OutputVector& results,
, m_unique_name("Function_" + to_string(m_instance_id))
{
if (std::any_of(results.cbegin(), results.cend(), [](Output<Node> n) {
return std::dynamic_pointer_cast<op::Result>(n.get_node_shared_ptr());
return as_type_ptr<op::Result>(n.get_node_shared_ptr());
}))
{
throw ngraph_error(
......@@ -76,7 +76,7 @@ Function::Function(const NodeVector& results,
, m_unique_name("Function_" + to_string(m_instance_id))
{
if (std::any_of(results.cbegin(), results.cend(), [](std::shared_ptr<Node> n) {
return std::dynamic_pointer_cast<op::Result>(n);
return as_type_ptr<op::Result>(n);
}))
{
throw ngraph_error(
......
......@@ -306,7 +306,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
ResultVector cloned_results;
for (shared_ptr<Node> node : func.get_results())
{
auto result = std::dynamic_pointer_cast<op::Result>(node_map.at(node.get()));
auto result = as_type_ptr<op::Result>(node_map.at(node.get()));
if (!result)
{
throw ngraph_error("Results should be of type op::Result");
......@@ -316,7 +316,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
std::vector<std::shared_ptr<op::Parameter>> cloned_params;
for (auto param : func.get_parameters())
{
cloned_params.push_back(std::dynamic_pointer_cast<op::Parameter>(node_map.at(param.get())));
cloned_params.push_back(as_type_ptr<op::Parameter>(node_map.at(param.get())));
}
// create and return cloned function
......@@ -325,7 +325,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
bool ngraph::is_equal_to_const_value(std::string const_value, const Output<Node>& reduce_constant)
{
if (auto rc = dynamic_pointer_cast<ngraph::op::Constant>(reduce_constant.get_node_shared_ptr()))
if (auto rc = as_type_ptr<ngraph::op::Constant>(reduce_constant.get_node_shared_ptr()))
{
auto cshape = rc->get_shape();
size_t n = shape_size(cshape);
......
......@@ -454,6 +454,12 @@ void Node::clear_control_dependents()
}
}
const op::AutoBroadcastSpec& Node::get_autob() const
{
static op::AutoBroadcastSpec s_spec;
return s_spec;
}
namespace ngraph
{
ostream& operator<<(ostream& out, const Node& node)
......
......@@ -56,6 +56,7 @@ namespace ngraph
namespace op
{
struct AutoBroadcastSpec;
class Constant;
} // namespace op
......@@ -157,37 +158,21 @@ namespace ngraph
{
return &get_type_info() == &NodeType::type_info;
}
/// Casts a Node to a shared_ptr<T> if is of type T, nullptr otherwise;
template <typename NodeType>
std::shared_ptr<NodeType> as_type_ptr()
{
return is_type<NodeType>() ? std::static_pointer_cast<NodeType>(shared_from_this())
: std::shared_ptr<NodeType>();
}
/// Casts a Node to a shared_ptr<T> if is of type T, nullptr otherwise;
template <typename NodeType>
std::shared_ptr<const NodeType> as_type_ptr() const
{
return is_type<NodeType>() ? std::static_pointer_cast<NodeType>(shared_from_this())
: std::shared_ptr<NodeType>();
}
/// Casts a Node to a T* if is of type T, nullptr otherwise;
template <typename NodeType>
NodeType* as_type()
{
return is_type<NodeType>() ? static_cast<NodeType*>(this) : nullptr;
}
/// Casts a Node to a T* if is of type T, nullptr otherwise;
template <typename NodeType>
const NodeType* as_type() const
{
return is_type<NodeType>() ? static_cast<const NodeType*>(this) : nullptr;
}
virtual bool is_unary_elementwise_arithmetic() const { return false; }
virtual bool is_binary_elementwise_arithmetic() const { return false; }
virtual bool is_binary_elementwise_comparison() const { return false; }
virtual bool is_binary_elementwise_logical() const { return false; }
/// \returns true if node supports autobroadcast operations
virtual bool supports_auto_broadcast() const { return false; }
/// \returns the autobroadcast spec
virtual const op::AutoBroadcastSpec& get_autob() const;
/// \returns true if the node can decompose
virtual bool supports_decompose() const { return false; }
/// \brief Decomposes the FusedOp into a sub-graph consisting of core ngraph ops
///
/// \return A vector of nodes comprising the sub-graph. The order of output
/// tensors must match the output tensors of the FusedOp
virtual NodeVector decompose_op() const { return NodeVector(); }
/// Returns the NodeTypeInfo for the node's class.
/// During transition to type_info, returns a dummy type_info for Node if the class
/// has not been updated yet.
......@@ -504,6 +489,36 @@ namespace ngraph
size_t m_placement_index = placement_invalid;
};
/// Casts a Node* to a NodeType* if it is of type NodeType, nullptr otherwise
template <typename NodeType>
NodeType* as_type(Node* node)
{
return node->template is_type<NodeType>() ? static_cast<NodeType*>(node) : nullptr;
}
/// Casts a const Node* to a const NodeType* if it is of type NodeType, nullptr otherwise
template <typename NodeType>
const NodeType* as_type(const Node* node)
{
return node->template is_type<NodeType>() ? static_cast<const NodeType*>(node) : nullptr;
}
/// Casts a shared_ptr<Node> to a shared_ptr<NodeType> if it is of type NodeType, nullptr otherwise
template <typename NodeType>
std::shared_ptr<NodeType> as_type_ptr(std::shared_ptr<Node> node_ptr)
{
return node_ptr->template is_type<NodeType>() ? std::static_pointer_cast<NodeType>(node_ptr)
: std::shared_ptr<NodeType>();
}
/// Casts a shared_ptr<const Node> to a shared_ptr<const NodeType> if it is of type NodeType, nullptr otherwise
template <typename NodeType>
std::shared_ptr<const NodeType> as_type_ptr(std::shared_ptr<const Node> node_ptr)
{
return node_ptr->template is_type<NodeType>() ? std::static_pointer_cast<NodeType>(node_ptr)
: std::shared_ptr<NodeType>();
}
/// \brief A handle for one of a node's inputs.
template <typename NodeType>
class Input
......
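
The member-template forms of as_type and as_type_ptr removed in the hunk above are replaced by the free functions added at its end; the raw-pointer overloads cover call sites that hold a Node* rather than a shared_ptr (see the ArithmeticReduction and LogicalReduction hunks further down). A hedged sketch of the raw-pointer form, with a hypothetical caller; assumes the usual ngraph/node.hpp and ngraph/op/constant.hpp headers:

void inspect(ngraph::Node* node)
{
    // Returns nullptr unless node's type_info matches op::Constant.
    if (ngraph::op::Constant* c = ngraph::as_type<ngraph::op::Constant>(node))
    {
        auto axes = c->get_axis_set_val();
        (void)axes;
    }
}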
......@@ -107,9 +107,9 @@ void op::DynReplaceSlice::validate_and_infer_types()
set_input_is_relevant_to_shape(3);
set_input_is_relevant_to_shape(4);
auto lower_bounds = dynamic_pointer_cast<op::Constant>(input_value(2).get_node_shared_ptr());
auto upper_bounds = dynamic_pointer_cast<op::Constant>(input_value(3).get_node_shared_ptr());
auto strides = dynamic_pointer_cast<op::Constant>(input_value(4).get_node_shared_ptr());
auto lower_bounds = as_type_ptr<op::Constant>(input_value(2).get_node_shared_ptr());
auto upper_bounds = as_type_ptr<op::Constant>(input_value(3).get_node_shared_ptr());
auto strides = as_type_ptr<op::Constant>(input_value(4).get_node_shared_ptr());
// TODO(amprocte): We can get a bit more information here about the ranks of arg and
// replacement by inspecting the attributes.
......
......@@ -50,7 +50,7 @@ void op::DynReshape::validate_and_infer_types()
set_input_is_relevant_to_shape(1);
if (auto const_shape = dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr()))
if (auto const_shape = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
{
std::vector<int64_t> out_shape_val = const_shape->get_vector<int64_t>();
NODE_VALIDATION_CHECK(this,
......
......@@ -86,9 +86,9 @@ void op::DynSlice::validate_and_infer_types()
set_input_is_relevant_to_shape(2);
set_input_is_relevant_to_shape(3);
auto lower_bounds = dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr());
auto upper_bounds = dynamic_pointer_cast<op::Constant>(input_value(2).get_node_shared_ptr());
auto strides = dynamic_pointer_cast<op::Constant>(input_value(3).get_node_shared_ptr());
auto lower_bounds = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
auto upper_bounds = as_type_ptr<op::Constant>(input_value(2).get_node_shared_ptr());
auto strides = as_type_ptr<op::Constant>(input_value(3).get_node_shared_ptr());
if (lower_bounds && upper_bounds && strides)
{
......
......@@ -49,7 +49,7 @@ void op::Interpolate::validate_and_infer_types()
}
}
if (auto const_shape = dynamic_pointer_cast<op::Constant>(get_argument(1)))
if (auto const_shape = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
{
auto out_shape = static_cast<const int64_t*>(const_shape->get_data_ptr());
size_t i = 0;
......
......@@ -58,7 +58,7 @@ void op::PriorBox::validate_and_infer_types()
set_input_is_relevant_to_shape(0);
if (auto const_shape = dynamic_pointer_cast<op::Constant>(get_argument(0)))
if (auto const_shape = as_type_ptr<op::Constant>(input_value(0).get_node_shared_ptr()))
{
NODE_VALIDATION_CHECK(this,
shape_size(const_shape->get_shape()) == 2,
......
......@@ -72,7 +72,7 @@ void op::PriorBoxClustered::validate_and_infer_types()
set_input_is_relevant_to_shape(0);
if (auto const_shape = dynamic_pointer_cast<op::Constant>(get_argument(0)))
if (auto const_shape = as_type_ptr<op::Constant>(input_value(0).get_node_shared_ptr()))
{
NODE_VALIDATION_CHECK(this,
shape_size(const_shape->get_shape()) == 2,
......
......@@ -123,12 +123,9 @@ static
template <typename T>
static PartialShape infer_output_shape(const op::Range* node, const element::Type& /* et */)
{
auto const_start =
dynamic_pointer_cast<op::Constant>(node->input_value(0).get_node_shared_ptr());
auto const_stop =
dynamic_pointer_cast<op::Constant>(node->input_value(1).get_node_shared_ptr());
auto const_step =
dynamic_pointer_cast<op::Constant>(node->input_value(2).get_node_shared_ptr());
auto const_start = as_type_ptr<op::Constant>(node->input_value(0).get_node_shared_ptr());
auto const_stop = as_type_ptr<op::Constant>(node->input_value(1).get_node_shared_ptr());
auto const_step = as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
T start = static_cast<T>(0);
T stop = static_cast<T>(0);
......
......@@ -60,8 +60,7 @@ void op::Tile::validate_and_infer_types()
auto out_shape = PartialShape::dynamic(output_rank);
if (auto const_repeats =
dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr()))
if (auto const_repeats = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
{
if (arg_shape.is_static())
{
......
......@@ -47,8 +47,7 @@ void op::Transpose::validate_and_infer_types()
set_input_is_relevant_to_shape(1);
if (auto input_const =
std::dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr()))
if (auto input_const = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
{
auto permutation = input_const->get_axis_vector_val();
NODE_VALIDATION_CHECK(this,
......
......@@ -24,7 +24,7 @@
using namespace std;
using namespace ngraph;
const string op::MatMul::type_name{"MatMul"};
constexpr NodeTypeInfo op::MatMul::type_info;
op::MatMul::MatMul(const Output<Node>& A,
const Output<Node>& B,
......
......@@ -29,8 +29,8 @@ namespace ngraph
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
static constexpr NodeTypeInfo type_info{"MatMul", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
MatMul() = default;
/// \brief Constructs a ScaleShift operation.
///
......
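
MatMul illustrates the mechanical change each op class gets for the new type system: the static type_name string and description() override give way to a static constexpr NodeTypeInfo plus a get_type_info() override, with an out-of-line definition in the .cpp. A sketch of the same pattern for a hypothetical op (the class name MyOp is illustrative, not part of the commit):

// my_op.hpp (sketch)
class MyOp : public ngraph::op::Op
{
public:
    NGRAPH_API
    static constexpr ngraph::NodeTypeInfo type_info{"MyOp", 0};
    const ngraph::NodeTypeInfo& get_type_info() const override { return type_info; }
    MyOp() = default;
};

// my_op.cpp (sketch): the out-of-line definition keeps the constexpr static
// ODR-usable before C++17, matching the MatMul change above.
constexpr ngraph::NodeTypeInfo MyOp::type_info;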
......@@ -80,7 +80,7 @@ AxisSet op::NormalizeL2::get_reduction_axes() const
{
AxisSet axes;
auto axes_input_node = input_value(1).get_node_shared_ptr();
if (auto const_op = dynamic_pointer_cast<op::Constant>(axes_input_node))
if (auto const_op = as_type_ptr<op::Constant>(axes_input_node))
{
axes = const_op->get_axis_set_val();
}
......
......@@ -43,7 +43,7 @@ NodeVector op::Squeeze::decompose_op() const
"doesn't support 'axes' input of other type than a Constant.");
// Get value of axes from Constant
auto axes_constant = dynamic_pointer_cast<op::Constant>(axes_node);
auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->get_vector<size_t>();
auto data_shape = data.get_shape();
......
......@@ -48,7 +48,7 @@ NodeVector op::Unsqueeze::decompose_op() const
auto axes_node = input_value(1).get_node_shared_ptr();
// Get value of axes from Constant
auto axes_constant = dynamic_pointer_cast<op::Constant>(axes_node);
auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->get_vector<size_t>();
auto data_shape = data.get_shape();
......
......@@ -47,7 +47,7 @@ AxisSet op::LRN::get_reduction_axes() const
{
AxisSet axes{1}; // channel axis as default
auto axes_input_node = input_value(1).get_node_shared_ptr();
if (auto const_op = dynamic_pointer_cast<op::Constant>(axes_input_node))
if (auto const_op = as_type_ptr<op::Constant>(axes_input_node))
{
axes = const_op->get_axis_set_val();
}
......
......@@ -59,7 +59,7 @@ op::TopK::TopK(const Output<Node>& arg,
size_t op::TopK::get_k() const
{
size_t k = 0;
if (auto const_op = dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr()))
if (auto const_op = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
{
k = const_op->get_vector<int64_t>()[0];
}
......
......@@ -41,13 +41,13 @@ op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
bool op::util::ArithmeticReduction::reduction_axes_constant() const
{
return dynamic_pointer_cast<op::Constant>(get_argument(1)) != nullptr;
return input_value(1).get_node()->is_type<op::Constant>();
}
const AxisSet op::util::ArithmeticReduction::get_reduction_axes() const
{
AxisSet axes;
if (auto const_op = dynamic_pointer_cast<op::Constant>(get_argument(1)))
if (auto const_op = as_type<op::Constant>(input_value(1).get_node()))
{
axes = const_op->get_axis_set_val();
}
......
......@@ -82,8 +82,10 @@ namespace ngraph
public:
void validate_and_infer_types() override;
const AutoBroadcastSpec& get_autob() const { return m_autob; }
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
bool is_binary_elementwise_arithmetic() const override { return true; }
bool supports_auto_broadcast() const override { return true; }
private:
AutoBroadcastSpec m_autob;
};
......
......@@ -88,8 +88,10 @@ namespace ngraph
public:
void validate_and_infer_types() override;
const AutoBroadcastSpec& get_autob() const { return m_autob; }
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
bool supports_auto_broadcast() const override { return true; }
bool is_binary_elementwise_comparison() const override { return true; }
private:
AutoBroadcastSpec m_autob;
};
......
......@@ -84,8 +84,10 @@ namespace ngraph
public:
void validate_and_infer_types() override;
const AutoBroadcastSpec& get_autob() const { return m_autob; }
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
bool supports_auto_broadcast() const override { return true; }
bool is_binary_elementwise_logical() const override { return true; }
private:
AutoBroadcastSpec m_autob;
};
......
......@@ -30,12 +30,7 @@ namespace ngraph
class FusedOp : public Op
{
public:
/// \brief Decomposes the FusedOp into a sub-graph consisting of core ngraph ops
///
/// \return A vector of nodes comprising the sub-graph. The order of output
/// tensors must match the match output tensors of the FusedOp
virtual NodeVector decompose_op() const = 0;
bool supports_decompose() const override { return true; }
void validate_and_infer_types() override;
/// Pre and post validation hooks for op-specific actions
......
......@@ -40,13 +40,13 @@ op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg,
bool op::util::LogicalReduction::reduction_axes_constant() const
{
return dynamic_pointer_cast<op::Constant>(get_argument(1)) != nullptr;
return input_value(1).get_node()->is_type<op::Constant>();
}
const AxisSet op::util::LogicalReduction::get_reduction_axes() const
{
AxisSet axes;
if (auto const_op = dynamic_pointer_cast<op::Constant>(get_argument(1)))
if (auto const_op = as_type<op::Constant>(input_value(1).get_node()))
{
axes = const_op->get_axis_set_val();
}
......
......@@ -67,6 +67,7 @@ namespace ngraph
public:
void validate_and_infer_types() override;
bool is_unary_elementwise_arithmetic() const override { return true; }
};
}
}
......
......@@ -56,7 +56,7 @@ static shared_ptr<pattern::Matcher>
static shared_ptr<pattern::op::Label> get_broadcast_label(shared_ptr<pattern::Matcher> matcher)
{
return dynamic_pointer_cast<pattern::op::Label>(matcher->get_pattern()->get_argument(1));
return static_pointer_cast<pattern::op::Label>(matcher->get_pattern()->get_argument(1));
}
//`simplify_concat` identifies slices-concat sequences
......@@ -147,7 +147,7 @@ static bool simplify_concat(shared_ptr<Node> n)
}
// check that no other node uses slices and reshapes
if (auto rcarg = dynamic_pointer_cast<op::Reshape>(carg))
if (auto rcarg = as_type_ptr<op::Reshape>(carg))
{
auto default_shape = get_default_order(rcarg->get_argument(0)->get_shape());
if (default_shape != rcarg->get_input_order())
......@@ -316,9 +316,9 @@ static bool simplify_add(shared_ptr<Node> n)
//`simplify_log` optimizes `log(exp(x)/y)` into `x - log(y)`
static bool simplify_log(shared_ptr<Node> n)
{
if (auto div = dynamic_pointer_cast<op::Divide>(n->get_argument(0)))
if (auto div = as_type_ptr<op::Divide>(n->input_value(0).get_node_shared_ptr()))
{
if (auto exp = dynamic_pointer_cast<op::Exp>(div->get_argument(0)))
if (auto exp = as_type_ptr<op::Exp>(div->input_value(0).get_node_shared_ptr()))
{
auto denom = div->get_argument(1);
auto diff =
......@@ -417,14 +417,14 @@ static bool simplify_reduction(shared_ptr<Node> n)
NGRAPH_DEBUG << "In simplify_reduction for " << n->get_name();
auto reduction = static_pointer_cast<T>(n);
auto broadcast = dynamic_pointer_cast<op::Broadcast>(n->get_argument(0));
auto broadcast = as_type_ptr<op::Broadcast>(n->input_value(0).get_node_shared_ptr());
if (!broadcast)
{
NGRAPH_DEBUG << n->get_name() << " isn't Broadcast";
return false;
}
auto cnst = dynamic_pointer_cast<op::Constant>(broadcast->get_argument(0));
auto cnst = as_type_ptr<op::Constant>(broadcast->input_value(0).get_node_shared_ptr());
if (!cnst || cnst->get_shape().size() > 0 /*not a scalar*/)
{
NGRAPH_DEBUG << broadcast->get_argument(0)->get_name() << " isn't a scalar constant";
......
......@@ -101,7 +101,7 @@ void pass::ConcatElimination::construct_concat_elimination()
auto pattern_map = m.get_pattern_map();
auto op = pattern_map[op_label];
auto root = std::dynamic_pointer_cast<op::Concat>(m.get_match_root());
auto root = as_type_ptr<op::Concat>(m.get_match_root());
if (root && (root->get_input_shape(0) == root->get_output_shape(0)))
{
NGRAPH_DEBUG << " eliminated " << m.get_match_root() << "\n";
......@@ -122,7 +122,7 @@ bool ngraph::pass::SelfConcatFusion::run_on_function(std::shared_ptr<Function> f
bool modify_graph = false;
auto has_multiple_inputs = [](std::shared_ptr<Node> n) {
auto input_size = n->get_input_size();
auto root = std::dynamic_pointer_cast<op::Concat>(n);
auto root = as_type_ptr<op::Concat>(n);
return (root && input_size > 1);
};
......@@ -178,7 +178,7 @@ void ngraph::pass::SelfConcatFusion::construct_concat_patterns(
if (matcher->match(n))
{
auto concat_op = matcher->get_pattern_map()[concat_op_label];
if (!std::dynamic_pointer_cast<op::Concat>(concat_op))
if (!concat_op->is_type<op::Concat>())
{
NGRAPH_DEBUG << "self_concat_fusion: Pattern matcher matched incorrect op. Matched "
<< concat_op->get_name() << " instead of a self concat";
......
......@@ -35,7 +35,7 @@ static shared_ptr<op::Constant>
{
vector<T> out_vec(shape_size(reduction_node->get_shape()));
if (auto max = dynamic_pointer_cast<op::Max>(reduction_node))
if (auto max = as_type_ptr<op::Max>(reduction_node))
{
runtime::reference::max<T>(constant->get_vector<T>().data(),
out_vec.data(),
......@@ -43,7 +43,7 @@ static shared_ptr<op::Constant>
reduction_node->get_shape(),
max->get_reduction_axes());
}
else if (auto min = dynamic_pointer_cast<op::Min>(reduction_node))
else if (auto min = as_type_ptr<op::Min>(reduction_node))
{
runtime::reference::min<T>(constant->get_vector<T>().data(),
out_vec.data(),
......@@ -51,7 +51,7 @@ static shared_ptr<op::Constant>
reduction_node->get_shape(),
min->get_reduction_axes());
}
else if (auto prod = dynamic_pointer_cast<op::Product>(reduction_node))
else if (auto prod = as_type_ptr<op::Product>(reduction_node))
{
runtime::reference::product<T>(constant->get_vector<T>().data(),
out_vec.data(),
......@@ -59,7 +59,7 @@ static shared_ptr<op::Constant>
reduction_node->get_shape(),
prod->get_reduction_axes());
}
else if (auto sum = dynamic_pointer_cast<op::Sum>(reduction_node))
else if (auto sum = as_type_ptr<op::Sum>(reduction_node))
{
runtime::reference::sum<T>(constant->get_vector<T>().data(),
out_vec.data(),
......
......@@ -57,14 +57,13 @@ void pass::ConstantFolding::construct_constant_dequantize()
auto pattern_map = m.get_pattern_map();
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto constant_match = as_type_ptr<op::Constant>(pattern_map[constant_label]);
auto dequant_match = pattern_map[dequant];
auto dequantize_op = dynamic_pointer_cast<op::Dequantize>(dequant_match);
auto dequantize_op = as_type_ptr<op::Dequantize>(dequant_match);
auto scale = dynamic_pointer_cast<op::Constant>(
dequant_match->input(1).get_source_output().get_node_shared_ptr());
auto offset = dynamic_pointer_cast<op::Constant>(
dequant_match->input(2).get_source_output().get_node_shared_ptr());
auto scale = as_type_ptr<op::Constant>(dequant_match->input_value(1).get_node_shared_ptr());
auto offset =
as_type_ptr<op::Constant>(dequant_match->input_value(2).get_node_shared_ptr());
NGRAPH_CHECK(revalidate_and_ensure_static(dequantize_op));
auto type = constant_match->get_element_type();
......
......@@ -28,7 +28,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C
{
vector<char> out_vec(shape_size(reduction_node->get_shape()));
if (auto all = dynamic_pointer_cast<::ngraph::op::All>(reduction_node))
if (auto all = as_type_ptr<::ngraph::op::All>(reduction_node))
{
runtime::reference::all(constant->get_vector<char>().data(),
out_vec.data(),
......@@ -36,7 +36,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C
reduction_node->get_shape(),
all->get_reduction_axes());
}
else if (auto any = dynamic_pointer_cast<::ngraph::op::Any>(reduction_node))
else if (auto any = as_type_ptr<::ngraph::op::Any>(reduction_node))
{
runtime::reference::any(constant->get_vector<char>().data(),
out_vec.data(),
......
......@@ -59,9 +59,9 @@ void pass::ConstantFolding::construct_constant_quantize()
auto pattern_map = m.get_pattern_map();
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto constant_match = as_type_ptr<op::Constant>(pattern_map[constant_label]);
auto quant_match = pattern_map[quant];
auto quantize_op = dynamic_pointer_cast<op::Quantize>(quant_match);
auto quantize_op = as_type_ptr<op::Quantize>(quant_match);
NGRAPH_CHECK(revalidate_and_ensure_static(quantize_op));
......
......@@ -40,10 +40,9 @@ using namespace ngraph;
bool is_supported_unary_op(std::shared_ptr<Node> n)
{
return std::dynamic_pointer_cast<op::Abs>(n) || std::dynamic_pointer_cast<op::Ceiling>(n) ||
std::dynamic_pointer_cast<op::Floor>(n) || std::dynamic_pointer_cast<op::Negative>(n) ||
std::dynamic_pointer_cast<op::Not>(n) || std::dynamic_pointer_cast<op::Relu>(n) ||
std::dynamic_pointer_cast<op::Sign>(n) || std::dynamic_pointer_cast<op::Sqrt>(n);
return n->is_type<op::Abs>() || n->is_type<op::Ceiling>() || n->is_type<op::Floor>() ||
n->is_type<op::Negative>() || n->is_type<op::Not>() || n->is_type<op::Relu>() ||
n->is_type<op::Sign>() || n->is_type<op::Sqrt>();
}
template <class T>
......@@ -52,7 +51,7 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
NodeExecutorTy func)
{
// check sqrt arg
if (std::dynamic_pointer_cast<op::Sqrt>(unary))
if (unary->is_type<op::Sqrt>())
{
std::vector<T> values{constant->get_vector<T>()};
if (std::any_of(values.begin(), values.end(), [](T i) { return i < T(0); }))
......@@ -75,42 +74,42 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
}
else
{
if (std::dynamic_pointer_cast<op::Abs>(unary))
if (unary->is_type<op::Abs>())
{
runtime::reference::abs<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Ceiling>(unary))
else if (unary->is_type<op::Ceiling>())
{
runtime::reference::ceiling<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Floor>(unary))
else if (unary->is_type<op::Floor>())
{
runtime::reference::floor<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Negative>(unary))
else if (unary->is_type<op::Negative>())
{
runtime::reference::negate<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Not>(unary))
else if (unary->is_type<op::Not>())
{
runtime::reference::logical_not<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Relu>(unary))
else if (unary->is_type<op::Relu>())
{
runtime::reference::relu<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Sign>(unary))
else if (unary->is_type<op::Sign>())
{
runtime::reference::sign<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Sqrt>(unary))
else if (unary->is_type<op::Sqrt>())
{
runtime::reference::sqrt<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
......@@ -129,8 +128,7 @@ void pass::ConstantFolding::construct_constant_unary()
auto constant_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto is_ue = [](std::shared_ptr<Node> n) {
return (pattern::has_class<op::util::UnaryElementwiseArithmetic>()(n) ||
pattern::has_class<op::Not>()(n));
return n->is_unary_elementwise_arithmetic() || pattern::has_class<op::Not>()(n);
};
auto ue = std::make_shared<pattern::op::Any>(constant_label, is_ue, NodeVector{constant_label});
......@@ -140,7 +138,7 @@ void pass::ConstantFolding::construct_constant_unary()
auto pattern_map = m.get_pattern_map();
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto constant_match = as_type_ptr<op::Constant>(pattern_map[constant_label]);
auto unary_match = m.get_match_root();
if (!is_supported_unary_op(unary_match))
......
......@@ -570,7 +570,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
{
if (is_used(n.get()))
{
if (dynamic_pointer_cast<op::Convolution>(n) == nullptr)
if (!n->is_type<op::Convolution>())
{
NGRAPH_DEBUG << "Not all live users of element wise operation are Convolution";
return false;
......@@ -821,16 +821,15 @@ void pass::CoreFusion::construct_zero_padded_reshaped_conv()
pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto pad_value_op = std::dynamic_pointer_cast<ngraph::op::Constant>(pattern_map[pad_value]);
auto pad_value_op = as_type_ptr<ngraph::op::Constant>(pattern_map[pad_value]);
if (!pad_value_op)
{
NGRAPH_DEBUG << "Pad value must be a constant";
return false;
}
const auto& matched_conv =
std::static_pointer_cast<ngraph::op::Convolution>(pattern_map[conv_label]);
const auto& matched_pad = std::static_pointer_cast<ngraph::op::Pad>(pattern_map[pad_label]);
const auto& matched_conv = as_type_ptr<ngraph::op::Convolution>(pattern_map[conv_label]);
const auto& matched_pad = as_type_ptr<ngraph::op::Pad>(pattern_map[pad_label]);
const auto& matched_reshape =
std::static_pointer_cast<ngraph::op::Reshape>(pattern_map[reshape_label]);
......@@ -905,7 +904,7 @@ void pass::CoreFusion::construct_zero_padded_conv()
pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto pad_value_op = std::dynamic_pointer_cast<ngraph::op::Constant>(pattern_map[pad_value]);
auto pad_value_op = as_type_ptr<ngraph::op::Constant>(pattern_map[pad_value]);
if (!pad_value_op)
{
NGRAPH_DEBUG << "Pad value must be a constant";
......@@ -976,7 +975,7 @@ void pass::CoreFusion::construct_zero_padded_conv_backprop_filters()
pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto pad_value_op = std::dynamic_pointer_cast<ngraph::op::Constant>(pattern_map[pad_value]);
auto pad_value_op = as_type_ptr<ngraph::op::Constant>(pattern_map[pad_value]);
if (!pad_value_op)
{
NGRAPH_DEBUG << "Pad value must be a constant";
......@@ -1036,7 +1035,7 @@ void pass::CoreFusion::construct_conv_bias()
auto pbcast = make_shared<op::Broadcast>(pbias, shape, AxisSet{0, 1, 2, 3});
auto pbcast_label = make_shared<pattern::op::Label>(pbcast, nullptr, NodeVector{pbcast});
auto reshape_pred = [](shared_ptr<Node> node) -> bool {
if (auto reshape = dynamic_pointer_cast<op::Reshape>(node))
if (auto reshape = as_type_ptr<op::Reshape>(node))
{
auto ishape = reshape->get_input_shape(0);
auto oshape = reshape->get_shape();
......@@ -1066,7 +1065,7 @@ void pass::CoreFusion::construct_conv_bias()
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto conv_m = dynamic_pointer_cast<op::Convolution>(m.get_match_root()->get_argument(0));
auto conv_m = as_type_ptr<op::Convolution>(m.get_match_root()->get_argument(0));
if (conv_m == nullptr)
{
......@@ -1138,7 +1137,7 @@ void pass::CoreFusion::construct_conv_bias_add()
auto add_m = m.get_match_root();
auto pattern_map = m.get_pattern_map();
auto conv_m = dynamic_pointer_cast<op::ConvolutionBias>(add_m->get_argument(1));
auto conv_m = as_type_ptr<op::ConvolutionBias>(add_m->get_argument(1));
auto add_input_m = add_m->get_argument(0);
if (!conv_m)
......
......@@ -29,7 +29,7 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
{
bool modified = false;
if (auto fused_op = dynamic_pointer_cast<op::util::FusedOp>(node))
if (node->supports_decompose())
{
if (m_has_direct_support && m_has_direct_support(*node))
{
......@@ -37,9 +37,9 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
return modified;
}
auto subgraph_outputs = fused_op->decompose_op();
auto subgraph_outputs = node->decompose_op();
// Run recursively until no more fused ops
auto subgraph = extract_subgraph(subgraph_outputs, fused_op->get_arguments());
auto subgraph = extract_subgraph(subgraph_outputs, node->get_arguments());
for (auto subgraph_node : subgraph)
{
run_on_node(subgraph_node);
......@@ -51,12 +51,11 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
for (size_t j = 0; j < output_node->get_outputs().size(); j++, i++)
{
// TODO: Provenance
set<descriptor::Input*> fop_users{begin(fused_op->get_outputs().at(i).get_inputs()),
end(fused_op->get_outputs().at(i).get_inputs())};
set<descriptor::Input*> fop_users{begin(node->get_outputs().at(i).get_inputs()),
end(node->get_outputs().at(i).get_inputs())};
for (auto fop_user : fop_users)
{
if (auto goe =
dynamic_cast<op::GetOutputElement*>(fop_user->get_raw_pointer_node()))
if (auto goe = as_type<op::GetOutputElement>(fop_user->get_raw_pointer_node()))
{
Output<Node> goe_output = goe->get_as_output();
if (goe_output.get_index() == i && !goe->get_output_inputs(0).empty())
......@@ -78,12 +77,12 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
}
}
}
if (i != fused_op->get_output_size())
if (i != node->get_output_size())
{
throw ngraph_error("While replacing " + node->get_name() +
", mismatch between op output count and outputs of the decomposed "
"subgraph. Expected: " +
to_string(fused_op->get_output_size()) + " Got: " + to_string(i));
to_string(node->get_output_size()) + " Got: " + to_string(i));
}
modified = true;
}
......
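
FusedOpDecomposition no longer downcasts to op::util::FusedOp; the virtual hooks added to Node (supports_decompose and decompose_op, see the node.hpp hunk near the top) carry the dispatch. A minimal sketch of that shape, with a hypothetical helper name:

// Sketch; relies only on the Node virtuals introduced in this commit.
bool try_decompose(const std::shared_ptr<ngraph::Node>& node)
{
    if (!node->supports_decompose())
    {
        return false;
    }
    // decompose_op() returns the core-op sub-graph that replaces the fused op;
    // its output order matches the fused op's outputs.
    ngraph::NodeVector subgraph_outputs = node->decompose_op();
    return !subgraph_outputs.empty();
}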
......@@ -24,27 +24,19 @@
using namespace std;
using namespace ngraph;
template <typename optype>
static bool broadcast_and_replace(std::shared_ptr<ngraph::Node>& node)
bool ngraph::pass::ImplicitBroadcastElimination::run_on_node(std::shared_ptr<ngraph::Node> node)
{
if (auto op = std::dynamic_pointer_cast<optype>(node))
if (node->supports_auto_broadcast())
{
if (op->get_autob().m_type != op::AutoBroadcastType::NONE)
if (node->get_autob().m_type != op::AutoBroadcastType::NONE)
{
auto new_args = pass::explicit_broadcast<optype>(op);
auto new_args = pass::explicit_broadcast(node);
for (size_t i = 0; i < new_args.size(); i++)
{
op->input(i).replace_source_output(new_args[i]->output(0));
node->input(i).replace_source_output(new_args[i]->output(0));
}
return true;
}
}
return false;
}
bool ngraph::pass::ImplicitBroadcastElimination::run_on_node(std::shared_ptr<ngraph::Node> node)
{
return broadcast_and_replace<op::util::BinaryElementwiseArithmetic>(node) ||
broadcast_and_replace<op::util::BinaryElementwiseComparison>(node) ||
broadcast_and_replace<op::util::BinaryElementwiseLogical>(node);
}
......@@ -23,22 +23,23 @@ namespace ngraph
{
namespace pass
{
template <typename T>
NodeVector explicit_broadcast(std::shared_ptr<T>& node)
NodeVector explicit_broadcast(std::shared_ptr<Node>& node)
{
NodeVector rc;
if (node->get_autob().m_type == op::AutoBroadcastType::NONE)
{
rc = node->get_arguments();
}
else if (node->get_autob().m_type == op::AutoBroadcastType::NUMPY)
{
rc = op::numpy_style_broadcast(node->get_arguments());
}
else
if (node->supports_auto_broadcast())
{
throw ngraph_error("Unsupported implicit broadcast type");
if (node->get_autob().m_type == op::AutoBroadcastType::NONE)
{
rc = node->get_arguments();
}
else if (node->get_autob().m_type == op::AutoBroadcastType::NUMPY)
{
rc = op::numpy_style_broadcast(node->get_arguments());
}
else
{
throw ngraph_error("Unsupported implicit broadcast type");
}
}
return rc;
}
......
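
Likewise, explicit_broadcast and ImplicitBroadcastElimination now key off the new Node virtuals supports_auto_broadcast() and get_autob() instead of casting to the three binary-elementwise base classes. A sketch of the check a caller can make (the function name is illustrative):

void report_autob(const std::shared_ptr<ngraph::Node>& node)
{
    if (node->supports_auto_broadcast() &&
        node->get_autob().m_type == ngraph::op::AutoBroadcastType::NUMPY)
    {
        // The pass replaces the implicit NumPy broadcast with an explicit
        // op::numpy_style_broadcast of the node's arguments.
    }
}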
......@@ -66,7 +66,7 @@ bool pass::LikeReplacement::run_on_function(shared_ptr<Function> function)
// Here we're checking on a common base class of a family of template classes,
// which is more than type info can handle.
auto sclb = dynamic_pointer_cast<op::ScalarConstantLikeBase>(n);
auto sclb = as_type_ptr<op::ScalarConstantLikeBase>(n);
if (sclb != nullptr)
{
replace_node(sclb, sclb->as_constant());
......
......@@ -58,7 +58,7 @@ bool pass::Liveness::run_on_function(shared_ptr<Function> function)
}
for (const shared_ptr<Node>& node : ops)
{
if (auto constant_node = dynamic_pointer_cast<op::Constant>(node))
if (auto constant_node = as_type_ptr<op::Constant>(node))
{
for (auto& output : constant_node->outputs())
{
......
......@@ -52,8 +52,7 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<Function> function)
{
auto op = std::static_pointer_cast<op::Op>(node);
// concat and slice in_place_oi should be treated differently
if (!std::dynamic_pointer_cast<op::Concat>(node) &&
!std::dynamic_pointer_cast<op::Slice>(node))
if (!node->is_type<op::Concat>() && !node->is_type<op::Slice>())
{
if (auto op_annotations = op->get_op_annotations())
{
......@@ -66,7 +65,7 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<Function> function)
// For destructive kernel, this should be the last use
// Non-destructive kernels can pass through if memory sharing is disabled
if ((node->liveness_free_list.count(input) != 0 ||
std::dynamic_pointer_cast<op::GetOutputElement>(node) ||
node->is_type<op::GetOutputElement>() ||
(m_disable_memory_sharing && !oi_pair.destructive &&
!input_node->is_parameter() && !input_node->is_constant())) &&
node->liveness_new_list.count(output) != 0)
......
......@@ -133,7 +133,7 @@ bool pass::NopElimination::run_on_function(std::shared_ptr<Function> function)
// Here we're checking on a common base class of a family of template classes,
// which is more than type info can handle.
auto sclb = std::dynamic_pointer_cast<op::ScalarConstantLikeBase>(n);
auto sclb = as_type_ptr<op::ScalarConstantLikeBase>(n);
if (sclb != nullptr)
{
replace_node(sclb, sclb->as_constant());
......
......@@ -52,7 +52,7 @@ void pass::ReshapeElimination::construct_identity_reshape_pattern()
auto pattern_map = m.get_pattern_map();
auto gop = pattern_map[op];
auto r1 = dynamic_pointer_cast<op::Reshape>(m.get_match_root());
auto r1 = as_type_ptr<op::Reshape>(m.get_match_root());
if (r1->get_shape() != gop->get_shape())
{
......@@ -152,9 +152,7 @@ void pass::ReshapeElimination::construct_reshapex2_pattern()
void pass::ReshapeElimination::construct_dot_transpose_pattern()
{
// dot(A,B).T = dot (B.T, A.T)
auto dot_pred = [](shared_ptr<Node> n) {
return static_cast<bool>(dynamic_pointer_cast<op::Dot>(n));
};
auto dot_pred = [](shared_ptr<Node> n) { return n->is_type<op::Dot>(); };
auto pdot = make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, dot_pred);
auto preshape = make_shared<op::Reshape>(pdot, AxisVector{1, 0}, Shape{1, 2});
......@@ -232,7 +230,7 @@ void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
// Need to check if the user of the last bound op is a reshape since the last reshape is
// allowed to have fan-out but the matcher will discard any reshape if it has fan-out
auto user_of_last_bound_reshape_op = last_bound_reshape_op->get_users(true)[0];
if (std::dynamic_pointer_cast<op::Reshape>(user_of_last_bound_reshape_op))
if (user_of_last_bound_reshape_op->is_type<op::Reshape>())
{
reshape_node_vector.push_back(user_of_last_bound_reshape_op);
last_bound_reshape_op = reshape_node_vector.back();
......@@ -251,7 +249,7 @@ void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
for (auto it = std::next(reshape_node_vector.begin()); it != reshape_node_vector.end();
it++)
{
auto r = std::dynamic_pointer_cast<op::Reshape>(*it);
auto r = as_type_ptr<op::Reshape>(*it);
// Check that the input to r is the last reshape stored in the
// subpattern vector
......@@ -286,9 +284,9 @@ void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
continue;
}
auto first_reshape = std::dynamic_pointer_cast<op::Reshape>(sub_pattern.front());
auto first_reshape = as_type_ptr<op::Reshape>(sub_pattern.front());
auto input_to_first_reshape = first_reshape->get_argument(0);
auto last_reshape = std::dynamic_pointer_cast<op::Reshape>(sub_pattern.back());
auto last_reshape = as_type_ptr<op::Reshape>(sub_pattern.back());
auto new_input_order = first_reshape->get_input_order();
auto new_out_shape = last_reshape->get_shape();
......
......@@ -47,7 +47,7 @@ using ReshapeMap = unordered_map<shared_ptr<Node>, shared_ptr<op::Reshape>>;
static string describe_reshape(shared_ptr<Node> node)
{
stringstream ss;
auto reshape = dynamic_pointer_cast<op::Reshape>(node);
auto reshape = as_type_ptr<op::Reshape>(node);
ss << reshape->get_name()
<< " ( axis order = " << ngraph::vector_to_string(reshape->get_input_order())
<< " , shape = " << vector_to_string(reshape->get_shape()) << " ) "
......@@ -167,14 +167,14 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
work_queue.pop_front();
auto n = csw.input.get_source_output().get_node_shared_ptr();
NGRAPH_DEBUG << "Processing (swimming) " << n->get_name();
if (auto unary = dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n))
if (n->is_unary_elementwise_arithmetic())
{
Swimmer nsw{unary->input(0), csw.reshape};
Swimmer nsw{n->input(0), csw.reshape};
work_queue.push_back(nsw);
NGRAPH_DEBUG << "Propagating reshape " << describe_reshape(csw.reshape) << " for "
<< n->get_name() << " to " << unary->get_argument(0);
<< n->get_name() << " to " << n->get_argument(0);
}
else if (dynamic_pointer_cast<op::Broadcast>(n))
else if (n->is_type<op::Broadcast>())
{
auto old_broadcast = static_pointer_cast<op::Broadcast>(n);
auto broadcast_axes = old_broadcast->get_broadcast_axes();
......@@ -324,7 +324,7 @@ static void sink_reshape(shared_ptr<op::Reshape> reshape,
}
}
static void sink_unary(shared_ptr<op::util::UnaryElementwiseArithmetic> n,
static void sink_unary(shared_ptr<Node> n,
ReshapeMap& reorders,
set<shared_ptr<Node>>& /* reshapes_to_delete */)
{
......@@ -333,7 +333,7 @@ static void sink_unary(shared_ptr<op::util::UnaryElementwiseArithmetic> n,
write_reshapemap(reorders, n, arg_reshape);
}
static void sink_binary(shared_ptr<op::util::BinaryElementwiseArithmetic> binary,
static void sink_binary(shared_ptr<Node> binary,
ReshapeMap& reorders,
set<shared_ptr<Node>>& reshapes_to_delete)
{
......@@ -533,31 +533,31 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
results.push_back(n);
}
if (auto reshape = dynamic_pointer_cast<op::Reshape>(n))
if (auto reshape = as_type_ptr<op::Reshape>(n))
{
sink_reshape(reshape, reorders, reshapes_to_delete);
}
else if (auto unary = dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n))
else if (n->is_unary_elementwise_arithmetic())
{
sink_unary(unary, reorders, reshapes_to_delete);
sink_unary(n, reorders, reshapes_to_delete);
}
else if (auto binary = dynamic_pointer_cast<op::util::BinaryElementwiseArithmetic>(n))
else if (n->is_binary_elementwise_arithmetic())
{
sink_binary(binary, reorders, reshapes_to_delete);
sink_binary(n, reorders, reshapes_to_delete);
}
else if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(n))
else if (auto goe = as_type_ptr<op::GetOutputElement>(n))
{
write_reshapemap(reorders, goe, create_default_reshape(goe));
}
else if (auto quantize = dynamic_pointer_cast<op::Quantize>(n))
else if (auto quantize = as_type_ptr<op::Quantize>(n))
{
sink_quantize(quantize, reorders, reshapes_to_delete);
}
else if (auto dequantize = dynamic_pointer_cast<op::Dequantize>(n))
else if (auto dequantize = as_type_ptr<op::Dequantize>(n))
{
sink_dequantize(dequantize, reorders, reshapes_to_delete);
}
else if (auto slice = dynamic_pointer_cast<op::Slice>(n))
else if (auto slice = as_type_ptr<op::Slice>(n))
{
// A heuristic. If Reshape has multiple slice users, if sunk
// it will be replicated by the number of its users
......@@ -578,11 +578,11 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
materialize_shapes(n, reorders, reshapes_to_delete);
}
}
else if (auto pad = dynamic_pointer_cast<op::Pad>(n))
else if (auto pad = as_type_ptr<op::Pad>(n))
{
sink_pad(pad, reorders, reshapes_to_delete);
}
else if (auto concat = dynamic_pointer_cast<op::Concat>(n))
else if (auto concat = as_type_ptr<op::Concat>(n))
{
sink_concat(concat, reorders, reshapes_to_delete);
}
......
......@@ -33,7 +33,7 @@ void pass::ValidateGraph::validate_parameters(const Function& function)
auto parameters = function.get_parameters();
for (auto node : function.get_ops())
{
shared_ptr<op::Parameter> p = dynamic_pointer_cast<op::Parameter>(node);
shared_ptr<op::Parameter> p = as_type_ptr<op::Parameter>(node);
if (nullptr != p)
{
auto it = find_if(parameters.begin(),
......
......@@ -163,7 +163,7 @@ static std::string label_edge(const std::shared_ptr<Node>& /* src */,
if (getenv("NGRAPH_VISUALIZE_EDGE_LABELS") != nullptr)
{
size_t output = 0;
if (auto goe = dst->as_type<op::GetOutputElement>())
if (auto goe = as_type_ptr<op::GetOutputElement>(dst))
{
output = goe->get_as_output().get_index();
}
......@@ -223,7 +223,7 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions)
traverse_nodes(f, [&](shared_ptr<Node> node) {
if (auto ck = dynamic_pointer_cast<ngraph::op::CompiledKernel>(node))
if (auto ck = as_type_ptr<ngraph::op::CompiledKernel>(node))
{
// print sub-graph
auto nodes_list = ck->get_node_list();
......@@ -416,7 +416,7 @@ string pass::VisualizeTree::get_node_name(shared_ptr<Node> node)
{
rc += "\\n" + node->get_name();
}
if (auto ck = node->as_type<ngraph::op::CompiledKernel>())
if (auto ck = as_type_ptr<ngraph::op::CompiledKernel>(node))
{
rc += "\\n{";
// add sub-graph node names
......
......@@ -111,7 +111,7 @@ bool pass::ZeroDimTensorElimination::run_on_function(shared_ptr<Function> f)
continue;
}
if (auto concat = dynamic_pointer_cast<op::Concat>(n))
if (auto concat = as_type_ptr<op::Concat>(n))
{
NodeVector non_zero_dim_args;
for (auto arg : concat->get_arguments())
......
......@@ -28,6 +28,11 @@ namespace ngraph
{
namespace pattern
{
constexpr NodeTypeInfo op::AnyOf::type_info;
constexpr NodeTypeInfo op::Any::type_info;
constexpr NodeTypeInfo op::Label::type_info;
constexpr NodeTypeInfo op::Skip::type_info;
std::shared_ptr<Node> Matcher::get_match_root() { return m_match_root; }
bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label,
const std::shared_ptr<Node>& graph_node,
......@@ -227,23 +232,23 @@ namespace ngraph
<< "pattern = " << pattern_node->get_name() << " matched "
<< graph_node->get_name();
if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node))
if (auto label_node = as_type_ptr<op::Label>(pattern_node))
{
return abort_match(watermark, match_pattern(label_node, graph_node, pattern_map));
}
if (auto skip_node = std::dynamic_pointer_cast<op::Skip>(
pattern_node)) // matches PatternSkipOp semantics
if (auto skip_node =
as_type_ptr<op::Skip>(pattern_node)) // matches PatternSkipOp semantics
{
return abort_match(watermark, match_skip(skip_node, graph_node, pattern_map));
}
if (auto any_node = std::dynamic_pointer_cast<op::Any>(pattern_node))
if (auto any_node = as_type_ptr<op::Any>(pattern_node))
{
return abort_match(watermark, match_any(any_node, graph_node, pattern_map));
}
if (auto any_of_node = std::dynamic_pointer_cast<op::AnyOf>(pattern_node))
if (auto any_of_node = as_type_ptr<op::AnyOf>(pattern_node))
{
return abort_match(watermark, match_any_of(any_of_node, graph_node, pattern_map));
}
......
......@@ -40,9 +40,7 @@ namespace ngraph
template <typename T>
std::function<bool(std::shared_ptr<Node>)> has_class()
{
auto pred = [](std::shared_ptr<Node> node) -> bool {
return std::dynamic_pointer_cast<T>(node) != nullptr;
};
auto pred = [](std::shared_ptr<Node> node) -> bool { return node->is_type<T>(); };
return pred;
}
......@@ -109,7 +107,7 @@ namespace ngraph
std::shared_ptr<T> matched;
for (auto arg : node->get_arguments())
{
if (auto t_casted = std::dynamic_pointer_cast<T>(arg))
if (auto t_casted = as_type_ptr<T>(arg))
{
if (matched)
{
......
......@@ -29,6 +29,9 @@ namespace ngraph
class Any : public Pattern
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"patternAny", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief creates an Any node containing a sub-pattern described by \sa type and \sa
/// shape.
Any(const element::Type& type,
......@@ -53,12 +56,6 @@ namespace ngraph
wrapped_nodes)
{
}
const std::string& description() const override
{
static std::string desc = "Any";
return desc;
}
};
}
}
......
......@@ -35,6 +35,9 @@ namespace ngraph
class AnyOf : public Pattern
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"patternAnyOf", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief creates an AnyOf node containing a sub-pattern described by \sa type and
/// \sa shape.
AnyOf(const element::Type& type,
......@@ -64,12 +67,6 @@ namespace ngraph
wrapped_nodes)
{
}
const std::string& description() const override
{
static std::string desc = "AnyOf";
return desc;
}
};
}
}
......
......@@ -31,6 +31,9 @@ namespace ngraph
class Label : public Pattern
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"patternLabel", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief creates a Label node containing a sub-pattern described by \sa type and
/// \sa shape.
///
......@@ -74,12 +77,6 @@ namespace ngraph
wrapped_nodes)
{
}
const std::string& description() const override
{
static std::string desc = "Label";
return desc;
}
};
}
}
......
......@@ -31,17 +31,14 @@ namespace ngraph
class Skip : public Pattern
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"patternSkip", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Skip(const std::shared_ptr<Node>& arg, Predicate predicate = nullptr)
: Pattern(NodeVector{arg}, predicate)
{
set_output_type(0, arg->get_element_type(), arg->get_output_partial_shape(0));
}
const std::string& description() const override
{
static std::string desc = "Skip";
return desc;
}
};
}
}
......
......@@ -31,14 +31,14 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
element::i8, Shape{}, [](std::shared_ptr<Node>) { return true; });
auto broadcast_op = std::make_shared<ngraph::op::Broadcast>(src_op, Shape{}, AxisSet{});
auto target_op = std::make_shared<pattern::op::AnyOf>(
element::i8,
Shape{},
[](std::shared_ptr<Node> node) {
return pattern::has_class<ngraph::op::util::UnaryElementwiseArithmetic>()(node) ||
pattern::has_class<ngraph::op::util::BinaryElementwiseArithmetic>()(node);
},
NodeVector{broadcast_op});
auto target_op =
std::make_shared<pattern::op::AnyOf>(element::i8,
Shape{},
[](std::shared_ptr<Node> node) {
return node->is_unary_elementwise_arithmetic() ||
node->is_binary_elementwise_arithmetic();
},
NodeVector{broadcast_op});
auto callback = [](pattern::Matcher& m) {
// Since the broadcast is going to an elementwise operation, we
......
......@@ -34,14 +34,14 @@ ngraph::runtime::plaidml::pass::ReplicateElision::ReplicateElision()
std::make_shared<pattern::op::Skip>(replicate_op, [](std::shared_ptr<Node> node) {
return pattern::has_class<plaidml::op::Replicate>()(node);
});
auto target_op = std::make_shared<pattern::op::AnyOf>(
element::i8,
Shape{},
[](std::shared_ptr<Node> node) {
return pattern::has_class<ngraph::op::util::UnaryElementwiseArithmetic>()(node) ||
pattern::has_class<ngraph::op::util::BinaryElementwiseArithmetic>()(node);
},
NodeVector{skip_op});
auto target_op =
std::make_shared<pattern::op::AnyOf>(element::i8,
Shape{},
[](std::shared_ptr<Node> node) {
return node->is_unary_elementwise_arithmetic() ||
node->is_binary_elementwise_arithmetic();
},
NodeVector{skip_op});
auto callback = [](pattern::Matcher& m) {
bool replaced_any = false;
......
......@@ -418,7 +418,7 @@ static void serialize_to_cpio(ostream& out, shared_ptr<ngraph::Function> func, s
traverse_nodes(const_cast<Function*>(func.get()),
[&](shared_ptr<Node> node) {
if (auto c = dynamic_pointer_cast<op::Constant>(node))
if (auto c = as_type_ptr<op::Constant>(node))
{
uint32_t size =
static_cast<uint32_t>(shape_size(c->get_output_shape(0)) *
......@@ -624,8 +624,7 @@ ParameterVector JSONDeserializer::deserialize_parameter_vector(json json_paramet
std::vector<std::shared_ptr<op::Parameter>> params;
for (auto& param_ref : json_parameters)
{
params.push_back(
dynamic_pointer_cast<op::Parameter>(deserialize_node_reference(param_ref)));
params.push_back(as_type_ptr<op::Parameter>(deserialize_node_reference(param_ref)));
}
return params;
}
......@@ -646,7 +645,7 @@ shared_ptr<Function> JSONDeserializer::deserialize_function(json func_js)
for (auto& result_ref : func_result)
{
auto fr = deserialize_node_reference(result_ref);
if (auto res = std::dynamic_pointer_cast<op::Result>(fr))
if (auto res = as_type_ptr<op::Result>(fr))
{
result.push_back(res);
// make sure we have `op::Result` on top of all outputs
......
......@@ -70,7 +70,7 @@ std::shared_ptr<Function>
ParameterVector new_parameters = f->get_parameters();
for (size_t i = 0; i < new_parameters.size(); i++)
{
new_parameters[i] = std::dynamic_pointer_cast<op::Parameter>(m[new_parameters[i].get()]);
new_parameters[i] = as_type_ptr<op::Parameter>(m[new_parameters[i].get()]);
// If the replacement for a Parameter is not itself a Parameter, we must have replaced it
// with a constant. We will insert a dead Parameter into the clone's parameters, in order
......
......@@ -250,8 +250,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
NodeVector result_nodes;
for (auto node : bprop->get_results())
{
auto result =
std::dynamic_pointer_cast<op::Result>(fprop_cache.node_param_map.at(node.get()));
auto result = as_type_ptr<op::Result>(fprop_cache.node_param_map.at(node.get()));
if (!result)
{
throw ngraph_error("Expected op::Result values for op::Result keys in node_param_map");
......@@ -266,15 +265,15 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
ParameterVector bprop_input_params;
for (auto param : bprop_inputs)
{
bprop_input_params.push_back(std::dynamic_pointer_cast<op::Parameter>(
fprop_cache.node_param_map.at(param.get())));
bprop_input_params.push_back(
as_type_ptr<op::Parameter>(fprop_cache.node_param_map.at(param.get())));
}
// add the cached fprop nodes as inputs to bprop
for (auto x : fprop_cache.fprop_output_nodes)
{
bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map.at(x)));
as_type_ptr<op::Parameter>(fprop_cache.node_param_map.at(x)));
}
return bprop_input_params;
};
......@@ -286,7 +285,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
ngraph::traverse_nodes(
result_nodes,
[&cloned_bprop_inputs, &fprop_cache, &inverted_node_map](std::shared_ptr<Node> node) {
auto pnode = std::dynamic_pointer_cast<op::Parameter>(node);
auto pnode = as_type_ptr<op::Parameter>(node);
if (pnode != nullptr &&
std::find(cloned_bprop_inputs.begin(), cloned_bprop_inputs.end(), pnode) ==
cloned_bprop_inputs.end())
......@@ -302,7 +301,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
for (auto fpirn : fprop_cache.fprop_output_nodes)
{
auto fpir = fpirn->shared_from_this();
if (std::dynamic_pointer_cast<op::Result>(fpir))
if (as_type_ptr<op::Result>(fpir))
{
throw ngraph_error("Expected op::Result in fprop->get_results()");
}
......
......@@ -321,7 +321,9 @@ TEST(pattern, matcher)
auto b = make_shared<op::Parameter>(element::i32, shape);
auto is_bea = pattern::has_class<op::util::BinaryElementwiseArithmetic>();
auto is_bea = [](std::shared_ptr<Node> node) -> bool {
return node->is_binary_elementwise_arithmetic();
};
auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
auto add_ab = a + b;
ASSERT_TRUE(n.match(bea, add_ab));
......