Commit abf1a2fb authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

zero dim elem fix (#1663)

* zero dim elem fix

* switch to find

* fix runtime errors
parent f16ace9c
......@@ -37,7 +37,9 @@ static bool has_zero_dim(std::shared_ptr<Node> node)
{
throw ngraph_error("has_zero_dim is called on multi-output op");
}
return shape_size(node->get_shape()) == 0;
const auto& shape = node->get_shape();
return std::find(shape.begin(), shape.end(), 0) != shape.end();
}
static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function> f)
......@@ -75,6 +77,7 @@ static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function>
bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngraph::Function> f)
{
bool replaced = false;
auto cvals = std::vector<std::string>(0);
// we need to go over all nodes since we could have sum or any other 0-length-tensor-to scalar op
// as an internal node (i.e. a node that isn't an argument to `op::Result`)
for (auto n : f->get_ordered_ops())
......@@ -93,7 +96,6 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
{
// we don't have to create constants every time but this is the easiest
// and it's CSE's job to eliminate the same ones
auto cvals = std::vector<std::string>(0);
auto constant =
std::make_shared<op::Constant>(n->get_element_type(), n->get_shape(), cvals);
replace_node(n, constant);
......@@ -102,8 +104,21 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
continue;
}
if (n->get_inputs().size() == 0)
{
continue;
}
auto arg = n->get_inputs().at(0).get_output().get_node();
if (arg->get_outputs().size() != 1 || !has_zero_dim(arg))
{
continue;
}
auto new_node = n->get_default_value();
if (!new_node || !has_zero_dim(n->get_argument(0)))
if (!new_node)
{
continue;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment