test_nan_inputs.py 1.8 KB
Newer Older
xuebingbing's avatar
xuebingbing committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
"""Test how the ufuncs in ``scipy.special`` handle NaN inputs."""
from typing import Callable, Dict

import numpy as np
from numpy.testing import assert_array_equal, assert_, suppress_warnings
import pytest
import scipy.special as sc


# Ufuncs expected to fail the NaN test, mapped to the reason string that is
# passed to ``pytest.mark.xfail``. Keys are the ufunc objects themselves
# (``_get_ufuncs`` looks entries up by ufunc object, not by name).
KNOWNFAILURES: Dict[np.ufunc, str] = {}

# Ufuncs whose outputs must be transformed before the NaN check, mapped to a
# callable that receives the ufunc's outputs unpacked as positional arguments.
POSTPROCESSING: Dict[np.ufunc, Callable] = {}


def _get_ufuncs():
    """Collect every ufunc exposed by ``scipy.special`` for parametrization.

    Returns
    -------
    tuple of (list, list)
        The ufuncs, each replaced by a ``pytest.param`` wrapper carrying a
        non-running ``xfail`` mark when it has an entry in ``KNOWNFAILURES``,
        and their attribute names. Both lists are in sorted-name order and
        index-aligned.
    """
    params = []
    ids = []
    for attr_name, candidate in sorted(vars(sc).items()):
        if not isinstance(candidate, np.ufunc):
            continue
        reason = KNOWNFAILURES.get(candidate)
        if reason is not None:
            # Keep known failures in the parametrization but never execute
            # them (run=False); the registered message becomes the reason.
            candidate = pytest.param(
                candidate, marks=pytest.mark.xfail(run=False, reason=reason))
        params.append(candidate)
        ids.append(attr_name)
    return params, ids


UFUNCS, UFUNC_NAMES = _get_ufuncs()


@pytest.mark.parametrize("func", UFUNCS, ids=UFUNC_NAMES)
def test_nan_inputs(func):
    """Check that ``func`` returns NaN when every one of its inputs is NaN.

    Ufuncs that reject real-valued inputs with ``TypeError`` are treated as
    passing, since the NaN check does not apply to them.
    """
    args = (np.nan,) * func.nin
    # One suppression context with both filters, so there is a single ``sup``
    # binding instead of an inner context shadowing the outer one.
    with suppress_warnings() as sup:
        # Ignore warnings about unsafe casts from legacy wrappers
        sup.filter(RuntimeWarning,
                   "floating point number truncated to an integer")
        # Ignore deprecation warnings raised by the ufunc call.
        sup.filter(DeprecationWarning)
        try:
            res = func(*args)
        except TypeError:
            # One of the arguments doesn't take real inputs
            return
    if func in POSTPROCESSING:
        # Outputs are passed unpacked, so entries are presumably registered
        # only for multi-output ufuncs whose result is a tuple — TODO confirm
        # when adding one.
        res = POSTPROCESSING[func](*res)

    msg = f"got {res} instead of nan"
    assert_array_equal(np.isnan(res), True, err_msg=msg)


def test_legacy_cast():
    """Check that ``sc.bdtrc`` returns NaN when its first argument is NaN.

    A RuntimeWarning about truncating a float to an integer, which a legacy
    wrapper may emit for this call, is suppressed.
    """
    with suppress_warnings() as suppressor:
        suppressor.filter(RuntimeWarning,
                          "floating point number truncated to an integer")
        value = sc.bdtrc(np.nan, 1, 0.5)
    assert_(np.isnan(value))