diff --git a/openml/datasets/dataset.py b/openml/datasets/dataset.py
index 0c9da1caf..30febcba5 100644
--- a/openml/datasets/dataset.py
+++ b/openml/datasets/dataset.py
@@ -345,9 +345,10 @@ def _download_data(self) -> None:
         # import required here to avoid circular import.
         from .functions import _get_dataset_arff, _get_dataset_parquet
 
-        self.data_file = str(_get_dataset_arff(self))
         if self._parquet_url is not None:
             self.parquet_file = str(_get_dataset_parquet(self))
+        if self.parquet_file is None:
+            self.data_file = str(_get_dataset_arff(self))
 
     def _get_arff(self, format: str) -> dict:  # noqa: A002
         """Read ARFF file and return decoded arff.
@@ -535,18 +536,7 @@ def _cache_compressed_file_from_file(
             feather_attribute_file,
         ) = self._compressed_cache_file_paths(data_file)
 
-        if data_file.suffix == ".arff":
-            data, categorical, attribute_names = self._parse_data_from_arff(data_file)
-        elif data_file.suffix == ".pq":
-            try:
-                data = pd.read_parquet(data_file)
-            except Exception as e:  # noqa: BLE001
-                raise Exception(f"File: {data_file}") from e
-
-            categorical = [data[c].dtype.name == "category" for c in data.columns]
-            attribute_names = list(data.columns)
-        else:
-            raise ValueError(f"Unknown file type for file '{data_file}'.")
+        attribute_names, categorical, data = self._parse_data_from_file(data_file)
 
         # Feather format does not work for sparse datasets, so we use pickle for sparse datasets
         if scipy.sparse.issparse(data):
@@ -572,6 +562,24 @@ def _cache_compressed_file_from_file(
 
         return data, categorical, attribute_names
 
+    def _parse_data_from_file(self, data_file: Path) -> tuple[list[str], list[bool], pd.DataFrame]:
+        if data_file.suffix == ".arff":
+            data, categorical, attribute_names = self._parse_data_from_arff(data_file)
+        elif data_file.suffix == ".pq":
+            attribute_names, categorical, data = self._parse_data_from_pq(data_file)
+        else:
+            raise ValueError(f"Unknown file type for file '{data_file}'.")
+        return attribute_names, categorical, data
+
+    def _parse_data_from_pq(self, data_file: Path) -> tuple[list[str], list[bool], pd.DataFrame]:
+        try:
+            data = pd.read_parquet(data_file)
+        except Exception as e:  # noqa: BLE001
+            raise Exception(f"File: {data_file}") from e
+        categorical = [data[c].dtype.name == "category" for c in data.columns]
+        attribute_names = list(data.columns)
+        return attribute_names, categorical, data
+
     def _load_data(self) -> tuple[pd.DataFrame | scipy.sparse.csr_matrix, list[bool], list[str]]:  # noqa: PLR0912, C901
         """Load data from compressed format or arff. Download data if not present on disk."""
         need_to_create_pickle = self.cache_format == "pickle" and self.data_pickle_file is None
@@ -636,8 +644,10 @@ def _load_data(self) -> tuple[pd.DataFrame | scipy.sparse.csr_matrix, list[bool]
                 "Please manually delete the cache file if you want OpenML-Python "
                 "to attempt to reconstruct it.",
             )
-            assert self.data_file is not None
-            data, categorical, attribute_names = self._parse_data_from_arff(Path(self.data_file))
+            file_to_load = self.data_file if self.parquet_file is None else self.parquet_file
+            assert file_to_load is not None
+            attr, cat, df = self._parse_data_from_file(Path(file_to_load))
+            return df, cat, attr
 
         data_up_to_date = isinstance(data, pd.DataFrame) or scipy.sparse.issparse(data)
         if self.cache_format == "pickle" and not data_up_to_date:
diff --git a/openml/datasets/functions.py b/openml/datasets/functions.py
index a797588d4..590955a5e 100644
--- a/openml/datasets/functions.py
+++ b/openml/datasets/functions.py
@@ -450,7 +450,7 @@ def get_datasets(
 
 
 @openml.utils.thread_safe_if_oslo_installed
-def get_dataset(  # noqa: C901, PLR0912
+def get_dataset(  # noqa: C901, PLR0912, PLR0915
     dataset_id: int | str,
     download_data: bool | None = None,  # Optional for deprecation warning; later again only bool
     version: int | None = None,
@@ -589,7 +589,6 @@ def get_dataset(  # noqa: C901, PLR0912
         if download_qualities:
             qualities_file = _get_dataset_qualities_file(did_cache_dir, dataset_id)
 
-        arff_file = _get_dataset_arff(description) if download_data else None
         if "oml:parquet_url" in description and download_data:
             try:
                 parquet_file = _get_dataset_parquet(
@@ -598,10 +597,14 @@ def get_dataset(  # noqa: C901, PLR0912
                 )
             except urllib3.exceptions.MaxRetryError:
                 parquet_file = None
-            if parquet_file is None and arff_file:
-                logger.warning("Failed to download parquet, fallback on ARFF.")
         else:
             parquet_file = None
+
+        arff_file = None
+        if parquet_file is None and download_data:
+            logger.warning("Failed to download parquet, fallback on ARFF.")
+            arff_file = _get_dataset_arff(description)
+
         remove_dataset_cache = False
     except OpenMLServerException as e:
         # if there was an exception
diff --git a/tests/test_datasets/test_dataset_functions.py b/tests/test_datasets/test_dataset_functions.py
index 844da8328..0740bd1b1 100644
--- a/tests/test_datasets/test_dataset_functions.py
+++ b/tests/test_datasets/test_dataset_functions.py
@@ -1574,6 +1574,7 @@ def test_get_dataset_parquet(self):
         assert dataset._parquet_url is not None
         assert dataset.parquet_file is not None
         assert os.path.isfile(dataset.parquet_file)
+        assert dataset.data_file is None  # is alias for arff path
 
     @pytest.mark.production()
     def test_list_datasets_with_high_size_parameter(self):