Commit 45752b0f authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

switch more passes to static_pointer_cast (#2041)

parent cccdc304
...@@ -530,7 +530,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop() ...@@ -530,7 +530,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
// now get the GOE0 which is the first output of lstm (ht) // now get the GOE0 which is the first output of lstm (ht)
for (auto& goes : lstm_nodes[index]->get_outputs().at(0).get_inputs()) for (auto& goes : lstm_nodes[index]->get_outputs().at(0).get_inputs())
{ {
auto goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(goes->get_node()); auto goe_node = std::static_pointer_cast<op::GetOutputElement>(goes->get_node());
// first output node of lstm // first output node of lstm
if (goe_node->get_n() == 0) if (goe_node->get_n() == 0)
{ {
......
...@@ -94,13 +94,13 @@ bool runtime::cpu::pass::CPUWorkspaceInsertion::run_on_function(std::shared_ptr< ...@@ -94,13 +94,13 @@ bool runtime::cpu::pass::CPUWorkspaceInsertion::run_on_function(std::shared_ptr<
bool runtime::cpu::pass::CPUWorkspaceInsertion::transform(pattern::Matcher& m) bool runtime::cpu::pass::CPUWorkspaceInsertion::transform(pattern::Matcher& m)
{ {
auto data = std::dynamic_pointer_cast<pattern::op::Label>(m.get_pattern()->get_argument(0)); auto data = std::static_pointer_cast<pattern::op::Label>(m.get_pattern()->get_argument(0));
auto delta = std::dynamic_pointer_cast<pattern::op::Label>(m.get_pattern()->get_argument(1)); auto delta = std::static_pointer_cast<pattern::op::Label>(m.get_pattern()->get_argument(1));
NGRAPH_DEBUG << "In a callback for construct_max_pool_with_indices against " NGRAPH_DEBUG << "In a callback for construct_max_pool_with_indices against "
<< m.get_match_root()->get_name(); << m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto m_max_pool_bprop = std::dynamic_pointer_cast<op::MaxPoolBackprop>(m.get_match_root()); auto m_max_pool_bprop = std::static_pointer_cast<op::MaxPoolBackprop>(m.get_match_root());
if (m_max_pool_bprop->get_shape().size() != 4 || if (m_max_pool_bprop->get_shape().size() != 4 ||
m_max_pool_bprop->get_window_shape().size() != 2 || m_max_pool_bprop->get_window_shape().size() != 2 ||
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment