Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
b485bb33
Commit
b485bb33
authored
Nov 30, 2017
by
Adam Procter
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
De-Eigenize broadcast, and extend it to higher dimensions
parent
c50164bc
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
548 additions
and
4 deletions
+548
-4
CMakeLists.txt
src/ngraph/CMakeLists.txt
+1
-0
coordinate_iterator.cpp
src/ngraph/coordinate_iterator.cpp
+110
-0
coordinate_iterator.hpp
src/ngraph/coordinate_iterator.hpp
+49
-0
ngraph.hpp
src/ngraph/ngraph.hpp
+1
-0
broadcast.hpp
src/ngraph/runtime/kernel/broadcast.hpp
+77
-0
external_function.cpp
src/ngraph/runtime/ngvm/external_function.cpp
+12
-3
broadcast.hpp
src/ngraph/runtime/ngvm/instruction/broadcast.hpp
+67
-0
CMakeLists.txt
test/CMakeLists.txt
+2
-1
backend_test.in.cpp
test/backend_test.in.cpp
+72
-0
coordinate_iterator.cpp
test/coordinate_iterator.cpp
+157
-0
No files found.
src/ngraph/CMakeLists.txt
View file @
b485bb33
...
...
@@ -15,6 +15,7 @@ set (SRC
autodiff/adjoints.cpp
builder/autobroadcast.cpp
builder/reduce_ops.cpp
coordinate_iterator.cpp
descriptor/input.cpp
descriptor/layout/dense_tensor_view_layout.cpp
descriptor/layout/tensor_view_layout.cpp
...
...
src/ngraph/coordinate_iterator.cpp
0 → 100644
View file @
b485bb33
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cassert>
#include <cstdio>
#include <iostream>
#include <vector>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_iterator.hpp"
#include "ngraph/except.hpp"
using
namespace
ngraph
;
CoordinateIterator
::
CoordinateIterator
(
const
Shape
&
space_shape
,
const
Strides
&
strides
,
const
Coordinate
&
window_outer_corner
,
const
Coordinate
&
window_inner_corner
)
:
m_space_shape
(
space_shape
)
,
m_strides
(
strides
)
,
m_window_outer_corner
(
window_outer_corner
)
,
m_window_inner_corner
(
window_inner_corner
)
,
m_current_coordinate
(
window_inner_corner
)
{
assert
(
space_shape
.
size
()
==
window_inner_corner
.
size
());
assert
(
space_shape
.
size
()
==
window_outer_corner
.
size
());
assert
(
space_shape
.
size
()
==
strides
.
size
());
for
(
size_t
i
=
0
;
i
<
space_shape
.
size
();
i
++
)
{
if
(
window_inner_corner
[
i
]
>
window_outer_corner
[
i
])
{
throw
ngraph_error
(
"Coordinate iterator inner corner is outside outer corner"
);
}
if
(
window_inner_corner
[
i
]
>=
m_space_shape
[
i
])
{
throw
ngraph_error
(
"Coordinate iterator inner corner is out of bounds"
);
}
if
(
window_outer_corner
[
i
]
>
m_space_shape
[
i
])
{
throw
ngraph_error
(
"Coordinate iterator outer corner is out of bounds"
);
}
if
(
m_strides
[
i
]
==
0
)
{
throw
ngraph_error
(
"Coordinate iterator stride is zero"
);
}
}
}
CoordinateIterator
::
CoordinateIterator
(
const
Shape
&
space_shape
)
:
CoordinateIterator
(
space_shape
,
Strides
(
space_shape
.
size
(),
1
),
space_shape
,
Coordinate
(
space_shape
.
size
(),
0
))
{
}
CoordinateIterator
::
CoordinateIterator
(
const
Shape
&
space_shape
,
const
Strides
&
strides
)
:
CoordinateIterator
(
space_shape
,
strides
,
space_shape
,
Coordinate
(
space_shape
.
size
(),
0
))
{
}
size_t
CoordinateIterator
::
get_current_index
()
const
{
size_t
index
=
0
;
size_t
stride
=
1
;
for
(
size_t
i
=
m_space_shape
.
size
();
i
--
>
0
;)
{
index
+=
m_current_coordinate
[
i
]
*
stride
;
stride
*=
m_space_shape
[
i
];
}
return
index
;
}
bool
CoordinateIterator
::
increment
()
{
bool
overflow
=
true
;
for
(
size_t
i
=
m_space_shape
.
size
();
i
--
>
0
;)
{
m_current_coordinate
[
i
]
+=
m_strides
[
i
];
if
(
m_current_coordinate
[
i
]
>=
m_window_outer_corner
[
i
])
{
m_current_coordinate
[
i
]
=
m_window_inner_corner
[
i
];
}
else
{
overflow
=
false
;
break
;
}
}
return
!
overflow
;
}
src/ngraph/coordinate_iterator.hpp
0 → 100644
View file @
b485bb33
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cstdio>
#include <iostream>
#include <vector>
#include "ngraph/common.hpp"
namespace
ngraph
{
class
CoordinateIterator
{
public
:
CoordinateIterator
(
const
Shape
&
space_shape
,
const
Strides
&
strides
,
const
Coordinate
&
window_outer_corner
,
const
Coordinate
&
window_inner_corner
);
CoordinateIterator
(
const
Shape
&
space_shape
);
CoordinateIterator
(
const
Shape
&
space_shape
,
const
Strides
&
strides
);
Coordinate
get_current_coordinate
()
const
{
return
m_current_coordinate
;
}
size_t
get_current_index
()
const
;
bool
increment
();
private
:
const
Shape
m_space_shape
;
const
Strides
m_strides
;
const
Coordinate
m_window_outer_corner
;
const
Coordinate
m_window_inner_corner
;
Coordinate
m_current_coordinate
;
};
}
src/ngraph/ngraph.hpp
View file @
b485bb33
...
...
@@ -44,6 +44,7 @@
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reduce_ops.hpp"
#include "ngraph/common.hpp"
#include "ngraph/coordinate_iterator.hpp"
#include "ngraph/descriptor/buffer.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
...
...
src/ngraph/runtime/kernel/broadcast.hpp
0 → 100644
View file @
b485bb33
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_iterator.hpp"
namespace
ngraph
{
namespace
runtime
{
namespace
kernel
{
template
<
typename
T
>
void
broadcast
(
T
*
arg
,
T
*
out
,
const
Shape
&
in_shape
,
const
Shape
&
out_shape
,
const
AxisSet
&
broadcast_axes
)
{
// For the outer loop we will walk over the entire input shape.
CoordinateIterator
arg_iter
(
in_shape
);
do
{
// For the inner loop we will walk across the entire axis for the new broadcast axes, and stay put at the current arg position for the existing axes.
Coordinate
arg_coordinate
=
arg_iter
.
get_current_coordinate
();
Strides
out_strides
(
out_shape
.
size
(),
1
);
Coordinate
out_outer_corner
(
out_shape
.
size
());
Coordinate
out_inner_corner
(
out_shape
.
size
());
size_t
arg_pos
=
0
;
for
(
size_t
i
=
0
;
i
<
out_shape
.
size
();
i
++
)
{
if
(
broadcast_axes
.
find
(
i
)
==
broadcast_axes
.
end
())
{
// This is an existing axis.
out_outer_corner
[
i
]
=
arg_coordinate
[
arg_pos
];
out_inner_corner
[
i
]
=
arg_coordinate
[
arg_pos
];
arg_pos
++
;
}
else
{
// This is a new broadcast axis.
out_outer_corner
[
i
]
=
out_shape
[
i
];
out_inner_corner
[
i
]
=
0
;
}
}
CoordinateIterator
out_iter
(
out_shape
,
out_strides
,
out_outer_corner
,
out_inner_corner
);
do
{
out
[
out_iter
.
get_current_index
()]
=
arg
[
arg_iter
.
get_current_index
()];
}
while
(
out_iter
.
increment
());
}
while
(
arg_iter
.
increment
());
}
}
}
}
src/ngraph/runtime/ngvm/external_function.cpp
View file @
b485bb33
...
...
@@ -96,6 +96,7 @@
#include "ngraph/runtime/ngvm/instruction/add.hpp"
#include "ngraph/runtime/ngvm/instruction/asin.hpp"
#include "ngraph/runtime/ngvm/instruction/atan.hpp"
#include "ngraph/runtime/ngvm/instruction/broadcast.hpp"
#include "ngraph/runtime/ngvm/instruction/call.hpp"
#include "ngraph/runtime/ngvm/instruction/ceiling.hpp"
#include "ngraph/runtime/ngvm/instruction/constant.hpp"
...
...
@@ -420,15 +421,23 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
auto
arg_tensor_type
=
dynamic_pointer_cast
<
const
TensorViewType
>
(
n
->
get_arguments
().
at
(
0
)
->
get_value_type
());
assert
(
nullptr
!=
arg_tensor_type
);
auto
arg_shape
=
arg_tensor_type
->
get_shape
();
auto
result_tensor_type
=
dynamic_pointer_cast
<
const
TensorViewType
>
(
n
->
get_value_type
());
assert
(
nullptr
!=
result_tensor_type
);
auto
arg_shape
=
arg_tensor_type
->
get_shape
();
auto
result_shape
=
result_tensor_type
->
get_shape
();
auto
&
result_element_type
=
result_tensor_type
->
get_element_type
();
PUSH_POLYMORPHIC_INSTRUCTION
(
result_element_type
,
"Broadcast has unhandled element type"
,
instruction
::
BroadcastInstruction
,
in
[
0
],
out
[
0
],
arg_shape
,
result_shape
,
broadcast
->
get_broadcast_axes
());
/*
if (broadcast->get_broadcast_axes().empty())
{
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
...
...
@@ -473,7 +482,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
else
{
throw ngraph_error("Broadcast not implemented for rank>2 in VM yet");
}
}
*/
};
REGISTER_TO_OP_MAP
(
op
::
Concat
)
...
...
src/ngraph/runtime/ngvm/instruction/broadcast.hpp
0 → 100644
View file @
b485bb33
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/kernel/broadcast.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace
ngraph
{
namespace
runtime
{
namespace
ngvm
{
namespace
instruction
{
template
<
typename
ET
>
class
BroadcastInstruction
:
public
Instruction
{
public
:
BroadcastInstruction
(
const
TensorViewInfo
&
arg
,
const
TensorViewInfo
&
out
,
const
Shape
&
arg_shape
,
const
Shape
&
out_shape
,
const
AxisSet
&
broadcast_axes
)
:
m_arg
(
arg
)
,
m_out
(
out
)
,
m_arg_shape
(
arg_shape
)
,
m_out_shape
(
out_shape
)
,
m_broadcast_axes
(
broadcast_axes
)
{
}
virtual
void
execute
(
CallFrame
&
call_frame
)
const
override
{
typename
ET
::
type
*
arg
=
get_tensor_data_ptr
<
ET
>
(
call_frame
,
m_arg
);
typename
ET
::
type
*
out
=
get_tensor_data_ptr
<
ET
>
(
call_frame
,
m_out
);
kernel
::
broadcast
<
typename
ET
::
type
>
(
arg
,
out
,
m_arg_shape
,
m_out_shape
,
m_broadcast_axes
);
}
protected
:
TensorViewInfo
m_arg
;
TensorViewInfo
m_out
;
Shape
m_arg_shape
;
Shape
m_out_shape
;
AxisSet
m_broadcast_axes
;
};
}
}
}
}
test/CMakeLists.txt
View file @
b485bb33
...
...
@@ -22,10 +22,11 @@ include_directories(
)
set
(
SRC
autodiff.cpp
builder_autobroadcast.cpp
builder_reduce_ops.cpp
autodiff.cpp
build_graph.cpp
coordinate_iterator.cpp
copy.cpp
eigen.cpp
element_type.cpp
...
...
test/backend_test.in.cpp
View file @
b485bb33
...
...
@@ -1512,6 +1512,78 @@ TEST(${BACKEND_NAME}, broadcast_vector_rowwise_int64)
result
->
get_vector
<
element
::
Int64
::
type
>
());
}
TEST
(
$
{
BACKEND_NAME
},
broadcast_matrix_0
)
{
auto
shape_a
=
Shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape_a
);
auto
shape_r
=
Shape
{
2
,
2
,
2
};
auto
rt
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape_r
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Broadcast
>
(
A
,
shape_r
,
AxisSet
{
0
}),
rt
,
op
::
Parameters
{
A
});
auto
manager
=
runtime
::
Manager
::
get
(
"${BACKEND_NAME}"
);
auto
external
=
manager
->
compile
(
f
);
auto
backend
=
manager
->
allocate_backend
();
auto
cf
=
backend
->
make_call_frame
(
external
);
// Create some tensors for input/output
auto
a
=
backend
->
make_primary_tensor_view
(
element
::
Float32
::
element_type
(),
shape_a
);
copy_data
(
a
,
vector
<
element
::
Float32
::
type
>
{
1
,
2
,
3
,
4
});
auto
result
=
backend
->
make_primary_tensor_view
(
element
::
Float32
::
element_type
(),
shape_r
);
cf
->
call
({
a
},
{
result
});
ASSERT_EQ
((
vector
<
element
::
Float32
::
type
>
{
1
,
2
,
3
,
4
,
1
,
2
,
3
,
4
}),
result
->
get_vector
<
element
::
Float32
::
type
>
());
}
TEST
(
$
{
BACKEND_NAME
},
broadcast_matrix_1
)
{
auto
shape_a
=
Shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape_a
);
auto
shape_r
=
Shape
{
2
,
2
,
2
};
auto
rt
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape_r
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Broadcast
>
(
A
,
shape_r
,
AxisSet
{
1
}),
rt
,
op
::
Parameters
{
A
});
auto
manager
=
runtime
::
Manager
::
get
(
"${BACKEND_NAME}"
);
auto
external
=
manager
->
compile
(
f
);
auto
backend
=
manager
->
allocate_backend
();
auto
cf
=
backend
->
make_call_frame
(
external
);
// Create some tensors for input/output
auto
a
=
backend
->
make_primary_tensor_view
(
element
::
Float32
::
element_type
(),
shape_a
);
copy_data
(
a
,
vector
<
element
::
Float32
::
type
>
{
1
,
2
,
3
,
4
});
auto
result
=
backend
->
make_primary_tensor_view
(
element
::
Float32
::
element_type
(),
shape_r
);
cf
->
call
({
a
},
{
result
});
ASSERT_EQ
((
vector
<
element
::
Float32
::
type
>
{
1
,
2
,
1
,
2
,
3
,
4
,
3
,
4
}),
result
->
get_vector
<
element
::
Float32
::
type
>
());
}
TEST
(
$
{
BACKEND_NAME
},
broadcast_matrix_2
)
{
auto
shape_a
=
Shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
Float32
::
element_type
(),
shape_a
);
auto
shape_r
=
Shape
{
2
,
2
,
2
};
auto
rt
=
make_shared
<
TensorViewType
>
(
element
::
Float32
::
element_type
(),
shape_r
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Broadcast
>
(
A
,
shape_r
,
AxisSet
{
2
}),
rt
,
op
::
Parameters
{
A
});
auto
manager
=
runtime
::
Manager
::
get
(
"${BACKEND_NAME}"
);
auto
external
=
manager
->
compile
(
f
);
auto
backend
=
manager
->
allocate_backend
();
auto
cf
=
backend
->
make_call_frame
(
external
);
// Create some tensors for input/output
auto
a
=
backend
->
make_primary_tensor_view
(
element
::
Float32
::
element_type
(),
shape_a
);
copy_data
(
a
,
vector
<
element
::
Float32
::
type
>
{
1
,
2
,
3
,
4
});
auto
result
=
backend
->
make_primary_tensor_view
(
element
::
Float32
::
element_type
(),
shape_r
);
cf
->
call
({
a
},
{
result
});
ASSERT_EQ
((
vector
<
element
::
Float32
::
type
>
{
1
,
1
,
2
,
2
,
3
,
3
,
4
,
4
}),
result
->
get_vector
<
element
::
Float32
::
type
>
());
}
TEST
(
$
{
BACKEND_NAME
},
convert_int32_float32
)
{
auto
shape
=
Shape
{
2
,
2
};
...
...
test/coordinate_iterator.cpp
0 → 100644
View file @
b485bb33
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include <memory>
using
namespace
std
;
using
namespace
ngraph
;
TEST
(
coordinate_iterator
,
construct
)
{
Shape
space_shape
{
2
,
3
,
5
,
6
};
Strides
strides
{
1
,
1
,
1
,
1
};
Coordinate
window_outer_corner
{
2
,
3
,
5
,
6
};
Coordinate
window_inner_corner
{
0
,
0
,
0
,
0
};
auto
ci
=
CoordinateIterator
(
space_shape
,
strides
,
window_outer_corner
,
window_inner_corner
);
}
TEST
(
coordinate_iterator
,
construct_defaults
)
{
Shape
space_shape
{
2
,
3
,
5
,
6
};
Strides
strides
{
2
,
2
,
2
,
1
};
auto
ci
=
CoordinateIterator
(
space_shape
,
strides
);
}
TEST
(
coordinate_iterator
,
construct_defaults_stride
)
{
Shape
space_shape
{
2
,
3
,
5
,
6
};
auto
ci
=
CoordinateIterator
(
space_shape
);
}
TEST
(
coordinate_iterator
,
construct_bad_outer_oob
)
{
Shape
space_shape
{
2
,
3
,
5
,
6
};
Strides
strides
{
1
,
1
,
1
,
1
};
Coordinate
window_outer_corner
{
2
,
4
,
5
,
6
};
Coordinate
window_inner_corner
{
0
,
0
,
0
,
0
};
EXPECT_ANY_THROW
({
auto
ci
=
CoordinateIterator
(
space_shape
,
strides
,
window_outer_corner
,
window_inner_corner
);
});
}
TEST
(
coordinate_iterator
,
construct_bad_inner_oob
)
{
Shape
space_shape
{
2
,
3
,
5
,
6
};
Strides
strides
{
1
,
1
,
1
,
1
};
Coordinate
window_outer_corner
{
2
,
3
,
5
,
6
};
Coordinate
window_inner_corner
{
0
,
3
,
0
,
0
};
EXPECT_ANY_THROW
({
auto
ci
=
CoordinateIterator
(
space_shape
,
strides
,
window_outer_corner
,
window_inner_corner
);
});
}
TEST
(
coordinate_iterator
,
construct_bad_inner_outside_outer
)
{
Shape
space_shape
{
2
,
3
,
5
,
6
};
Strides
strides
{
1
,
1
,
1
,
1
};
Coordinate
window_outer_corner
{
2
,
1
,
5
,
6
};
Coordinate
window_inner_corner
{
0
,
2
,
0
,
0
};
EXPECT_ANY_THROW
({
auto
ci
=
CoordinateIterator
(
space_shape
,
strides
,
window_outer_corner
,
window_inner_corner
);
});
}
TEST
(
coordinate_iterator
,
construct_bad_zero_stride
)
{
Shape
space_shape
{
2
,
3
,
5
,
6
};
Strides
strides
{
1
,
0
,
1
,
1
};
Coordinate
window_outer_corner
{
2
,
3
,
5
,
6
};
Coordinate
window_inner_corner
{
0
,
0
,
0
,
0
};
EXPECT_ANY_THROW
({
auto
ci
=
CoordinateIterator
(
space_shape
,
strides
,
window_outer_corner
,
window_inner_corner
);
});
}
TEST
(
coordinate_iterator
,
cover_count_defaults
)
{
Shape
space_shape
{
2
,
3
,
5
,
6
};
auto
ci
=
CoordinateIterator
(
space_shape
);
size_t
count
=
0
;
size_t
expected_index
=
0
;
do
{
count
++
;
EXPECT_EQ
(
ci
.
get_current_index
(),
expected_index
);
expected_index
++
;
}
while
(
ci
.
increment
());
EXPECT_EQ
(
count
,
2
*
3
*
5
*
6
);
}
TEST
(
coordinate_iterator
,
cover_count_stride_2
)
{
Shape
space_shape
{
2
,
3
,
5
,
6
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
ci
=
CoordinateIterator
(
space_shape
,
strides
);
size_t
count
=
0
;
size_t
expected_index
=
0
;
do
{
count
++
;
EXPECT_EQ
(
ci
.
get_current_index
(),
expected_index
);
expected_index
+=
2
;
}
while
(
ci
.
increment
());
EXPECT_EQ
(
count
,
2
*
3
*
5
*
6
/
2
);
}
#define CEIL_DIV(x, y) (1 + (((x)-1) / (y)))
TEST
(
coordinate_iterator
,
cover_count_stride_uneven
)
{
Shape
space_shape
{
2
,
3
,
5
,
6
};
Strides
strides
{
1
,
2
,
2
,
3
};
auto
ci
=
CoordinateIterator
(
space_shape
,
strides
);
size_t
count
=
0
;
do
{
count
++
;
}
while
(
ci
.
increment
());
EXPECT_EQ
(
count
,
CEIL_DIV
(
2
,
1
)
*
CEIL_DIV
(
3
,
2
)
*
CEIL_DIV
(
5
,
2
)
*
CEIL_DIV
(
6
,
3
));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment