commit
9d7be132b6
@ -0,0 +1,4 @@
|
|||||||
|
/lightning_logs/
|
||||||
|
/paper.csv
|
||||||
|
/rock.csv
|
||||||
|
/scissor.csv
|
@ -0,0 +1,60 @@
|
|||||||
|
import lightning.pytorch as pl
|
||||||
|
from torch.utils.data import random_split, DataLoader, Dataset
|
||||||
|
import torch
|
||||||
|
import pandas as pd
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class RPSDataset(Dataset):
|
||||||
|
def __init__(self, data_dir):
|
||||||
|
super().__init__()
|
||||||
|
self.data_dir = data_dir
|
||||||
|
self.data, self.labels, self.label_dict = self.get_data()
|
||||||
|
|
||||||
|
def get_data(self):
|
||||||
|
labels = []
|
||||||
|
label_dict = {}
|
||||||
|
data = None
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
assert len(os.listdir(self.data_dir)) != 0
|
||||||
|
|
||||||
|
for f in os.listdir(self.data_dir):
|
||||||
|
if f.endswith(".csv"):
|
||||||
|
current_data = torch.from_numpy(
|
||||||
|
pd.read_csv(os.path.join(self.data_dir, f), index_col=0).to_numpy()).to(torch.float32)
|
||||||
|
|
||||||
|
if i not in label_dict:
|
||||||
|
label_dict[i] = f.split(".csv")[0].strip()
|
||||||
|
labels = labels + [i] * current_data.shape[0]
|
||||||
|
i += 1
|
||||||
|
if data is None:
|
||||||
|
data = current_data
|
||||||
|
else:
|
||||||
|
data = torch.cat([data, current_data], dim=0)
|
||||||
|
return data, labels, label_dict
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.labels)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.data[idx], self.labels[idx]
|
||||||
|
|
||||||
|
|
||||||
|
class RockPaperScissorData(pl.LightningDataModule):
|
||||||
|
def __init__(self, data_dir, batch_size):
|
||||||
|
super().__init__()
|
||||||
|
self.data_dir = data_dir
|
||||||
|
self.train_data = None
|
||||||
|
self.val_data = None
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
def prepare_data(self) -> None:
|
||||||
|
self.train_data = RPSDataset(self.data_dir)
|
||||||
|
self.train_data, self.val_data = random_split(self.train_data, [0.8, 0.2])
|
||||||
|
|
||||||
|
def train_dataloader(self):
|
||||||
|
return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)
|
||||||
|
|
||||||
|
def val_dataloader(self):
|
||||||
|
return DataLoader(self.val_data, batch_size=self.batch_size, shuffle=True)
|
@ -0,0 +1,71 @@
|
|||||||
|
import mediapipe as mp
|
||||||
|
import cv2
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
from typing import Union, List, Any
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
def process_landmark(landmarks):
|
||||||
|
landmark_lst = []
|
||||||
|
for hand in landmarks:
|
||||||
|
current_hand_lm = hand.landmark
|
||||||
|
for lm in current_hand_lm:
|
||||||
|
landmark_lst.append(lm.x)
|
||||||
|
landmark_lst.append(lm.y)
|
||||||
|
landmark_lst.append(lm.z)
|
||||||
|
|
||||||
|
# Set to 0 if there is only 1 hand
|
||||||
|
if len(landmarks) == 1:
|
||||||
|
for i in range(21):
|
||||||
|
landmark_lst.append(0)
|
||||||
|
landmark_lst.append(0)
|
||||||
|
landmark_lst.append(0)
|
||||||
|
return landmark_lst
|
||||||
|
|
||||||
|
|
||||||
|
def create_training_data(label: str, data_lst: list[list[Union[int, Any]]],
|
||||||
|
data_path: Union[Path, str]) -> None:
|
||||||
|
|
||||||
|
df = pd.DataFrame(data=data_lst)
|
||||||
|
file_path = os.path.join(data_path, label + ".csv")
|
||||||
|
df.to_csv(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def capture_images(num_frames: int) -> list[list[Union[int, Any]]]:
|
||||||
|
mpHands = mp.solutions.hands
|
||||||
|
Hands = mpHands.Hands()
|
||||||
|
mpDraws = mp.solutions.drawing_utils
|
||||||
|
|
||||||
|
cap = cv2.VideoCapture(0)
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
data_list = []
|
||||||
|
while True:
|
||||||
|
ret, frame = cap.read()
|
||||||
|
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
|
result = Hands.process(frame_rgb)
|
||||||
|
|
||||||
|
if result.multi_hand_landmarks:
|
||||||
|
for hand_landmark in result.multi_hand_landmarks:
|
||||||
|
mpDraws.draw_landmarks(frame, hand_landmark, mpHands.HAND_CONNECTIONS)
|
||||||
|
data_list.append(process_landmark(result.multi_hand_landmarks))
|
||||||
|
count += 1
|
||||||
|
if count >= num_frames:
|
||||||
|
break
|
||||||
|
|
||||||
|
cv2.imshow("Frame", frame)
|
||||||
|
|
||||||
|
if cv2.waitKey(1) == ord("q"):
|
||||||
|
break
|
||||||
|
cap.release()
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
return data_list
|
||||||
|
|
||||||
|
|
||||||
|
# Press the green button in the gutter to run the script.
|
||||||
|
if __name__ == '__main__':
|
||||||
|
label = input("Enter the label: ")
|
||||||
|
data_list = capture_images(500)
|
||||||
|
create_training_data(label, data_list, os.getcwd())
|
@ -0,0 +1,96 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from model import ClassifyModel
|
||||||
|
from datasets import RPSDataset
|
||||||
|
import cv2
|
||||||
|
import mediapipe as mp
|
||||||
|
|
||||||
|
|
||||||
|
def process_data(landmarks):
|
||||||
|
landmark_lst = []
|
||||||
|
for hand in landmarks:
|
||||||
|
current_hand_lm = hand.landmark
|
||||||
|
for lm in current_hand_lm:
|
||||||
|
landmark_lst.append(lm.x)
|
||||||
|
landmark_lst.append(lm.y)
|
||||||
|
landmark_lst.append(lm.z)
|
||||||
|
|
||||||
|
# Set to 0 if there is only 1 hand
|
||||||
|
if len(landmarks) == 1:
|
||||||
|
for i in range(21):
|
||||||
|
landmark_lst.append(0)
|
||||||
|
landmark_lst.append(0)
|
||||||
|
landmark_lst.append(0)
|
||||||
|
landmark_t = torch.Tensor(landmark_lst).unsqueeze(0)
|
||||||
|
return landmark_t
|
||||||
|
|
||||||
|
|
||||||
|
def predict(_model, x, label_dict):
|
||||||
|
output = torch.argmax(_model(x), dim=1).item()
|
||||||
|
return label_dict[output]
|
||||||
|
|
||||||
|
def computer_predict(label_dict, model=None):
|
||||||
|
prediction = 0
|
||||||
|
if not model:
|
||||||
|
prediction = random.randint(0, 2)
|
||||||
|
return label_dict[prediction]
|
||||||
|
|
||||||
|
def show_result(u_output, c_output):
|
||||||
|
if u_output == c_output:
|
||||||
|
return "Draw"
|
||||||
|
if (u_output == "rock" and c_output == "paper") or \
|
||||||
|
(u_output == "paper" and c_output == "scissor") or \
|
||||||
|
(u_output == "scissor" and c_output == "rock"):
|
||||||
|
return "Lose"
|
||||||
|
return "Win"
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
mpHands = mp.solutions.hands
|
||||||
|
Hands = mpHands.Hands()
|
||||||
|
mpDraws = mp.solutions.drawing_utils
|
||||||
|
label_dict = RPSDataset(os.getcwd()).label_dict
|
||||||
|
model = ClassifyModel.load_from_checkpoint("./lightning_logs/version_6/checkpoints/epoch=12-step=247.ckpt")
|
||||||
|
model.eval()
|
||||||
|
cap = cv2.VideoCapture(0)
|
||||||
|
|
||||||
|
current_choose = None
|
||||||
|
output = None
|
||||||
|
|
||||||
|
if not cap.isOpened():
|
||||||
|
exit(1)
|
||||||
|
while True:
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if not ret:
|
||||||
|
break
|
||||||
|
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
|
if current_choose:
|
||||||
|
cv2.putText(frame, f"You choose: {output}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX,
|
||||||
|
2, (0, 0, 255), 2, 2)
|
||||||
|
computer_choose = computer_predict(label_dict)
|
||||||
|
result = show_result(current_choose, computer_choose)
|
||||||
|
cv2.putText(frame, f"Computer choose: {computer_choose}, You {result}", (50, 100), cv2.FONT_HERSHEY_SIMPLEX,
|
||||||
|
2, (0, 0, 255), 2, 2)
|
||||||
|
cv2.imshow("frame", frame)
|
||||||
|
cv2.waitKey(1)
|
||||||
|
time.sleep(3)
|
||||||
|
current_choose = None
|
||||||
|
else:
|
||||||
|
result = Hands.process(frame_rgb)
|
||||||
|
|
||||||
|
if result.multi_hand_landmarks:
|
||||||
|
x = process_data(result.multi_hand_landmarks)
|
||||||
|
output = predict(model, x, label_dict)
|
||||||
|
cv2.putText(frame, f"CURRENT: {output}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX,
|
||||||
|
2, (0, 0, 255), 2, 2)
|
||||||
|
|
||||||
|
key = cv2.waitKey(1)
|
||||||
|
cv2.imshow("frame", frame)
|
||||||
|
if key == ord("o"):
|
||||||
|
current_choose = output
|
||||||
|
if key == ord("q"):
|
||||||
|
break
|
||||||
|
cap.release()
|
||||||
|
cv2.destroyAllWindows()
|
@ -0,0 +1,62 @@
|
|||||||
|
import torch
|
||||||
|
import lightning.pytorch as pl
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifyModel(pl.LightningModule):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.layer1 = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(126, 256),
|
||||||
|
torch.nn.ReLU()
|
||||||
|
)
|
||||||
|
self.layer2 = torch.nn.Sequential(
|
||||||
|
torch.nn.Dropout(),
|
||||||
|
torch.nn.Linear(256, 512),
|
||||||
|
torch.nn.ReLU()
|
||||||
|
)
|
||||||
|
self.layer3 = torch.nn.Sequential(
|
||||||
|
torch.nn.Dropout(),
|
||||||
|
torch.nn.Linear(512, 512),
|
||||||
|
torch.nn.ReLU()
|
||||||
|
)
|
||||||
|
self.layer4 = torch.nn.Sequential(
|
||||||
|
torch.nn.Dropout(),
|
||||||
|
torch.nn.Linear(512, 256),
|
||||||
|
torch.nn.ReLU()
|
||||||
|
)
|
||||||
|
self.layer5 = torch.nn.Sequential(
|
||||||
|
torch.nn.Dropout(),
|
||||||
|
torch.nn.Linear(256, 128),
|
||||||
|
torch.nn.ReLU()
|
||||||
|
)
|
||||||
|
self.classifier = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(128, 3),
|
||||||
|
torch.nn.Softmax(dim=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x) -> Any:
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
x = self.layer5(x)
|
||||||
|
return self.classifier(x)
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
x, y = batch
|
||||||
|
output = self.forward(x)
|
||||||
|
log_output = torch.log(output)
|
||||||
|
loss = torch.nn.NLLLoss()(log_output, y)
|
||||||
|
self.log("train_loss", loss.item(), prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
x, y = batch
|
||||||
|
output = self.forward(x)
|
||||||
|
log_output = torch.log(output)
|
||||||
|
loss = torch.nn.NLLLoss()(log_output, y)
|
||||||
|
self.log("val_loss", loss.item(), prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||||
|
|
||||||
|
def configure_optimizers(self) -> Any:
|
||||||
|
return torch.optim.Adam(self.parameters())
|
@ -0,0 +1,16 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from model import ClassifyModel
|
||||||
|
import lightning.pytorch as pl
|
||||||
|
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
|
||||||
|
from datasets import RockPaperScissorData
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
trainer = pl.Trainer(max_epochs=100, callbacks=[EarlyStopping(monitor="val_loss",
|
||||||
|
mode="min")],
|
||||||
|
default_root_dir=os.getcwd())
|
||||||
|
data = RockPaperScissorData(os.getcwd(), 64)
|
||||||
|
model = ClassifyModel()
|
||||||
|
|
||||||
|
trainer.fit(model, data)
|
||||||
|
|
Loading…
Reference in new issue