Skip to content

Splits#

thunder.layout.split.Split #

Bases: Layout

Source code in thunder/layout/split.py
class Split(Layout):
    """
    Layout that partitions entries into several folds and keeps one sub-experiment per fold.

    `build` writes a `fold_<i>` folder for every fold, `load` reads one of them back and `set`
    selects the active fold. Once a fold is selected, its subsets are reachable by position
    (`layout[0]`) or, when `names` was given, by name (`layout.train`).
    """

    def __init__(
        self, split: SplitType, entries: Sequence, *args: Any, names: Sequence[str] | None = None, **kwargs: Any
    ):
        """
        Splits data according to split function.

        Parameters
        ----------
        split: Callable
            Split function, or a sklearn splitter (any object exposing a `split` method).
        entries: Sequence
            Series of ids or torch Dataset or Connectome Layer.
        args: Any
            args for split.
        names: Optional[Sequence[str]]
            Names of the subsets inside each fold, e.g. 'train', 'val', 'test'.
            Must be unique and match the number of subsets in the first fold.
        kwargs: Any
            kwargs for split.

        Raises
        ------
        TypeError
            If `split` is neither callable nor has a `split` attribute.

        Examples
        --------
        ```python
        from sklearn.model_selection import KFold

        ids = [0, 1, ...]
        layout = Split(KFold(3), ids, names=["train", "test"])
        ```
        """
        if not callable(split):
            if not hasattr(split, "split"):
                raise TypeError(f"Expected either a function, or a sklearn splitter, got {type(split)!r}")
            split = split.split

        ids = entries_to_ids(entries)
        # TODO: safer way to unify types
        splits = [tuple(map(jsonify, xs)) for xs in split(ids, *args, **kwargs)]
        if names is not None:
            # TODO
            assert len(set(names)) == len(names), f"The names must be unique, got {names!r}"
            assert len(splits[0]) == len(names), (
                f"Expected {len(splits[0])} name(s), one per subset of a fold, got {len(names)}"
            )

        self.entries = entries
        self.splits = splits
        self.names = names
        # index of the active fold, chosen later via `set`
        self.fold: int | None = None

    def __getitem__(self, item: int):
        """Return the `item`-th subset of the currently selected fold."""
        return self._subset(item)

    def __getattr__(self, name: str):
        """
        Return the subset of the currently selected fold called `name`.

        Only invoked when regular attribute lookup fails. Raises `AttributeError` for anything
        that is not one of `names`, so `hasattr` and `getattr(..., default)` behave as expected.
        """
        # Read from the instance dict instead of `self.names`: when `names` is not set yet
        # (e.g. on a bare instance probed during copy/unpickling) `self.names` would re-enter
        # this method and recurse without bound.
        names = self.__dict__.get("names")
        if names is None or name not in names:
            raise AttributeError(name)
        return self._subset(names.index(name))

    def _subset(self, idx):
        """Return the part of `entries` that belongs to subset `idx` of the selected fold."""
        # TODO
        assert self.fold is not None, "No fold is selected, call `set` first"
        return entries_subset(self.entries, self.splits[self.fold][idx])

    def build(self, experiment: Path, config: Config):
        """
        Create the on-disk structure of the experiment, yielding one `Node` per fold.

        Dumps `config` to `experiment / "experiment.config"`, then for each fold creates
        `experiment / "fold_<i>"` and stores there the fold's split (`split.json`) and a copy of
        the config updated with `ExpName` and `GroupName` (`experiment.config`).

        This is a generator: nothing is written until it is iterated. `mkdir` is called without
        `exist_ok`, so an already existing fold folder is an error.
        """
        config.dump(experiment / "experiment.config")
        name = experiment.name
        for fold, split in enumerate(self.splits):
            folder = experiment / f"fold_{fold}"
            folder.mkdir()
            save(split, folder / "split.json")

            local = config.copy().update(ExpName=f"{name}({fold})", GroupName=name)
            local.dump(folder / "experiment.config")
            yield Node(name=str(fold))

    def load(self, experiment: Path, node: Node | None) -> tuple[Config, Path, dict[str, Any]]:
        """
        Load the fold described by `node` from the `experiment` folder.

        `node.name` is read unconditionally, so despite the annotation `node` must not be None.

        Returns
        -------
        The fold's config, the fold's folder and a dict with the fold index (`"fold"`)
        and the stored split as a tuple (`"split"`).
        """
        folder = experiment / f"fold_{node.name}"
        return (
            Config.load(folder / "experiment.config"),
            folder,
            {
                "fold": int(node.name),
                "split": tuple(load(folder / "split.json")),
            },
        )

    def set(self, fold: int, split: Sequence[Sequence] | None = None):
        """
        Select the active fold and optionally verify it against a reference split.

        Parameters
        ----------
        fold: int
            Index of the fold to select.
        split: Optional[Sequence[Sequence]]
            Reference split, e.g. the one returned by `load`. If None, a warning is issued
            and no consistency check is made.

        Raises
        ------
        ValueError
            If `split` differs from this layout's split for `fold`.
        """
        self.fold = fold
        if split is None:
            warnings.warn("No reference split provided. Your results might be inconsistent!", UserWarning)
        else:
            if split != self.splits[fold]:
                # TODO: consistency error?
                raise ValueError(f"The reference split differs from the layout's split for fold {fold}")

build(experiment, config) #

Source code in thunder/layout/split.py
def build(self, experiment: Path, config: Config):
    """
    Lay out the experiment on disk, yielding one `Node` per fold.

    The base config goes to `experiment / "experiment.config"`; each fold gets its own
    `fold_<i>` folder with the fold's `split.json` and a config copy updated with
    `ExpName` / `GroupName`. Being a generator, it does nothing until iterated.
    """
    config.dump(experiment / "experiment.config")
    group = experiment.name
    for index, parts in enumerate(self.splits):
        target = experiment / f"fold_{index}"
        target.mkdir()
        save(parts, target / "split.json")

        # per-fold overrides applied on top of a copy of the base config
        overrides = {"ExpName": f"{group}({index})", "GroupName": group}
        fold_config = config.copy().update(**overrides)
        fold_config.dump(target / "experiment.config")
        yield Node(name=str(index))

load(experiment, node) #

Source code in thunder/layout/split.py
def load(self, experiment: Path, node: Node | None) -> tuple[Config, Path, dict[str, Any]]:
    """
    Read back the fold identified by `node.name` from the `experiment` folder.

    Returns the fold's config, its folder and a dict holding the fold index (`"fold"`)
    and the stored split converted to a tuple (`"split"`).
    """
    folder = experiment / f"fold_{node.name}"
    fold_config = Config.load(folder / "experiment.config")
    params = {
        "fold": int(node.name),
        "split": tuple(load(folder / "split.json")),
    }
    return fold_config, folder, params

set(fold, split=None) #

Source code in thunder/layout/split.py
def set(self, fold: int, split: Sequence[Sequence] | None = None):
    """
    Make `fold` the active fold and, if a reference `split` is given, check that it
    equals the stored split for that fold.

    Warns when no reference is provided; raises `ValueError` on a mismatch.
    """
    self.fold = fold
    if split is None:
        warnings.warn("No reference split provided. Your results might be inconsistent!", UserWarning)
        return

    if split != self.splits[fold]:
        # TODO: consistency error?
        raise ValueError

thunder.layout.split.SingleSplit #

Bases: Layout

Source code in thunder/layout/split.py
class SingleSplit(Layout):
    """
    Layout for a single-fold experiment whose entries are divided into named subsets.

    Each subset is reachable as an attribute named after the corresponding `sizes` key,
    e.g. `layout.train`.
    """

    def __init__(
        self,
        entries: Sequence,
        *,
        shuffle: bool = True,
        random_state: np.random.RandomState | int | None = 0,
        **sizes: int | float,
    ):
        """
        Creates single fold experiment, with custom number of sets.

        Parameters
        ----------
        entries: Sequence
            Sequence of ids, or another container accepted by `entries_to_ids`
            (presumably the same kinds as for `Split` - TODO confirm).
        shuffle: bool
            Whether to shuffle entries.
        random_state : Union[np.random.RandomState, int, None]
            Random state forwarded to `multi_split`. A value that is not already
            a `np.random.RandomState` is used to construct one.
        sizes: Union[int, float]
            Size of each split, keyed by the split's name.

        Examples
        --------
        ```python
        ids = [...]
        layout = SingleSplit(ids, train=0.7, val=0.1, test=0.2)
        ```
        """
        if not isinstance(random_state, np.random.RandomState):
            random_state = np.random.RandomState(random_state)

        ids = entries_to_ids(entries)
        self.entries = entries
        # mapping: split name -> ids of that split; `strict=True` makes a mismatch between
        # the number of requested sizes and the number of produced parts an error
        self.split = dict(
            zip(
                sizes.keys(),
                multi_split(ids, list(sizes.values()), shuffle=shuffle, random_state=random_state),
                strict=True,
            )
        )

    def __getattr__(self, name: str):
        """
        Return the part of `entries` belonging to the split called `name`.

        Only invoked when regular attribute lookup fails; raises `AttributeError`
        for anything that is not a split name.
        """
        # Read from the instance dict instead of `self.split`: when `split` is not set yet
        # (e.g. on a bare instance probed during copy/unpickling) `self.split` would re-enter
        # this method and recurse without bound.
        split = self.__dict__.get("split")
        if split is None or name not in split:
            raise AttributeError(name)
        return entries_subset(self.entries, split[name])

    def build(self, experiment: Path, config: Config):
        """
        Store the split and the config in the `experiment` folder.

        Writes `split.json` and `experiment.config` (a copy of `config` updated with
        `ExpName` and `GroupName`, both set to the folder's name). Returns an empty list.
        """
        # NOTE(review): the same path is written again below, so this first dump is
        # presumably overwritten - confirm whether it is needed.
        config.dump(experiment / "experiment.config")
        name = experiment.name
        save(self.split, experiment / "split.json")

        local = config.copy().update(ExpName=name, GroupName=name)
        local.dump(experiment / "experiment.config")
        return []

    def load(self, experiment: Path, node: Node | None) -> tuple[Config, Path, dict[str, Any]]:
        """
        Load the experiment from the `experiment` folder; `node` is not used.

        Returns
        -------
        The stored config, the experiment folder and a dict with the stored split (`"split"`).
        """
        return (
            Config.load(experiment / "experiment.config"),
            experiment,
            {
                "split": load(experiment / "split.json"),
            },
        )

    def set(self, split: dict[str, Sequence] | None = None):
        """
        Optionally verify this layout's split against a reference one.

        Parameters
        ----------
        split: Optional[dict[str, Sequence]]
            Reference split, e.g. the one returned by `load`. If None, a warning is issued
            and no consistency check is made.

        Raises
        ------
        ValueError
            If `split` differs from this layout's split.
        """
        if split is None:
            warnings.warn("No reference split provided. Your results might be inconsistent!", UserWarning)
        else:
            if split != self.split:
                # TODO: consistency error?
                raise ValueError("The reference split differs from the layout's split")

build(experiment, config) #

Source code in thunder/layout/split.py
def build(self, experiment: Path, config: Config):
    """
    Persist the split and the config into the `experiment` folder.

    Writes `split.json` plus `experiment.config`, the latter being a copy of `config`
    updated with `ExpName` and `GroupName` set to the folder's name. Returns an empty list.
    """
    config_path = experiment / "experiment.config"
    config.dump(config_path)

    title = experiment.name
    save(self.split, experiment / "split.json")

    # same destination as the first dump
    config.copy().update(ExpName=title, GroupName=title).dump(config_path)
    return []

load(experiment, node) #

Source code in thunder/layout/split.py
def load(self, experiment: Path, node: Node | None) -> tuple[Config, Path, dict[str, Any]]:
    """
    Read the experiment back from the `experiment` folder; `node` is not used.

    Returns the stored config, the folder itself and a dict with the stored split (`"split"`).
    """
    stored_config = Config.load(experiment / "experiment.config")
    params = {"split": load(experiment / "split.json")}
    return stored_config, experiment, params

set(split=None) #

Source code in thunder/layout/split.py
def set(self, split: dict[str, Sequence] | None = None):
    """
    Compare a reference `split`, when given, with the one stored on this layout.

    Warns when no reference is provided; raises `ValueError` on a mismatch.
    """
    if split is None:
        warnings.warn("No reference split provided. Your results might be inconsistent!", UserWarning)
        return

    if split != self.split:
        # TODO: consistency error?
        raise ValueError