You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

60 lines
1.9 KiB

2 years ago
import lightning.pytorch as pl
from torch.utils.data import random_split, DataLoader, Dataset
import torch
import pandas as pd
import os
class RPSDataset(Dataset):
    """In-memory dataset built from every ``.csv`` file in one directory.

    Each CSV file is treated as one class: all of its rows receive the same
    integer label, and ``label_dict`` maps that integer back to the file name
    (text before ``.csv``, surrounding whitespace stripped).

    Attributes set by ``__init__``:
        data_dir: directory the CSV files are read from.
        data: float32 tensor holding the rows of all files, concatenated
            along dim 0.
        labels: list with one integer class label per row of ``data``.
        label_dict: dict mapping integer label -> class name.
    """

    def __init__(self, data_dir):
        super().__init__()
        self.data_dir = data_dir
        self.data, self.labels, self.label_dict = self.get_data()

    def get_data(self):
        """Read all CSV files in ``self.data_dir`` into one tensor.

        Files are visited in sorted name order so the label assigned to each
        class is the same on every run and every machine (``os.listdir``
        itself returns entries in arbitrary order).

        Returns:
            A ``(data, labels, label_dict)`` tuple as described on the class.

        Raises:
            ValueError: if the directory contains no ``.csv`` file.
        """
        csv_names = sorted(
            name for name in os.listdir(self.data_dir) if name.endswith(".csv")
        )
        # Explicit raise instead of ``assert``: asserts vanish under ``-O``,
        # and a directory holding only non-CSV files must be rejected too.
        if not csv_names:
            raise ValueError(f"no .csv files found in {self.data_dir!r}")

        chunks = []
        labels = []
        label_dict = {}
        for label, name in enumerate(csv_names):
            # The first CSV column is the index, not a feature.
            frame = pd.read_csv(os.path.join(self.data_dir, name), index_col=0)
            rows = torch.from_numpy(frame.to_numpy()).to(torch.float32)
            label_dict[label] = name.split(".csv")[0].strip()
            labels.extend([label] * rows.shape[0])
            chunks.append(rows)

        # Concatenate once at the end rather than re-copying inside the loop.
        return torch.cat(chunks, dim=0), labels, label_dict

    def __len__(self):
        """Return the total number of rows across all loaded files."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(row_tensor, integer_label)`` for position ``idx``."""
        return self.data[idx], self.labels[idx]
class RockPaperScissorData(pl.LightningDataModule):
    """Lightning data module serving an 80/20 train/validation split.

    The full ``RPSDataset`` is loaded from ``data_dir`` and randomly split
    into ``train_data`` (80%) and ``val_data`` (20%).

    Args:
        data_dir: directory of per-class CSV files, passed to ``RPSDataset``.
        batch_size: batch size used by both data loaders.
    """

    def __init__(self, data_dir, batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.train_data = None
        self.val_data = None
        self.batch_size = batch_size

    def _ensure_split(self):
        """Load the dataset and split it once; later calls are no-ops."""
        if self.train_data is not None and self.val_data is not None:
            return
        full_data = RPSDataset(self.data_dir)
        # The split draws from the global torch RNG; seed it before fitting
        # if the split has to be reproducible or identical across processes.
        self.train_data, self.val_data = random_split(full_data, [0.8, 0.2])

    def prepare_data(self) -> None:
        """Build the split when this hook is invoked.

        Lightning documents that this hook runs on a single process only, so
        state set here is not visible to other ranks; ``setup`` covers them.
        The split is still created here so code that calls ``prepare_data()``
        directly and then reads the datasets keeps working.
        """
        self._ensure_split()

    def setup(self, stage=None) -> None:
        """Build the split on every process (no-op if it already exists).

        Args:
            stage: stage name passed by Lightning; unused because the same
                split serves every stage.
        """
        self._ensure_split()

    def train_dataloader(self):
        """Return a shuffled loader over the training split."""
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return a loader over the validation split in fixed order."""
        # Validation does not need shuffling; a fixed order keeps evaluation
        # runs comparable batch-for-batch.
        return DataLoader(self.val_data, batch_size=self.batch_size, shuffle=False)