import os

import lightning.pytorch as pl
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset, random_split


class RPSDataset(Dataset):
    """One CSV file per class; each row is one float32 sample."""

    def __init__(self, data_dir):
        super().__init__()
        self.data_dir = data_dir
        # Eagerly load all samples, their integer labels, and the
        # index -> class-name mapping.
        self.data, self.labels, self.label_dict = self.get_data()

    def get_data(self):
        labels = []
        label_dict = {}
        data = None
        i = 0

        assert len(os.listdir(self.data_dir)) != 0, f"no files found in {self.data_dir}"

        # Sort the listing so class indices are deterministic;
        # os.listdir() returns entries in arbitrary order.
        for f in sorted(os.listdir(self.data_dir)):
            if f.endswith(".csv"):
                current_data = torch.from_numpy(
                    pd.read_csv(os.path.join(self.data_dir, f), index_col=0).to_numpy()
                ).to(torch.float32)

                # The file name (minus ".csv") names class i, e.g. "rock.csv" -> "rock".
                label_dict[i] = f.split(".csv")[0].strip()
                labels = labels + [i] * current_data.shape[0]
                i += 1

                if data is None:
                    data = current_data
                else:
                    data = torch.cat([data, current_data], dim=0)
        return data, labels, label_dict

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

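
# A minimal usage sketch for RPSDataset on its own. The "data/" path and the
# one-CSV-per-class layout (e.g. rock.csv, paper.csv, scissors.csv) are
# assumptions for illustration, not documented by this file:
#
#     dataset = RPSDataset("data/")
#     x, y = dataset[0]          # x: float32 feature row, y: int class index
#     print(dataset.label_dict)  # e.g. {0: "paper", 1: "rock", 2: "scissors"}
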

class RockPaperScissorData(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.train_data = None
        self.val_data = None
        self.batch_size = batch_size

    def setup(self, stage=None):
        # setup() runs on every process, so the train/val split assigned here
        # is visible everywhere; prepare_data() runs on a single process and
        # should not set state.
        full_dataset = RPSDataset(self.data_dir)
        self.train_data, self.val_data = random_split(full_dataset, [0.8, 0.2])

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        # Do not shuffle validation batches; order has no effect on metrics,
        # and Lightning warns when a val loader shuffles.
        return DataLoader(self.val_data, batch_size=self.batch_size, shuffle=False)
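

# A hedged end-to-end sketch: RPSClassifier is hypothetical (no LightningModule
# is defined in this file), and the "data/" directory of per-class CSVs is an
# assumption:
#
# if __name__ == "__main__":
#     dm = RockPaperScissorData(data_dir="data/", batch_size=32)
#     model = RPSClassifier()  # hypothetical LightningModule
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(model, datamodule=dm)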