commit
9d7be132b6
@ -0,0 +1,4 @@
/lightning_logs/
/paper.csv
/rock.csv
/scissor.csv
@ -0,0 +1,60 @@
import lightning.pytorch as pl
from torch.utils.data import random_split, DataLoader, Dataset
import torch
import pandas as pd
import os


class RPSDataset(Dataset):
    """Loads every *.csv in data_dir; each file holds the samples for one class (rock / paper / scissor)."""

    def __init__(self, data_dir):
        super().__init__()
        self.data_dir = data_dir
        self.data, self.labels, self.label_dict = self.get_data()

    def get_data(self):
        labels = []
        label_dict = {}
        data = None
        i = 0

        assert len(os.listdir(self.data_dir)) != 0, "data directory is empty"

        for f in os.listdir(self.data_dir):
            if f.endswith(".csv"):
                # Each CSV row is one flattened landmark vector; column 0 is the saved index.
                current_data = torch.from_numpy(
                    pd.read_csv(os.path.join(self.data_dir, f), index_col=0).to_numpy()).to(torch.float32)

                # The file name (without ".csv") becomes the class name for label i.
                if i not in label_dict:
                    label_dict[i] = f.split(".csv")[0].strip()
                labels = labels + [i] * current_data.shape[0]
                i += 1
                if data is None:
                    data = current_data
                else:
                    data = torch.cat([data, current_data], dim=0)
        return data, labels, label_dict

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


class RockPaperScissorData(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.train_data = None
        self.val_data = None
        self.batch_size = batch_size

    def prepare_data(self) -> None:
        # 80/20 split between training and validation data.
        self.train_data = RPSDataset(self.data_dir)
        self.train_data, self.val_data = random_split(self.train_data, [0.8, 0.2])

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        # Validation data does not need to be shuffled.
        return DataLoader(self.val_data, batch_size=self.batch_size, shuffle=False)
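A minimal usage sketch for this module, assuming the rock.csv, paper.csv and scissor.csv files listed in the .gitignore above sit in the working directory (label indices depend on directory listing order, so the printed mapping is only illustrative):

from datasets import RPSDataset, RockPaperScissorData

ds = RPSDataset(".")             # reads every *.csv in the current directory
print(len(ds), ds.label_dict)    # e.g. {0: 'paper', 1: 'rock', 2: 'scissor'}; order follows os.listdir
dm = RockPaperScissorData(".", batch_size=64)
dm.prepare_data()                # builds the 80/20 train/val split
batch_x, batch_y = next(iter(dm.train_dataloader()))
print(batch_x.shape)             # torch.Size([64, 126]) with the landmark CSVs produced by the capture script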
@ -0,0 +1,71 @@
import mediapipe as mp
import cv2
from pathlib import Path
import os
from typing import Union, Any
import pandas as pd


def process_landmark(landmarks):
    """Flatten the detected hand landmarks into a single list of x, y, z values."""
    landmark_lst = []
    for hand in landmarks:
        current_hand_lm = hand.landmark
        for lm in current_hand_lm:
            landmark_lst.append(lm.x)
            landmark_lst.append(lm.y)
            landmark_lst.append(lm.z)

    # Pad the second hand with zeros if only one hand is detected,
    # so every sample has the same length.
    if len(landmarks) == 1:
        for i in range(21):
            landmark_lst.append(0)
            landmark_lst.append(0)
            landmark_lst.append(0)
    return landmark_lst


def create_training_data(label: str, data_lst: list[list[Union[int, Any]]],
                         data_path: Union[Path, str]) -> None:
    # Save the collected landmark vectors as <label>.csv in data_path.
    df = pd.DataFrame(data=data_lst)
    file_path = os.path.join(data_path, label + ".csv")
    df.to_csv(file_path)


def capture_images(num_frames: int) -> list[list[Union[int, Any]]]:
    mpHands = mp.solutions.hands
    Hands = mpHands.Hands()
    mpDraws = mp.solutions.drawing_utils

    cap = cv2.VideoCapture(0)
    count = 0

    data_list = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        result = Hands.process(frame_rgb)

        if result.multi_hand_landmarks:
            for hand_landmark in result.multi_hand_landmarks:
                mpDraws.draw_landmarks(frame, hand_landmark, mpHands.HAND_CONNECTIONS)
            data_list.append(process_landmark(result.multi_hand_landmarks))
            count += 1
            if count >= num_frames:
                break

        cv2.imshow("Frame", frame)

        # Press "q" to stop capturing early.
        if cv2.waitKey(1) == ord("q"):
            break
    cap.release()
    cv2.destroyAllWindows()

    return data_list


if __name__ == '__main__':
    label = input("Enter the label: ")
    data_list = capture_images(500)
    create_training_data(label, data_list, os.getcwd())
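Each captured row therefore contains 2 hands × 21 landmarks × 3 coordinates = 126 values (a single detected hand is zero-padded to the same length), which matches the Linear(126, 256) input layer of the classifier defined further down in this commit.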
@ -0,0 +1,96 @@
import os
import random
import time

import torch
from model import ClassifyModel
from datasets import RPSDataset
import cv2
import mediapipe as mp


def process_data(landmarks):
    """Flatten the detected hand landmarks into a (1, 126) float tensor for the model."""
    landmark_lst = []
    for hand in landmarks:
        current_hand_lm = hand.landmark
        for lm in current_hand_lm:
            landmark_lst.append(lm.x)
            landmark_lst.append(lm.y)
            landmark_lst.append(lm.z)

    # Pad the second hand with zeros if only one hand is detected.
    if len(landmarks) == 1:
        for i in range(21):
            landmark_lst.append(0)
            landmark_lst.append(0)
            landmark_lst.append(0)
    landmark_t = torch.Tensor(landmark_lst).unsqueeze(0)
    return landmark_t


def predict(_model, x, label_dict):
    # Run the classifier on its own device without tracking gradients.
    with torch.no_grad():
        output = torch.argmax(_model(x.to(_model.device)), dim=1).item()
    return label_dict[output]


def computer_predict(label_dict, model=None):
    # Without a model of its own, the computer simply picks a random class.
    prediction = 0
    if not model:
        prediction = random.randint(0, 2)
    return label_dict[prediction]


def show_result(u_output, c_output):
    if u_output == c_output:
        return "Draw"
    if (u_output == "rock" and c_output == "paper") or \
            (u_output == "paper" and c_output == "scissor") or \
            (u_output == "scissor" and c_output == "rock"):
        return "Lose"
    return "Win"


if __name__ == '__main__':
    mpHands = mp.solutions.hands
    Hands = mpHands.Hands()
    mpDraws = mp.solutions.drawing_utils
    label_dict = RPSDataset(os.getcwd()).label_dict
    model = ClassifyModel.load_from_checkpoint("./lightning_logs/version_6/checkpoints/epoch=12-step=247.ckpt")
    model.eval()
    cap = cv2.VideoCapture(0)

    current_choose = None
    output = None

    if not cap.isOpened():
        exit(1)
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if current_choose:
            # A choice has been locked in: show both moves and the outcome for 3 seconds.
            cv2.putText(frame, f"You chose: {current_choose}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX,
                        2, (0, 0, 255), 2, 2)
            computer_choose = computer_predict(label_dict)
            result = show_result(current_choose, computer_choose)
            cv2.putText(frame, f"Computer chose: {computer_choose}. You {result}", (50, 100), cv2.FONT_HERSHEY_SIMPLEX,
                        2, (0, 0, 255), 2, 2)
            cv2.imshow("frame", frame)
            cv2.waitKey(1)
            time.sleep(3)
            current_choose = None
        else:
            result = Hands.process(frame_rgb)

            if result.multi_hand_landmarks:
                x = process_data(result.multi_hand_landmarks)
                output = predict(model, x, label_dict)
                cv2.putText(frame, f"CURRENT: {output}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX,
                            2, (0, 0, 255), 2, 2)

            # Press "o" to lock in the current prediction, "q" to quit.
            key = cv2.waitKey(1)
            cv2.imshow("frame", frame)
            if key == ord("o"):
                current_choose = output
            if key == ord("q"):
                break
    cap.release()
    cv2.destroyAllWindows()
@ -0,0 +1,62 @@
import torch
import lightning.pytorch as pl
from typing import Any


class ClassifyModel(pl.LightningModule):
    """MLP classifier: 126 landmark features -> 3 classes (rock / paper / scissor)."""

    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Sequential(
            torch.nn.Linear(126, 256),
            torch.nn.ReLU()
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(256, 512),
            torch.nn.ReLU()
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(512, 512),
            torch.nn.ReLU()
        )
        self.layer4 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU()
        )
        self.layer5 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU()
        )
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(128, 3),
            torch.nn.Softmax(dim=1)
        )

    def forward(self, x) -> Any:
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        return self.classifier(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.forward(x)
        # forward() returns softmax probabilities, so NLLLoss on their log is cross-entropy.
        log_output = torch.log(output)
        loss = torch.nn.NLLLoss()(log_output, y)
        self.log("train_loss", loss.item(), prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.forward(x)
        log_output = torch.log(output)
        loss = torch.nn.NLLLoss()(log_output, y)
        self.log("val_loss", loss.item(), prog_bar=True, logger=True, on_step=True, on_epoch=True)

    def configure_optimizers(self) -> Any:
        return torch.optim.Adam(self.parameters())
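A quick standalone check of the loss formulation used above: NLLLoss applied to the log of the softmax probabilities is numerically equivalent to CrossEntropyLoss applied to the raw logits (the latter is the more numerically stable form):

import torch

logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 0])
probs = torch.nn.Softmax(dim=1)(logits)

nll_on_log_probs = torch.nn.NLLLoss()(torch.log(probs), targets)
cross_entropy = torch.nn.CrossEntropyLoss()(logits, targets)
assert torch.allclose(nll_on_log_probs, cross_entropy, atol=1e-6)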
@ -0,0 +1,16 @@
import os

from model import ClassifyModel
import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from datasets import RockPaperScissorData

if __name__ == "__main__":
    # Train for up to 100 epochs, stopping early once val_loss stops improving.
    trainer = pl.Trainer(max_epochs=100,
                         callbacks=[EarlyStopping(monitor="val_loss", mode="min")],
                         default_root_dir=os.getcwd())
    data = RockPaperScissorData(os.getcwd(), 64)
    model = ClassifyModel()

    trainer.fit(model, data)