from __future__ import division, print_function, absolute_import

from itertools import product

from numpy.testing import assert_allclose
import pytest

from scipy import special
from scipy.special import cython_special


int_points = [-10, -1, 1, 10]
real_points = [-10.0, -1.0, 1.0, 10.0]
complex_points = [complex(*tup) for tup in product(real_points, repeat=2)]


CYTHON_SIGNATURE_MAP = {
    'f': 'float',
    'd': 'double',
    'g': 'long double',
    'F': 'float complex',
    'D': 'double complex',
    'G': 'long double complex',
    'i':'int',
    'l': 'long'
}


TEST_POINTS = {
    'f': real_points,
    'd': real_points,
    'g': real_points,
    'F': complex_points,
    'D': complex_points,
    'G': complex_points,
    'i': int_points,
    'l': int_points,
}


PARAMS = [
    (special.agm, cython_special.agm, ('dd',), None),
    (special.airy, cython_special._airy_pywrap, ('d', 'D'), None),
    (special.airye, cython_special._airye_pywrap, ('d', 'D'), None),
    (special.bdtr, cython_special.bdtr, ('lld', 'ddd'), None),
    (special.bdtrc, cython_special.bdtrc, ('lld', 'ddd'), None),
    (special.bdtri, cython_special.bdtri, ('lld', 'ddd'), None),
    (special.bdtrik, cython_special.bdtrik, ('ddd',), None),
    (special.bdtrin, cython_special.bdtrin, ('ddd',), None),
    (special.bei, cython_special.bei, ('d',), None),
    (special.beip, cython_special.beip, ('d',), None),
    (special.ber, cython_special.ber, ('d',), None),
    (special.berp, cython_special.berp, ('d',), None),
    (special.besselpoly, cython_special.besselpoly, ('ddd',), None),
    (special.beta, cython_special.beta, ('dd',), None),
    (special.betainc, cython_special.betainc, ('ddd',), None),
    (special.betaincinv, cython_special.betaincinv, ('ddd',), None),
    (special.betaln, cython_special.betaln, ('dd',), None),
    (special.binom, cython_special.binom, ('dd',), None),
    (special.boxcox, cython_special.boxcox, ('dd',), None),
    (special.boxcox1p, cython_special.boxcox1p, ('dd',), None),
    (special.btdtr, cython_special.btdtr, ('ddd',), None),
    (special.btdtri, cython_special.btdtri, ('ddd',), None),
    (special.btdtria, cython_special.btdtria, ('ddd',), None),
    (special.btdtrib, cython_special.btdtrib, ('ddd',), None),
    (special.cbrt, cython_special.cbrt, ('d',), None),
    (special.chdtr, cython_special.chdtr, ('dd',), None),
    (special.chdtrc, cython_special.chdtrc, ('dd',), None),
    (special.chdtri, cython_special.chdtri, ('dd',), None),
    (special.chdtriv, cython_special.chdtriv, ('dd',), None),
    (special.chndtr, cython_special.chndtr, ('ddd',), None),
    (special.chndtridf, cython_special.chndtridf, ('ddd',), None),
    (special.chndtrinc, cython_special.chndtrinc, ('ddd',), None),
    (special.chndtrix, cython_special.chndtrix, ('ddd',), None),
    (special.cosdg, cython_special.cosdg, ('d',), None),
    (special.cosm1, cython_special.cosm1, ('d',), None),
    (special.cotdg, cython_special.cotdg, ('d',), None),
    (special.dawsn, cython_special.dawsn, ('d', 'D'), None),
    (special.ellipe, cython_special.ellipe, ('d',), None),
    (special.ellipeinc, cython_special.ellipeinc, ('dd',), None),
    (special.ellipj, cython_special._ellipj_pywrap, ('dd',), None),
    (special.ellipkinc, cython_special.ellipkinc, ('dd',), None),
    (special.ellipkm1, cython_special.ellipkm1, ('d',), None),
    (special.entr, cython_special.entr, ('d',), None),
    (special.erf, cython_special.erf, ('d', 'D'), None),
    (special.erfc, cython_special.erfc, ('d', 'D'), None),
    (special.erfcx, cython_special.erfcx, ('d', 'D'), None),
    (special.erfi, cython_special.erfi, ('d', 'D'), None),
    (special.eval_chebyc, cython_special.eval_chebyc, ('dd', 'dD', 'ld'), None),
    (special.eval_chebys, cython_special.eval_chebys, ('dd', 'dD', 'ld'),
     'd and l differ for negative int'),
    (special.eval_chebyt, cython_special.eval_chebyt, ('dd', 'dD', 'ld'),
     'd and l differ for negative int'),
    (special.eval_chebyu, cython_special.eval_chebyu, ('dd', 'dD', 'ld'),
     'd and l differ for negative int'),
    (special.eval_gegenbauer, cython_special.eval_gegenbauer, ('ddd', 'ddD', 'ldd'),
     'd and l differ for negative int'),
    (special.eval_genlaguerre, cython_special.eval_genlaguerre, ('ddd', 'ddD', 'ldd'),
     'd and l differ for negative int'),
    (special.eval_hermite, cython_special.eval_hermite, ('ld',), None),
    (special.eval_hermitenorm, cython_special.eval_hermitenorm, ('ld',), None),
    (special.eval_jacobi, cython_special.eval_jacobi, ('dddd', 'dddD', 'lddd'),
     'd and l differ for negative int'),
    (special.eval_laguerre, cython_special.eval_laguerre, ('dd', 'dD', 'ld'),
     'd and l differ for negative int'),
    (special.eval_legendre, cython_special.eval_legendre, ('dd', 'dD', 'ld'), None),
    (special.eval_sh_chebyt, cython_special.eval_sh_chebyt, ('dd', 'dD', 'ld'), None),
    (special.eval_sh_chebyu, cython_special.eval_sh_chebyu, ('dd', 'dD', 'ld'),
     'd and l differ for negative int'),
    (special.eval_sh_jacobi, cython_special.eval_sh_jacobi, ('dddd', 'dddD', 'lddd'),
     'd and l differ for negative int'),
    (special.eval_sh_legendre, cython_special.eval_sh_legendre, ('dd', 'dD', 'ld'), None),
    (special.exp1, cython_special.exp1, ('d', 'D'), None),
    (special.exp10, cython_special.exp10, ('d',), None),
    (special.exp2, cython_special.exp2, ('d',), None),
    (special.expi, cython_special.expi, ('d', 'D'), None),
    (special.expit, cython_special.expit, ('f', 'd', 'g'), None),
    (special.expm1, cython_special.expm1, ('d', 'D'), None),
    (special.expn, cython_special.expn, ('ld', 'dd'), None),
    (special.exprel, cython_special.exprel, ('d',), None),
    (special.fdtr, cython_special.fdtr, ('ddd',), None),
    (special.fdtrc, cython_special.fdtrc, ('ddd',), None),
    (special.fdtri, cython_special.fdtri, ('ddd',), None),
    (special.fdtridfd, cython_special.fdtridfd, ('ddd',), None),
    (special.fresnel, cython_special._fresnel_pywrap, ('d', 'D'), None),
    (special.gamma, cython_special.gamma, ('d', 'D'), None),
    (special.gammainc, cython_special.gammainc, ('dd',), None),
    (special.gammaincc, cython_special.gammaincc, ('dd',), None),
    (special.gammainccinv, cython_special.gammainccinv, ('dd',), None),
    (special.gammaincinv, cython_special.gammaincinv, ('dd',), None),
    (special.gammaln, cython_special.gammaln, ('d',), None),
    (special.gammasgn, cython_special.gammasgn, ('d',), None),
    (special.gdtr, cython_special.gdtr, ('ddd',), None),
    (special.gdtrc, cython_special.gdtrc, ('ddd',), None),
    (special.gdtria, cython_special.gdtria, ('ddd',), None),
    (special.gdtrib, cython_special.gdtrib, ('ddd',), None),
    (special.gdtrix, cython_special.gdtrix, ('ddd',), None),
    (special.hankel1, cython_special.hankel1, ('dD',), None),
    (special.hankel1e, cython_special.hankel1e, ('dD',), None),
    (special.hankel2, cython_special.hankel2, ('dD',), None),
    (special.hankel2e, cython_special.hankel2e, ('dD',), None),
    (special.huber, cython_special.huber, ('dd',), None),
    (special.hyp0f1, cython_special.hyp0f1, ('dd', 'dD'), None),
    (special.hyp1f1, cython_special.hyp1f1, ('ddd', 'ddD'), None),
    (special.hyp1f2, cython_special._hyp1f2_pywrap, ('dddd',), None),
    (special.hyp2f0, cython_special._hyp2f0_pywrap, ('dddl', 'dddd'), None),
    (special.hyp2f1, cython_special.hyp2f1, ('dddd', 'dddD'), None),
    (special.hyp3f0, cython_special._hyp3f0_pywrap, ('dddd',), None),
    (special.hyperu, cython_special.hyperu, ('ddd',), None),
    (special.i0, cython_special.i0, ('d',), None),
    (special.i0e, cython_special.i0e, ('d',), None),
    (special.i1, cython_special.i1, ('d',), None),
    (special.i1e, cython_special.i1e, ('d',), None),
    (special.inv_boxcox, cython_special.inv_boxcox, ('dd',), None),
    (special.inv_boxcox1p, cython_special.inv_boxcox1p, ('dd',), None),
    (special.it2i0k0, cython_special._it2i0k0_pywrap, ('d',), None),
    (special.it2j0y0, cython_special._it2j0y0_pywrap, ('d',), None),
    (special.it2struve0, cython_special.it2struve0, ('d',), None),
    (special.itairy, cython_special._itairy_pywrap, ('d',), None),
    (special.iti0k0, cython_special._iti0k0_pywrap, ('d',), None),
    (special.itj0y0, cython_special._itj0y0_pywrap, ('d',), None),
    (special.itmodstruve0, cython_special.itmodstruve0, ('d',), None),
    (special.itstruve0, cython_special.itstruve0, ('d',), None),
    (special.iv, cython_special.iv, ('dd', 'dD'), None),
    (special.ive, cython_special.ive, ('dd', 'dD'), None),
    (special.j0, cython_special.j0, ('d',), None),
    (special.j1, cython_special.j1, ('d',), None),
    (special.jv, cython_special.jv, ('dd', 'dD'), None),
    (special.jve, cython_special.jve, ('dd', 'dD'), None),
    (special.k0, cython_special.k0, ('d',), None),
    (special.k0e, cython_special.k0e, ('d',), None),
    (special.k1, cython_special.k1, ('d',), None),
    (special.k1e, cython_special.k1e, ('d',), None),
    (special.kei, cython_special.kei, ('d',), None),
    (special.keip, cython_special.keip, ('d',), None),
    (special.kelvin, cython_special._kelvin_pywrap, ('d',), None),
    (special.ker, cython_special.ker, ('d',), None),
    (special.kerp, cython_special.kerp, ('d',), None),
    (special.kl_div, cython_special.kl_div, ('dd',), None),
    (special.kn, cython_special.kn, ('ld', 'dd'), None),
    (special.kolmogi, cython_special.kolmogi, ('d',), None),
    (special.kolmogorov, cython_special.kolmogorov, ('d',), None),
    (special.kv, cython_special.kv, ('dd', 'dD'), None),
    (special.kve, cython_special.kve, ('dd', 'dD'), None),
    (special.log1p, cython_special.log1p, ('d', 'D'), None),
    (special.log_ndtr, cython_special.log_ndtr, ('d', 'D'), None),
    (special.loggamma, cython_special.loggamma, ('D',), None),
    (special.logit, cython_special.logit, ('f', 'd', 'g'), None),
    (special.lpmv, cython_special.lpmv, ('ddd',), None),
    (special.mathieu_a, cython_special.mathieu_a, ('dd',), None),
    (special.mathieu_b, cython_special.mathieu_b, ('dd',), None),
    (special.mathieu_cem, cython_special._mathieu_cem_pywrap, ('ddd',), None),
    (special.mathieu_modcem1, cython_special._mathieu_modcem1_pywrap, ('ddd',), None),
    (special.mathieu_modcem2, cython_special._mathieu_modcem2_pywrap, ('ddd',), None),
    (special.mathieu_modsem1, cython_special._mathieu_modsem1_pywrap, ('ddd',), None),
    (special.mathieu_modsem2, cython_special._mathieu_modsem2_pywrap, ('ddd',), None),
    (special.mathieu_sem, cython_special._mathieu_sem_pywrap, ('ddd',), None),
    (special.modfresnelm, cython_special._modfresnelm_pywrap, ('d',), None),
    (special.modfresnelp, cython_special._modfresnelp_pywrap, ('d',), None),
    (special.modstruve, cython_special.modstruve, ('dd',), None),
    (special.nbdtr, cython_special.nbdtr, ('lld', 'ddd'), None),
    (special.nbdtrc, cython_special.nbdtrc, ('lld', 'ddd'), None),
    (special.nbdtri, cython_special.nbdtri, ('lld', 'ddd'), None),
    (special.nbdtrik, cython_special.nbdtrik, ('ddd',), None),
    (special.nbdtrin, cython_special.nbdtrin, ('ddd',), None),
    (special.ncfdtr, cython_special.ncfdtr, ('dddd',), None),
    (special.ncfdtri, cython_special.ncfdtri, ('dddd',), None),
    (special.ncfdtridfd, cython_special.ncfdtridfd, ('dddd',), None),
    (special.ncfdtridfn, cython_special.ncfdtridfn, ('dddd',), None),
    (special.ncfdtrinc, cython_special.ncfdtrinc, ('dddd',), None),
    (special.nctdtr, cython_special.nctdtr, ('ddd',), None),
    (special.nctdtridf, cython_special.nctdtridf, ('ddd',), None),
    (special.nctdtrinc, cython_special.nctdtrinc, ('ddd',), None),
    (special.nctdtrit, cython_special.nctdtrit, ('ddd',), None),
    (special.ndtr, cython_special.ndtr, ('d', 'D'), None),
    (special.ndtri, cython_special.ndtri, ('d',), None),
    (special.nrdtrimn, cython_special.nrdtrimn, ('ddd',), None),
    (special.nrdtrisd, cython_special.nrdtrisd, ('ddd',), None),
    (special.obl_ang1, cython_special._obl_ang1_pywrap, ('dddd',), None),
    (special.obl_ang1_cv, cython_special._obl_ang1_cv_pywrap, ('ddddd',), None),
    (special.obl_cv, cython_special.obl_cv, ('ddd',), None),
    (special.obl_rad1, cython_special._obl_rad1_pywrap, ('dddd',), "see gh-6211"),
    (special.obl_rad1_cv, cython_special._obl_rad1_cv_pywrap, ('ddddd',), "see gh-6211"),
    (special.obl_rad2, cython_special._obl_rad2_pywrap, ('dddd',), "see gh-6211"),
    (special.obl_rad2_cv, cython_special._obl_rad2_cv_pywrap, ('ddddd',), "see gh-6211"),
    (special.pbdv, cython_special._pbdv_pywrap, ('dd',), None),
    (special.pbvv, cython_special._pbvv_pywrap, ('dd',), None),
    (special.pbwa, cython_special._pbwa_pywrap, ('dd',), None),
    (special.pdtr, cython_special.pdtr, ('ld', 'dd'), None),
    (special.pdtrc, cython_special.pdtrc, ('ld', 'dd'), None),
    (special.pdtri, cython_special.pdtri, ('ld', 'dd'), None),
    (special.pdtrik, cython_special.pdtrik, ('dd',), None),
    (special.poch, cython_special.poch, ('dd',), None),
    (special.pro_ang1, cython_special._pro_ang1_pywrap, ('dddd',), None),
    (special.pro_ang1_cv, cython_special._pro_ang1_cv_pywrap, ('ddddd',), None),
    (special.pro_cv, cython_special.pro_cv, ('ddd',), None),
    (special.pro_rad1, cython_special._pro_rad1_pywrap, ('dddd',), "see gh-6211"),
    (special.pro_rad1_cv, cython_special._pro_rad1_cv_pywrap, ('ddddd',), "see gh-6211"),
    (special.pro_rad2, cython_special._pro_rad2_pywrap, ('dddd',), "see gh-6211"),
    (special.pro_rad2_cv, cython_special._pro_rad2_cv_pywrap, ('ddddd',), "see gh-6211"),
    (special.pseudo_huber, cython_special.pseudo_huber, ('dd',), None),
    (special.psi, cython_special.psi, ('d', 'D'), None),
    (special.radian, cython_special.radian, ('ddd',), None),
    (special.rel_entr, cython_special.rel_entr, ('dd',), None),
    (special.rgamma, cython_special.rgamma, ('d', 'D'), None),
    (special.round, cython_special.round, ('d',), None),
    (special.shichi, cython_special._shichi_pywrap, ('d', 'D'), None),
    (special.sici, cython_special._sici_pywrap, ('d', 'D'), None),
    (special.sindg, cython_special.sindg, ('d',), None),
    (special.smirnov, cython_special.smirnov, ('ld', 'dd'), None),
    (special.smirnovi, cython_special.smirnovi, ('ld', 'dd'), None),
    (special.spence, cython_special.spence, ('d', 'D'), None),
    (special.sph_harm, cython_special.sph_harm, ('lldd', 'dddd'), None),
    (special.stdtr, cython_special.stdtr, ('dd',), None),
    (special.stdtridf, cython_special.stdtridf, ('dd',), None),
    (special.stdtrit, cython_special.stdtrit, ('dd',), None),
    (special.struve, cython_special.struve, ('dd',), None),
    (special.tandg, cython_special.tandg, ('d',), None),
    (special.tklmbda, cython_special.tklmbda, ('dd',), None),
    (special.wofz, cython_special.wofz, ('D',), None),
    (special.wrightomega, cython_special.wrightomega, ('D',), None),
    (special.xlog1py, cython_special.xlog1py, ('dd', 'DD'), None),
    (special.xlogy, cython_special.xlogy, ('dd', 'DD'), None),
    (special.y0, cython_special.y0, ('d',), None),
    (special.y1, cython_special.y1, ('d',), None),
    (special.yn, cython_special.yn, ('ld', 'dd'), None),
    (special.yv, cython_special.yv, ('dd', 'dD'), None),
    (special.yve, cython_special.yve, ('dd', 'dD'), None),
    (special.zetac, cython_special.zetac, ('d',), None),
    (special.owens_t, cython_special.owens_t, ('dd',), None)
]


IDS = [x[0].__name__ for x in PARAMS]


def _generate_test_points(typecodes):
    axes = tuple(map(lambda x: TEST_POINTS[x], typecodes))
    pts = list(product(*axes))
    return pts


def test_cython_api_completeness():
    # Check that everything is tested
    skip = []
    for name in dir(cython_special):
        func = getattr(cython_special, name)
        if callable(func) and not (name.startswith('_bench') or name in skip):
            for _, cyfun, _, _ in PARAMS:
                if cyfun is func:
                    break
            else:
                raise RuntimeError("{} missing from tests!".format(name))


@pytest.mark.parametrize("param", PARAMS, ids=IDS)
def test_cython_api(param):
    pyfunc, cyfunc, specializations, knownfailure = param
    if knownfailure:
        pytest.xfail(reason=knownfailure)

    # Check which parameters are expected to be fused types
    values = [set() for code in specializations[0]]
    for typecodes in specializations:
        for j, v in enumerate(typecodes):
            values[j].add(v)
    seen = set()
    is_fused_code = [False] * len(values)
    for j, v in enumerate(values):
        vv = tuple(sorted(v))
        if vv in seen:
            continue
        is_fused_code[j] = (len(v) > 1)
        seen.add(vv)

    # Check results
    for typecodes in specializations:
        # Pick the correct specialized function
        signature = []
        for j, code in enumerate(typecodes):
            if is_fused_code[j]:
                signature.append(CYTHON_SIGNATURE_MAP[code])

        if signature:
            cy_spec_func = cyfunc[tuple(signature)]
        else:
            signature = None
            cy_spec_func = cyfunc

        # Test it
        pts = _generate_test_points(typecodes)
        for pt in pts:
            pyval = pyfunc(*pt)
            cyval = cy_spec_func(*pt)
            assert_allclose(cyval, pyval, err_msg="{} {} {}".format(pt, typecodes, signature))