Unverified Commit ea40bc41 authored by Adam Procter's avatar Adam Procter Committed by GitHub

Merge pull request #1708 from NervanaSystems/aprocter/cherry-pick-1663

Cherry-pick "zero dim elem fix (#1663)"
parents f3c88459 427bcc1f
...@@ -37,7 +37,9 @@ static bool has_zero_dim(std::shared_ptr<Node> node) ...@@ -37,7 +37,9 @@ static bool has_zero_dim(std::shared_ptr<Node> node)
{ {
throw ngraph_error("has_zero_dim is called on multi-output op"); throw ngraph_error("has_zero_dim is called on multi-output op");
} }
return shape_size(node->get_shape()) == 0;
const auto& shape = node->get_shape();
return std::find(shape.begin(), shape.end(), 0) != shape.end();
} }
static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function> f) static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function> f)
...@@ -75,6 +77,7 @@ static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function> ...@@ -75,6 +77,7 @@ static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function>
bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngraph::Function> f) bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngraph::Function> f)
{ {
bool replaced = false; bool replaced = false;
auto cvals = std::vector<std::string>(0);
// we need to go over all nodes since we could have sum or any other 0-length-tensor-to scalar op // we need to go over all nodes since we could have sum or any other 0-length-tensor-to scalar op
// as an internal node (i.e. a node that isn't an argument to `op::Result`) // as an internal node (i.e. a node that isn't an argument to `op::Result`)
for (auto n : f->get_ordered_ops()) for (auto n : f->get_ordered_ops())
...@@ -93,7 +96,6 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr ...@@ -93,7 +96,6 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
{ {
// we don't have to create constants every time but this is the easiest // we don't have to create constants every time but this is the easiest
// and it's CSE's job to eliminate the same ones // and it's CSE's job to eliminate the same ones
auto cvals = std::vector<std::string>(0);
auto constant = auto constant =
std::make_shared<op::Constant>(n->get_element_type(), n->get_shape(), cvals); std::make_shared<op::Constant>(n->get_element_type(), n->get_shape(), cvals);
replace_node(n, constant); replace_node(n, constant);
...@@ -102,8 +104,21 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr ...@@ -102,8 +104,21 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
continue; continue;
} }
if (n->get_inputs().size() == 0)
{
continue;
}
auto arg = n->get_inputs().at(0).get_output().get_node();
if (arg->get_outputs().size() != 1 || !has_zero_dim(arg))
{
continue;
}
auto new_node = n->get_default_value(); auto new_node = n->get_default_value();
if (!new_node || !has_zero_dim(n->get_argument(0)))
if (!new_node)
{ {
continue; continue;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment