Unverified Commit a383b22b authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #3577 from NervanaSystems/gauri/new_reshape_error_master

Fix reshape error for onnx (on master)
...@@ -47,6 +47,7 @@ NodeVector op::Squeeze::decompose_op() const ...@@ -47,6 +47,7 @@ NodeVector op::Squeeze::decompose_op() const
auto axes = axes_constant->get_vector<size_t>(); auto axes = axes_constant->get_vector<size_t>();
auto data_shape = data.get_shape(); auto data_shape = data.get_shape();
std::vector<uint64_t> axes_to_squeeze(data_shape.size());
// Prepare set of unique axes marked to be removed from input data. // Prepare set of unique axes marked to be removed from input data.
if (axes.empty()) if (axes.empty())
...@@ -56,8 +57,11 @@ NodeVector op::Squeeze::decompose_op() const ...@@ -56,8 +57,11 @@ NodeVector op::Squeeze::decompose_op() const
{ {
if (data_shape.at(idx) == 1) if (data_shape.at(idx) == 1)
{ {
// Mark with zero elements to remove; axes_to_squeeze.at(idx) = 1;
data_shape.at(idx) = 0; }
else
{
axes_to_squeeze.at(idx) = 0;
} }
} }
} }
...@@ -70,16 +74,14 @@ NodeVector op::Squeeze::decompose_op() const ...@@ -70,16 +74,14 @@ NodeVector op::Squeeze::decompose_op() const
this, this,
(data_shape.at(axis) == 1), (data_shape.at(axis) == 1),
"provided axis value is invalid. Only axes of size 1 may be removed."); "provided axis value is invalid. Only axes of size 1 may be removed.");
axes_to_squeeze.at(axis) = 1;
// Mark with zero elements to remove;
data_shape.at(axis) = 0;
} }
} }
Shape output_data_shape; Shape output_data_shape;
for (size_t idx = 0; idx < data_shape.size(); ++idx) for (size_t idx = 0; idx < data_shape.size(); ++idx)
{ {
if (data_shape.at(idx) != 0) if (axes_to_squeeze.at(idx) == 0)
{ {
output_data_shape.push_back(data_shape.at(idx)); output_data_shape.push_back(data_shape.at(idx));
} }
......
...@@ -30,7 +30,7 @@ namespace ngraph ...@@ -30,7 +30,7 @@ namespace ngraph
template <typename ElementType> template <typename ElementType>
void result(const void* arg, void* out, size_t count, int arena) void result(const void* arg, void* out, size_t count, int arena)
{ {
if (arg != out) if (arg != out && count != 0)
{ {
memcpy(out, arg, sizeof(ElementType) * count); memcpy(out, arg, sizeof(ElementType) * count);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment