Unverified Commit 66328c7b authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into cyphers/cpu-call-frame-cleanup

parents 6017ac61 4b44442e
......@@ -100,7 +100,7 @@ namespace ngraph
in_transform_start[in_batch_axis] = batch_index;
in_transform_end[in_batch_axis] = batch_index + 1;
in_transform_start[in_channel_axis] = 0;
in_transform_end[in_channel_axis] = n_in_channels;
in_transform_end[in_channel_axis] = 1;
for (size_t i = 2; i < n_spatial_dimensions + 2; i++)
{
......@@ -124,7 +124,6 @@ namespace ngraph
{
in_transform_axis_order[i] = i;
}
CoordinateTransform in_transform(in_shape,
in_transform_start,
in_transform_end,
......@@ -150,7 +149,7 @@ namespace ngraph
filter_transform_start[filter_out_channel_axis] = out_channel;
filter_transform_end[filter_out_channel_axis] = out_channel + 1;
filter_transform_start[filter_in_channel_axis] = 0;
filter_transform_end[filter_in_channel_axis] = n_in_channels;
filter_transform_end[filter_in_channel_axis] = 1;
for (size_t i = 2; i < n_spatial_dimensions + 2; i++)
{
......@@ -165,22 +164,34 @@ namespace ngraph
//
// out[O] += in[I] * filter[F].
T result = 0;
float result = 0;
CoordinateTransform::Iterator in_it = in_transform.begin();
CoordinateTransform::Iterator filter_it = filter_transform.begin();
CoordinateTransform::Iterator in_it_end = in_transform.end();
CoordinateTransform::Iterator filter_it_end = filter_transform.end();
size_t in_channel_stride = row_major_strides(in_shape).at(in_channel_axis);
size_t filter_in_channel_stride =
row_major_strides(filter_shape).at(filter_in_channel_axis);
while (in_it != in_it_end && filter_it != filter_it_end)
{
const Coordinate& in_coord = *in_it;
T v = in_transform.has_source_coordinate(in_coord)
? in[in_transform.index(in_coord)]
: 0;
result += v * filter[filter_transform.index(*filter_it)];
if (in_transform.has_source_coordinate(in_coord))
{
size_t in_idx = in_transform.index(in_coord);
const Coordinate& filter_coord = *filter_it;
size_t filter_idx = filter_transform.index(filter_coord);
for (size_t in_channel = 0; in_channel < n_in_channels; ++in_channel)
{
T in_v = in[in_idx];
T f_v = filter[filter_idx];
result += in_v * f_v;
in_idx += in_channel_stride;
filter_idx += filter_in_channel_stride;
}
}
++in_it;
++filter_it;
}
......
......@@ -31,7 +31,10 @@ def random_array_float_literals(length, seed=8086):
for i in range(0, length):
# generate numbers that can be exactly represented in binary
literal_n = np.float32(random.randint(-64, 64)) / 64.0
sig_bits = 6
range_bits = 2
literal_n = np.float32(random.randint(-pow(2, sig_bits-1),
pow(2, sig_bits-1))) / pow(2.0, sig_bits - range_bits)
literals.append(str(literal_n))
return literals
......@@ -128,7 +131,7 @@ def convolution_ref(data_batch, filter, move_strides, filter_dilation, below_pad
slice_tops = (0, 0) + tuple(np.clip(above_pads, None, 0))
slices = list(map(lambda p: slice(
p[0], p[1] if p[1] < 0 else None), zip(slice_bottoms, slice_tops)))
data_batch = data_batch[slices]
data_batch = data_batch[tuple(slices)]
item_count = data_batch.shape[0] # N
ci_count = data_batch.shape[1] # Ci
......@@ -371,7 +374,6 @@ tests = [
(2, 3, 8, 8, 8), (5, 3, 2, 3, 4), (2, 3, 2), (3, 2, 2), (2, 1, 2), (1, 2, 3), (2, 3, 2), "// "),
]
def main():
assert(len(sys.argv) > 1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment