Commit d1aeaa8b authored by gaurides's avatar gaurides Committed by Scott Cyphers

Fix reshape error for onnx (#3573)

* Fix reshape error for onnx

* Change to fix CI error

* Removed old comments
parent 6818cc5f
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
// Parameters which ngraph-unittest uses: // Parameters which ngraph-unittest uses:
String PR_URL = CHANGE_URL String PR_URL = CHANGE_URL
String PR_COMMIT_AUTHOR = CHANGE_AUTHOR String PR_COMMIT_AUTHOR = CHANGE_AUTHOR
String JENKINS_BRANCH = "ngraph-core-r0.22.0.rc1" String JENKINS_BRANCH = "aslepko/r0.22"
Integer TIMEOUTTIME = "3600" Integer TIMEOUTTIME = "3600"
// BRANCH parameter is no loner needed // BRANCH parameter is no loner needed
// TRIGGER_URL parameter is no longer needed // TRIGGER_URL parameter is no longer needed
......
...@@ -47,6 +47,7 @@ NodeVector op::Squeeze::decompose_op() const ...@@ -47,6 +47,7 @@ NodeVector op::Squeeze::decompose_op() const
auto axes = axes_constant->get_vector<size_t>(); auto axes = axes_constant->get_vector<size_t>();
auto data_shape = data->get_shape(); auto data_shape = data->get_shape();
std::vector<uint64_t> axes_to_squeeze(data_shape.size());
// Prepare set of unique axes marked to be removed from input data. // Prepare set of unique axes marked to be removed from input data.
if (axes.empty()) if (axes.empty())
...@@ -56,8 +57,11 @@ NodeVector op::Squeeze::decompose_op() const ...@@ -56,8 +57,11 @@ NodeVector op::Squeeze::decompose_op() const
{ {
if (data_shape.at(idx) == 1) if (data_shape.at(idx) == 1)
{ {
// Mark with zero elements to remove; axes_to_squeeze.at(idx) = 1;
data_shape.at(idx) = 0; }
else
{
axes_to_squeeze.at(idx) = 0;
} }
} }
} }
...@@ -71,15 +75,14 @@ NodeVector op::Squeeze::decompose_op() const ...@@ -71,15 +75,14 @@ NodeVector op::Squeeze::decompose_op() const
(data_shape.at(axis) == 1), (data_shape.at(axis) == 1),
"provided axis value is invalid. Only axes of size 1 may be removed."); "provided axis value is invalid. Only axes of size 1 may be removed.");
// Mark with zero elements to remove; axes_to_squeeze.at(axis) = 1;
data_shape.at(axis) = 0;
} }
} }
Shape output_data_shape; Shape output_data_shape;
for (size_t idx = 0; idx < data_shape.size(); ++idx) for (size_t idx = 0; idx < data_shape.size(); ++idx)
{ {
if (data_shape.at(idx) != 0) if (axes_to_squeeze.at(idx) == 0)
{ {
output_data_shape.push_back(data_shape.at(idx)); output_data_shape.push_back(data_shape.at(idx));
} }
......
...@@ -30,7 +30,7 @@ namespace ngraph ...@@ -30,7 +30,7 @@ namespace ngraph
template <typename ElementType> template <typename ElementType>
void result(const void* arg, void* out, size_t count, int arena) void result(const void* arg, void* out, size_t count, int arena)
{ {
if (arg != out) if (arg != out && count != 0)
{ {
memcpy(out, arg, sizeof(ElementType) * count); memcpy(out, arg, sizeof(ElementType) * count);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment