Unverified Commit 15ceedb7 authored by Scott Cyphers, committed by GitHub

Remove some dynamic_ptr_casts (#3622)

* Remove some dynamic_ptr_casts

* More removals

* Review comments

* Update src/ngraph/node.hpp
Co-Authored-By: Adam Procter <adam.m.procter@intel.com>
parent ecf7a396
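
Every hunk below applies one substitution: RTTI-based std::dynamic_pointer_cast / dynamic_cast is replaced by as_type_ptr / as_type / is_type, which compare a statically registered type_info address and then use a plain static cast. The following self-contained sketch (not the real ngraph headers; Add and Mul are stand-in classes, not ngraph ops) shows the mechanism this relies on:

#include <cstdint>
#include <iostream>
#include <memory>

struct NodeTypeInfo
{
    const char* name;
    uint64_t version;
};

struct Node
{
    virtual ~Node() = default;
    virtual const NodeTypeInfo& get_type_info() const = 0;

    // Identity test: compare type_info addresses; no RTTI involved.
    template <typename T>
    bool is_type() const { return &get_type_info() == &T::type_info; }
};

// Checked downcast in the style this commit introduces:
// nullptr on mismatch, static_pointer_cast on match.
template <typename T>
std::shared_ptr<T> as_type_ptr(std::shared_ptr<Node> node_ptr)
{
    return node_ptr->template is_type<T>() ? std::static_pointer_cast<T>(node_ptr)
                                           : std::shared_ptr<T>();
}

// Raw-pointer flavor, used where no ownership is needed.
template <typename T>
T* as_type(Node* node)
{
    return node->template is_type<T>() ? static_cast<T*>(node) : nullptr;
}

struct Add : Node
{
    static constexpr NodeTypeInfo type_info{"Add", 0};
    const NodeTypeInfo& get_type_info() const override { return type_info; }
};
constexpr NodeTypeInfo Add::type_info;

struct Mul : Node
{
    static constexpr NodeTypeInfo type_info{"Mul", 0};
    const NodeTypeInfo& get_type_info() const override { return type_info; }
};
constexpr NodeTypeInfo Mul::type_info;

int main()
{
    std::shared_ptr<Node> n = std::make_shared<Add>();
    std::cout << (as_type_ptr<Add>(n) ? "Add" : "not Add") << '\n'; // Add
    std::cout << (as_type_ptr<Mul>(n) ? "Mul" : "not Mul") << '\n'; // not Mul
}

The out-of-line constexpr definitions matter: is_type compares the address of type_info, so each class must have exactly one such object.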
@@ -52,7 +52,7 @@ Function::Function(const OutputVector& results,
     , m_unique_name("Function_" + to_string(m_instance_id))
 {
     if (std::any_of(results.cbegin(), results.cend(), [](Output<Node> n) {
-            return std::dynamic_pointer_cast<op::Result>(n.get_node_shared_ptr());
+            return as_type_ptr<op::Result>(n.get_node_shared_ptr());
         }))
     {
         throw ngraph_error(
@@ -76,7 +76,7 @@ Function::Function(const NodeVector& results,
     , m_unique_name("Function_" + to_string(m_instance_id))
 {
     if (std::any_of(results.cbegin(), results.cend(), [](std::shared_ptr<Node> n) {
-            return std::dynamic_pointer_cast<op::Result>(n);
+            return as_type_ptr<op::Result>(n);
         }))
     {
         throw ngraph_error(
......
@@ -306,7 +306,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
     ResultVector cloned_results;
     for (shared_ptr<Node> node : func.get_results())
     {
-        auto result = std::dynamic_pointer_cast<op::Result>(node_map.at(node.get()));
+        auto result = as_type_ptr<op::Result>(node_map.at(node.get()));
         if (!result)
         {
             throw ngraph_error("Results should be of type op::Result");
@@ -316,7 +316,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
     std::vector<std::shared_ptr<op::Parameter>> cloned_params;
     for (auto param : func.get_parameters())
     {
-        cloned_params.push_back(std::dynamic_pointer_cast<op::Parameter>(node_map.at(param.get())));
+        cloned_params.push_back(as_type_ptr<op::Parameter>(node_map.at(param.get())));
     }
     // create and return cloned function
@@ -325,7 +325,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
 bool ngraph::is_equal_to_const_value(std::string const_value, const Output<Node>& reduce_constant)
 {
-    if (auto rc = dynamic_pointer_cast<ngraph::op::Constant>(reduce_constant.get_node_shared_ptr()))
+    if (auto rc = as_type_ptr<ngraph::op::Constant>(reduce_constant.get_node_shared_ptr()))
     {
         auto cshape = rc->get_shape();
         size_t n = shape_size(cshape);
......
@@ -454,6 +454,12 @@ void Node::clear_control_dependents()
     }
 }
 
+const op::AutoBroadcastSpec& Node::get_autob() const
+{
+    static op::AutoBroadcastSpec s_spec;
+    return s_spec;
+}
+
 namespace ngraph
 {
     ostream& operator<<(ostream& out, const Node& node)
......
@@ -56,6 +56,7 @@ namespace ngraph
     namespace op
     {
+        struct AutoBroadcastSpec;
         class Constant;
     } // namespace op
@@ -157,37 +158,21 @@ namespace ngraph
         {
             return &get_type_info() == &NodeType::type_info;
         }
-
-        /// Casts a Node to a shared_ptr<T> if is of type T, nullptr otherwise;
-        template <typename NodeType>
-        std::shared_ptr<NodeType> as_type_ptr()
-        {
-            return is_type<NodeType>() ? std::static_pointer_cast<NodeType>(shared_from_this())
-                                       : std::shared_ptr<NodeType>();
-        }
-
-        /// Casts a Node to a shared_ptr<T> if is of type T, nullptr otherwise;
-        template <typename NodeType>
-        std::shared_ptr<const NodeType> as_type_ptr() const
-        {
-            return is_type<NodeType>() ? std::static_pointer_cast<NodeType>(shared_from_this())
-                                       : std::shared_ptr<NodeType>();
-        }
-
-        /// Casts a Node to a T* if is of type T, nullptr otherwise;
-        template <typename NodeType>
-        NodeType* as_type()
-        {
-            return is_type<NodeType>() ? static_cast<NodeType*>(this) : nullptr;
-        }
-
-        /// Casts a Node to a T* if is of type T, nullptr otherwise;
-        template <typename NodeType>
-        const NodeType* as_type() const
-        {
-            return is_type<NodeType>() ? static_cast<const NodeType*>(this) : nullptr;
-        }
+        virtual bool is_unary_elementwise_arithmetic() const { return false; }
+        virtual bool is_binary_elementwise_arithmetic() const { return false; }
+        virtual bool is_binary_elementwise_comparison() const { return false; }
+        virtual bool is_binary_elementwise_logical() const { return false; }
+        /// \returns true if node supports autobroadcast operations
+        virtual bool supports_auto_broadcast() const { return false; }
+        /// \returns the autobroadcast spec
+        virtual const op::AutoBroadcastSpec& get_autob() const;
+        /// \returns true if the node can decompose
+        virtual bool supports_decompose() const { return false; }
+        /// \brief Decomposes the FusedOp into a sub-graph consisting of core ngraph ops
+        ///
+        /// \return A vector of nodes comprising the sub-graph. The order of output
+        ///         tensors must match the output tensors of the FusedOp
+        virtual NodeVector decompose_op() const { return NodeVector(); }
         /// Returns the NodeTypeInfo for the node's class.
         /// During transition to type_info, returns a dummy type_info for Node if the class
         /// has not been updated yet.
@@ -504,6 +489,36 @@ namespace ngraph
         size_t m_placement_index = placement_invalid;
     };
 
+    /// Casts a Node* to a NodeType* if it is of type NodeType, nullptr otherwise
+    template <typename NodeType>
+    NodeType* as_type(Node* node)
+    {
+        return node->template is_type<NodeType>() ? static_cast<NodeType*>(node) : nullptr;
+    }
+
+    /// Casts a const Node* to a const NodeType* if it is of type NodeType, nullptr otherwise
+    template <typename NodeType>
+    const NodeType* as_type(const Node* node)
+    {
+        return node->template is_type<NodeType>() ? static_cast<const NodeType*>(node) : nullptr;
+    }
+
+    /// Casts a shared_ptr<Node> to a shared_ptr<NodeType> if it is of type NodeType, nullptr otherwise
+    template <typename NodeType>
+    std::shared_ptr<NodeType> as_type_ptr(std::shared_ptr<Node> node_ptr)
+    {
+        return node_ptr->template is_type<NodeType>() ? std::static_pointer_cast<NodeType>(node_ptr)
+                                                      : std::shared_ptr<NodeType>();
+    }
+
+    /// Casts a shared_ptr<const Node> to a shared_ptr<const NodeType> if it is of type NodeType, nullptr otherwise
+    template <typename NodeType>
+    std::shared_ptr<const NodeType> as_type_ptr(std::shared_ptr<const Node> node_ptr)
+    {
+        return node_ptr->template is_type<NodeType>()
+                   ? std::static_pointer_cast<const NodeType>(node_ptr)
+                   : std::shared_ptr<const NodeType>();
+    }
+
     /// \brief A handle for one of a node's inputs.
     template <typename NodeType>
     class Input
......
@@ -107,9 +107,9 @@ void op::DynReplaceSlice::validate_and_infer_types()
     set_input_is_relevant_to_shape(3);
     set_input_is_relevant_to_shape(4);
 
-    auto lower_bounds = dynamic_pointer_cast<op::Constant>(input_value(2).get_node_shared_ptr());
-    auto upper_bounds = dynamic_pointer_cast<op::Constant>(input_value(3).get_node_shared_ptr());
-    auto strides = dynamic_pointer_cast<op::Constant>(input_value(4).get_node_shared_ptr());
+    auto lower_bounds = as_type_ptr<op::Constant>(input_value(2).get_node_shared_ptr());
+    auto upper_bounds = as_type_ptr<op::Constant>(input_value(3).get_node_shared_ptr());
+    auto strides = as_type_ptr<op::Constant>(input_value(4).get_node_shared_ptr());
 
     // TODO(amprocte): We can get a bit more information here about the ranks of arg and
     // replacement by inspecting the attributes.
......
@@ -50,7 +50,7 @@ void op::DynReshape::validate_and_infer_types()
     set_input_is_relevant_to_shape(1);
 
-    if (auto const_shape = dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr()))
+    if (auto const_shape = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
     {
         std::vector<int64_t> out_shape_val = const_shape->get_vector<int64_t>();
         NODE_VALIDATION_CHECK(this,
......
@@ -86,9 +86,9 @@ void op::DynSlice::validate_and_infer_types()
     set_input_is_relevant_to_shape(2);
     set_input_is_relevant_to_shape(3);
 
-    auto lower_bounds = dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr());
-    auto upper_bounds = dynamic_pointer_cast<op::Constant>(input_value(2).get_node_shared_ptr());
-    auto strides = dynamic_pointer_cast<op::Constant>(input_value(3).get_node_shared_ptr());
+    auto lower_bounds = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
+    auto upper_bounds = as_type_ptr<op::Constant>(input_value(2).get_node_shared_ptr());
+    auto strides = as_type_ptr<op::Constant>(input_value(3).get_node_shared_ptr());
 
     if (lower_bounds && upper_bounds && strides)
     {
......
@@ -49,7 +49,7 @@ void op::Interpolate::validate_and_infer_types()
         }
     }
 
-    if (auto const_shape = dynamic_pointer_cast<op::Constant>(get_argument(1)))
+    if (auto const_shape = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
     {
         auto out_shape = static_cast<const int64_t*>(const_shape->get_data_ptr());
         size_t i = 0;
......
@@ -58,7 +58,7 @@ void op::PriorBox::validate_and_infer_types()
     set_input_is_relevant_to_shape(0);
 
-    if (auto const_shape = dynamic_pointer_cast<op::Constant>(get_argument(0)))
+    if (auto const_shape = as_type_ptr<op::Constant>(input_value(0).get_node_shared_ptr()))
     {
         NODE_VALIDATION_CHECK(this,
                               shape_size(const_shape->get_shape()) == 2,
......
@@ -72,7 +72,7 @@ void op::PriorBoxClustered::validate_and_infer_types()
     set_input_is_relevant_to_shape(0);
 
-    if (auto const_shape = dynamic_pointer_cast<op::Constant>(get_argument(0)))
+    if (auto const_shape = as_type_ptr<op::Constant>(input_value(0).get_node_shared_ptr()))
     {
         NODE_VALIDATION_CHECK(this,
                               shape_size(const_shape->get_shape()) == 2,
......
@@ -123,12 +123,9 @@ static
 template <typename T>
 static PartialShape infer_output_shape(const op::Range* node, const element::Type& /* et */)
 {
-    auto const_start =
-        dynamic_pointer_cast<op::Constant>(node->input_value(0).get_node_shared_ptr());
-    auto const_stop =
-        dynamic_pointer_cast<op::Constant>(node->input_value(1).get_node_shared_ptr());
-    auto const_step =
-        dynamic_pointer_cast<op::Constant>(node->input_value(2).get_node_shared_ptr());
+    auto const_start = as_type_ptr<op::Constant>(node->input_value(0).get_node_shared_ptr());
+    auto const_stop = as_type_ptr<op::Constant>(node->input_value(1).get_node_shared_ptr());
+    auto const_step = as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
 
     T start = static_cast<T>(0);
     T stop = static_cast<T>(0);
......
@@ -60,8 +60,7 @@ void op::Tile::validate_and_infer_types()
     auto out_shape = PartialShape::dynamic(output_rank);
 
-    if (auto const_repeats =
-            dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr()))
+    if (auto const_repeats = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
     {
         if (arg_shape.is_static())
         {
......
@@ -47,8 +47,7 @@ void op::Transpose::validate_and_infer_types()
     set_input_is_relevant_to_shape(1);
 
-    if (auto input_const =
-            std::dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr()))
+    if (auto input_const = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
     {
         auto permutation = input_const->get_axis_vector_val();
         NODE_VALIDATION_CHECK(this,
......
@@ -24,7 +24,7 @@
 using namespace std;
 using namespace ngraph;
 
-const string op::MatMul::type_name{"MatMul"};
+constexpr NodeTypeInfo op::MatMul::type_info;
 
 op::MatMul::MatMul(const Output<Node>& A,
                    const Output<Node>& B,
......
@@ -29,8 +29,8 @@ namespace ngraph
         {
         public:
             NGRAPH_API
-            static const std::string type_name;
-            const std::string& description() const override { return type_name; }
+            static constexpr NodeTypeInfo type_info{"MatMul", 0};
+            const NodeTypeInfo& get_type_info() const override { return type_info; }
             MatMul() = default;
             /// \brief Constructs a MatMul operation.
             ///
......
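
The MatMul hunks above are the per-op migration this scheme depends on: type_name/description() give way to type_info/get_type_info(). Below is a compilable sketch of the same two-part pattern for a hypothetical MyOp (NGRAPH_API and the rest of ngraph's Node are reduced to a minimal base here):

#include <cstdint>

struct NodeTypeInfo
{
    const char* name;
    uint64_t version;
};

struct Node
{
    virtual ~Node() = default;
    virtual const NodeTypeInfo& get_type_info() const = 0;
};

// Header side, as in matmul.hpp above.
class MyOp : public Node
{
public:
    static constexpr NodeTypeInfo type_info{"MyOp", 0};
    const NodeTypeInfo& get_type_info() const override { return type_info; }
};

// Source side, as in matmul.cpp above: the one out-of-line definition gives
// &MyOp::type_info a unique address for is_type<T>() to compare against.
constexpr NodeTypeInfo MyOp::type_info;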
@@ -80,7 +80,7 @@ AxisSet op::NormalizeL2::get_reduction_axes() const
 {
     AxisSet axes;
     auto axes_input_node = input_value(1).get_node_shared_ptr();
-    if (auto const_op = dynamic_pointer_cast<op::Constant>(axes_input_node))
+    if (auto const_op = as_type_ptr<op::Constant>(axes_input_node))
     {
         axes = const_op->get_axis_set_val();
     }
......
@@ -43,7 +43,7 @@ NodeVector op::Squeeze::decompose_op() const
                           "doesn't support 'axes' input of other type than a Constant.");
 
     // Get value of axes from Constant
-    auto axes_constant = dynamic_pointer_cast<op::Constant>(axes_node);
+    auto axes_constant = as_type_ptr<op::Constant>(axes_node);
     auto axes = axes_constant->get_vector<size_t>();
     auto data_shape = data.get_shape();
......
@@ -48,7 +48,7 @@ NodeVector op::Unsqueeze::decompose_op() const
     auto axes_node = input_value(1).get_node_shared_ptr();
 
     // Get value of axes from Constant
-    auto axes_constant = dynamic_pointer_cast<op::Constant>(axes_node);
+    auto axes_constant = as_type_ptr<op::Constant>(axes_node);
     auto axes = axes_constant->get_vector<size_t>();
     auto data_shape = data.get_shape();
......
@@ -47,7 +47,7 @@ AxisSet op::LRN::get_reduction_axes() const
 {
     AxisSet axes{1}; // channel axis as default
     auto axes_input_node = input_value(1).get_node_shared_ptr();
-    if (auto const_op = dynamic_pointer_cast<op::Constant>(axes_input_node))
+    if (auto const_op = as_type_ptr<op::Constant>(axes_input_node))
     {
         axes = const_op->get_axis_set_val();
     }
......
@@ -59,7 +59,7 @@ op::TopK::TopK(const Output<Node>& arg,
 size_t op::TopK::get_k() const
 {
     size_t k = 0;
-    if (auto const_op = dynamic_pointer_cast<op::Constant>(input_value(1).get_node_shared_ptr()))
+    if (auto const_op = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
     {
         k = const_op->get_vector<int64_t>()[0];
     }
......
@@ -41,13 +41,13 @@ op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
 bool op::util::ArithmeticReduction::reduction_axes_constant() const
 {
-    return dynamic_pointer_cast<op::Constant>(get_argument(1)) != nullptr;
+    return input_value(1).get_node()->is_type<op::Constant>();
 }
 
 const AxisSet op::util::ArithmeticReduction::get_reduction_axes() const
 {
     AxisSet axes;
-    if (auto const_op = dynamic_pointer_cast<op::Constant>(get_argument(1)))
+    if (auto const_op = as_type<op::Constant>(input_value(1).get_node()))
     {
         axes = const_op->get_axis_set_val();
     }
......
@@ -82,8 +82,10 @@ namespace ngraph
         public:
             void validate_and_infer_types() override;
 
-            const AutoBroadcastSpec& get_autob() const { return m_autob; }
+            const AutoBroadcastSpec& get_autob() const override { return m_autob; }
             void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
+            bool is_binary_elementwise_arithmetic() const override { return true; }
+            bool supports_auto_broadcast() const override { return true; }
         private:
             AutoBroadcastSpec m_autob;
         };
......
@@ -88,8 +88,10 @@ namespace ngraph
         public:
             void validate_and_infer_types() override;
 
-            const AutoBroadcastSpec& get_autob() const { return m_autob; }
+            const AutoBroadcastSpec& get_autob() const override { return m_autob; }
             void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
+            bool supports_auto_broadcast() const override { return true; }
+            bool is_binary_elementwise_comparison() const override { return true; }
         private:
             AutoBroadcastSpec m_autob;
         };
......
@@ -84,8 +84,10 @@ namespace ngraph
        public:
            void validate_and_infer_types() override;
 
-            const AutoBroadcastSpec& get_autob() const { return m_autob; }
+            const AutoBroadcastSpec& get_autob() const override { return m_autob; }
             void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
+            bool supports_auto_broadcast() const override { return true; }
+            bool is_binary_elementwise_logical() const override { return true; }
         private:
             AutoBroadcastSpec m_autob;
         };
......
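
These three headers follow one pattern: Node (earlier in this diff) supplies conservative virtual defaults, and each op-family base class overrides the predicates that apply to it, so passes can interrogate any Node without a downcast. A compilable sketch with stand-in types, not ngraph's real classes:

#include <iostream>

// Stand-ins for ngraph's op::AutoBroadcastType / op::AutoBroadcastSpec.
enum class AutoBroadcastType { NONE, NUMPY };
struct AutoBroadcastSpec
{
    AutoBroadcastType m_type = AutoBroadcastType::NONE;
};

struct Node
{
    virtual ~Node() = default;
    // Conservative defaults, as added to Node in this commit.
    virtual bool is_binary_elementwise_arithmetic() const { return false; }
    virtual bool supports_auto_broadcast() const { return false; }
    virtual const AutoBroadcastSpec& get_autob() const
    {
        static AutoBroadcastSpec s_spec; // NONE, like Node::get_autob above
        return s_spec;
    }
};

// Stand-in for op::util::BinaryElementwiseArithmetic.
struct BinaryElementwiseArithmetic : Node
{
    bool is_binary_elementwise_arithmetic() const override { return true; }
    bool supports_auto_broadcast() const override { return true; }
    const AutoBroadcastSpec& get_autob() const override { return m_autob; }
    AutoBroadcastSpec m_autob{AutoBroadcastType::NUMPY};
};

int main()
{
    BinaryElementwiseArithmetic add;
    const Node& n = add;
    // A pass can now branch on virtuals instead of dynamic_cast:
    if (n.supports_auto_broadcast() && n.get_autob().m_type != AutoBroadcastType::NONE)
    {
        std::cout << "would insert explicit broadcasts\n";
    }
}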
@@ -30,12 +30,7 @@ namespace ngraph
         class FusedOp : public Op
         {
         public:
-            /// \brief Decomposes the FusedOp into a sub-graph consisting of core ngraph ops
-            ///
-            /// \return A vector of nodes comprising the sub-graph. The order of output
-            ///         tensors must match the output tensors of the FusedOp
-            virtual NodeVector decompose_op() const = 0;
+            bool supports_decompose() const override { return true; }
 
             void validate_and_infer_types() override;
 
             /// Pre and post validation hooks for op-specific actions
......
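
With decompose_op hoisted onto Node as a virtual with an empty default, FusedOpDecomposition (later in this diff) only needs supports_decompose() rather than a cast to op::util::FusedOp. A sketch under those assumptions; SquaredDifference here is a hypothetical fused op, not ngraph's:

#include <iostream>
#include <memory>
#include <vector>

struct Node;
using NodeVector = std::vector<std::shared_ptr<Node>>;

struct Node
{
    virtual ~Node() = default;
    // Defaults added to Node in this commit: nothing to decompose.
    virtual bool supports_decompose() const { return false; }
    virtual NodeVector decompose_op() const { return NodeVector(); }
};

// Stand-in for op::util::FusedOp: one override flips the predicate
// for every fused op that derives from it.
struct FusedOp : Node
{
    bool supports_decompose() const override { return true; }
};

// Hypothetical fused op that expands into two core nodes.
struct SquaredDifference : FusedOp
{
    NodeVector decompose_op() const override
    {
        return {std::make_shared<Node>(), std::make_shared<Node>()};
    }
};

// Mirrors the test in FusedOpDecomposition::run_on_node below:
// no dynamic_pointer_cast<FusedOp> required.
bool decompose_if_possible(const std::shared_ptr<Node>& node)
{
    if (!node->supports_decompose())
    {
        return false;
    }
    std::cout << "decomposed into " << node->decompose_op().size() << " nodes\n";
    return true;
}

int main()
{
    decompose_if_possible(std::make_shared<SquaredDifference>()); // decomposed into 2 nodes
    decompose_if_possible(std::make_shared<Node>());              // false; node left alone
}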
@@ -40,13 +40,13 @@ op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg,
 bool op::util::LogicalReduction::reduction_axes_constant() const
 {
-    return dynamic_pointer_cast<op::Constant>(get_argument(1)) != nullptr;
+    return input_value(1).get_node()->is_type<op::Constant>();
 }
 
 const AxisSet op::util::LogicalReduction::get_reduction_axes() const
 {
     AxisSet axes;
-    if (auto const_op = dynamic_pointer_cast<op::Constant>(get_argument(1)))
+    if (auto const_op = as_type<op::Constant>(input_value(1).get_node()))
     {
         axes = const_op->get_axis_set_val();
     }
......
@@ -67,6 +67,7 @@ namespace ngraph
         public:
             void validate_and_infer_types() override;
+            bool is_unary_elementwise_arithmetic() const override { return true; }
         };
     }
 }
......
@@ -56,7 +56,7 @@ static shared_ptr<pattern::Matcher>
 static shared_ptr<pattern::op::Label> get_broadcast_label(shared_ptr<pattern::Matcher> matcher)
 {
-    return dynamic_pointer_cast<pattern::op::Label>(matcher->get_pattern()->get_argument(1));
+    return static_pointer_cast<pattern::op::Label>(matcher->get_pattern()->get_argument(1));
 }
 
 //`simplify_concat` identifies slices-concat sequences
@@ -147,7 +147,7 @@ static bool simplify_concat(shared_ptr<Node> n)
         }
 
         // check that no other node uses slices and reshapes
-        if (auto rcarg = dynamic_pointer_cast<op::Reshape>(carg))
+        if (auto rcarg = as_type_ptr<op::Reshape>(carg))
         {
             auto default_shape = get_default_order(rcarg->get_argument(0)->get_shape());
             if (default_shape != rcarg->get_input_order())
@@ -316,9 +316,9 @@ static bool simplify_add(shared_ptr<Node> n)
 //`simplify_log` optimizes `log(exp(x)/y)` into `x - log(y)`
 static bool simplify_log(shared_ptr<Node> n)
 {
-    if (auto div = dynamic_pointer_cast<op::Divide>(n->get_argument(0)))
+    if (auto div = as_type_ptr<op::Divide>(n->input_value(0).get_node_shared_ptr()))
     {
-        if (auto exp = dynamic_pointer_cast<op::Exp>(div->get_argument(0)))
+        if (auto exp = as_type_ptr<op::Exp>(div->input_value(0).get_node_shared_ptr()))
         {
             auto denom = div->get_argument(1);
             auto diff =
@@ -417,14 +417,14 @@ static bool simplify_reduction(shared_ptr<Node> n)
     NGRAPH_DEBUG << "In simplify_reduction for " << n->get_name();
     auto reduction = static_pointer_cast<T>(n);
 
-    auto broadcast = dynamic_pointer_cast<op::Broadcast>(n->get_argument(0));
+    auto broadcast = as_type_ptr<op::Broadcast>(n->input_value(0).get_node_shared_ptr());
     if (!broadcast)
     {
         NGRAPH_DEBUG << n->get_name() << " isn't Broadcast";
         return false;
     }
 
-    auto cnst = dynamic_pointer_cast<op::Constant>(broadcast->get_argument(0));
+    auto cnst = as_type_ptr<op::Constant>(broadcast->input_value(0).get_node_shared_ptr());
     if (!cnst || cnst->get_shape().size() > 0 /*not a scalar*/)
     {
         NGRAPH_DEBUG << broadcast->get_argument(0)->get_name() << " isn't a scalar constant";
......
@@ -101,7 +101,7 @@ void pass::ConcatElimination::construct_concat_elimination()
         auto pattern_map = m.get_pattern_map();
         auto op = pattern_map[op_label];
 
-        auto root = std::dynamic_pointer_cast<op::Concat>(m.get_match_root());
+        auto root = as_type_ptr<op::Concat>(m.get_match_root());
         if (root && (root->get_input_shape(0) == root->get_output_shape(0)))
         {
             NGRAPH_DEBUG << " eliminated " << m.get_match_root() << "\n";
@@ -122,7 +122,7 @@ bool ngraph::pass::SelfConcatFusion::run_on_function(std::shared_ptr<Function> f
     bool modify_graph = false;
     auto has_multiple_inputs = [](std::shared_ptr<Node> n) {
         auto input_size = n->get_input_size();
-        auto root = std::dynamic_pointer_cast<op::Concat>(n);
+        auto root = as_type_ptr<op::Concat>(n);
         return (root && input_size > 1);
     };
@@ -178,7 +178,7 @@ void ngraph::pass::SelfConcatFusion::construct_concat_patterns(
     if (matcher->match(n))
     {
         auto concat_op = matcher->get_pattern_map()[concat_op_label];
-        if (!std::dynamic_pointer_cast<op::Concat>(concat_op))
+        if (!concat_op->is_type<op::Concat>())
         {
             NGRAPH_DEBUG << "self_concat_fusion: Pattern matcher matched incorrect op. Matched "
                          << concat_op->get_name() << " instead of a self concat";
......
@@ -35,7 +35,7 @@ static shared_ptr<op::Constant>
 {
     vector<T> out_vec(shape_size(reduction_node->get_shape()));
 
-    if (auto max = dynamic_pointer_cast<op::Max>(reduction_node))
+    if (auto max = as_type_ptr<op::Max>(reduction_node))
     {
         runtime::reference::max<T>(constant->get_vector<T>().data(),
                                    out_vec.data(),
@@ -43,7 +43,7 @@ static shared_ptr<op::Constant>
                                    reduction_node->get_shape(),
                                    max->get_reduction_axes());
     }
-    else if (auto min = dynamic_pointer_cast<op::Min>(reduction_node))
+    else if (auto min = as_type_ptr<op::Min>(reduction_node))
     {
         runtime::reference::min<T>(constant->get_vector<T>().data(),
                                    out_vec.data(),
@@ -51,7 +51,7 @@ static shared_ptr<op::Constant>
                                    reduction_node->get_shape(),
                                    min->get_reduction_axes());
     }
-    else if (auto prod = dynamic_pointer_cast<op::Product>(reduction_node))
+    else if (auto prod = as_type_ptr<op::Product>(reduction_node))
     {
         runtime::reference::product<T>(constant->get_vector<T>().data(),
                                        out_vec.data(),
@@ -59,7 +59,7 @@ static shared_ptr<op::Constant>
                                        reduction_node->get_shape(),
                                        prod->get_reduction_axes());
     }
-    else if (auto sum = dynamic_pointer_cast<op::Sum>(reduction_node))
+    else if (auto sum = as_type_ptr<op::Sum>(reduction_node))
     {
         runtime::reference::sum<T>(constant->get_vector<T>().data(),
                                    out_vec.data(),
......
This diff is collapsed.
@@ -57,14 +57,13 @@ void pass::ConstantFolding::construct_constant_dequantize()
         auto pattern_map = m.get_pattern_map();
 
-        auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
+        auto constant_match = as_type_ptr<op::Constant>(pattern_map[constant_label]);
         auto dequant_match = pattern_map[dequant];
-        auto dequantize_op = dynamic_pointer_cast<op::Dequantize>(dequant_match);
-        auto scale = dynamic_pointer_cast<op::Constant>(
-            dequant_match->input(1).get_source_output().get_node_shared_ptr());
-        auto offset = dynamic_pointer_cast<op::Constant>(
-            dequant_match->input(2).get_source_output().get_node_shared_ptr());
+        auto dequantize_op = as_type_ptr<op::Dequantize>(dequant_match);
+        auto scale = as_type_ptr<op::Constant>(dequant_match->input_value(1).get_node_shared_ptr());
+        auto offset =
+            as_type_ptr<op::Constant>(dequant_match->input_value(2).get_node_shared_ptr());
 
         NGRAPH_CHECK(revalidate_and_ensure_static(dequantize_op));
         auto type = constant_match->get_element_type();
......
@@ -28,7 +28,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C
 {
     vector<char> out_vec(shape_size(reduction_node->get_shape()));
 
-    if (auto all = dynamic_pointer_cast<::ngraph::op::All>(reduction_node))
+    if (auto all = as_type_ptr<::ngraph::op::All>(reduction_node))
     {
         runtime::reference::all(constant->get_vector<char>().data(),
                                 out_vec.data(),
@@ -36,7 +36,7 @@ static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::C
                                 reduction_node->get_shape(),
                                 all->get_reduction_axes());
     }
-    else if (auto any = dynamic_pointer_cast<::ngraph::op::Any>(reduction_node))
+    else if (auto any = as_type_ptr<::ngraph::op::Any>(reduction_node))
     {
         runtime::reference::any(constant->get_vector<char>().data(),
                                 out_vec.data(),
......
@@ -59,9 +59,9 @@ void pass::ConstantFolding::construct_constant_quantize()
         auto pattern_map = m.get_pattern_map();
 
-        auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
+        auto constant_match = as_type_ptr<op::Constant>(pattern_map[constant_label]);
         auto quant_match = pattern_map[quant];
-        auto quantize_op = dynamic_pointer_cast<op::Quantize>(quant_match);
+        auto quantize_op = as_type_ptr<op::Quantize>(quant_match);
 
         NGRAPH_CHECK(revalidate_and_ensure_static(quantize_op));
......
@@ -40,10 +40,9 @@ using namespace ngraph;
 bool is_supported_unary_op(std::shared_ptr<Node> n)
 {
-    return std::dynamic_pointer_cast<op::Abs>(n) || std::dynamic_pointer_cast<op::Ceiling>(n) ||
-           std::dynamic_pointer_cast<op::Floor>(n) || std::dynamic_pointer_cast<op::Negative>(n) ||
-           std::dynamic_pointer_cast<op::Not>(n) || std::dynamic_pointer_cast<op::Relu>(n) ||
-           std::dynamic_pointer_cast<op::Sign>(n) || std::dynamic_pointer_cast<op::Sqrt>(n);
+    return n->is_type<op::Abs>() || n->is_type<op::Ceiling>() || n->is_type<op::Floor>() ||
+           n->is_type<op::Negative>() || n->is_type<op::Not>() || n->is_type<op::Relu>() ||
+           n->is_type<op::Sign>() || n->is_type<op::Sqrt>();
 }
 
 template <class T>
@@ -52,7 +51,7 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
                                              NodeExecutorTy func)
 {
     // check sqrt arg
-    if (std::dynamic_pointer_cast<op::Sqrt>(unary))
+    if (unary->is_type<op::Sqrt>())
     {
         std::vector<T> values{constant->get_vector<T>()};
         if (std::any_of(values.begin(), values.end(), [](T i) { return i < T(0); }))
@@ -75,42 +74,42 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
     }
     else
     {
-        if (std::dynamic_pointer_cast<op::Abs>(unary))
+        if (unary->is_type<op::Abs>())
         {
             runtime::reference::abs<T>(
                 constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
         }
-        else if (std::dynamic_pointer_cast<op::Ceiling>(unary))
+        else if (unary->is_type<op::Ceiling>())
        {
             runtime::reference::ceiling<T>(
                 constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
         }
-        else if (std::dynamic_pointer_cast<op::Floor>(unary))
+        else if (unary->is_type<op::Floor>())
         {
             runtime::reference::floor<T>(
                 constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
         }
-        else if (std::dynamic_pointer_cast<op::Negative>(unary))
+        else if (unary->is_type<op::Negative>())
         {
             runtime::reference::negate<T>(
                 constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
         }
-        else if (std::dynamic_pointer_cast<op::Not>(unary))
+        else if (unary->is_type<op::Not>())
        {
             runtime::reference::logical_not<T>(
                 constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
         }
-        else if (std::dynamic_pointer_cast<op::Relu>(unary))
+        else if (unary->is_type<op::Relu>())
         {
             runtime::reference::relu<T>(
                 constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
         }
-        else if (std::dynamic_pointer_cast<op::Sign>(unary))
+        else if (unary->is_type<op::Sign>())
        {
             runtime::reference::sign<T>(
                 constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
         }
-        else if (std::dynamic_pointer_cast<op::Sqrt>(unary))
+        else if (unary->is_type<op::Sqrt>())
        {
             runtime::reference::sqrt<T>(
                 constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
@@ -129,8 +128,7 @@ void pass::ConstantFolding::construct_constant_unary()
     auto constant_label = make_shared<pattern::op::Label>(
         element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
     auto is_ue = [](std::shared_ptr<Node> n) {
-        return (pattern::has_class<op::util::UnaryElementwiseArithmetic>()(n) ||
-                pattern::has_class<op::Not>()(n));
+        return n->is_unary_elementwise_arithmetic() || pattern::has_class<op::Not>()(n);
     };
     auto ue = std::make_shared<pattern::op::Any>(constant_label, is_ue, NodeVector{constant_label});
@@ -140,7 +138,7 @@ void pass::ConstantFolding::construct_constant_unary()
         auto pattern_map = m.get_pattern_map();
 
-        auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
+        auto constant_match = as_type_ptr<op::Constant>(pattern_map[constant_label]);
         auto unary_match = m.get_match_root();
 
         if (!is_supported_unary_op(unary_match))
......
@@ -570,7 +570,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
     {
         if (is_used(n.get()))
         {
-            if (dynamic_pointer_cast<op::Convolution>(n) == nullptr)
+            if (!n->is_type<op::Convolution>())
             {
                 NGRAPH_DEBUG << "Not all live users of element wise operation are Convolution";
                 return false;
@@ -821,16 +821,15 @@ void pass::CoreFusion::construct_zero_padded_reshaped_conv()
                                   pattern::Matcher& m) {
         auto pattern_map = m.get_pattern_map();
 
-        auto pad_value_op = std::dynamic_pointer_cast<ngraph::op::Constant>(pattern_map[pad_value]);
+        auto pad_value_op = as_type_ptr<ngraph::op::Constant>(pattern_map[pad_value]);
         if (!pad_value_op)
         {
             NGRAPH_DEBUG << "Pad value must be a constant";
             return false;
         }
 
-        const auto& matched_conv =
-            std::static_pointer_cast<ngraph::op::Convolution>(pattern_map[conv_label]);
-        const auto& matched_pad = std::static_pointer_cast<ngraph::op::Pad>(pattern_map[pad_label]);
+        const auto& matched_conv = as_type_ptr<ngraph::op::Convolution>(pattern_map[conv_label]);
+        const auto& matched_pad = as_type_ptr<ngraph::op::Pad>(pattern_map[pad_label]);
         const auto& matched_reshape =
             std::static_pointer_cast<ngraph::op::Reshape>(pattern_map[reshape_label]);
@@ -905,7 +904,7 @@ void pass::CoreFusion::construct_zero_padded_conv()
                                   pattern::Matcher& m) {
         auto pattern_map = m.get_pattern_map();
 
-        auto pad_value_op = std::dynamic_pointer_cast<ngraph::op::Constant>(pattern_map[pad_value]);
+        auto pad_value_op = as_type_ptr<ngraph::op::Constant>(pattern_map[pad_value]);
         if (!pad_value_op)
         {
             NGRAPH_DEBUG << "Pad value must be a constant";
@@ -976,7 +975,7 @@ void pass::CoreFusion::construct_zero_padded_conv_backprop_filters()
                                   pattern::Matcher& m) {
         auto pattern_map = m.get_pattern_map();
 
-        auto pad_value_op = std::dynamic_pointer_cast<ngraph::op::Constant>(pattern_map[pad_value]);
+        auto pad_value_op = as_type_ptr<ngraph::op::Constant>(pattern_map[pad_value]);
         if (!pad_value_op)
         {
             NGRAPH_DEBUG << "Pad value must be a constant";
@@ -1036,7 +1035,7 @@ void pass::CoreFusion::construct_conv_bias()
     auto pbcast = make_shared<op::Broadcast>(pbias, shape, AxisSet{0, 1, 2, 3});
     auto pbcast_label = make_shared<pattern::op::Label>(pbcast, nullptr, NodeVector{pbcast});
     auto reshape_pred = [](shared_ptr<Node> node) -> bool {
-        if (auto reshape = dynamic_pointer_cast<op::Reshape>(node))
+        if (auto reshape = as_type_ptr<op::Reshape>(node))
         {
             auto ishape = reshape->get_input_shape(0);
             auto oshape = reshape->get_shape();
@@ -1066,7 +1065,7 @@ void pass::CoreFusion::construct_conv_bias()
                      << m.get_match_root()->get_name();
         auto pattern_map = m.get_pattern_map();
 
-        auto conv_m = dynamic_pointer_cast<op::Convolution>(m.get_match_root()->get_argument(0));
+        auto conv_m = as_type_ptr<op::Convolution>(m.get_match_root()->get_argument(0));
 
         if (conv_m == nullptr)
         {
@@ -1138,7 +1137,7 @@ void pass::CoreFusion::construct_conv_bias_add()
         auto add_m = m.get_match_root();
         auto pattern_map = m.get_pattern_map();
-        auto conv_m = dynamic_pointer_cast<op::ConvolutionBias>(add_m->get_argument(1));
+        auto conv_m = as_type_ptr<op::ConvolutionBias>(add_m->get_argument(1));
         auto add_input_m = add_m->get_argument(0);
 
         if (!conv_m)
......
@@ -29,7 +29,7 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
 {
     bool modified = false;
 
-    if (auto fused_op = dynamic_pointer_cast<op::util::FusedOp>(node))
+    if (node->supports_decompose())
     {
         if (m_has_direct_support && m_has_direct_support(*node))
         {
@@ -37,9 +37,9 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
             return modified;
         }
 
-        auto subgraph_outputs = fused_op->decompose_op();
+        auto subgraph_outputs = node->decompose_op();
         // Run recursively until no more fused ops
-        auto subgraph = extract_subgraph(subgraph_outputs, fused_op->get_arguments());
+        auto subgraph = extract_subgraph(subgraph_outputs, node->get_arguments());
         for (auto subgraph_node : subgraph)
         {
             run_on_node(subgraph_node);
@@ -51,12 +51,11 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
             for (size_t j = 0; j < output_node->get_outputs().size(); j++, i++)
             {
                 // TODO: Provenance
-                set<descriptor::Input*> fop_users{begin(fused_op->get_outputs().at(i).get_inputs()),
-                                                  end(fused_op->get_outputs().at(i).get_inputs())};
+                set<descriptor::Input*> fop_users{begin(node->get_outputs().at(i).get_inputs()),
+                                                  end(node->get_outputs().at(i).get_inputs())};
                 for (auto fop_user : fop_users)
                 {
-                    if (auto goe =
-                            dynamic_cast<op::GetOutputElement*>(fop_user->get_raw_pointer_node()))
+                    if (auto goe = as_type<op::GetOutputElement>(fop_user->get_raw_pointer_node()))
                     {
                         Output<Node> goe_output = goe->get_as_output();
                         if (goe_output.get_index() == i && !goe->get_output_inputs(0).empty())
@@ -78,12 +77,12 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
             }
         }
-        if (i != fused_op->get_output_size())
+        if (i != node->get_output_size())
         {
             throw ngraph_error("While replacing " + node->get_name() +
                                ", mismatch between op output count and outputs of the decomposed "
                                "subgraph. Expected: " +
-                               to_string(fused_op->get_output_size()) + " Got: " + to_string(i));
+                               to_string(node->get_output_size()) + " Got: " + to_string(i));
         }
         modified = true;
     }
......
@@ -24,27 +24,19 @@
 using namespace std;
 using namespace ngraph;
 
-template <typename optype>
-static bool broadcast_and_replace(std::shared_ptr<ngraph::Node>& node)
+bool ngraph::pass::ImplicitBroadcastElimination::run_on_node(std::shared_ptr<ngraph::Node> node)
 {
-    if (auto op = std::dynamic_pointer_cast<optype>(node))
+    if (node->supports_auto_broadcast())
     {
-        if (op->get_autob().m_type != op::AutoBroadcastType::NONE)
+        if (node->get_autob().m_type != op::AutoBroadcastType::NONE)
         {
-            auto new_args = pass::explicit_broadcast<optype>(op);
+            auto new_args = pass::explicit_broadcast(node);
             for (size_t i = 0; i < new_args.size(); i++)
             {
-                op->input(i).replace_source_output(new_args[i]->output(0));
+                node->input(i).replace_source_output(new_args[i]->output(0));
             }
             return true;
         }
     }
     return false;
 }
-
-bool ngraph::pass::ImplicitBroadcastElimination::run_on_node(std::shared_ptr<ngraph::Node> node)
-{
-    return broadcast_and_replace<op::util::BinaryElementwiseArithmetic>(node) ||
-           broadcast_and_replace<op::util::BinaryElementwiseComparison>(node) ||
-           broadcast_and_replace<op::util::BinaryElementwiseLogical>(node);
-}
@@ -23,22 +23,23 @@ namespace ngraph
 {
     namespace pass
     {
-        template <typename T>
-        NodeVector explicit_broadcast(std::shared_ptr<T>& node)
+        NodeVector explicit_broadcast(std::shared_ptr<Node>& node)
         {
             NodeVector rc;
-            if (node->get_autob().m_type == op::AutoBroadcastType::NONE)
-            {
-                rc = node->get_arguments();
-            }
-            else if (node->get_autob().m_type == op::AutoBroadcastType::NUMPY)
-            {
-                rc = op::numpy_style_broadcast(node->get_arguments());
-            }
-            else
-            {
-                throw ngraph_error("Unsupported implicit broadcast type");
-            }
+            if (node->supports_auto_broadcast())
+            {
+                if (node->get_autob().m_type == op::AutoBroadcastType::NONE)
+                {
+                    rc = node->get_arguments();
+                }
+                else if (node->get_autob().m_type == op::AutoBroadcastType::NUMPY)
+                {
+                    rc = op::numpy_style_broadcast(node->get_arguments());
+                }
+                else
+                {
+                    throw ngraph_error("Unsupported implicit broadcast type");
+                }
+            }
             return rc;
         }
......
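
The rewritten explicit_broadcast above is the payoff of the supports_auto_broadcast/get_autob virtuals: one non-template function now serves all three binary-elementwise families, and ImplicitBroadcastElimination no longer instantiates it per op type. A sketch of the same dispatch with stand-in types; numpy_style_broadcast is stubbed out because its implementation is outside this diff:

#include <memory>
#include <stdexcept>
#include <vector>

struct Node;
using NodeVector = std::vector<std::shared_ptr<Node>>;

enum class AutoBroadcastType { NONE, NUMPY };
struct AutoBroadcastSpec
{
    AutoBroadcastType m_type = AutoBroadcastType::NONE;
};

struct Node
{
    virtual ~Node() = default;
    virtual bool supports_auto_broadcast() const { return false; }
    virtual const AutoBroadcastSpec& get_autob() const
    {
        static AutoBroadcastSpec s_spec;
        return s_spec;
    }
    NodeVector get_arguments() const { return m_args; }
    NodeVector m_args;
};

// Stub standing in for op::numpy_style_broadcast.
NodeVector numpy_style_broadcast(NodeVector args)
{
    return args;
}

// Same shape as pass::explicit_broadcast above: a single function over Node
// instead of a template instantiated for each binary-elementwise base class.
NodeVector explicit_broadcast(const std::shared_ptr<Node>& node)
{
    NodeVector rc;
    if (node->supports_auto_broadcast())
    {
        if (node->get_autob().m_type == AutoBroadcastType::NONE)
        {
            rc = node->get_arguments();
        }
        else if (node->get_autob().m_type == AutoBroadcastType::NUMPY)
        {
            rc = numpy_style_broadcast(node->get_arguments());
        }
        else
        {
            throw std::runtime_error("Unsupported implicit broadcast type");
        }
    }
    return rc;
}

int main()
{
    auto n = std::make_shared<Node>();
    return static_cast<int>(explicit_broadcast(n).size()); // 0: default Node doesn't broadcast
}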
@@ -66,7 +66,7 @@ bool pass::LikeReplacement::run_on_function(shared_ptr<Function> function)
         // Here we're checking on a common base class of a family of template classes,
         // which is more than type info can handle.
-        auto sclb = dynamic_pointer_cast<op::ScalarConstantLikeBase>(n);
+        auto sclb = as_type_ptr<op::ScalarConstantLikeBase>(n);
         if (sclb != nullptr)
         {
             replace_node(sclb, sclb->as_constant());
......
@@ -58,7 +58,7 @@ bool pass::Liveness::run_on_function(shared_ptr<Function> function)
     }
     for (const shared_ptr<Node>& node : ops)
     {
-        if (auto constant_node = dynamic_pointer_cast<op::Constant>(node))
+        if (auto constant_node = as_type_ptr<op::Constant>(node))
         {
             for (auto& output : constant_node->outputs())
             {
......
@@ -52,8 +52,7 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<Function> function)
     {
         auto op = std::static_pointer_cast<op::Op>(node);
         // concat and slice in_place_oi should be treated differently
-        if (!std::dynamic_pointer_cast<op::Concat>(node) &&
-            !std::dynamic_pointer_cast<op::Slice>(node))
+        if (!node->is_type<op::Concat>() && !node->is_type<op::Slice>())
         {
             if (auto op_annotations = op->get_op_annotations())
             {
@@ -66,7 +65,7 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<Function> function)
                     // For destructive kernel, this should be the last use
                     // Non-destructive kernels can pass through if memory sharing is disabled
                     if ((node->liveness_free_list.count(input) != 0 ||
-                         std::dynamic_pointer_cast<op::GetOutputElement>(node) ||
+                         node->is_type<op::GetOutputElement>() ||
                          (m_disable_memory_sharing && !oi_pair.destructive &&
                           !input_node->is_parameter() && !input_node->is_constant())) &&
                         node->liveness_new_list.count(output) != 0)
......
@@ -133,7 +133,7 @@ bool pass::NopElimination::run_on_function(std::shared_ptr<Function> function)
         // Here we're checking on a common base class of a family of template classes,
         // which is more than type info can handle.
-        auto sclb = std::dynamic_pointer_cast<op::ScalarConstantLikeBase>(n);
+        auto sclb = as_type_ptr<op::ScalarConstantLikeBase>(n);
         if (sclb != nullptr)
         {
             replace_node(sclb, sclb->as_constant());
......
@@ -52,7 +52,7 @@ void pass::ReshapeElimination::construct_identity_reshape_pattern()
         auto pattern_map = m.get_pattern_map();
         auto gop = pattern_map[op];
 
-        auto r1 = dynamic_pointer_cast<op::Reshape>(m.get_match_root());
+        auto r1 = as_type_ptr<op::Reshape>(m.get_match_root());
 
         if (r1->get_shape() != gop->get_shape())
         {
@@ -152,9 +152,7 @@ void pass::ReshapeElimination::construct_reshapex2_pattern()
 void pass::ReshapeElimination::construct_dot_transpose_pattern()
 {
     // dot(A,B).T = dot (B.T, A.T)
-    auto dot_pred = [](shared_ptr<Node> n) {
-        return static_cast<bool>(dynamic_pointer_cast<op::Dot>(n));
-    };
+    auto dot_pred = [](shared_ptr<Node> n) { return n->is_type<op::Dot>(); };
 
     auto pdot = make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, dot_pred);
     auto preshape = make_shared<op::Reshape>(pdot, AxisVector{1, 0}, Shape{1, 2});
@@ -232,7 +230,7 @@ void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
             // Need to check if the user of the last bound op is a reshape since the last reshape is
             // allowed to have fan-out but the matcher will discard any reshape if it has fan-out
             auto user_of_last_bound_reshape_op = last_bound_reshape_op->get_users(true)[0];
-            if (std::dynamic_pointer_cast<op::Reshape>(user_of_last_bound_reshape_op))
+            if (user_of_last_bound_reshape_op->is_type<op::Reshape>())
             {
                 reshape_node_vector.push_back(user_of_last_bound_reshape_op);
                 last_bound_reshape_op = reshape_node_vector.back();
@@ -251,7 +249,7 @@ void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
         for (auto it = std::next(reshape_node_vector.begin()); it != reshape_node_vector.end();
              it++)
         {
-            auto r = std::dynamic_pointer_cast<op::Reshape>(*it);
+            auto r = as_type_ptr<op::Reshape>(*it);
 
             // Check that the input to r is the last reshape stored in the
             // subpattern vector
@@ -286,9 +284,9 @@ void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
                 continue;
             }
 
-            auto first_reshape = std::dynamic_pointer_cast<op::Reshape>(sub_pattern.front());
+            auto first_reshape = as_type_ptr<op::Reshape>(sub_pattern.front());
             auto input_to_first_reshape = first_reshape->get_argument(0);
-            auto last_reshape = std::dynamic_pointer_cast<op::Reshape>(sub_pattern.back());
+            auto last_reshape = as_type_ptr<op::Reshape>(sub_pattern.back());
 
             auto new_input_order = first_reshape->get_input_order();
             auto new_out_shape = last_reshape->get_shape();
......
@@ -47,7 +47,7 @@ using ReshapeMap = unordered_map<shared_ptr<Node>, shared_ptr<op::Reshape>>;
 static string describe_reshape(shared_ptr<Node> node)
 {
     stringstream ss;
-    auto reshape = dynamic_pointer_cast<op::Reshape>(node);
+    auto reshape = as_type_ptr<op::Reshape>(node);
     ss << reshape->get_name()
        << " ( axis order = " << ngraph::vector_to_string(reshape->get_input_order())
        << " , shape = " << vector_to_string(reshape->get_shape()) << " ) "
@@ -167,14 +167,14 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
         work_queue.pop_front();
         auto n = csw.input.get_source_output().get_node_shared_ptr();
         NGRAPH_DEBUG << "Processing (swimming) " << n->get_name();
-        if (auto unary = dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n))
+        if (n->is_unary_elementwise_arithmetic())
         {
-            Swimmer nsw{unary->input(0), csw.reshape};
+            Swimmer nsw{n->input(0), csw.reshape};
             work_queue.push_back(nsw);
             NGRAPH_DEBUG << "Propagating reshape " << describe_reshape(csw.reshape) << " for "
-                         << n->get_name() << " to " << unary->get_argument(0);
+                         << n->get_name() << " to " << n->get_argument(0);
         }
-        else if (dynamic_pointer_cast<op::Broadcast>(n))
+        else if (n->is_type<op::Broadcast>())
         {
             auto old_broadcast = static_pointer_cast<op::Broadcast>(n);
             auto broadcast_axes = old_broadcast->get_broadcast_axes();
@@ -324,7 +324,7 @@ static void sink_reshape(shared_ptr<op::Reshape> reshape,
         }
     }
 
-static void sink_unary(shared_ptr<op::util::UnaryElementwiseArithmetic> n,
+static void sink_unary(shared_ptr<Node> n,
                        ReshapeMap& reorders,
                        set<shared_ptr<Node>>& /* reshapes_to_delete */)
 {
@@ -333,7 +333,7 @@ static void sink_unary(shared_ptr<op::util::UnaryElementwiseArithmetic> n,
     write_reshapemap(reorders, n, arg_reshape);
 }
 
-static void sink_binary(shared_ptr<op::util::BinaryElementwiseArithmetic> binary,
+static void sink_binary(shared_ptr<Node> binary,
                         ReshapeMap& reorders,
                         set<shared_ptr<Node>>& reshapes_to_delete)
 {
@@ -533,31 +533,31 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
             results.push_back(n);
         }
 
-        if (auto reshape = dynamic_pointer_cast<op::Reshape>(n))
+        if (auto reshape = as_type_ptr<op::Reshape>(n))
        {
             sink_reshape(reshape, reorders, reshapes_to_delete);
        }
-        else if (auto unary = dynamic_pointer_cast<op::util::UnaryElementwiseArithmetic>(n))
+        else if (n->is_unary_elementwise_arithmetic())
        {
-            sink_unary(unary, reorders, reshapes_to_delete);
+            sink_unary(n, reorders, reshapes_to_delete);
        }
-        else if (auto binary = dynamic_pointer_cast<op::util::BinaryElementwiseArithmetic>(n))
+        else if (n->is_binary_elementwise_arithmetic())
        {
-            sink_binary(binary, reorders, reshapes_to_delete);
+            sink_binary(n, reorders, reshapes_to_delete);
        }
-        else if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(n))
+        else if (auto goe = as_type_ptr<op::GetOutputElement>(n))
        {
             write_reshapemap(reorders, goe, create_default_reshape(goe));
        }
-        else if (auto quantize = dynamic_pointer_cast<op::Quantize>(n))
+        else if (auto quantize = as_type_ptr<op::Quantize>(n))
        {
             sink_quantize(quantize, reorders, reshapes_to_delete);
} }
else if (auto dequantize = dynamic_pointer_cast<op::Dequantize>(n)) else if (auto dequantize = as_type_ptr<op::Dequantize>(n))
{ {
sink_dequantize(dequantize, reorders, reshapes_to_delete); sink_dequantize(dequantize, reorders, reshapes_to_delete);
} }
else if (auto slice = dynamic_pointer_cast<op::Slice>(n)) else if (auto slice = as_type_ptr<op::Slice>(n))
{ {
            // A heuristic: if a Reshape has multiple Slice users, sinking it // A heuristic: if a Reshape has multiple Slice users, sinking it
            // would replicate it once per user // would replicate it once per user
...@@ -578,11 +578,11 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function> ...@@ -578,11 +578,11 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
materialize_shapes(n, reorders, reshapes_to_delete); materialize_shapes(n, reorders, reshapes_to_delete);
} }
} }
else if (auto pad = dynamic_pointer_cast<op::Pad>(n)) else if (auto pad = as_type_ptr<op::Pad>(n))
{ {
sink_pad(pad, reorders, reshapes_to_delete); sink_pad(pad, reorders, reshapes_to_delete);
} }
else if (auto concat = dynamic_pointer_cast<op::Concat>(n)) else if (auto concat = as_type_ptr<op::Concat>(n))
{ {
sink_concat(concat, reorders, reshapes_to_delete); sink_concat(concat, reorders, reshapes_to_delete);
} }
......
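For the elementwise base classes, the dispatch above cannot use an exact type test, so the pass now asks the node directly via category predicates, and `sink_unary`/`sink_binary` accept plain `shared_ptr<Node>`. A hedged sketch of that shape; the predicate names come from the diff, everything else is a stand-in:

```cpp
#include <memory>

struct Node
{
    virtual ~Node() = default;
    // Category predicates: one virtual bool per op family instead of a
    // dynamic cast to op::util::{Unary,Binary}ElementwiseArithmetic.
    virtual bool is_unary_elementwise_arithmetic() const { return false; }
    virtual bool is_binary_elementwise_arithmetic() const { return false; }
};

struct Abs : Node
{
    bool is_unary_elementwise_arithmetic() const override { return true; }
};

struct Add : Node
{
    bool is_binary_elementwise_arithmetic() const override { return true; }
};

// As in the diff, the helpers take shared_ptr<Node>: they only use the
// base interface, so no typed pointer is required.
void sink_unary(const std::shared_ptr<Node>&) { /* rewrite reshapes */ }
void sink_binary(const std::shared_ptr<Node>&) { /* rewrite reshapes */ }

void dispatch(const std::shared_ptr<Node>& n)
{
    if (n->is_unary_elementwise_arithmetic())
    {
        sink_unary(n); // no cast needed before calling the helper
    }
    else if (n->is_binary_elementwise_arithmetic())
    {
        sink_binary(n);
    }
}

int main()
{
    dispatch(std::make_shared<Abs>());
    dispatch(std::make_shared<Add>());
}
```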
...@@ -33,7 +33,7 @@ void pass::ValidateGraph::validate_parameters(const Function& function) ...@@ -33,7 +33,7 @@ void pass::ValidateGraph::validate_parameters(const Function& function)
auto parameters = function.get_parameters(); auto parameters = function.get_parameters();
for (auto node : function.get_ops()) for (auto node : function.get_ops())
{ {
shared_ptr<op::Parameter> p = dynamic_pointer_cast<op::Parameter>(node); shared_ptr<op::Parameter> p = as_type_ptr<op::Parameter>(node);
if (nullptr != p) if (nullptr != p)
{ {
auto it = find_if(parameters.begin(), auto it = find_if(parameters.begin(),
......
...@@ -163,7 +163,7 @@ static std::string label_edge(const std::shared_ptr<Node>& /* src */, ...@@ -163,7 +163,7 @@ static std::string label_edge(const std::shared_ptr<Node>& /* src */,
if (getenv("NGRAPH_VISUALIZE_EDGE_LABELS") != nullptr) if (getenv("NGRAPH_VISUALIZE_EDGE_LABELS") != nullptr)
{ {
size_t output = 0; size_t output = 0;
if (auto goe = dst->as_type<op::GetOutputElement>()) if (auto goe = as_type_ptr<op::GetOutputElement>(dst))
{ {
output = goe->get_as_output().get_index(); output = goe->get_as_output().get_index();
} }
...@@ -223,7 +223,7 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions) ...@@ -223,7 +223,7 @@ bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions)
traverse_nodes(f, [&](shared_ptr<Node> node) { traverse_nodes(f, [&](shared_ptr<Node> node) {
if (auto ck = dynamic_pointer_cast<ngraph::op::CompiledKernel>(node)) if (auto ck = as_type_ptr<ngraph::op::CompiledKernel>(node))
{ {
// print sub-graph // print sub-graph
auto nodes_list = ck->get_node_list(); auto nodes_list = ck->get_node_list();
...@@ -416,7 +416,7 @@ string pass::VisualizeTree::get_node_name(shared_ptr<Node> node) ...@@ -416,7 +416,7 @@ string pass::VisualizeTree::get_node_name(shared_ptr<Node> node)
{ {
rc += "\\n" + node->get_name(); rc += "\\n" + node->get_name();
} }
if (auto ck = node->as_type<ngraph::op::CompiledKernel>()) if (auto ck = as_type_ptr<ngraph::op::CompiledKernel>(node))
{ {
rc += "\\n{"; rc += "\\n{";
// add sub-graph node names // add sub-graph node names
......
...@@ -111,7 +111,7 @@ bool pass::ZeroDimTensorElimination::run_on_function(shared_ptr<Function> f) ...@@ -111,7 +111,7 @@ bool pass::ZeroDimTensorElimination::run_on_function(shared_ptr<Function> f)
continue; continue;
} }
if (auto concat = dynamic_pointer_cast<op::Concat>(n)) if (auto concat = as_type_ptr<op::Concat>(n))
{ {
NodeVector non_zero_dim_args; NodeVector non_zero_dim_args;
for (auto arg : concat->get_arguments()) for (auto arg : concat->get_arguments())
......
...@@ -28,6 +28,11 @@ namespace ngraph ...@@ -28,6 +28,11 @@ namespace ngraph
{ {
namespace pattern namespace pattern
{ {
constexpr NodeTypeInfo op::AnyOf::type_info;
constexpr NodeTypeInfo op::Any::type_info;
constexpr NodeTypeInfo op::Label::type_info;
constexpr NodeTypeInfo op::Skip::type_info;
std::shared_ptr<Node> Matcher::get_match_root() { return m_match_root; } std::shared_ptr<Node> Matcher::get_match_root() { return m_match_root; }
bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label, bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label,
const std::shared_ptr<Node>& graph_node, const std::shared_ptr<Node>& graph_node,
...@@ -227,23 +232,23 @@ namespace ngraph ...@@ -227,23 +232,23 @@ namespace ngraph
<< "pattern = " << pattern_node->get_name() << " matched " << "pattern = " << pattern_node->get_name() << " matched "
<< graph_node->get_name(); << graph_node->get_name();
if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node)) if (auto label_node = as_type_ptr<op::Label>(pattern_node))
{ {
return abort_match(watermark, match_pattern(label_node, graph_node, pattern_map)); return abort_match(watermark, match_pattern(label_node, graph_node, pattern_map));
} }
if (auto skip_node = std::dynamic_pointer_cast<op::Skip>( if (auto skip_node =
pattern_node)) // matches PatternSkipOp semantics as_type_ptr<op::Skip>(pattern_node)) // matches PatternSkipOp semantics
{ {
return abort_match(watermark, match_skip(skip_node, graph_node, pattern_map)); return abort_match(watermark, match_skip(skip_node, graph_node, pattern_map));
} }
if (auto any_node = std::dynamic_pointer_cast<op::Any>(pattern_node)) if (auto any_node = as_type_ptr<op::Any>(pattern_node))
{ {
return abort_match(watermark, match_any(any_node, graph_node, pattern_map)); return abort_match(watermark, match_any(any_node, graph_node, pattern_map));
} }
if (auto any_of_node = std::dynamic_pointer_cast<op::AnyOf>(pattern_node)) if (auto any_of_node = as_type_ptr<op::AnyOf>(pattern_node))
{ {
return abort_match(watermark, match_any_of(any_of_node, graph_node, pattern_map)); return abort_match(watermark, match_any_of(any_of_node, graph_node, pattern_map));
} }
......
...@@ -40,9 +40,7 @@ namespace ngraph ...@@ -40,9 +40,7 @@ namespace ngraph
template <typename T> template <typename T>
std::function<bool(std::shared_ptr<Node>)> has_class() std::function<bool(std::shared_ptr<Node>)> has_class()
{ {
auto pred = [](std::shared_ptr<Node> node) -> bool { auto pred = [](std::shared_ptr<Node> node) -> bool { return node->is_type<T>(); };
return std::dynamic_pointer_cast<T>(node) != nullptr;
};
return pred; return pred;
} }
...@@ -109,7 +107,7 @@ namespace ngraph ...@@ -109,7 +107,7 @@ namespace ngraph
std::shared_ptr<T> matched; std::shared_ptr<T> matched;
for (auto arg : node->get_arguments()) for (auto arg : node->get_arguments())
{ {
if (auto t_casted = std::dynamic_pointer_cast<T>(arg)) if (auto t_casted = as_type_ptr<T>(arg))
{ {
if (matched) if (matched)
{ {
......
...@@ -29,6 +29,9 @@ namespace ngraph ...@@ -29,6 +29,9 @@ namespace ngraph
class Any : public Pattern class Any : public Pattern
{ {
public: public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"patternAny", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief creates a Any node containing a sub-pattern described by \sa type and \sa /// \brief creates a Any node containing a sub-pattern described by \sa type and \sa
/// shape. /// shape.
Any(const element::Type& type, Any(const element::Type& type,
...@@ -53,12 +56,6 @@ namespace ngraph ...@@ -53,12 +56,6 @@ namespace ngraph
wrapped_nodes) wrapped_nodes)
{ {
} }
const std::string& description() const override
{
static std::string desc = "Any";
return desc;
}
}; };
} }
} }
......
...@@ -35,6 +35,9 @@ namespace ngraph ...@@ -35,6 +35,9 @@ namespace ngraph
class AnyOf : public Pattern class AnyOf : public Pattern
{ {
public: public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"patternAnyOf", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief creates a AnyOf node containing a sub-pattern described by \sa type and /// \brief creates a AnyOf node containing a sub-pattern described by \sa type and
/// \sa shape. /// \sa shape.
AnyOf(const element::Type& type, AnyOf(const element::Type& type,
...@@ -64,12 +67,6 @@ namespace ngraph ...@@ -64,12 +67,6 @@ namespace ngraph
wrapped_nodes) wrapped_nodes)
{ {
} }
const std::string& description() const override
{
static std::string desc = "AnyOf";
return desc;
}
}; };
} }
} }
......
...@@ -31,6 +31,9 @@ namespace ngraph ...@@ -31,6 +31,9 @@ namespace ngraph
class Label : public Pattern class Label : public Pattern
{ {
public: public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"patternLabel", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief creates a Label node containing a sub-pattern described by \sa type and /// \brief creates a Label node containing a sub-pattern described by \sa type and
/// \sa shape. /// \sa shape.
/// ///
...@@ -74,12 +77,6 @@ namespace ngraph ...@@ -74,12 +77,6 @@ namespace ngraph
wrapped_nodes) wrapped_nodes)
{ {
} }
const std::string& description() const override
{
static std::string desc = "Label";
return desc;
}
}; };
} }
} }
......
...@@ -31,17 +31,14 @@ namespace ngraph ...@@ -31,17 +31,14 @@ namespace ngraph
class Skip : public Pattern class Skip : public Pattern
{ {
public: public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"patternSkip", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Skip(const std::shared_ptr<Node>& arg, Predicate predicate = nullptr) Skip(const std::shared_ptr<Node>& arg, Predicate predicate = nullptr)
: Pattern(NodeVector{arg}, predicate) : Pattern(NodeVector{arg}, predicate)
{ {
set_output_type(0, arg->get_element_type(), arg->get_output_partial_shape(0)); set_output_type(0, arg->get_element_type(), arg->get_output_partial_shape(0));
} }
const std::string& description() const override
{
static std::string desc = "Skip";
return desc;
}
}; };
} }
} }
......
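Each pattern-op header above pairs an in-class `type_info` with a `get_type_info()` override, while matcher.cpp gains the four `constexpr NodeTypeInfo op::X::type_info;` lines seen earlier. Under pre-C++17 rules, a `static constexpr` data member that is odr-used (here, returned by reference) still requires that out-of-line definition. A minimal sketch, with an assumed `NodeTypeInfo` stand-in:

```cpp
#include <cstdint>

// Minimal stand-in for nGraph's NodeTypeInfo.
struct NodeTypeInfo
{
    const char* name;
    uint64_t version;
};

// Header (e.g. skip.hpp): declaration with its initializer.
struct Skip
{
    static constexpr NodeTypeInfo type_info{"patternSkip", 0};
    // Returning a reference odr-uses type_info, so a definition must
    // exist in exactly one translation unit.
    const NodeTypeInfo& get_type_info() const { return type_info; }
};

// Source file (e.g. matcher.cpp): the definition; no initializer repeated.
constexpr NodeTypeInfo Skip::type_info;
```

From C++17 on, the in-class declaration is implicitly inline and the out-of-line definition becomes redundant, though still harmless.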
...@@ -31,14 +31,14 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast() ...@@ -31,14 +31,14 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
element::i8, Shape{}, [](std::shared_ptr<Node>) { return true; }); element::i8, Shape{}, [](std::shared_ptr<Node>) { return true; });
auto broadcast_op = std::make_shared<ngraph::op::Broadcast>(src_op, Shape{}, AxisSet{}); auto broadcast_op = std::make_shared<ngraph::op::Broadcast>(src_op, Shape{}, AxisSet{});
auto target_op = std::make_shared<pattern::op::AnyOf>( auto target_op =
element::i8, std::make_shared<pattern::op::AnyOf>(element::i8,
Shape{}, Shape{},
[](std::shared_ptr<Node> node) { [](std::shared_ptr<Node> node) {
return pattern::has_class<ngraph::op::util::UnaryElementwiseArithmetic>()(node) || return node->is_unary_elementwise_arithmetic() ||
pattern::has_class<ngraph::op::util::BinaryElementwiseArithmetic>()(node); node->is_binary_elementwise_arithmetic();
}, },
NodeVector{broadcast_op}); NodeVector{broadcast_op});
auto callback = [](pattern::Matcher& m) { auto callback = [](pattern::Matcher& m) {
// Since the broadcast is going to an elementwise operation, we // Since the broadcast is going to an elementwise operation, we
......
...@@ -34,14 +34,14 @@ ngraph::runtime::plaidml::pass::ReplicateElision::ReplicateElision() ...@@ -34,14 +34,14 @@ ngraph::runtime::plaidml::pass::ReplicateElision::ReplicateElision()
std::make_shared<pattern::op::Skip>(replicate_op, [](std::shared_ptr<Node> node) { std::make_shared<pattern::op::Skip>(replicate_op, [](std::shared_ptr<Node> node) {
return pattern::has_class<plaidml::op::Replicate>()(node); return pattern::has_class<plaidml::op::Replicate>()(node);
}); });
auto target_op = std::make_shared<pattern::op::AnyOf>( auto target_op =
element::i8, std::make_shared<pattern::op::AnyOf>(element::i8,
Shape{}, Shape{},
[](std::shared_ptr<Node> node) { [](std::shared_ptr<Node> node) {
return pattern::has_class<ngraph::op::util::UnaryElementwiseArithmetic>()(node) || return node->is_unary_elementwise_arithmetic() ||
pattern::has_class<ngraph::op::util::BinaryElementwiseArithmetic>()(node); node->is_binary_elementwise_arithmetic();
}, },
NodeVector{skip_op}); NodeVector{skip_op});
auto callback = [](pattern::Matcher& m) { auto callback = [](pattern::Matcher& m) {
bool replaced_any = false; bool replaced_any = false;
......
...@@ -418,7 +418,7 @@ static void serialize_to_cpio(ostream& out, shared_ptr<ngraph::Function> func, s ...@@ -418,7 +418,7 @@ static void serialize_to_cpio(ostream& out, shared_ptr<ngraph::Function> func, s
traverse_nodes(const_cast<Function*>(func.get()), traverse_nodes(const_cast<Function*>(func.get()),
[&](shared_ptr<Node> node) { [&](shared_ptr<Node> node) {
if (auto c = dynamic_pointer_cast<op::Constant>(node)) if (auto c = node->as_type<op::Constant>())
{ {
uint32_t size = uint32_t size =
static_cast<uint32_t>(shape_size(c->get_output_shape(0)) * static_cast<uint32_t>(shape_size(c->get_output_shape(0)) *
...@@ -624,8 +624,7 @@ ParameterVector JSONDeserializer::deserialize_parameter_vector(json json_paramet ...@@ -624,8 +624,7 @@ ParameterVector JSONDeserializer::deserialize_parameter_vector(json json_paramet
std::vector<std::shared_ptr<op::Parameter>> params; std::vector<std::shared_ptr<op::Parameter>> params;
for (auto& param_ref : json_parameters) for (auto& param_ref : json_parameters)
{ {
params.push_back( params.push_back(as_type_ptr<op::Parameter>(deserialize_node_reference(param_ref)));
dynamic_pointer_cast<op::Parameter>(deserialize_node_reference(param_ref)));
} }
return params; return params;
} }
...@@ -646,7 +645,7 @@ shared_ptr<Function> JSONDeserializer::deserialize_function(json func_js) ...@@ -646,7 +645,7 @@ shared_ptr<Function> JSONDeserializer::deserialize_function(json func_js)
for (auto& result_ref : func_result) for (auto& result_ref : func_result)
{ {
auto fr = deserialize_node_reference(result_ref); auto fr = deserialize_node_reference(result_ref);
if (auto res = std::dynamic_pointer_cast<op::Result>(fr)) if (auto res = as_type_ptr<op::Result>(fr))
{ {
result.push_back(res); result.push_back(res);
// make sure we have `op::Result` on top of all outputs // make sure we have `op::Result` on top of all outputs
......
...@@ -70,7 +70,7 @@ std::shared_ptr<Function> ...@@ -70,7 +70,7 @@ std::shared_ptr<Function>
ParameterVector new_parameters = f->get_parameters(); ParameterVector new_parameters = f->get_parameters();
for (size_t i = 0; i < new_parameters.size(); i++) for (size_t i = 0; i < new_parameters.size(); i++)
{ {
new_parameters[i] = std::dynamic_pointer_cast<op::Parameter>(m[new_parameters[i].get()]); new_parameters[i] = as_type_ptr<op::Parameter>(m[new_parameters[i].get()]);
// If the replacement for a Parameter is not itself a Parameter, we must have replaced it // If the replacement for a Parameter is not itself a Parameter, we must have replaced it
// with a constant. We will insert a dead Parameter into the clone's parameters, in order // with a constant. We will insert a dead Parameter into the clone's parameters, in order
......
...@@ -250,8 +250,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -250,8 +250,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
NodeVector result_nodes; NodeVector result_nodes;
for (auto node : bprop->get_results()) for (auto node : bprop->get_results())
{ {
auto result = auto result = as_type_ptr<op::Result>(fprop_cache.node_param_map.at(node.get()));
std::dynamic_pointer_cast<op::Result>(fprop_cache.node_param_map.at(node.get()));
if (!result) if (!result)
{ {
throw ngraph_error("Expected op::Result values for op::Result keys in node_param_map"); throw ngraph_error("Expected op::Result values for op::Result keys in node_param_map");
...@@ -266,15 +265,15 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -266,15 +265,15 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
ParameterVector bprop_input_params; ParameterVector bprop_input_params;
for (auto param : bprop_inputs) for (auto param : bprop_inputs)
{ {
bprop_input_params.push_back(std::dynamic_pointer_cast<op::Parameter>( bprop_input_params.push_back(
fprop_cache.node_param_map.at(param.get()))); as_type_ptr<op::Parameter>(fprop_cache.node_param_map.at(param.get())));
} }
// add the cached fprop nodes as inputs to bprop // add the cached fprop nodes as inputs to bprop
for (auto x : fprop_cache.fprop_output_nodes) for (auto x : fprop_cache.fprop_output_nodes)
{ {
bprop_input_params.push_back( bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map.at(x))); as_type_ptr<op::Parameter>(fprop_cache.node_param_map.at(x)));
} }
return bprop_input_params; return bprop_input_params;
}; };
...@@ -286,7 +285,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -286,7 +285,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
ngraph::traverse_nodes( ngraph::traverse_nodes(
result_nodes, result_nodes,
[&cloned_bprop_inputs, &fprop_cache, &inverted_node_map](std::shared_ptr<Node> node) { [&cloned_bprop_inputs, &fprop_cache, &inverted_node_map](std::shared_ptr<Node> node) {
auto pnode = std::dynamic_pointer_cast<op::Parameter>(node); auto pnode = as_type_ptr<op::Parameter>(node);
if (pnode != nullptr && if (pnode != nullptr &&
std::find(cloned_bprop_inputs.begin(), cloned_bprop_inputs.end(), pnode) == std::find(cloned_bprop_inputs.begin(), cloned_bprop_inputs.end(), pnode) ==
cloned_bprop_inputs.end()) cloned_bprop_inputs.end())
...@@ -302,7 +301,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop, ...@@ -302,7 +301,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
for (auto fpirn : fprop_cache.fprop_output_nodes) for (auto fpirn : fprop_cache.fprop_output_nodes)
{ {
auto fpir = fpirn->shared_from_this(); auto fpir = fpirn->shared_from_this();
if (std::dynamic_pointer_cast<op::Result>(fpir)) if (as_type_ptr<op::Result>(fpir))
{ {
throw ngraph_error("Expected op::Result in fprop->get_results()"); throw ngraph_error("Expected op::Result in fprop->get_results()");
} }
......
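The test hunk that follows is the one place the substitution is not mechanical: `pattern::has_class<op::util::BinaryElementwiseArithmetic>()` tested membership in a base class, which an exact `type_info` comparison cannot express, so the predicate is rewritten against `is_binary_elementwise_arithmetic()`. A small sketch of the difference, using assumed stand-in types:

```cpp
#include <cassert>
#include <memory>

struct Node
{
    virtual ~Node() = default;
    virtual bool is_binary_elementwise_arithmetic() const { return false; }
};

// Stand-in for op::util::BinaryElementwiseArithmetic.
struct BinaryElementwiseArithmetic : Node
{
    bool is_binary_elementwise_arithmetic() const override { return true; }
};

struct Add : BinaryElementwiseArithmetic
{
};

int main()
{
    std::shared_ptr<Node> n = std::make_shared<Add>();

    // dynamic_pointer_cast sees through the hierarchy: Add is-a
    // BinaryElementwiseArithmetic, so the old has_class<> test matched.
    assert(std::dynamic_pointer_cast<BinaryElementwiseArithmetic>(n));

    // An exact type_info comparison would only match Add itself; the
    // category predicate recovers the base-class semantics.
    assert(n->is_binary_elementwise_arithmetic());
}
```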
...@@ -321,7 +321,9 @@ TEST(pattern, matcher) ...@@ -321,7 +321,9 @@ TEST(pattern, matcher)
auto b = make_shared<op::Parameter>(element::i32, shape); auto b = make_shared<op::Parameter>(element::i32, shape);
auto is_bea = pattern::has_class<op::util::BinaryElementwiseArithmetic>(); auto is_bea = [](std::shared_ptr<Node> node) -> bool {
return node->is_binary_elementwise_arithmetic();
};
auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b}); auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
auto add_ab = a + b; auto add_ab = a + b;
ASSERT_TRUE(n.match(bea, add_ab)); ASSERT_TRUE(n.match(bea, add_ab));
......