Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion openml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)
from .__version__ import __version__
from .datasets import OpenMLDataFeature, OpenMLDataset
from .evaluations import OpenMLEvaluation
from .evaluations import OpenMLEvaluation, list_estimation_procedures
from .flows import OpenMLFlow
from .runs import OpenMLRun
from .setups import OpenMLParameter, OpenMLSetup
Expand Down Expand Up @@ -122,6 +122,7 @@ def populate_cache(
"exceptions",
"extensions",
"flows",
"list_estimation_procedures",
"runs",
"setups",
"study",
Expand Down
8 changes: 7 additions & 1 deletion openml/evaluations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
# License: BSD 3-Clause

from .evaluation import OpenMLEvaluation
from .functions import list_evaluation_measures, list_evaluations, list_evaluations_setups
from .functions import (
list_estimation_procedures,
list_evaluation_measures,
list_evaluations,
list_evaluations_setups,
)

__all__ = [
"OpenMLEvaluation",
"list_estimation_procedures",
"list_evaluation_measures",
"list_evaluations",
"list_evaluations_setups",
Expand Down
39 changes: 35 additions & 4 deletions openml/evaluations/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,16 +298,42 @@ def list_evaluation_measures() -> list[str]:
return qualities["oml:evaluation_measures"]["oml:measures"][0]["oml:measure"]


def list_estimation_procedures() -> list[str]:
"""Return list of evaluation procedures available.
@overload
def list_estimation_procedures(include_ids: Literal[True]) -> dict[int, str]: ...


@overload
def list_estimation_procedures(include_ids: Literal[False] = ...) -> list[str]: ...


def list_estimation_procedures(include_ids: bool = False) -> dict[int, str] | list[str]: # noqa: FBT002
"""Return dictionary or list of estimation procedures available.

The function performs an API call to retrieve the entire list of
evaluation procedures' names that are available.
estimation procedures' ids and names that are available.

Parameters
----------
include_ids : bool, optional (default=False)
If True, return a dictionary mapping estimation procedure id to name.
If False, return a list of estimation procedure names.

Returns
-------
list
list of estimation procedure names (default), or dict mapping
estimation procedure id to name if include_ids=True
"""
if not include_ids:
import warnings

warnings.warn(
"Returning a list from list_estimation_procedures is deprecated "
"and will be removed in a future release. "
"Use include_ids=True to get a dict of {id: name} instead.",
DeprecationWarning,
stacklevel=2,
)

api_call = "estimationprocedure/list"
xml_string = openml._api_calls._perform_api_call(api_call, "get")
api_results = xmltodict.parse(xml_string)
Expand All @@ -322,6 +348,11 @@ def list_estimation_procedures() -> list[str]:
if not isinstance(api_results["oml:estimationprocedures"]["oml:estimationprocedure"], list):
raise TypeError('Error in return XML, does not contain "oml:estimationprocedure" as a list')

if include_ids:
return {
int(prod["oml:id"]): prod["oml:name"]
for prod in api_results["oml:estimationprocedures"]["oml:estimationprocedure"]
}
return [
prod["oml:name"]
for prod in api_results["oml:estimationprocedures"]["oml:estimationprocedure"]
Expand Down
36 changes: 36 additions & 0 deletions tests/test_evaluations/test_evaluation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,39 @@ def test_list_evaluations_setups_filter_task(self):
task_id = [6]
size = 121
self._check_list_evaluation_setups(tasks=task_id, size=size)

@pytest.mark.test_server()
def test_list_estimation_procedures_return_type(self):
procedures = openml.evaluations.list_estimation_procedures(include_ids=True)
assert isinstance(procedures, dict)
assert len(procedures) > 0
assert all(isinstance(k, int) for k in procedures.keys())
assert all(isinstance(v, str) for v in procedures.values())

@pytest.mark.test_server()
def test_list_estimation_procedures_top_level_accessible(self):
procedures = openml.list_estimation_procedures(include_ids=True)
assert isinstance(procedures, dict)
assert len(procedures) > 0
assert all(isinstance(k, int) for k in procedures.keys())
assert all(isinstance(v, str) for v in procedures.values())

@pytest.mark.test_server()
def test_list_estimation_procedures_ids_are_positive_ints(self):
procedures = openml.evaluations.list_estimation_procedures(include_ids=True)
first_id = list(procedures.keys())[0]
assert isinstance(first_id, int)
assert first_id > 0

@pytest.mark.test_server()
def test_list_estimation_procedures_default_returns_list(self):
import warnings

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
procedures = openml.evaluations.list_estimation_procedures()
assert isinstance(procedures, list)
assert len(procedures) > 0
assert all(isinstance(s, str) for s in procedures)
# confirm deprecation warning was raised
assert any(issubclass(warning.category, DeprecationWarning) for warning in w)