import os

import lightning.pytorch as pl
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset, random_split


class RPSDataset(Dataset):
    """Dataset built from per-class CSV files in a directory.

    Each ``*.csv`` file in ``data_dir`` is one class: its rows become
    float32 feature vectors and the file name (minus the extension)
    becomes the class name.
    """

    def __init__(self, data_dir):
        super().__init__()
        self.data_dir = data_dir
        self.data, self.labels, self.label_dict = self.get_data()

    def get_data(self):
        """Load every CSV in ``self.data_dir`` into one stacked tensor.

        Returns:
            tuple ``(data, labels, label_dict)`` where ``data`` is a
            float32 tensor of all rows concatenated, ``labels`` is a list
            with one int class index per row, and ``label_dict`` maps
            class index -> class name.

        Raises:
            FileNotFoundError: if the directory contains no ``.csv`` file.
        """
        labels = []
        label_dict = {}
        # Collect chunks and concatenate once at the end; torch.cat inside
        # the loop would be O(n^2) in the total number of rows.
        chunks = []
        # Sort file names so class indices are deterministic: os.listdir
        # order is filesystem-dependent.
        csv_files = sorted(f for f in os.listdir(self.data_dir) if f.endswith(".csv"))
        for label_idx, fname in enumerate(csv_files):
            frame = pd.read_csv(os.path.join(self.data_dir, fname), index_col=0)
            current_data = torch.from_numpy(frame.to_numpy()).to(torch.float32)
            label_dict[label_idx] = fname[: -len(".csv")].strip()
            labels.extend([label_idx] * current_data.shape[0])
            chunks.append(current_data)
        if not chunks:
            # Explicit raise instead of assert: asserts are stripped under
            # ``python -O``, and the old assert also passed when the
            # directory had files but none of them were CSVs.
            raise FileNotFoundError(f"no .csv files found in {self.data_dir!r}")
        return torch.cat(chunks, dim=0), labels, label_dict

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


class RockPaperScissorData(pl.LightningDataModule):
    """LightningDataModule serving an 80/20 train/val split of RPSDataset."""

    def __init__(self, data_dir, batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_data = None
        self.val_data = None

    def prepare_data(self) -> None:
        # Download/disk-only hook. Do NOT assign state here: Lightning
        # calls prepare_data on a single process in distributed training,
        # so attributes set here never reach the other ranks. Dataset
        # construction and splitting live in setup() instead.
        pass

    def setup(self, stage=None) -> None:
        # setup() runs on every process, so the split is visible to all
        # ranks. Guard so repeated setup calls keep the same split.
        if self.train_data is None:
            full = RPSDataset(self.data_dir)
            self.train_data, self.val_data = random_split(full, [0.8, 0.2])

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        # No shuffling for validation: metrics do not depend on order and
        # a fixed order keeps evaluation deterministic.
        return DataLoader(self.val_data, batch_size=self.batch_size, shuffle=False)