Unverified Commit a383b22b authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #3577 from NervanaSystems/gauri/new_reshape_error_master

Fix reshape error for onnx (on master)
parents a15d2ec5 76e144d3
......@@ -47,6 +47,7 @@ NodeVector op::Squeeze::decompose_op() const
auto axes = axes_constant->get_vector<size_t>();
auto data_shape = data.get_shape();
std::vector<uint64_t> axes_to_squeeze(data_shape.size());
// Prepare set of unique axes marked to be removed from input data.
if (axes.empty())
......@@ -56,8 +57,11 @@ NodeVector op::Squeeze::decompose_op() const
{
if (data_shape.at(idx) == 1)
{
// Mark with zero elements to remove;
data_shape.at(idx) = 0;
axes_to_squeeze.at(idx) = 1;
}
else
{
axes_to_squeeze.at(idx) = 0;
}
}
}
......@@ -70,16 +74,14 @@ NodeVector op::Squeeze::decompose_op() const
this,
(data_shape.at(axis) == 1),
"provided axis value is invalid. Only axes of size 1 may be removed.");
// Mark with zero elements to remove;
data_shape.at(axis) = 0;
axes_to_squeeze.at(axis) = 1;
}
}
Shape output_data_shape;
for (size_t idx = 0; idx < data_shape.size(); ++idx)
{
if (data_shape.at(idx) != 0)
if (axes_to_squeeze.at(idx) == 0)
{
output_data_shape.push_back(data_shape.at(idx));
}
......
......@@ -30,7 +30,7 @@ namespace ngraph
template <typename ElementType>
void result(const void* arg, void* out, size_t count, int arena)
{
if (arg != out)
if (arg != out && count != 0)
{
memcpy(out, arg, sizeof(ElementType) * count);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment