Commit fa7da175 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Adam Procter

fix ztde concat (#1918)

parent 95c60166
......@@ -21,6 +21,7 @@
#include "ngraph/log.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/max_pool.hpp"
......@@ -71,6 +72,7 @@ static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function>
zero_length_nodes.erase(n);
}
}
return zero_length_nodes.size() > 0;
}
......@@ -109,6 +111,27 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
continue;
}
if (auto concat = std::dynamic_pointer_cast<op::Concat>(n))
{
NodeVector non_zero_dim_args;
for (auto arg : concat->get_arguments())
{
if (!has_zero_dim(arg))
{
non_zero_dim_args.push_back(arg);
}
}
if (non_zero_dim_args.size() < concat->get_inputs().size())
{
auto new_concat = concat->copy_with_new_args(non_zero_dim_args);
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with "
<< new_concat->get_name();
ngraph::replace_node(concat, new_concat);
continue;
}
}
auto arg = n->get_inputs().at(0).get_output().get_node();
if (arg->get_outputs().size() != 1 || !has_zero_dim(arg))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment