Commit df6d7281 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Jayaram Bobba

double check that the third arg is indeed maxpool (#2116)

parent bdfa4719
......@@ -61,7 +61,9 @@ static std::shared_ptr<pattern::Matcher> create_maxpool_with_indices_matcher()
Shape window_shape{3};
auto max_pool = std::make_shared<op::MaxPool>(data, window_shape);
auto delta = std::make_shared<pattern::op::Label>(element::f32, max_pool->get_shape());
auto max_pool_label = std::make_shared<pattern::op::Label>(element::f32, max_pool->get_shape());
auto is_max_pool = pattern::has_class<op::MaxPool>();
auto max_pool_label =
std::make_shared<pattern::op::Label>(element::f32, max_pool->get_shape(), is_max_pool);
auto max_pool_bprop =
std::make_shared<op::MaxPoolBackprop>(data,
delta,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment