Commit 686ee9ab authored by Robert Kimball's avatar Robert Kimball Committed by Adam Procter

Optimize the Coordinate class to prevent copies (#358)

parent 07ba1bef
...@@ -368,7 +368,7 @@ void CoordinateTransform::Iterator::operator+=(size_t n) ...@@ -368,7 +368,7 @@ void CoordinateTransform::Iterator::operator+=(size_t n)
} }
} }
Coordinate CoordinateTransform::Iterator::operator*() const Coordinate& CoordinateTransform::Iterator::operator*() const
{ {
return m_coordinate; return m_coordinate;
} }
......
...@@ -70,7 +70,7 @@ namespace ngraph ...@@ -70,7 +70,7 @@ namespace ngraph
void operator++(); void operator++();
Iterator operator++(int); Iterator operator++(int);
void operator+=(size_t n); void operator+=(size_t n);
Coordinate operator*(); const Coordinate& operator*() const;
bool operator!=(const Iterator& it); bool operator!=(const Iterator& it);
bool operator==(const Iterator& it); bool operator==(const Iterator& it);
......
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
CoordinateTransform input_transform(in_shape); CoordinateTransform input_transform(in_shape);
CoordinateTransform output_transform(out_shape); CoordinateTransform output_transform(out_shape);
for (Coordinate output_coord : output_transform) for (const Coordinate& output_coord : output_transform)
{ {
Coordinate input_coord = project_coordinate(output_coord, broadcast_axes); Coordinate input_coord = project_coordinate(output_coord, broadcast_axes);
......
...@@ -39,7 +39,7 @@ namespace ngraph ...@@ -39,7 +39,7 @@ namespace ngraph
for (size_t i = 0; i < args.size(); i++) for (size_t i = 0; i < args.size(); i++)
{ {
// The start coordinate for the copy is (0,...,0) except at the concatenation axis. // The start coordinate for the copy is (0,...,0) except at the concatenation axis.
Coordinate out_start_coord = Coordinate(out_shape.size(), 0); Coordinate out_start_coord(out_shape.size(), 0);
out_start_coord[concatenation_axis] = concatenation_pos; out_start_coord[concatenation_axis] = concatenation_pos;
// The end coordinate for the copy is the same as the output shape except at the // The end coordinate for the copy is the same as the output shape except at the
...@@ -54,7 +54,7 @@ namespace ngraph ...@@ -54,7 +54,7 @@ namespace ngraph
CoordinateTransform::Iterator output_chunk_it = output_chunk_transform.begin(); CoordinateTransform::Iterator output_chunk_it = output_chunk_transform.begin();
for (Coordinate input_coord : input_transform) for (const Coordinate& input_coord : input_transform)
{ {
size_t input_index = input_transform.index(input_coord); size_t input_index = input_transform.index(input_coord);
size_t output_chunk_index = output_chunk_transform.index(*output_chunk_it); size_t output_chunk_index = output_chunk_transform.index(*output_chunk_it);
......
...@@ -149,12 +149,14 @@ namespace ngraph ...@@ -149,12 +149,14 @@ namespace ngraph
while (input_it != input_batch_transform.end() && while (input_it != input_batch_transform.end() &&
filter_it != filter_transform.end()) filter_it != filter_transform.end())
{ {
Coordinate input_batch_coord = *input_it++; const Coordinate& input_batch_coord = *input_it;
Coordinate filter_coord = *filter_it++; const Coordinate& filter_coord = *filter_it;
T v = input_batch_transform.in_padding(input_batch_coord) T v = input_batch_transform.in_padding(input_batch_coord)
? 0 ? 0
: arg0[input_batch_transform.index(input_batch_coord)]; : arg0[input_batch_transform.index(input_batch_coord)];
result += v * arg1[filter_transform.index(filter_coord)]; result += v * arg1[filter_transform.index(filter_coord)];
++input_it;
++filter_it;
} }
out[output_transform.index(out_coord)] = result; out[output_transform.index(out_coord)] = result;
......
...@@ -67,9 +67,9 @@ namespace ngraph ...@@ -67,9 +67,9 @@ namespace ngraph
// for the dotted axes. // for the dotted axes.
CoordinateTransform dot_axes_transform(dot_axis_sizes); CoordinateTransform dot_axes_transform(dot_axis_sizes);
for (Coordinate arg0_projected_coord : arg0_projected_transform) for (const Coordinate& arg0_projected_coord : arg0_projected_transform)
{ {
for (Coordinate arg1_projected_coord : arg1_projected_transform) for (const Coordinate& arg1_projected_coord : arg1_projected_transform)
{ {
// The output coordinate is just the concatenation of the projected coordinates. // The output coordinate is just the concatenation of the projected coordinates.
Coordinate out_coord(arg0_projected_coord.size() + Coordinate out_coord(arg0_projected_coord.size() +
...@@ -87,16 +87,15 @@ namespace ngraph ...@@ -87,16 +87,15 @@ namespace ngraph
size_t out_index = output_transform.index(out_coord); size_t out_index = output_transform.index(out_coord);
// Walk along the dotted axes. // Walk along the dotted axes.
for (Coordinate dot_axis_positions : dot_axes_transform) Coordinate arg0_coord(arg0_shape.size());
Coordinate arg1_coord(arg1_shape.size());
auto arg0_it = std::copy(arg0_projected_coord.begin(),
arg0_projected_coord.end(),
arg0_coord.begin());
for (const Coordinate& dot_axis_positions : dot_axes_transform)
{ {
// In order to find the points to multiply together, we need to inject our current // In order to find the points to multiply together, we need to inject our current
// positions along the dotted axes back into the projected arg0 and arg1 coordinates. // positions along the dotted axes back into the projected arg0 and arg1 coordinates.
Coordinate arg0_coord(arg0_shape.size());
Coordinate arg1_coord(arg1_shape.size());
auto arg0_it = std::copy(arg0_projected_coord.begin(),
arg0_projected_coord.end(),
arg0_coord.begin());
std::copy( std::copy(
dot_axis_positions.begin(), dot_axis_positions.end(), arg0_it); dot_axis_positions.begin(), dot_axis_positions.end(), arg0_it);
......
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
// At the outermost level we will walk over every output coordinate O. // At the outermost level we will walk over every output coordinate O.
CoordinateTransform output_transform(out_shape); CoordinateTransform output_transform(out_shape);
for (Coordinate out_coord : output_transform) for (const Coordinate& out_coord : output_transform)
{ {
// Our output coordinate O will have the form: // Our output coordinate O will have the form:
// //
...@@ -88,7 +88,7 @@ namespace ngraph ...@@ -88,7 +88,7 @@ namespace ngraph
? -std::numeric_limits<T>::infinity() ? -std::numeric_limits<T>::infinity()
: std::numeric_limits<T>::min(); : std::numeric_limits<T>::min();
for (Coordinate input_batch_coord : input_batch_transform) for (const Coordinate& input_batch_coord : input_batch_transform)
{ {
T x = arg[input_batch_transform.index(input_batch_coord)]; T x = arg[input_batch_transform.index(input_batch_coord)];
result = x > result ? x : result; result = x > result ? x : result;
......
...@@ -34,7 +34,7 @@ namespace ngraph ...@@ -34,7 +34,7 @@ namespace ngraph
// Step 1: Zero out the output. // Step 1: Zero out the output.
CoordinateTransform output_transform(out_shape); CoordinateTransform output_transform(out_shape);
for (Coordinate output_coord : output_transform) for (const Coordinate& output_coord : output_transform)
{ {
out[output_transform.index(output_coord)] = 0; out[output_transform.index(output_coord)] = 0;
} }
...@@ -43,7 +43,7 @@ namespace ngraph ...@@ -43,7 +43,7 @@ namespace ngraph
// are encountered. // are encountered.
CoordinateTransform input_transform(in_shape); CoordinateTransform input_transform(in_shape);
for (Coordinate input_coord : input_transform) for (const Coordinate& input_coord : input_transform)
{ {
T val = arg[input_transform.index(input_coord)]; T val = arg[input_transform.index(input_coord)];
......
...@@ -36,14 +36,14 @@ namespace ngraph ...@@ -36,14 +36,14 @@ namespace ngraph
{ {
CoordinateTransform output_transform(out_shape); CoordinateTransform output_transform(out_shape);
for (Coordinate output_coord : output_transform) for (const Coordinate& output_coord : output_transform)
{ {
out[output_transform.index(output_coord)] = *arg1; out[output_transform.index(output_coord)] = *arg1;
} }
CoordinateTransform input_transform(in_shape); CoordinateTransform input_transform(in_shape);
for (Coordinate input_coord : input_transform) for (const Coordinate& input_coord : input_transform)
{ {
Coordinate output_coord = project_coordinate(input_coord, reduction_axes); Coordinate output_coord = project_coordinate(input_coord, reduction_axes);
size_t input_index = input_transform.index(input_coord); size_t input_index = input_transform.index(input_coord);
......
...@@ -50,12 +50,14 @@ namespace ngraph ...@@ -50,12 +50,14 @@ namespace ngraph
CoordinateTransform::Iterator output_it = output_transform.begin(); CoordinateTransform::Iterator output_it = output_transform.begin();
for (Coordinate input_coord : input_transform) for (const Coordinate& input_coord : input_transform)
{ {
Coordinate output_coord = *output_it++; const Coordinate& output_coord = *output_it;
out[output_transform.index(output_coord)] = out[output_transform.index(output_coord)] =
arg1[input_transform.index(input_coord)]; arg1[input_transform.index(input_coord)];
++output_it;
} }
} }
} }
......
...@@ -43,12 +43,14 @@ namespace ngraph ...@@ -43,12 +43,14 @@ namespace ngraph
CoordinateTransform output_transform(out_shape); CoordinateTransform output_transform(out_shape);
CoordinateTransform::Iterator output_it = output_transform.begin(); CoordinateTransform::Iterator output_it = output_transform.begin();
for (Coordinate input_coord : input_transform) for (const Coordinate& input_coord : input_transform)
{ {
Coordinate output_coord = *output_it++; const Coordinate& output_coord = *output_it;
out[output_transform.index(output_coord)] = out[output_transform.index(output_coord)] =
arg[input_transform.index(input_coord)]; arg[input_transform.index(input_coord)];
++output_it;
} }
} }
} }
......
...@@ -39,11 +39,13 @@ namespace ngraph ...@@ -39,11 +39,13 @@ namespace ngraph
CoordinateTransform::Iterator output_it = output_transform.begin(); CoordinateTransform::Iterator output_it = output_transform.begin();
for (Coordinate in_coord : input_transform) for (const Coordinate& in_coord : input_transform)
{ {
Coordinate out_coord = *output_it++; const Coordinate& out_coord = *output_it;
out[output_transform.index(out_coord)] = arg[input_transform.index(in_coord)]; out[output_transform.index(out_coord)] = arg[input_transform.index(in_coord)];
++output_it;
} }
} }
} }
......
...@@ -34,14 +34,14 @@ namespace ngraph ...@@ -34,14 +34,14 @@ namespace ngraph
{ {
CoordinateTransform output_transform(out_shape); CoordinateTransform output_transform(out_shape);
for (Coordinate output_coord : output_transform) for (const Coordinate& output_coord : output_transform)
{ {
out[output_transform.index(output_coord)] = 0; out[output_transform.index(output_coord)] = 0;
} }
CoordinateTransform input_transform(in_shape); CoordinateTransform input_transform(in_shape);
for (Coordinate input_coord : input_transform) for (const Coordinate& input_coord : input_transform)
{ {
Coordinate output_coord = project_coordinate(input_coord, reduction_axes); Coordinate output_coord = project_coordinate(input_coord, reduction_axes);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment