diff --git a/checkpoint/orbax/checkpoint/_src/path/storage_backend.py b/checkpoint/orbax/checkpoint/_src/path/storage_backend.py index 961994c82..25dbfd63b 100644 --- a/checkpoint/orbax/checkpoint/_src/path/storage_backend.py +++ b/checkpoint/orbax/checkpoint/_src/path/storage_backend.py @@ -19,6 +19,8 @@ and local file systems are provided here """ +from __future__ import annotations + import abc import dataclasses import enum @@ -26,6 +28,27 @@ from absl import logging from etils import epath from orbax.checkpoint._src.path import atomicity_types +from orbax.checkpoint._src.path import gcs_utils +from orbax.checkpoint._src.path import types +from orbax.checkpoint.google.path import cns2_utils + + +@enum.unique +class FilesystemType(enum.Enum): + """Enum class for supported file system types.""" + + GCS = 'GCS' + LOCAL = 'LOCAL' + + + @classmethod + def resolve(cls, path: types.PathLike) -> FilesystemType: + path = epath.Path(path) + if gcs_utils.is_gcs_path(path): + return cls.GCS + if path.parts[0] != '/': + raise ValueError(f'Invalid path {path!r}: not absolute.') + return cls.LOCAL @dataclasses.dataclass(frozen=True)