diff --git a/openml/__init__.py b/openml/__init__.py index 9a457c146..d1892f609 100644 --- a/openml/__init__.py +++ b/openml/__init__.py @@ -36,7 +36,7 @@ ) from .__version__ import __version__ from .datasets import OpenMLDataFeature, OpenMLDataset -from .evaluations import OpenMLEvaluation +from .evaluations import OpenMLEvaluation, list_estimation_procedures from .flows import OpenMLFlow from .runs import OpenMLRun from .setups import OpenMLParameter, OpenMLSetup @@ -122,6 +122,7 @@ def populate_cache( "exceptions", "extensions", "flows", + "list_estimation_procedures", "runs", "setups", "study", diff --git a/openml/evaluations/__init__.py b/openml/evaluations/__init__.py index b56d0c2d5..29344b03a 100644 --- a/openml/evaluations/__init__.py +++ b/openml/evaluations/__init__.py @@ -1,10 +1,16 @@ # License: BSD 3-Clause from .evaluation import OpenMLEvaluation -from .functions import list_evaluation_measures, list_evaluations, list_evaluations_setups +from .functions import ( + list_estimation_procedures, + list_evaluation_measures, + list_evaluations, + list_evaluations_setups, +) __all__ = [ "OpenMLEvaluation", + "list_estimation_procedures", "list_evaluation_measures", "list_evaluations", "list_evaluations_setups", diff --git a/openml/evaluations/functions.py b/openml/evaluations/functions.py index 61c95a480..b5661c4df 100644 --- a/openml/evaluations/functions.py +++ b/openml/evaluations/functions.py @@ -298,16 +298,42 @@ def list_evaluation_measures() -> list[str]: return qualities["oml:evaluation_measures"]["oml:measures"][0]["oml:measure"] -def list_estimation_procedures() -> list[str]: - """Return list of evaluation procedures available. +@overload +def list_estimation_procedures(include_ids: Literal[True]) -> dict[int, str]: ... + + +@overload +def list_estimation_procedures(include_ids: Literal[False] = ...) -> list[str]: ... + + +def list_estimation_procedures(include_ids: bool = False) -> dict[int, str] | list[str]: # noqa: FBT002 + """Return dictionary or list of estimation procedures available. The function performs an API call to retrieve the entire list of - evaluation procedures' names that are available. + estimation procedures' ids and names that are available. + + Parameters + ---------- + include_ids : bool, optional (default=False) + If True, return a dictionary mapping estimation procedure id to name. + If False, return a list of estimation procedure names. Returns ------- - list + list of estimation procedure names (default), or dict mapping + estimation procedure id to name if include_ids=True """ + if not include_ids: + import warnings + + warnings.warn( + "Returning a list from list_estimation_procedures is deprecated " + "and will be removed in a future release. " + "Use include_ids=True to get a dict of {id: name} instead.", + DeprecationWarning, + stacklevel=2, + ) + api_call = "estimationprocedure/list" xml_string = openml._api_calls._perform_api_call(api_call, "get") api_results = xmltodict.parse(xml_string) @@ -322,6 +348,11 @@ def list_estimation_procedures() -> list[str]: if not isinstance(api_results["oml:estimationprocedures"]["oml:estimationprocedure"], list): raise TypeError('Error in return XML, does not contain "oml:estimationprocedure" as a list') + if include_ids: + return { + int(prod["oml:id"]): prod["oml:name"] + for prod in api_results["oml:estimationprocedures"]["oml:estimationprocedure"] + } return [ prod["oml:name"] for prod in api_results["oml:estimationprocedures"]["oml:estimationprocedure"] diff --git a/tests/test_evaluations/test_evaluation_functions.py b/tests/test_evaluations/test_evaluation_functions.py index e15556d7b..6a1429830 100644 --- a/tests/test_evaluations/test_evaluation_functions.py +++ b/tests/test_evaluations/test_evaluation_functions.py @@ -264,3 +264,39 @@ def test_list_evaluations_setups_filter_task(self): task_id = [6] size = 121 self._check_list_evaluation_setups(tasks=task_id, size=size) + + @pytest.mark.test_server() + def test_list_estimation_procedures_return_type(self): + procedures = openml.evaluations.list_estimation_procedures(include_ids=True) + assert isinstance(procedures, dict) + assert len(procedures) > 0 + assert all(isinstance(k, int) for k in procedures.keys()) + assert all(isinstance(v, str) for v in procedures.values()) + + @pytest.mark.test_server() + def test_list_estimation_procedures_top_level_accessible(self): + procedures = openml.list_estimation_procedures(include_ids=True) + assert isinstance(procedures, dict) + assert len(procedures) > 0 + assert all(isinstance(k, int) for k in procedures.keys()) + assert all(isinstance(v, str) for v in procedures.values()) + + @pytest.mark.test_server() + def test_list_estimation_procedures_ids_are_positive_ints(self): + procedures = openml.evaluations.list_estimation_procedures(include_ids=True) + first_id = list(procedures.keys())[0] + assert isinstance(first_id, int) + assert first_id > 0 + + @pytest.mark.test_server() + def test_list_estimation_procedures_default_returns_list(self): + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + procedures = openml.evaluations.list_estimation_procedures() + assert isinstance(procedures, list) + assert len(procedures) > 0 + assert all(isinstance(s, str) for s in procedures) + # confirm deprecation warning was raised + assert any(issubclass(warning.category, DeprecationWarning) for warning in w)