Commit 686ee9ab authored by Robert Kimball's avatar Robert Kimball Committed by Adam Procter

Optimize the Coordinate class to prevent copies (#358)

parent 07ba1bef
......@@ -368,7 +368,7 @@ void CoordinateTransform::Iterator::operator+=(size_t n)
}
}
Coordinate CoordinateTransform::Iterator::operator*()
const Coordinate& CoordinateTransform::Iterator::operator*() const
{
return m_coordinate;
}
......
......@@ -70,7 +70,7 @@ namespace ngraph
void operator++();
Iterator operator++(int);
void operator+=(size_t n);
Coordinate operator*();
const Coordinate& operator*() const;
bool operator!=(const Iterator& it);
bool operator==(const Iterator& it);
......
......@@ -35,7 +35,7 @@ namespace ngraph
CoordinateTransform input_transform(in_shape);
CoordinateTransform output_transform(out_shape);
for (Coordinate output_coord : output_transform)
for (const Coordinate& output_coord : output_transform)
{
Coordinate input_coord = project_coordinate(output_coord, broadcast_axes);
......
......@@ -39,7 +39,7 @@ namespace ngraph
for (size_t i = 0; i < args.size(); i++)
{
// The start coordinate for the copy is (0,...,0) except at the concatenation axis.
Coordinate out_start_coord = Coordinate(out_shape.size(), 0);
Coordinate out_start_coord(out_shape.size(), 0);
out_start_coord[concatenation_axis] = concatenation_pos;
// The end coordinate for the copy is the same as the output shape except at the
......@@ -54,7 +54,7 @@ namespace ngraph
CoordinateTransform::Iterator output_chunk_it = output_chunk_transform.begin();
for (Coordinate input_coord : input_transform)
for (const Coordinate& input_coord : input_transform)
{
size_t input_index = input_transform.index(input_coord);
size_t output_chunk_index = output_chunk_transform.index(*output_chunk_it);
......
......@@ -149,12 +149,14 @@ namespace ngraph
while (input_it != input_batch_transform.end() &&
filter_it != filter_transform.end())
{
Coordinate input_batch_coord = *input_it++;
Coordinate filter_coord = *filter_it++;
const Coordinate& input_batch_coord = *input_it;
const Coordinate& filter_coord = *filter_it;
T v = input_batch_transform.in_padding(input_batch_coord)
? 0
: arg0[input_batch_transform.index(input_batch_coord)];
result += v * arg1[filter_transform.index(filter_coord)];
++input_it;
++filter_it;
}
out[output_transform.index(out_coord)] = result;
......
......@@ -67,9 +67,9 @@ namespace ngraph
// for the dotted axes.
CoordinateTransform dot_axes_transform(dot_axis_sizes);
for (Coordinate arg0_projected_coord : arg0_projected_transform)
for (const Coordinate& arg0_projected_coord : arg0_projected_transform)
{
for (Coordinate arg1_projected_coord : arg1_projected_transform)
for (const Coordinate& arg1_projected_coord : arg1_projected_transform)
{
// The output coordinate is just the concatenation of the projected coordinates.
Coordinate out_coord(arg0_projected_coord.size() +
......@@ -87,16 +87,15 @@ namespace ngraph
size_t out_index = output_transform.index(out_coord);
// Walk along the dotted axes.
for (Coordinate dot_axis_positions : dot_axes_transform)
Coordinate arg0_coord(arg0_shape.size());
Coordinate arg1_coord(arg1_shape.size());
auto arg0_it = std::copy(arg0_projected_coord.begin(),
arg0_projected_coord.end(),
arg0_coord.begin());
for (const Coordinate& dot_axis_positions : dot_axes_transform)
{
// In order to find the points to multiply together, we need to inject our current
// positions along the dotted axes back into the projected arg0 and arg1 coordinates.
Coordinate arg0_coord(arg0_shape.size());
Coordinate arg1_coord(arg1_shape.size());
auto arg0_it = std::copy(arg0_projected_coord.begin(),
arg0_projected_coord.end(),
arg0_coord.begin());
std::copy(
dot_axis_positions.begin(), dot_axis_positions.end(), arg0_it);
......
......@@ -36,7 +36,7 @@ namespace ngraph
// At the outermost level we will walk over every output coordinate O.
CoordinateTransform output_transform(out_shape);
for (Coordinate out_coord : output_transform)
for (const Coordinate& out_coord : output_transform)
{
// Our output coordinate O will have the form:
//
......@@ -88,7 +88,7 @@ namespace ngraph
? -std::numeric_limits<T>::infinity()
: std::numeric_limits<T>::min();
for (Coordinate input_batch_coord : input_batch_transform)
for (const Coordinate& input_batch_coord : input_batch_transform)
{
T x = arg[input_batch_transform.index(input_batch_coord)];
result = x > result ? x : result;
......
......@@ -34,7 +34,7 @@ namespace ngraph
// Step 1: Zero out the output.
CoordinateTransform output_transform(out_shape);
for (Coordinate output_coord : output_transform)
for (const Coordinate& output_coord : output_transform)
{
out[output_transform.index(output_coord)] = 0;
}
......@@ -43,7 +43,7 @@ namespace ngraph
// are encountered.
CoordinateTransform input_transform(in_shape);
for (Coordinate input_coord : input_transform)
for (const Coordinate& input_coord : input_transform)
{
T val = arg[input_transform.index(input_coord)];
......
......@@ -36,14 +36,14 @@ namespace ngraph
{
CoordinateTransform output_transform(out_shape);
for (Coordinate output_coord : output_transform)
for (const Coordinate& output_coord : output_transform)
{
out[output_transform.index(output_coord)] = *arg1;
}
CoordinateTransform input_transform(in_shape);
for (Coordinate input_coord : input_transform)
for (const Coordinate& input_coord : input_transform)
{
Coordinate output_coord = project_coordinate(input_coord, reduction_axes);
size_t input_index = input_transform.index(input_coord);
......
......@@ -50,12 +50,14 @@ namespace ngraph
CoordinateTransform::Iterator output_it = output_transform.begin();
for (Coordinate input_coord : input_transform)
for (const Coordinate& input_coord : input_transform)
{
Coordinate output_coord = *output_it++;
const Coordinate& output_coord = *output_it;
out[output_transform.index(output_coord)] =
arg1[input_transform.index(input_coord)];
++output_it;
}
}
}
......
......@@ -43,12 +43,14 @@ namespace ngraph
CoordinateTransform output_transform(out_shape);
CoordinateTransform::Iterator output_it = output_transform.begin();
for (Coordinate input_coord : input_transform)
for (const Coordinate& input_coord : input_transform)
{
Coordinate output_coord = *output_it++;
const Coordinate& output_coord = *output_it;
out[output_transform.index(output_coord)] =
arg[input_transform.index(input_coord)];
++output_it;
}
}
}
......
......@@ -39,11 +39,13 @@ namespace ngraph
CoordinateTransform::Iterator output_it = output_transform.begin();
for (Coordinate in_coord : input_transform)
for (const Coordinate& in_coord : input_transform)
{
Coordinate out_coord = *output_it++;
const Coordinate& out_coord = *output_it;
out[output_transform.index(out_coord)] = arg[input_transform.index(in_coord)];
++output_it;
}
}
}
......
......@@ -34,14 +34,14 @@ namespace ngraph
{
CoordinateTransform output_transform(out_shape);
for (Coordinate output_coord : output_transform)
for (const Coordinate& output_coord : output_transform)
{
out[output_transform.index(output_coord)] = 0;
}
CoordinateTransform input_transform(in_shape);
for (Coordinate input_coord : input_transform)
for (const Coordinate& input_coord : input_transform)
{
Coordinate output_coord = project_coordinate(input_coord, reduction_axes);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment