Unverified Commit 2fc73b43 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

optimization for about 2x speedup (#2036)

* optimization for about 2x speedup

* more optimizations
parent 0ac2a8b6
...@@ -45,6 +45,7 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape, ...@@ -45,6 +45,7 @@ CoordinateTransform::CoordinateTransform(const Shape& source_shape,
, m_target_padding_below(target_padding_below) , m_target_padding_below(target_padding_below)
, m_target_padding_above(target_padding_above) , m_target_padding_above(target_padding_above)
, m_target_dilation_strides(target_dilation_strides) , m_target_dilation_strides(target_dilation_strides)
, m_end_iterator(Shape(), true)
{ {
m_n_axes = source_shape.size(); m_n_axes = source_shape.size();
......
...@@ -93,8 +93,7 @@ namespace ngraph ...@@ -93,8 +93,7 @@ namespace ngraph
}; };
Iterator begin() noexcept { return Iterator(m_target_shape); } Iterator begin() noexcept { return Iterator(m_target_shape); }
Iterator end() noexcept { return Iterator(m_target_shape, true); } Iterator end() noexcept { return m_end_iterator; }
private:
size_t index_source(const Coordinate& c) const; size_t index_source(const Coordinate& c) const;
static Strides default_strides(size_t n_axes); static Strides default_strides(size_t n_axes);
static CoordinateDiff default_padding(size_t n_axes); static CoordinateDiff default_padding(size_t n_axes);
...@@ -113,5 +112,6 @@ namespace ngraph ...@@ -113,5 +112,6 @@ namespace ngraph
Shape m_target_shape; Shape m_target_shape;
size_t m_n_axes; size_t m_n_axes;
Iterator m_end_iterator;
}; };
} }
...@@ -59,7 +59,7 @@ namespace ngraph ...@@ -59,7 +59,7 @@ namespace ngraph
// At the outermost level we will walk over every output coordinate O. // At the outermost level we will walk over every output coordinate O.
CoordinateTransform output_transform(out_shape); CoordinateTransform output_transform(out_shape);
for (Coordinate out_coord : output_transform) for (const Coordinate& out_coord : output_transform)
{ {
// Our output coordinate O will have the form: // Our output coordinate O will have the form:
// //
...@@ -169,9 +169,10 @@ namespace ngraph ...@@ -169,9 +169,10 @@ namespace ngraph
CoordinateTransform::Iterator input_it = input_batch_transform.begin(); CoordinateTransform::Iterator input_it = input_batch_transform.begin();
CoordinateTransform::Iterator filter_it = filter_transform.begin(); CoordinateTransform::Iterator filter_it = filter_transform.begin();
CoordinateTransform::Iterator input_it_end = input_batch_transform.end();
CoordinateTransform::Iterator filter_it_end = filter_transform.end();
while (input_it != input_batch_transform.end() && while (input_it != input_it_end && filter_it != filter_it_end)
filter_it != filter_transform.end())
{ {
const Coordinate& input_batch_coord = *input_it; const Coordinate& input_batch_coord = *input_it;
Coordinate filter_coord = *filter_it; Coordinate filter_coord = *filter_it;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment