Commit 51ad59d3 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Make Klockwork Happy Again : Dynamic to Static pointer_casts (ngraph/src/pass, part 1) (#1777)

* dynamic to static casts

* ScalarConstantLikeBase needs a dynamic cast

* revert change
parent c8620b11
...@@ -104,7 +104,7 @@ static bool simplify_concat(std::shared_ptr<Node> n) ...@@ -104,7 +104,7 @@ static bool simplify_concat(std::shared_ptr<Node> n)
return false; return false;
} }
auto slice = std::dynamic_pointer_cast<op::Slice>(matcher->get_pattern_map()[lslice]); auto slice = std::static_pointer_cast<op::Slice>(matcher->get_pattern_map()[lslice]);
if (branch_tip) if (branch_tip)
{ {
if (branch_tip != matcher->get_pattern_map()[ltip]) if (branch_tip != matcher->get_pattern_map()[ltip])
...@@ -170,7 +170,7 @@ static bool simplify_concat(std::shared_ptr<Node> n) ...@@ -170,7 +170,7 @@ static bool simplify_concat(std::shared_ptr<Node> n)
} }
} }
auto concat = std::dynamic_pointer_cast<op::Concat>(n); auto concat = std::static_pointer_cast<op::Concat>(n);
size_t concat_axis = concat->get_concatenation_axis(); size_t concat_axis = concat->get_concatenation_axis();
auto slice_shape = branch_tip->get_users().at(0)->get_shape(); auto slice_shape = branch_tip->get_users().at(0)->get_shape();
...@@ -424,7 +424,7 @@ template <typename T, ...@@ -424,7 +424,7 @@ template <typename T,
static bool simplify_reduction(std::shared_ptr<Node> n) static bool simplify_reduction(std::shared_ptr<Node> n)
{ {
NGRAPH_DEBUG << "In simplify_reduction for " << n->get_name(); NGRAPH_DEBUG << "In simplify_reduction for " << n->get_name();
auto reduction = std::dynamic_pointer_cast<T>(n); auto reduction = std::static_pointer_cast<T>(n);
auto broadcast = std::dynamic_pointer_cast<op::Broadcast>(n->get_argument(0)); auto broadcast = std::dynamic_pointer_cast<op::Broadcast>(n->get_argument(0));
if (!broadcast) if (!broadcast)
......
...@@ -71,7 +71,7 @@ shared_ptr<op::Constant> make_constant_pad(shared_ptr<op::Constant> constant, ...@@ -71,7 +71,7 @@ shared_ptr<op::Constant> make_constant_pad(shared_ptr<op::Constant> constant,
{ {
auto out_shape = pad->get_shape(); auto out_shape = pad->get_shape();
vector<T> out_vec(shape_size(out_shape)); vector<T> out_vec(shape_size(out_shape));
auto pad_value = std::dynamic_pointer_cast<op::Constant>(pad->get_argument(1)); auto pad_value = std::static_pointer_cast<op::Constant>(pad->get_argument(1));
runtime::reference::pad<T>(constant->get_vector<T>().data(), runtime::reference::pad<T>(constant->get_vector<T>().data(),
pad_value->get_vector<T>().data(), pad_value->get_vector<T>().data(),
...@@ -105,8 +105,8 @@ void ngraph::pass::ConstantFolding::construct_constant_pad() ...@@ -105,8 +105,8 @@ void ngraph::pass::ConstantFolding::construct_constant_pad()
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]); auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto pad_match = dynamic_pointer_cast<op::Pad>(m.get_match_root()); auto pad_match = static_pointer_cast<op::Pad>(m.get_match_root());
auto type = constant_match->get_element_type(); auto type = constant_match->get_element_type();
if (type == element::i32) if (type == element::i32)
...@@ -149,8 +149,8 @@ void ngraph::pass::ConstantFolding::construct_constant_reshape() ...@@ -149,8 +149,8 @@ void ngraph::pass::ConstantFolding::construct_constant_reshape()
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]); auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto reshape_match = dynamic_pointer_cast<op::Reshape>(m.get_match_root()); auto reshape_match = static_pointer_cast<op::Reshape>(m.get_match_root());
auto type = constant_match->get_element_type(); auto type = constant_match->get_element_type();
if (type == element::i32) if (type == element::i32)
...@@ -214,8 +214,8 @@ void ngraph::pass::ConstantFolding::construct_constant_broadcast() ...@@ -214,8 +214,8 @@ void ngraph::pass::ConstantFolding::construct_constant_broadcast()
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]); auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto broadcast_match = dynamic_pointer_cast<op::Broadcast>(m.get_match_root()); auto broadcast_match = static_pointer_cast<op::Broadcast>(m.get_match_root());
auto type = constant_match->get_element_type(); auto type = constant_match->get_element_type();
if (type == element::i32) if (type == element::i32)
......
...@@ -210,8 +210,8 @@ void pass::CoreFusion::construct_folded_batch_norm() ...@@ -210,8 +210,8 @@ void pass::CoreFusion::construct_folded_batch_norm()
<< m.get_match_root()->get_name(); << m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto m_bn = std::dynamic_pointer_cast<op::BatchNorm>(m.get_match_root()); auto m_bn = std::static_pointer_cast<op::BatchNorm>(m.get_match_root());
auto m_conv = std::dynamic_pointer_cast<op::Convolution>(m_bn->get_argument(2)); auto m_conv = std::static_pointer_cast<op::Convolution>(m_bn->get_argument(2));
if (m_conv->get_users().size() > 1) if (m_conv->get_users().size() > 1)
{ {
...@@ -397,7 +397,7 @@ static std::shared_ptr<Node> reduce_broadcast(std::shared_ptr<Node> broadcast) ...@@ -397,7 +397,7 @@ static std::shared_ptr<Node> reduce_broadcast(std::shared_ptr<Node> broadcast)
{ {
const size_t H = 2; const size_t H = 2;
const size_t W = 3; const size_t W = 3;
auto matched_broadcast_w1 = std::dynamic_pointer_cast<op::Broadcast>(broadcast); auto matched_broadcast_w1 = std::static_pointer_cast<op::Broadcast>(broadcast);
Shape shape_w1{matched_broadcast_w1->get_shape()}; Shape shape_w1{matched_broadcast_w1->get_shape()};
shape_w1[H] /= 2; shape_w1[H] /= 2;
shape_w1[W] /= 2; shape_w1[W] /= 2;
...@@ -531,7 +531,7 @@ void pass::CoreFusion::construct_optimized_strided_conv() ...@@ -531,7 +531,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
NGRAPH_DEBUG << "element-wise isn't data"; NGRAPH_DEBUG << "element-wise isn't data";
return false; return false;
} }
auto sconv = std::dynamic_pointer_cast<op::Convolution>(sc); auto sconv = std::static_pointer_cast<op::Convolution>(sc);
sparse_shape_index = shape_to_index(sconv->get_shape()); sparse_shape_index = shape_to_index(sconv->get_shape());
if (sparse_shape_index == 0) if (sparse_shape_index == 0)
{ {
...@@ -553,7 +553,7 @@ void pass::CoreFusion::construct_optimized_strided_conv() ...@@ -553,7 +553,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
const size_t full_shape_index = sparse_shape_index - 1; const size_t full_shape_index = sparse_shape_index - 1;
auto m_conv_stride1 = auto m_conv_stride1 =
std::dynamic_pointer_cast<op::Convolution>(pattern_map[conv_stride1_label]); std::static_pointer_cast<op::Convolution>(pattern_map[conv_stride1_label]);
if (!are_img_dims_equal(m_conv_stride1->get_shape(), supported_shapes[full_shape_index]) || if (!are_img_dims_equal(m_conv_stride1->get_shape(), supported_shapes[full_shape_index]) ||
!are_img_dims_equal(m_conv_stride1->get_argument(1)->get_shape(), win_size_1) || !are_img_dims_equal(m_conv_stride1->get_argument(1)->get_shape(), win_size_1) ||
...@@ -568,7 +568,7 @@ void pass::CoreFusion::construct_optimized_strided_conv() ...@@ -568,7 +568,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
} }
auto m_conv_stride3 = auto m_conv_stride3 =
std::dynamic_pointer_cast<op::Convolution>(pattern_map[conv_stride3_label]); std::static_pointer_cast<op::Convolution>(pattern_map[conv_stride3_label]);
if (!are_img_dims_equal(m_conv_stride3->get_shape(), supported_shapes[full_shape_index]) || if (!are_img_dims_equal(m_conv_stride3->get_shape(), supported_shapes[full_shape_index]) ||
!are_img_dims_equal(m_conv_stride3->get_argument(1)->get_shape(), shape_3) || !are_img_dims_equal(m_conv_stride3->get_argument(1)->get_shape(), shape_3) ||
......
...@@ -72,8 +72,8 @@ static bool cse_constant(std::shared_ptr<Node> a, std::shared_ptr<Node> b) ...@@ -72,8 +72,8 @@ static bool cse_constant(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
return false; return false;
} }
auto ca = std::dynamic_pointer_cast<op::Constant>(a); auto ca = std::static_pointer_cast<op::Constant>(a);
auto cb = std::dynamic_pointer_cast<op::Constant>(b); auto cb = std::static_pointer_cast<op::Constant>(b);
size_t size = shape_size(a->get_shape()) * a->get_element_type().size(); size_t size = shape_size(a->get_shape()) * a->get_element_type().size();
...@@ -84,8 +84,8 @@ static bool cse_reshape(std::shared_ptr<Node> a, std::shared_ptr<Node> b) ...@@ -84,8 +84,8 @@ static bool cse_reshape(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
{ {
NGRAPH_DEBUG << "In cse_reshape for " << a->get_name() << " and " << b->get_name(); NGRAPH_DEBUG << "In cse_reshape for " << a->get_name() << " and " << b->get_name();
auto reshape_a = std::dynamic_pointer_cast<ngraph::op::Reshape>(a); auto reshape_a = std::static_pointer_cast<ngraph::op::Reshape>(a);
auto reshape_b = std::dynamic_pointer_cast<ngraph::op::Reshape>(b); auto reshape_b = std::static_pointer_cast<ngraph::op::Reshape>(b);
return (a->get_argument(0) == b->get_argument(0)) && return (a->get_argument(0) == b->get_argument(0)) &&
(reshape_a->get_input_order() == reshape_b->get_input_order()) && (reshape_a->get_input_order() == reshape_b->get_input_order()) &&
...@@ -95,8 +95,8 @@ static bool cse_broadcast(std::shared_ptr<Node> a, std::shared_ptr<Node> b) ...@@ -95,8 +95,8 @@ static bool cse_broadcast(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
{ {
NGRAPH_DEBUG << "In cse_broadcast for " << a->get_name() << " and " << b->get_name(); NGRAPH_DEBUG << "In cse_broadcast for " << a->get_name() << " and " << b->get_name();
auto broadcast_a = std::dynamic_pointer_cast<ngraph::op::Broadcast>(a); auto broadcast_a = std::static_pointer_cast<ngraph::op::Broadcast>(a);
auto broadcast_b = std::dynamic_pointer_cast<ngraph::op::Broadcast>(b); auto broadcast_b = std::static_pointer_cast<ngraph::op::Broadcast>(b);
return (a->get_argument(0) == b->get_argument(0)) && return (a->get_argument(0) == b->get_argument(0)) &&
(broadcast_a->get_broadcast_axes() == broadcast_b->get_broadcast_axes()) && (broadcast_a->get_broadcast_axes() == broadcast_b->get_broadcast_axes()) &&
...@@ -121,8 +121,8 @@ static bool cse_reduction(std::shared_ptr<Node> a, std::shared_ptr<Node> b) ...@@ -121,8 +121,8 @@ static bool cse_reduction(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
{ {
NGRAPH_DEBUG << "In cse_reduction for " << a->get_name() << " and " << b->get_name(); NGRAPH_DEBUG << "In cse_reduction for " << a->get_name() << " and " << b->get_name();
auto ar_a = std::dynamic_pointer_cast<op::util::ArithmeticReduction>(a); auto ar_a = std::static_pointer_cast<op::util::ArithmeticReduction>(a);
auto ar_b = std::dynamic_pointer_cast<op::util::ArithmeticReduction>(b); auto ar_b = std::static_pointer_cast<op::util::ArithmeticReduction>(b);
return ar_a->get_argument(0) == ar_b->get_argument(0) && return ar_a->get_argument(0) == ar_b->get_argument(0) &&
ar_a->get_reduction_axes() == ar_b->get_reduction_axes(); ar_a->get_reduction_axes() == ar_b->get_reduction_axes();
......
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
HANDLER_DECL(replace_broadcast_like) HANDLER_DECL(replace_broadcast_like)
{ {
// Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like" argument // Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like" argument
auto broadcast_like = std::dynamic_pointer_cast<ngraph::op::BroadcastLike>(node); auto broadcast_like = std::static_pointer_cast<ngraph::op::BroadcastLike>(node);
ngraph::replace_node( ngraph::replace_node(
node, node,
std::make_shared<ngraph::op::Broadcast>(broadcast_like->get_argument(0), std::make_shared<ngraph::op::Broadcast>(broadcast_like->get_argument(0),
......
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
HANDLER_DECL(eliminate_pad) HANDLER_DECL(eliminate_pad)
{ {
auto pad = std::dynamic_pointer_cast<ngraph::op::Pad>(node); auto pad = std::static_pointer_cast<ngraph::op::Pad>(node);
if (pad->get_input_shape(0) == pad->get_output_shape(0)) if (pad->get_input_shape(0) == pad->get_output_shape(0))
{ {
ngraph::replace_node(node, node->get_argument(0)); ngraph::replace_node(node, node->get_argument(0));
...@@ -47,7 +47,7 @@ HANDLER_DECL(eliminate_pad) ...@@ -47,7 +47,7 @@ HANDLER_DECL(eliminate_pad)
HANDLER_DECL(eliminate_sum) HANDLER_DECL(eliminate_sum)
{ {
auto sum = std::dynamic_pointer_cast<ngraph::op::Sum>(node); auto sum = std::static_pointer_cast<ngraph::op::Sum>(node);
if (sum->get_reduction_axes().empty()) if (sum->get_reduction_axes().empty())
{ {
ngraph::replace_node(node, node->get_argument(0)); ngraph::replace_node(node, node->get_argument(0));
...@@ -58,7 +58,7 @@ HANDLER_DECL(eliminate_sum) ...@@ -58,7 +58,7 @@ HANDLER_DECL(eliminate_sum)
HANDLER_DECL(eliminate_convert) HANDLER_DECL(eliminate_convert)
{ {
auto convert = std::dynamic_pointer_cast<ngraph::op::Convert>(node); auto convert = std::static_pointer_cast<ngraph::op::Convert>(node);
if (convert->get_convert_element_type() == convert->get_argument(0)->get_element_type()) if (convert->get_convert_element_type() == convert->get_argument(0)->get_element_type())
{ {
ngraph::replace_node(node, node->get_argument(0)); ngraph::replace_node(node, node->get_argument(0));
...@@ -69,7 +69,7 @@ HANDLER_DECL(eliminate_convert) ...@@ -69,7 +69,7 @@ HANDLER_DECL(eliminate_convert)
HANDLER_DECL(eliminate_slice) HANDLER_DECL(eliminate_slice)
{ {
auto slice = std::dynamic_pointer_cast<ngraph::op::Slice>(node); auto slice = std::static_pointer_cast<ngraph::op::Slice>(node);
if (slice->get_input_shape(0) == slice->get_output_shape(0)) if (slice->get_input_shape(0) == slice->get_output_shape(0))
{ {
ngraph::replace_node(node, node->get_argument(0)); ngraph::replace_node(node, node->get_argument(0));
...@@ -81,7 +81,7 @@ HANDLER_DECL(eliminate_slice) ...@@ -81,7 +81,7 @@ HANDLER_DECL(eliminate_slice)
HANDLER_DECL(replace_broadcast_like) HANDLER_DECL(replace_broadcast_like)
{ {
// Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like" argument // Replace a broadcast like with the broadcast to eliminate the pseudo-dependency on the "like" argument
auto broadcast_like = std::dynamic_pointer_cast<ngraph::op::BroadcastLike>(node); auto broadcast_like = std::static_pointer_cast<ngraph::op::BroadcastLike>(node);
ngraph::replace_node( ngraph::replace_node(
node, node,
std::make_shared<ngraph::op::Broadcast>(broadcast_like->get_argument(0), std::make_shared<ngraph::op::Broadcast>(broadcast_like->get_argument(0),
...@@ -92,7 +92,7 @@ HANDLER_DECL(replace_broadcast_like) ...@@ -92,7 +92,7 @@ HANDLER_DECL(replace_broadcast_like)
HANDLER_DECL(eliminate_broadcast) HANDLER_DECL(eliminate_broadcast)
{ {
auto broadcast = std::dynamic_pointer_cast<ngraph::op::Broadcast>(node); auto broadcast = std::static_pointer_cast<ngraph::op::Broadcast>(node);
if (broadcast->get_input_shape(0) == broadcast->get_output_shape(0)) if (broadcast->get_input_shape(0) == broadcast->get_output_shape(0))
{ {
ngraph::replace_node(node, node->get_argument(0)); ngraph::replace_node(node, node->get_argument(0));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment