Benchopt Test definition

import warnings

import pytest
import numpy as np

from benchopt.stopping_criterion import StoppingCriterion
from benchopt.stopping_criterion import SAMPLING_STRATEGIES
from benchopt.utils.dynamic_modules import _get_module_from_file
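
# These generic checks are run against every component declared in a
# benchmark (its Objective, Datasets and Solvers), typically through the
# `benchopt test` command.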


def test_dataset_class(benchmark, dataset_class):
    """Check that all dataset classes respect the public API"""

    # Check that the dataset_class exposes a name
    assert hasattr(dataset_class, 'name'), "All datasets should expose a name"
    assert isinstance(dataset_class.name, str), (
        "The dataset's name should be a string")

    # Ensure that the dataset exposes a `get_data` function
    # that is callable
    dataset = dataset_class.get_instance()
    assert hasattr(dataset, 'get_data'), (
        "All datasets should implement get_data"
    )
    assert callable(dataset.get_data), (
        "dataset.get_data should be a callable"
    )


def test_dataset_get_data(benchmark, dataset_class):
    """Check that all installed dataset_class.get_data return the right result
    """

    # skip the test if the dataset is not installed
    if not dataset_class.is_installed():
        pytest.skip("Dataset is not installed")

    dataset = dataset_class.get_instance()

    data = dataset._get_data()
    assert isinstance(data, dict), (
        f"The returned data from get_data should be a dict. Got {data}."
    )
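
# For reference, a minimal dataset passing the checks above might look like
# the sketch below (class layout follows the benchopt API; the simulated
# values and field names are illustrative, not part of this test module):
#
#     import numpy as np
#     from benchopt import BaseDataset
#
#     class Dataset(BaseDataset):
#         name = "simulated"
#
#         def get_data(self):
#             rng = np.random.RandomState(42)
#             X, y = rng.randn(50, 10), rng.randn(50)
#             return dict(X=X, y=y)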


def test_benchmark_objective(benchmark, objective_class):
    # check that the result of the objective function is compatible with
    # benchopt, does not contain `objective_name` and is not empty.
    objective = objective_class.get_instance()
    dataset_class, test_params = benchmark.get_test_dataset()
    dataset = dataset_class.get_instance(**test_params[0])

    # get one value for the objective, with the test_dataset
    objective.set_dataset(dataset)
    result = objective._get_one_result()
    objective_output = objective(result)

    # check that the output has proper type and is not empty
    assert isinstance(objective_output, list), (
        "The output of the objective function should be a list of dicts by "
        "design. Please report this issue on benchopt's GitHub."
    )
    objective_output = objective_output[0]
    assert isinstance(objective_output, dict), (
        "The output of the objective function should be a dict by design. "
        "Please report this issue on benchopt's GitHub."
    )
    assert "objective_name" not in objective_output, (
        "`objective_name` is a reserved key in the objective output dict. "
        "Please remove it from the output of the objective function."
    )
    assert len(objective_output) > 0, (
        "The output of the objective function should not be an empty dict."
    )
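
# A matching objective could look like the sketch below; the method names
# follow recent benchopt versions, and the data/result fields (`X`, `y`,
# `beta`, `value`) are illustrative:
#
#     import numpy as np
#     from benchopt import BaseObjective
#
#     class Objective(BaseObjective):
#         name = "Ordinary Least Squares"
#
#         def set_data(self, X, y):
#             self.X, self.y = X, y
#
#         def get_objective(self):
#             return dict(X=self.X, y=self.y)
#
#         def get_one_result(self):
#             return dict(beta=np.zeros(self.X.shape[1]))
#
#         def evaluate_result(self, beta):
#             return dict(value=0.5 * np.sum((self.y - self.X @ beta) ** 2))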


def test_solver_class(benchmark, solver_class):
    """Check that all installed solver classes respect the public API"""

    # Check that the solver_class exposes a name
    assert hasattr(solver_class, 'name'), "All solvers should expose a name"
    assert isinstance(solver_class.name, str), (
        "The solver's name should be a string"
    )

    # Check that the solver_class uses a valid sampling_strategy
    if solver_class.sampling_strategy is not None:
        msg = f"sampling_strategy should be in {SAMPLING_STRATEGIES}."
        assert solver_class.sampling_strategy in SAMPLING_STRATEGIES, msg

    # Check that the solver_class uses a valid stopping criterion
    if hasattr(solver_class, 'stopping_criterion'):
        assert isinstance(solver_class.stopping_criterion, StoppingCriterion), (  # noqa E501
            "stopping_criterion should be an instance of StoppingCriterion. "
            f"Got '{solver_class.stopping_criterion}'"
        )

    # Check that the solver_class uses a valid callable to override get_next.
    if hasattr(solver_class, 'get_next'):
        assert callable(solver_class.get_next), (
            "`get_next` for class Solver in "
            f"'{solver_class.__module__}' should be a callable."
        )
        # Make sure the signature is `get_next(self, stop_val)` with an int
        # argument: create an instance of `solver_class` and call `get_next`.
        solver_class().get_next(0)
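
# For reference, a minimal solver exposing these attributes might look like
# the sketch below (solver name and gradient update are illustrative):
#
#     import numpy as np
#     from benchopt import BaseSolver
#     from benchopt.stopping_criterion import SufficientProgressCriterion
#
#     class Solver(BaseSolver):
#         name = "gd"
#         sampling_strategy = 'iteration'
#         stopping_criterion = SufficientProgressCriterion(patience=5)
#
#         def set_objective(self, X, y):
#             self.X, self.y = X, y
#
#         def run(self, n_iter):
#             beta = np.zeros(self.X.shape[1])
#             step = 1 / np.linalg.norm(self.X, ord=2) ** 2
#             for _ in range(n_iter):
#                 beta -= step * self.X.T @ (self.X @ beta - self.y)
#             self.beta = beta
#
#         def get_result(self):
#             return dict(beta=self.beta)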


def test_solver_install_api(benchmark, solver_class):

    # Check that the solver_class exposes a known install cmd
    assert solver_class.install_cmd in [None, 'conda', 'shell']

    # Check that a 'shell' install cmd comes with an install script
    if solver_class.install_cmd == 'shell':
        assert hasattr(solver_class, 'install_script')
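
# Install metadata is declared as class attributes on the solver, e.g.
# (package names are illustrative):
#
#     class Solver(BaseSolver):
#         install_cmd = 'conda'
#         requirements = ['numpy', 'scipy']
#
# With `install_cmd = 'shell'`, an `install_script` attribute must point to
# the script to run, as checked above.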


@pytest.mark.requires_install
def test_solver_install(test_env_name, benchmark, solver_class):

    # Make sure that the current benchmark is correctly set
    from benchopt.benchmark import Benchmark
    benchmark = Benchmark(benchmark.benchmark_dir)

    # Assert that the install works and the solver is detected as installed
    solver_class.install(env_name=test_env_name)
    solver_class.is_installed(
        env_name=test_env_name, raise_on_not_installed=True
    )
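
# Note: this check installs the solver in a dedicated test environment
# (`test_env_name`); the `requires_install` marker makes it easy to deselect
# these slower install tests when they are not wanted.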


def test_solver_stopping_criterion(benchmark, solver_class):
    # Check each solver stopping_criterion is compatible with the objective
    objective_class = benchmark.get_benchmark_objective()
    objective = objective_class.get_instance()

    dataset_class, test_params = benchmark.get_test_dataset()
    dataset = dataset_class.get_instance(**test_params[0])

    # Make sure the solver inherits the stopping criterion from the objective
    # if it is set globally for this benchmark
    solver_class._inherit_stopping_criterion(objective)
    stopping_criterion = solver_class._stopping_criterion

    # check that stopping_criterion has the proper type
    assert isinstance(stopping_criterion, StoppingCriterion), (
        "The solver's stopping_criterion should be an instance of "
        "StoppingCriterion."
    )

    # Check that if a key_to_monitor is specified, the objective returns it
    if stopping_criterion.key_to_monitor is not None:
        key = stopping_criterion.key_to_monitor_
        assert key.startswith('objective_'), (
            "The solver's stopping_criterion key_to_monitor should start with "
            "'objective_'."
        )

        objective.set_dataset(dataset)
        result = objective._get_one_result()
        objective_output = objective(result)[0]

        requested_key = stopping_criterion.key_to_monitor
        available_keys = list(objective_output.keys())
        if not requested_key.startswith('objective_'):
            available_keys = [
                k.replace('objective_', '') for k in available_keys
            ]
        assert key in objective_output, (
            f"The solver's stopping_criterion monitors {requested_key}, but "
            "the objective does not return this key. Available keys are "
            f"{available_keys}."
        )
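
# The monitored key is configured on the stopping criterion; per the
# normalization above (`key_to_monitor_`), the `objective_` prefix can be
# omitted. For instance, assuming the criterion constructor accepts the
# `key_to_monitor` keyword used above:
#
#     stopping_criterion = SufficientProgressCriterion(
#         patience=5, key_to_monitor='value'
#     )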


def test_solver_run(benchmark, solver_class):
    # Check that the solver runs with at least one configuration of a
    # simulated dataset.

    if not solver_class.is_installed():
        pytest.skip("Solver is not installed")

    test_config = getattr(solver_class, "test_config", {})
    objective_config = test_config.get('objective', {})
    dataset_config = test_config.get('dataset', {})

    objective_class = benchmark.get_benchmark_objective()
    objective = objective_class.get_instance(**objective_config)

    dataset_class, test_params = benchmark.get_test_dataset()

    reasons = set()
    solver_ran_once = False
    for params in test_params:
        params.update(dataset_config)
        dataset = dataset_class.get_instance(**params)

        objective.set_dataset(dataset)
        solver = solver_class.get_instance(**test_config.get('solver', {}))
        skip, reason = solver._set_objective(objective)
        if skip:
            reasons.add(reason)
            continue
        solver_ran_once = True
        _test_solver_one_objective(solver, objective)

    reasons = "\n-".join(sorted(reasons))
    assert solver_ran_once, (
        'Solver skipped all test configurations. At least one simulated '
        'dataset and one objective config should be compatible with the '
        'solver, potentially provided through the "Solver.test_config" class '
        f'attribute or with `Dataset.test_parameters`. Reasons:\n-{reasons}'
    )
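
# The test configuration read above can be declared on the solver class, e.g.
# (the parameter names are illustrative):
#
#     class Solver(BaseSolver):
#         test_config = dict(
#             dataset=dict(n_samples=50), objective=dict(reg=0.1)
#         )
#
# Datasets can similarly expose a `Dataset.test_parameters` attribute to
# control which simulated configurations are tried.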


def _test_solver_one_objective(solver, objective):
    # Test that a solver runs with a given objective and gives proper results.

    is_convex = getattr(objective, "is_convex", False)

    if solver._solver_strategy in ['iteration', 'callback']:
        stop_val = 5000 if is_convex else 2
    else:
        stop_val = 1e-10 if is_convex else 1e-2
    solver.run_once(stop_val)

    # Check that returned results are compatible with the objective
    result = solver.get_result()
    objective(result)

    # Only check optimality for convex problems, when the solver returns a
    # single value which is a np.ndarray
    if (is_convex and len(result) == 1
            and isinstance(list(result.values())[0], np.ndarray)):
        key = list(result.keys())[0]
        arr = result[key]
        val_star = objective(result)[0]['objective_value']
        for _ in range(100):
            eps = 1e-5 * np.random.randn(*arr.shape)
            val_eps = objective({key: arr + eps})[0]['objective_value']

            diff = val_eps - val_star
            assert diff >= 0
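
# Note: the perturbation loop in `_test_solver_one_objective` is a first-order
# optimality check: for a convex objective, no small random perturbation of
# the returned point should yield a smaller objective value.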


##############################################################################
# --- Utilities for benchmark-specific checks ------------------------------ #
##############################################################################


@pytest.fixture(autouse=True)
def check_test(request):

    if 'benchmark' not in request.fixturenames:
        raise ValueError(
            '`check_test` fixture should only be used in tests parametrized '
            'with `benchmark` fixture'
        )

    benchmark = request.getfixturevalue('benchmark')
    test_config_file = benchmark.get_test_config_file()
    if test_config_file is None:
        return None
    test_config_module = _get_module_from_file(test_config_file)
    check_func_name = f"check_{request.function.__name__}"
    check_func = getattr(test_config_module, check_func_name, None)
    if check_func is None and request.function.__name__ == "test_solver_run":
        # Backward compatibility for benchmarks before benchopt 1.7.1
        check_func = getattr(test_config_module, "check_test_solver", None)
        if check_func is not None:
            warn_msg = (
                "The function `check_test_solver` is deprecated since "
                "benchopt 1.7.1. Please rename it to `check_test_solver_run`."
            )
            warnings.warn(warn_msg, DeprecationWarning)
    if check_func is not None:
        try:
            check_func(
                benchmark,
                *[
                    request.getfixturevalue(f) for f in request.fixturenames
                    if f not in [
                        'check_test', 'benchmark', 'request', 'test_env_name'
                    ]
                ]
            )
        except TypeError:
            # Backward compatibility for benchmarks before benchopt 1.7.1
            check_func(request.getfixturevalue('solver_class'))
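

# Benchmark authors can customize these checks by providing a test
# configuration module (typically a `test_config.py` file at the root of the
# benchmark) that defines `check_<test_name>` hooks matching the signature
# used above, e.g. (the skip condition is illustrative):
#
#     import pytest
#
#     def check_test_solver_run(benchmark, solver_class):
#         if solver_class.name.lower() == 'my-slow-solver':
#             pytest.skip("known issue on the simulated test dataset")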