Commit e54156cf authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

more passes to static (#2027)

parent 3a47eafc
...@@ -56,7 +56,7 @@ void ngraph::runtime::cpu::pass::CPUHorizontalFusion::cpu_conv_horizontal_fusion ...@@ -56,7 +56,7 @@ void ngraph::runtime::cpu::pass::CPUHorizontalFusion::cpu_conv_horizontal_fusion
NGRAPH_DEBUG << "conv_horizontal_fusion: In a callback for conv horizontal fusion for " NGRAPH_DEBUG << "conv_horizontal_fusion: In a callback for conv horizontal fusion for "
<< m.get_match_root()->get_name(); << m.get_match_root()->get_name();
auto conv_bias_root = std::dynamic_pointer_cast<op::ConvolutionBias>(m.get_match_root()); auto conv_bias_root = std::static_pointer_cast<op::ConvolutionBias>(m.get_match_root());
//check if the node has been replaced //check if the node has been replaced
if (conv_bias_root->get_users().empty()) if (conv_bias_root->get_users().empty())
......
...@@ -331,7 +331,7 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi ...@@ -331,7 +331,7 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
for (auto op : op_nodes) for (auto op : op_nodes)
{ {
const auto old_slice = const auto old_slice =
std::dynamic_pointer_cast<op::Slice>(op_seg_map[op].at(Type::DATA)); std::static_pointer_cast<op::Slice>(op_seg_map[op].at(Type::DATA));
const auto& old_lower_bounds = old_slice->get_lower_bounds(); const auto& old_lower_bounds = old_slice->get_lower_bounds();
// lower bound matching the current time step // lower bound matching the current time step
const Coordinate lower_bounds{old_lower_bounds[1], 0}; const Coordinate lower_bounds{old_lower_bounds[1], 0};
...@@ -403,7 +403,7 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n) ...@@ -403,7 +403,7 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n)
std::shared_ptr<Node> data; std::shared_ptr<Node> data;
std::shared_ptr<Node> weights; std::shared_ptr<Node> weights;
auto concat = std::dynamic_pointer_cast<op::Concat>(n); auto concat = std::static_pointer_cast<op::Concat>(n);
std::shared_ptr<op::Convolution> sconv; std::shared_ptr<op::Convolution> sconv;
NodeVector slices; NodeVector slices;
...@@ -423,14 +423,14 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n) ...@@ -423,14 +423,14 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n)
return {nullptr}; return {nullptr};
} }
sconv = std::dynamic_pointer_cast<op::Convolution>(arg); sconv = std::static_pointer_cast<op::Convolution>(arg);
if (arg->get_input_shape(0).size() != 4) if (arg->get_input_shape(0).size() != 4)
{ {
NGRAPH_DEBUG << "convolution data's rank isn't equal to 4"; NGRAPH_DEBUG << "convolution data's rank isn't equal to 4";
return {nullptr}; return {nullptr};
} }
if (!is_trivial_convolution(std::dynamic_pointer_cast<op::Convolution>(arg))) if (!is_trivial_convolution(sconv))
{ {
NGRAPH_DEBUG << arg->get_name() << " isn't trivial convolution"; NGRAPH_DEBUG << arg->get_name() << " isn't trivial convolution";
return {nullptr}; return {nullptr};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment