class Split(Layout):
    """
    Layout that partitions ``entries`` into folds with a split function.

    After a fold is activated via ``set``, the parts of that fold can be retrieved
    either by position (``layout[0]``) or, when ``names`` were given, by name
    (``layout.train``).
    """

    def __init__(self, split: SplitType, entries: Sequence, *args: Any, names: Optional[Sequence[str]] = None,
                 **kwargs: Any):
        """
        Splits data according to split function.

        Parameters
        ----------
        split: Callable
            Split function, or a sklearn splitter (any object with a ``split`` method).
        entries: Sequence
            Series of ids or torch Dataset or Connectome Layer.
        args: Any
            args for split.
        names: Optional[Sequence[str]]
            Names of folds, e.g. 'train', 'val', 'test'. Must be unique, and every
            fold must contain exactly ``len(names)`` parts.
        kwargs: Any
            kwargs for split.

        Raises
        ------
        TypeError
            If ``split`` is neither callable nor has a ``split`` attribute.
        ValueError
            If ``names`` contains duplicates or its length differs from the number
            of parts in a fold.

        Examples
        ----------
        ```python
        from sklearn.model_selection import KFold

        ids = [0, 1, ...]
        layout = Split(KFold(3), ids, names=["train", "test"])
        ```
        """
        if not callable(split):
            if not hasattr(split, 'split'):
                raise TypeError(f'Expected either a function, or a sklearn splitter, got {type(split)!r}')
            split = split.split

        ids = entries_to_ids(entries)
        # TODO: safer way to unify types
        splits = [tuple(map(jsonify, xs)) for xs in split(ids, *args, **kwargs)]

        if names is not None:
            # explicit exceptions instead of `assert`, which is stripped under `python -O`
            if len(set(names)) != len(names):
                raise ValueError(f'The names must be unique, got {list(names)!r}')
            # check every fold, not only the first one; an empty list of folds is accepted
            for fold, parts in enumerate(splits):
                if len(parts) != len(names):
                    raise ValueError(
                        f'Fold {fold} consists of {len(parts)} parts, but {len(names)} names were provided'
                    )

        self.entries = entries
        self.splits = splits
        self.names = names
        # index of the active fold; stays None until `set` is called
        self.fold: Optional[int] = None

    def __getitem__(self, item: int):
        """Return the ``item``-th part of the active fold."""
        return self._subset(item)

    def __getattr__(self, name: str):
        """
        Return the part of the active fold called ``name``.

        Only invoked when regular attribute lookup fails. Unknown names raise
        ``AttributeError`` so that ``hasattr``, ``getattr`` with a default, and the
        ``copy``/``pickle`` protocol probes keep working.
        """
        # bypass __getattr__ while reading `names`: on a not-yet-initialized instance
        # (e.g. during unpickling) `self.names` would recurse into this method forever
        try:
            names = object.__getattribute__(self, 'names')
        except AttributeError:
            raise AttributeError(name) from None

        if names is None:
            raise AttributeError(name)

        try:
            idx = names.index(name)
        except ValueError:
            # `index` signals a missing value with ValueError, attribute access must not
            raise AttributeError(name) from None
        return self._subset(idx)

    def _subset(self, idx):
        """
        Select from ``entries`` the ids stored in part ``idx`` of the active fold.

        Raises
        ------
        RuntimeError
            If no fold was activated via ``set`` yet.
        """
        if self.fold is None:
            raise RuntimeError('The fold is not selected yet. Call `set` first')
        return entries_subset(self.entries, self.splits[self.fold][idx])

    def build(self, experiment: Path, config: Config):
        """
        Lay out the experiment on disk, one folder per fold.

        Dumps ``config`` to ``experiment / 'experiment.config'``; then, for each fold
        ``i``, creates ``experiment / 'fold_{i}'``, saves the fold's split to
        ``split.json`` and dumps a per-fold config with ``ExpName`` and ``GroupName``
        overridden. Yields a ``Node`` named after the fold index.

        This is a generator: nothing is written until it is iterated.
        """
        config.dump(experiment / 'experiment.config')
        name = experiment.name
        for fold, split in enumerate(self.splits):
            folder = experiment / f'fold_{fold}'
            folder.mkdir()
            save(split, folder / 'split.json')
            # NOTE(review): relies on `update` returning the updated config - confirm
            local = config.copy().update(ExpName=f'{name}({fold})', GroupName=name)
            local.dump(folder / 'experiment.config')
            yield Node(name=str(fold))

    def load(self, experiment: Path, node: Optional[Node]) -> Tuple[Config, Path, Dict[str, Any]]:
        """
        Load what was stored by ``build`` for the fold denoted by ``node``.

        Returns the fold's config, the fold's folder, and a dict with the keys
        ``'fold'`` (fold index) and ``'split'`` (the saved split as a tuple), which
        match the parameters of ``set``.

        Although annotated as optional, ``node.name`` is read unconditionally, so
        ``node`` must not be None here.
        """
        folder = experiment / f'fold_{node.name}'
        return Config.load(folder / 'experiment.config'), folder, {
            'fold': int(node.name),
            'split': tuple(load(folder / 'split.json')),
        }

    def set(self, fold: int, split: Optional[Sequence[Sequence]] = None):
        """
        Activate ``fold`` and optionally verify it against a reference ``split``.

        Parameters
        ----------
        fold: int
            Index of the fold to activate.
        split: Optional[Sequence[Sequence]]
            Reference split, e.g. the one returned by ``load``. If omitted, a
            ``UserWarning`` is issued because consistency can't be checked.

        Raises
        ------
        ValueError
            If ``split`` differs from the split computed for ``fold``.
        """
        self.fold = fold
        if split is None:
            warnings.warn('No reference split provided. Your results might be inconsistent!', UserWarning)
            return

        # unify the outer containers: a list never compares equal to a tuple,
        # even when their contents coincide
        # TODO: consistency error?
        if tuple(split) != tuple(self.splits[fold]):
            raise ValueError(f'The provided split does not match the split computed for fold {fold}')