class Split(Layout):
    """Layout that partitions entries into folds using a split function."""

    def __init__(
        self, split: SplitType, entries: Sequence, *args: Any, names: Sequence[str] | None = None, **kwargs: Any
    ):
        """
        Splits data according to split function.

        Parameters
        ----------
        split: Callable
            Split function, or a sklearn splitter (any object with a ``split`` method).
            It is called as ``split(ids, *args, **kwargs)`` and must yield one
            sequence of parts per fold.
        entries: Sequence
            Series of ids or torch Dataset or Connectome Layer.
        args: Any
            args for split.
        names: Optional[Sequence[str]]
            Names of the parts inside each fold, e.g. 'train', 'val', 'test'.
            When given, each part is also reachable as an attribute (``layout.train``).
        kwargs: Any
            kwargs for split.

        Raises
        ------
        TypeError
            If ``split`` is neither callable nor has a ``split`` method.
        ValueError
            If ``names`` contains duplicates or its length differs from the
            number of parts in any fold.

        Examples
        ----------
        ```python
        from sklearn.model_selection import KFold
        ids = [0, 1, ...]
        layout = Split(KFold(3), ids, names=["train", "test"])
        ```
        """
        if not callable(split):
            if not hasattr(split, "split"):
                raise TypeError(f"Expected either a function, or a sklearn splitter, got {type(split)!r}")
            split = split.split

        ids = entries_to_ids(entries)
        # TODO: safer way to unify types
        # Each fold becomes a tuple of jsonified parts, so that it compares equal
        # to the version reloaded from ``split.json``.
        splits = [tuple(map(jsonify, xs)) for xs in split(ids, *args, **kwargs)]

        if names is not None:
            # Explicit checks instead of ``assert``: they must survive ``python -O``.
            if len(set(names)) != len(names):
                raise ValueError(f"Fold part names must be unique, got {list(names)!r}")
            # Every fold (not just the first) must have one part per name.
            for fold, parts in enumerate(splits):
                if len(parts) != len(names):
                    raise ValueError(
                        f"Fold {fold} has {len(parts)} parts, but {len(names)} names were given: {list(names)!r}"
                    )

        self.entries = entries
        self.splits = splits
        self.names = names
        # Index of the active fold; chosen later via ``set``.
        self.fold: int | None = None

    def __getitem__(self, item: int):
        """Return the entries of part ``item`` of the active fold."""
        return self._subset(item)

    def __getattr__(self, name: str):
        """Resolve a part name (e.g. ``layout.train``) to its entries in the active fold.

        Only called when normal attribute lookup fails. Unknown names raise
        ``AttributeError`` so that ``hasattr``/``getattr(..., default)`` and
        ``copy``/``pickle`` probing (``__deepcopy__``, ``__setstate__``) behave.
        """
        # Read through ``__dict__``: on an instance whose ``__init__`` has not run
        # (e.g. during unpickling), ``self.names`` would re-enter this method forever.
        names = self.__dict__.get("names")
        if names is None:
            raise AttributeError(name)
        try:
            idx = names.index(name)
        except ValueError:
            raise AttributeError(name) from None
        return self._subset(idx)

    def _subset(self, idx):
        """Return ``entries`` restricted to the ids of part ``idx`` of the active fold.

        Raises
        ------
        RuntimeError
            If no fold has been selected yet with ``set``.
        """
        # TODO
        if self.fold is None:
            raise RuntimeError("No fold is selected; call `set(fold, ...)` before accessing fold parts")
        return entries_subset(self.entries, self.splits[self.fold][idx])

    def build(self, experiment: Path, config: Config):
        """Write the experiment layout to disk, one folder per fold.

        Dumps ``config`` to ``experiment/experiment.config``. Then, for each fold ``i``, it:

        - creates ``experiment/fold_i`` (``mkdir`` without ``exist_ok``, so it fails if
          the folder already exists);
        - saves the fold's split to ``split.json``;
        - writes a per-fold config with ``ExpName`` and ``GroupName`` overridden.

        Yields a ``Node`` named after the fold index.
        """
        config.dump(experiment / "experiment.config")
        name = experiment.name
        for fold, split in enumerate(self.splits):
            folder = experiment / f"fold_{fold}"
            folder.mkdir()
            save(split, folder / "split.json")
            # NOTE(review): relies on ``Config.update`` returning the updated config.
            local = config.copy().update(ExpName=f"{name}({fold})", GroupName=name)
            local.dump(folder / "experiment.config")
            yield Node(name=str(fold))

    def load(self, experiment: Path, node: Node | None) -> tuple[Config, Path, dict[str, Any]]:
        """Load the per-fold config, folder and stored split for ``node``.

        Returns ``(config, folder, {"fold": int, "split": tuple})``. The dict's
        keys match the parameters of ``set``. ``node`` must not be ``None`` here:
        its ``name`` is dereferenced.
        """
        folder = experiment / f"fold_{node.name}"
        return (
            Config.load(folder / "experiment.config"),
            folder,
            {
                "fold": int(node.name),
                "split": tuple(load(folder / "split.json")),
            },
        )

    def set(self, fold: int, split: Sequence[Sequence] | None = None):
        """Select the active fold, optionally checking it against a reference split.

        Parameters
        ----------
        fold: int
            Index of the fold to activate.
        split: Optional[Sequence[Sequence]]
            Previously stored split for this fold (e.g. loaded from ``split.json``).
            If omitted, a warning is emitted because consistency cannot be checked.

        Raises
        ------
        ValueError
            If ``split`` differs from the split computed for ``fold``. In that case
            the active fold is left unchanged.
        """
        if split is None:
            warnings.warn("No reference split provided. Your results might be inconsistent!", UserWarning, stacklevel=2)
        # Normalise the outer container so a list-typed reference compares equal to the stored tuple.
        elif tuple(split) != self.splits[fold]:
            # TODO: consistency error?
            raise ValueError(f"The reference split for fold {fold} does not match the split computed by this layout")
        # Assign only after validation, so a rejected split cannot leave a mismatched fold active.
        self.fold = fold