#!/usr/bin/env python
# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

import sys
import numpy as np
import math
import random
from operator import mul

# Produces `length` pseudo-random floating point literals as strings of the form
# '.NN' or '-.NN' (two decimal digits, optional leading minus sign).
#
# The module-level random generator is re-seeded with `seed` on every call, so the
# same (length, seed) pair always yields the same list, and a shorter request is a
# prefix of a longer one made with the same seed.
def random_array_float_literals(length,seed=8086):
    random.seed(seed)

    def next_literal():
        # Draw order matters for reproducibility: magnitude first, then sign.
        hundredths = random.randint(0,99)
        negative = (random.randint(0,1) == 1)
        prefix = '-' if negative else ''
        return '%s.%02d' % (prefix, hundredths)

    return [next_literal() for _ in range(length)]

# Adds two equal-length sequences component by component and returns the sums as a tuple.
# Raises AssertionError if the lengths differ.
def tuple_plus(t1,t2):
    assert(len(t1) == len(t2))

    return tuple(a + b for (a, b) in zip(t1, t2))

# Multiplies two equal-length sequences component by component and returns the products as a tuple.
# Raises AssertionError if the lengths differ.
def tuple_times(t1,t2):
    assert(len(t1) == len(t2))

    return tuple(a * b for (a, b) in zip(t1, t2))

#
# Convolution reference
#
#    Arguments:
#    data_batch       : [N ][Ci][D1]...[Dn], n > 0
#    filter           : [Co][Ci][W1]...[Wn]
#    move_strides     = (s1,...,sn)
#    filter_dilation  = (l1,...,ln)
#    below_pads       = (p1,...,pn)
#    above_pads       = (q1,...,qn)
#    data_dilation    = (t1,...,tn)
#
#    Returns:
#    output_batch     : [N ][Co][D'1]...[D'n]
#
# Where the D's are computed according to TensorFlow-style "valid" convolution rules, but *after* padding.
# See https://www.tensorflow.org/api_docs/python/tf/nn/convolution.
#
# Processing order: the data batch is first dilated (zeros inserted between elements), then
# zero-padded wherever a pad is positive, then cropped wherever a pad is negative, and finally
# the (dilated) filter window is slid over it accumulating a sum of products.
#
# The result is a freshly allocated array with NumPy's default float dtype. The inputs are
# not modified. Raises AssertionError if the ranks or per-dimension tuple lengths disagree, if
# the channel counts differ, or if the dilated window does not fit into the padded input.
#
# (The parameter name `filter` shadows the builtin; it is kept so keyword callers keep working.)
#
def convolution_ref(data_batch, filter, move_strides, filter_dilation, below_pads, above_pads, data_dilation):
    rank = len(data_batch.shape)
    spatial_rank = rank - 2

    assert(rank == len(filter.shape))
    assert(rank > 2)
    assert(data_batch.shape[1] == filter.shape[1])
    assert(len(move_strides) == spatial_rank)
    assert(len(filter_dilation) == spatial_rank)
    assert(len(below_pads) == spatial_rank)
    assert(len(above_pads) == spatial_rank)
    assert(len(data_dilation) == spatial_rank)

    # Leaves the batch and channel axes untouched when indexing.
    keep_batch_and_channel = (slice(None), slice(None))

    # Dilate the input batch: element i of each spatial axis moves to position i*t, with zeros
    # in between. One strided assignment handles any number of spatial dimensions.
    dilated_item_shape = [(d - 1) * t + 1 for (d, t) in zip(data_batch.shape[2:], data_dilation)]
    dilated_batch = np.zeros(list(data_batch.shape[:2]) + dilated_item_shape)
    dilated_batch[keep_batch_and_channel + tuple(slice(0, None, t) for t in data_dilation)] = data_batch
    data_batch = dilated_batch

    # Pad the input batch wherever the pads are positive. The two leading (0,0) entries are
    # for the batch and channel dims, which are never padded. A real list is passed because
    # np.pad cannot consume a lazy zip object.
    pad_widths = [(0, 0), (0, 0)] + [(max(int(p), 0), max(int(q), 0)) for (p, q) in zip(below_pads, above_pads)]
    data_batch = np.pad(data_batch, pad_widths, mode='constant', constant_values=0)

    # Slice the input batch wherever the pads are negative. The index must be a tuple of
    # slices; NumPy does not accept a list or iterator of slices as a multi-axis index.
    crop = keep_batch_and_channel + tuple(slice(max(-int(p), 0), int(q) if q < 0 else None)
                                          for (p, q) in zip(below_pads, above_pads))
    data_batch = data_batch[crop]

    item_count = data_batch.shape[0]               # N
    co_count = filter.shape[0]                     # Co
    input_item_shape = list(data_batch.shape[2:])  # D1, ..., Dn
    window_virtual_shape = list(filter.shape[2:])  # W1, ..., Wn

    output_item_shape = []  # D'1,...,D'n
    for (d_in,d_win,dil,mov) in zip(input_item_shape,window_virtual_shape,filter_dilation,move_strides):
        # Physical extent of the dilated window along this axis; it has to fit into the input.
        d_phys = (d_win - 1) * dil + 1
        assert(d_phys <= d_in)

        d_out = int(math.ceil((float(d_in) - (float(d_win) - 1.0) * float(dil))/float(mov))) # Formula is taken from TF's definition for VALID convolution.
        assert(d_out > 0)
        output_item_shape.append(d_out)

    output_shape = [item_count,co_count]+output_item_shape # N,Co,D'1,...,D'n
    output_batch = np.zeros(output_shape)

    # Coordinates (ci, w1, ..., wn) within one output channel's filter, in row-major order.
    # They are the same for every output element, so they are enumerated only once.
    filter_indices = list(np.ndindex(*filter.shape[1:]))

    # Walk over the output batch space.
    for output_index in np.ndindex(*output_batch.shape):
        # Break up the output coordinate to figure out where we are in terms of batch index, output channel, and spatial position.
        item, co, output_pos = output_index[0], output_index[1], output_index[2:]

        # Input position of the window's first element for this output position.
        window_origin = [o * s for (o, s) in zip(output_pos, move_strides)]

        total = output_batch[output_index]

        # Walk over the filter for the current output channel.
        for filter_index in filter_indices:
            # Break up the filter coordinate to figure out where we are in terms of input channel and filter shape position.
            ci, filter_pos = filter_index[0], filter_index[1:]

            # Build up the coordinate within the space N,Ci,D1,...,Dn that we need to read from in the input batch.
            input_index = (item,ci) + tuple(b + w * l for (b, w, l) in zip(window_origin, filter_pos, filter_dilation))

            # Add to the sum-of-products.
            total = total + filter[(co,) + filter_index] * data_batch[input_index]

        output_batch[output_index] = total

    return output_batch

# Renders a sequence of dimensions as comma-separated decimal integers with no spaces,
# e.g. (1,3,5) -> '1,3,5'. An empty sequence gives the empty string.
def shape_str(shape):
    return ','.join('%d' % extent for extent in shape)

# Formats a number as a C++ float literal with an 'f' suffix.
#
# Values whose full-precision '%g' rendering has neither a decimal point nor an exponent
# (i.e. it prints as a plain integer) become '<digits>.0f'; everything else is printed in
# fixed notation with eight decimal places, e.g. 0.5 -> '0.50000000f'.
def scalar_str(x):
    digits = '%.1000g' % x

    prints_as_integer = not ("." in digits or "e" in digits)
    if prints_as_integer:
        return digits + ".0f"

    return "%.8ff" % float(digits)

# Renders every element of `data`, in np.nditer traversal order, as a C++ float literal
# (see scalar_str) and joins them with commas, without spaces.
def data_str(data):
    return ','.join(scalar_str(element) for element in np.nditer(data))

# Writes one generated C++ test case to the open file object `f`.
#
#    Arguments:
#    t : 9-tuple (test name, batch shape, filters shape, move strides, filter dilation,
#                 below pads, above pads, data dilation, bprop), i.e. one row of `tests`.
#                 bprop is the text pasted into the generated `if(...)` that guards the
#                 autodiff comparison.
#    f : writable text file object; the test source is appended with f.write().
#
# The input and filter values are pseudo-random literals from random_array_float_literals;
# the expected output is computed with convolution_ref. A progress line is printed to
# stdout for each test because the reference computation can take a while.
def emit_test(t,f):
    test_name, input_batch_shape, filters_shape, move_strides, filter_dilation, below_pads, above_pads, data_dilation, bprop = t

    # Element counts are computed with np.prod rather than the builtin reduce(), which
    # only exists on Python 2.
    input_batch_literals = random_array_float_literals(int(np.prod(input_batch_shape)))
    filters_literals = random_array_float_literals(int(np.prod(filters_shape)))

    # Real lists are built here because np.array() on a lazy map object does not produce
    # a numeric array on Python 3.
    input_batch_array = np.array([np.float32(s) for s in input_batch_literals]).reshape(input_batch_shape)
    filters_array = np.array([np.float32(s) for s in filters_literals]).reshape(filters_shape)

    print ("Generating convolution test '%s'..." % test_name)

    output_batch_data = convolution_ref(input_batch_array,filters_array,move_strides,filter_dilation,below_pads,above_pads,data_dilation)

    template = '''
// !!!!!!!!!!!!!! THIS FILE IS AUTOGENERATED OUTSIDE OF THE BUILD PROCESS !!!!!!!!!!!!!!
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DO NOT EDIT THIS FILE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
//
// DO NOT EDIT THIS FILE. If you want to add new tests, you should edit
//  test/ref_generators/generate_convolution_ref.py and regenerate this file.
//
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DO NOT EDIT THIS FILE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// !!!!!!!!!!!!!! THIS FILE IS AUTOGENERATED OUTSIDE OF THE BUILD PROCESS !!!!!!!!!!!!!!
TEST (${BACKEND_NAME}, %s)
{
    SKIP_TEST_FOR("GPU", "${BACKEND_NAME}");
    Shape shape_a{%s};
    Shape shape_b{%s};
    Shape shape_r{%s};
    auto make_graph = [shape_a, shape_b] {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);
        auto B = make_shared<op::Parameter>(element::f32, shape_b);
        return make_shared<Function>(make_shared<op::Convolution>(A, B,
                                                                  Strides{%s},        // move_strides
                                                                  Strides{%s},        // filter_dilation
                                                                  CoordinateDiff{%s}, // below_pads
                                                                  CoordinateDiff{%s}, // above_pads
                                                                  Strides{%s}),       // data_dilation
                                     op::ParameterVector{A, B});
    };

    auto backend = runtime::Backend::create("${BACKEND_NAME}");
    auto function = make_graph();

    // Create some tensors for input/output
    auto a = backend->create_tensor(element::f32, shape_a);
    copy_data(a, vector<float>{%s});
    auto b = backend->create_tensor(element::f32, shape_b);
    copy_data(b, vector<float>{%s});
    auto result = backend->create_tensor(element::f32, shape_r);

    vector<float> expected_result{%s};

    backend->call(function, {result}, {a, b});
    EXPECT_TRUE(test::all_close<float>(vector<float>{expected_result}, read_vector<float>(result), 1.0e-4f, 1.0e-6f));
    // only test backprop for certain cases as it takes significant compute resources
    if(%s) {
        EXPECT_TRUE(autodiff_numeric_compare<float>(backend, make_graph, {a, b}, .01f, .01f));
    }
}
'''
    # The substitutions below must stay in the same order as the %s slots in the template.
    f.write (template % (test_name,
                         shape_str(input_batch_shape),
                         shape_str(filters_shape),
                         shape_str(output_batch_data.shape),
                         shape_str(move_strides),
                         shape_str(filter_dilation),
                         shape_str(below_pads),
                         shape_str(above_pads),
                         shape_str(data_dilation),
                         ",".join("%.8ff" % float(s) for s in input_batch_literals),
                         ",".join("%.8ff" % float(s) for s in filters_literals),
                         data_str(output_batch_data),
                         bprop))

# Table of convolution test cases, one tuple per generated C++ test. main() passes each
# tuple to emit_test(), which unpacks it in the column order shown below.
#
# Shapes are (N,Ci,D1,...,Dn) for the batch and (Co,Ci,W1,...,Wn) for the filters; the
# stride, dilation and pad tuples have one entry per spatial dimension. Negative pad
# values crop the (dilated) input instead of padding it. The bprop column is the literal
# text ("true" or "false") pasted into the generated test's `if(...)` that guards the
# autodiff comparison.
#
#                                                                              filter                                      data
#         test name                                batch shape   filts shape   stride    dilation  below-pads  above-pads  dilation   bprop?
tests = [
         ("convolution_2d_1item",                  (1,1,3,5),    (2,1,2,2),    (1,1),    (1,1),    (0,0),      (0,0),      (1,1),     "true"),
         ("convolution_2d_1item_padded_1_1x1_1",   (1,1,3,5),    (2,1,2,2),    (1,1),    (1,1),    (1,1),      (1,1),      (1,1),     "true"),
         ("convolution_2d_1item_padded_2_3x4_5",   (1,1,3,5),    (2,1,2,2),    (1,1),    (1,1),    (2,3),      (4,5),      (1,1),     "true"),
         ("convolution_2d_2items",                 (2,1,3,5),    (2,1,2,2),    (1,1),    (1,1),    (0,0),      (0,0),      (1,1),     "true"),
         ("convolution_2d_2items_strided",         (2,1,3,5),    (2,1,2,2),    (2,2),    (1,1),    (0,0),      (0,0),      (1,1),     "true"),
         ("convolution_2d_2items_strided_padded",  (2,1,3,5),    (2,1,2,2),    (2,2),    (1,1),    (4,2),      (5,7),      (1,1),     "true"),
         ("convolution_2d_2items_strided_padded_same",
                                                   (2,1,3,5),    (2,1,2,2),    (2,2),    (1,1),    (2,2),      (2,2),      (1,1),     "true"),
         ("convolution_2d_2items_dilated",         (2,1,3,5),    (2,1,2,2),    (1,1),    (2,2),    (0,0),      (0,0),      (1,1),     "true"),
         ("convolution_2d_2items_dilated_padded",  (2,1,3,5),    (2,1,2,2),    (1,1),    (2,2),    (4,2),      (5,7),      (1,1),     "true"),
         ("convolution_3d_2items",                 (2,1,3,5,8),  (2,1,2,2,3),  (1,1,1),  (1,1,1),  (0,0,0),    (0,0,0),    (1,1,1),   "true"),
         ("convolution_4d_2items",                 (2,1,3,5,8,7),(2,1,2,2,3,1),(1,1,1,1),(1,1,1,1),(0,0,0,0),  (0,0,0,0),  (1,1,1,1), "false"),
         ("convolution_4d_4items",                 (4,3,3,5,8,7),(4,3,2,2,3,1),(1,1,1,1),(1,1,1,1),(0,0,0,0),  (0,0,0,0),  (1,1,1,1), "false"),
         ("convolution_4d_4items_padded_neg",      (4,3,3,5,8,7),(4,3,2,2,3,1),(1,1,1,1),(1,1,1,1),(-1,2,-3,2),(1,0,0,-3), (1,1,1,1), "false"),
         ("convolution_4d_4items_strided",         (4,3,3,5,8,7),(4,3,2,2,3,1),(2,1,3,2),(1,1,1,1),(0,0,0,0),  (0,0,0,0),  (1,1,1,1), "false"),
         ("convolution_4d_4items_dilated",         (4,3,3,5,8,7),(4,3,2,2,3,1),(1,1,1,1),(2,1,3,2),(0,0,0,0),  (0,0,0,0),  (1,1,1,1), "false"),
         ("convolution_4d_4items_strided_dilated", (4,3,8,8,8,8),(4,3,2,2,3,1),(3,2,2,3),(2,1,3,2),(0,0,0,0),  (0,0,0,0),  (1,1,1,1), "false"),
         ("convolution_4d_4items_strided_dilated_padded",
                                                   (4,3,8,8,8,8),(4,3,2,2,3,1),(3,2,2,3),(2,1,3,2),(2,4,6,8),  (1,3,5,7),  (1,1,1,1), "false"),
         ("convolution_4d_4items_strided_dilated_padded_neg",
                                                   (4,3,8,8,8,8),(4,3,2,2,3,1),(3,2,2,3),(2,1,3,2),(-2,4,0,5), (1,3,-1,-4),(1,1,1,1), "false"),
         ("convolution_4d_4items_strided_dilated_padded_same",
                                                   (4,3,8,8,8,8),(4,3,2,2,3,1),(3,2,2,3),(2,1,3,2),(3,3,3,3),  (3,3,3,3),  (1,1,1,1), "false"),
         ("convolution_2d_1item_1o1i_data_dilated",(1,1,3,5),    (1,1,2,2),    (1,1),    (1,1),    (0,0),      (0,0),      (2,2),     "true"),
         ("convolution_2d_1item_2o1i_data_dilated",(1,1,3,5),    (2,1,2,2),    (1,1),    (1,1),    (0,0),      (0,0),      (2,2),     "true"),
         ("convolution_2d_1item_2o2i_data_dilated",(1,2,3,5),    (2,2,2,2),    (1,1),    (1,1),    (0,0),      (0,0),      (2,2),     "true"),
         ("convolution_2d_1item_5o3i_data_dilated",(1,3,3,5),    (5,3,2,2),    (1,1),    (1,1),    (0,0),      (0,0),      (2,2),     "true"),
         ("convolution_2d_2item_5o3i_data_dilated",(2,3,3,5),    (5,3,2,2),    (1,1),    (1,1),    (0,0),      (0,0),      (2,2),     "true"),
         ("convolution_2d_8item_large_5o3i_data_dilated",
                                                   (8,3,16,16),  (5,3,2,2),    (1,1),    (1,1),    (0,0),      (0,0),      (2,2),     "false"),
         ("convolution_2d_8item_large_5o3i_uneven_filter_data_dilated",
                                                   (8,3,16,16),  (5,3,2,3),    (1,1),    (1,1),    (0,0),      (0,0),      (2,2),     "false"),
         ("convolution_2d_8item_large_5o3i_uneven_filter_uneven_data_dilation_data_dilated",
                                                   (8,3,16,16),  (5,3,2,3),    (1,1),    (1,1),    (0,0),      (0,0),      (2,3),     "false"),
         ("convolution_3d_2item_large_5o3i_uneven_filter_uneven_data_dilation_data_dilated",
                                                   (2,3,8,8,8),  (5,3,2,3,4),  (1,1,1),  (1,1,1),  (0,0,0),    (0,0,0),    (2,3,2),   "false"),
         ("convolution_3d_1item_large_5o3i_padded_uneven_filter_uneven_data_dilation_data_dilated",
                                                   (1,3,8,8,8),  (5,3,2,3,4),  (1,1,1),  (1,1,1),  (2,1,2),    (1,2,3),    (2,3,2),   "false"),
         ("convolution_3d_2item_large_5o3i_padded_strided_uneven_filter_uneven_data_dilation_data_dilated",
                                                   (2,3,8,8,8),  (5,3,2,3,4),  (2,3,2),  (1,1,1),  (2,1,2),    (1,2,3),    (2,3,2),   "false"),
         ("convolution_3d_2item_large_5o3i_padded_strided_uneven_filter_uneven_data_dilation_filter_dilated_data_dilated",
                                                   (2,3,8,8,8),  (5,3,2,3,4),  (2,3,2),  (3,2,2),  (2,1,2),    (1,2,3),    (2,3,2),   "false"),
        ]

# Script entry point: writes the complete generated C++ test file to the path given as the
# first command-line argument (sys.argv[1]), overwriting any existing file.
#
# The output consists of a fixed preamble (clang-format guard, license, includes), one test
# per entry of `tests`, and a closing clang-format guard. Raises AssertionError when no
# output path was supplied.
def main():
    assert(len(sys.argv)>1)

    # The with-block guarantees the file is closed even if generating a test raises.
    with open(sys.argv[1],'w') as f:
        f.write('''
// clang-format off

/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
//
// !!!!!!!!!!!!!! THIS FILE IS AUTOGENERATED OUTSIDE OF THE BUILD PROCESS !!!!!!!!!!!!!!
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DO NOT EDIT THIS FILE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
//
// It takes quite a while to compute the results.
//
// DO NOT EDIT THIS FILE. If you want to add new tests, you should edit
//  test/ref_generators/generate_convolution_ref.py and regenerate this file.
//
// To regenerate:
//
//   $ cd <ngraph source dir>/test
//   $ ./update_reference.sh
//
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DO NOT EDIT THIS FILE !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// !!!!!!!!!!!!!! THIS FILE IS AUTOGENERATED OUTSIDE OF THE BUILD PROCESS !!!!!!!!!!!!!!
//

#include <cmath>

#include "gtest/gtest.h"

#include "ngraph/ngraph.hpp"
#include "util/test_tools.hpp"
#include "util/autodiff/numeric_compare.hpp"

using namespace std;
using namespace ngraph;

''')

        for t in tests:
            emit_test(t,f)

        f.write('''
// clang-format on
''')

# Generate the reference test file only when run as a script, not when imported.
if __name__ == "__main__":
    main()