# Source code for enchanter.utils.datasets

from typing import Callable, Optional, Union

import numpy as np
import torch
from torch.utils.data import Dataset


__all__ = ["TimeSeriesLabeledDataset", "TimeSeriesUnlabeledDataset"]


class TimeSeriesUnlabeledDataset(Dataset):
    """Map-style dataset over a collection of time series without labels.

    Item ``i`` is ``data[i]`` (indexing along the first axis), passed through
    ``transform`` when one is given. The dataset length is ``data.shape[0]``.

    Args:
        data: Tensor or NumPy array whose first axis indexes samples. It is
            stored as given; NumPy arrays are not converted to tensors here,
            so ``transform`` receives whatever ``data[i]`` yields.
        transform: Optional callable applied to each item when it is accessed.

    Examples:
        >>> import numpy as np
        >>> from torch.utils.data import DataLoader
        >>> ds = TimeSeriesUnlabeledDataset(data=np.zeros((4, 10, 3)))
        >>> len(ds)
        4
        >>> loader = DataLoader(ds)
        >>> data = next(iter(loader))
    """

    def __init__(
        self,
        data: Union[torch.Tensor, np.ndarray],
        transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    ):
        # Kept unconverted, so this may be a NumPy array as well as a tensor.
        self.data: Union[torch.Tensor, np.ndarray] = data
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of samples (size of the first axis of ``data``)."""
        return self.data.shape[0]

    def __getitem__(self, item):
        """Return ``data[item]``, transformed if a transform was supplied."""
        data = self.data[item]
        # Compare with None instead of relying on truthiness: a callable that
        # defines __len__/__bool__ and evaluates falsy must still be applied.
        if self.transform is not None:
            data = self.transform(data)
        return data
class TimeSeriesLabeledDataset(TimeSeriesUnlabeledDataset):
    """Map-style dataset pairing each time series with its target.

    Item ``i`` is ``(data_i, targets[i])``, where ``data_i`` is ``data[i]``
    passed through ``transform`` if one is given. The transform is applied
    to the data only, never to the target.

    Args:
        data: Tensor or NumPy array whose first axis indexes samples.
        targets: Tensor or NumPy array with one entry per sample in ``data``,
            indexed along its first axis.
        transform: Optional callable applied to each data item on access.

    Raises:
        ValueError: If ``targets`` does not have the same number of entries
            as ``data`` has samples.

    Examples:
        >>> import numpy as np
        >>> from torch.utils.data import DataLoader
        >>> ds = TimeSeriesLabeledDataset(data=np.zeros((4, 10, 3)), targets=np.arange(4))
        >>> len(ds)
        4
        >>> loader = DataLoader(ds)
        >>> data, targets = next(iter(loader))
    """

    def __init__(
        self,
        data: Union[torch.Tensor, np.ndarray],
        targets: Union[torch.Tensor, np.ndarray],
        transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    ):
        super().__init__(data, transform)
        # Fail fast on mismatched inputs. Otherwise a short `targets` raises an
        # IndexError partway through an epoch, and a long one is silently truncated.
        n_samples = len(self)
        n_targets = len(targets)
        if n_targets != n_samples:
            raise ValueError(
                f"data and targets must have the same length, "
                f"got {n_samples} samples and {n_targets} targets"
            )
        self.targets = targets

    def __getitem__(self, item):
        """Return ``(data_item, targets[item])`` for the given index."""
        data = super().__getitem__(item)
        target = self.targets[item]
        return data, target