commit 9d7be132b6897a2469be13ea5478903c91e605f5
Author: Khiem Ton
Date:   Sun May 28 15:23:40 2023 +0700

    init commit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c2d6e77
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+/lightning_logs/
+/paper.csv
+/rock.csv
+/scissor.csv
diff --git a/datasets.py b/datasets.py
new file mode 100644
index 0000000..626fdf1
--- /dev/null
+++ b/datasets.py
@@ -0,0 +1,64 @@
+import lightning.pytorch as pl
+from torch.utils.data import random_split, DataLoader, Dataset
+import torch
+import pandas as pd
+import os
+
+
+class RPSDataset(Dataset):
+    def __init__(self, data_dir):
+        super().__init__()
+        self.data_dir = data_dir
+        self.data, self.labels, self.label_dict = self.get_data()
+
+    def get_data(self):
+        labels = []
+        label_dict = {}
+        data = None
+        i = 0
+
+        assert any(f.endswith(".csv") for f in os.listdir(self.data_dir)), "no CSV training files found"
+
+        # Each *.csv file holds the samples of one class; the file stem is the label name.
+        for f in os.listdir(self.data_dir):
+            if f.endswith(".csv"):
+                current_data = torch.from_numpy(
+                    pd.read_csv(os.path.join(self.data_dir, f), index_col=0).to_numpy()).to(torch.float32)
+
+                if i not in label_dict:
+                    label_dict[i] = f.split(".csv")[0].strip()
+                labels = labels + [i] * current_data.shape[0]
+                i += 1
+                if data is None:
+                    data = current_data
+                else:
+                    data = torch.cat([data, current_data], dim=0)
+        return data, labels, label_dict
+
+    def __len__(self):
+        return len(self.labels)
+
+    def __getitem__(self, idx):
+        return self.data[idx], self.labels[idx]
+
+
+class RockPaperScissorData(pl.LightningDataModule):
+    def __init__(self, data_dir, batch_size):
+        super().__init__()
+        self.data_dir = data_dir
+        self.train_data = None
+        self.val_data = None
+        self.batch_size = batch_size
+
+    def setup(self, stage=None) -> None:
+        # Assign state in setup() rather than prepare_data(): Lightning calls
+        # prepare_data() once per node and does not replicate its state.
+        self.train_data = RPSDataset(self.data_dir)
+        self.train_data, self.val_data = random_split(self.train_data, [0.8, 0.2])
+
+    def train_dataloader(self):
+        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)
+
+    def val_dataloader(self):
+        # Validation data should not be shuffled
+        return DataLoader(self.val_data, batch_size=self.batch_size, shuffle=False)
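
A quick sanity check for RPSDataset, as a minimal sketch: it assumes rock.csv, paper.csv and scissor.csv already exist in the working directory, each row holding the 126 features (2 hands x 21 landmarks x 3 coordinates) produced by get_train_data.py below.

    import os
    from datasets import RPSDataset

    ds = RPSDataset(os.getcwd())
    x, y = ds[0]
    print(x.shape)        # torch.Size([126])
    print(ds.label_dict)  # e.g. {0: 'paper', 1: 'rock', 2: 'scissor'}, in os.listdir order
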
diff --git a/get_train_data.py b/get_train_data.py
new file mode 100644
index 0000000..46752af
--- /dev/null
+++ b/get_train_data.py
@@ -0,0 +1,71 @@
+import mediapipe as mp
+import cv2
+from pathlib import Path
+import os
+from typing import Union, Any
+import pandas as pd
+
+
+def process_landmark(landmarks):
+    landmark_lst = []
+    for hand in landmarks:
+        current_hand_lm = hand.landmark
+        for lm in current_hand_lm:
+            landmark_lst.append(lm.x)
+            landmark_lst.append(lm.y)
+            landmark_lst.append(lm.z)
+
+    # Pad with zeros so the feature vector always covers two hands (2 x 21 x 3 = 126)
+    if len(landmarks) == 1:
+        for i in range(21):
+            landmark_lst.append(0)
+            landmark_lst.append(0)
+            landmark_lst.append(0)
+    return landmark_lst
+
+
+def create_training_data(label: str, data_lst: list[list[Union[int, Any]]],
+                         data_path: Union[Path, str]) -> None:
+    df = pd.DataFrame(data=data_lst)
+    file_path = os.path.join(data_path, label + ".csv")
+    df.to_csv(file_path)
+
+
+def capture_images(num_frames: int) -> list[list[Union[int, Any]]]:
+    mpHands = mp.solutions.hands
+    Hands = mpHands.Hands()
+    mpDraws = mp.solutions.drawing_utils
+
+    cap = cv2.VideoCapture(0)
+    count = 0
+
+    data_list = []
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            break
+        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        result = Hands.process(frame_rgb)
+
+        if result.multi_hand_landmarks:
+            for hand_landmark in result.multi_hand_landmarks:
+                mpDraws.draw_landmarks(frame, hand_landmark, mpHands.HAND_CONNECTIONS)
+            data_list.append(process_landmark(result.multi_hand_landmarks))
+            count += 1
+            if count >= num_frames:
+                break
+
+        cv2.imshow("Frame", frame)
+
+        if cv2.waitKey(1) == ord("q"):
+            break
+    cap.release()
+    cv2.destroyAllWindows()
+
+    return data_list
+
+
+if __name__ == '__main__':
+    label = input("Enter the label: ")
+    data_list = capture_images(500)
+    create_training_data(label, data_list, os.getcwd())
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..e3e293f
--- /dev/null
+++ b/main.py
@@ -0,0 +1,102 @@
+import os
+import random
+import time
+
+import torch
+from model import ClassifyModel
+from datasets import RPSDataset
+import cv2
+import mediapipe as mp
+
+
+def process_data(landmarks):
+    landmark_lst = []
+    for hand in landmarks:
+        current_hand_lm = hand.landmark
+        for lm in current_hand_lm:
+            landmark_lst.append(lm.x)
+            landmark_lst.append(lm.y)
+            landmark_lst.append(lm.z)
+
+    # Pad with zeros so the feature vector always covers two hands (2 x 21 x 3 = 126)
+    if len(landmarks) == 1:
+        for i in range(21):
+            landmark_lst.append(0)
+            landmark_lst.append(0)
+            landmark_lst.append(0)
+    landmark_t = torch.Tensor(landmark_lst).unsqueeze(0)
+    return landmark_t
+
+
+def predict(_model, x, label_dict):
+    # Inference only, so gradient tracking is not needed
+    with torch.no_grad():
+        output = torch.argmax(_model(x), dim=1).item()
+    return label_dict[output]
+
+
+def computer_predict(label_dict, model=None):
+    prediction = 0
+    if not model:
+        prediction = random.randint(0, 2)
+    return label_dict[prediction]
+
+
+def show_result(u_output, c_output):
+    if u_output == c_output:
+        return "Draw"
+    if (u_output == "rock" and c_output == "paper") or \
+            (u_output == "paper" and c_output == "scissor") or \
+            (u_output == "scissor" and c_output == "rock"):
+        return "Lose"
+    return "Win"
+
+
+if __name__ == '__main__':
+    mpHands = mp.solutions.hands
+    Hands = mpHands.Hands()
+    mpDraws = mp.solutions.drawing_utils
+    label_dict = RPSDataset(os.getcwd()).label_dict
+    model = ClassifyModel.load_from_checkpoint("./lightning_logs/version_6/checkpoints/epoch=12-step=247.ckpt")
+    model.eval()
+    cap = cv2.VideoCapture(0)
+
+    current_choose = None
+    output = None
+
+    if not cap.isOpened():
+        exit(1)
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            break
+        # After the player locks in a choice with "o", show the round result for 3 seconds
+        if current_choose:
+            cv2.putText(frame, f"You choose: {current_choose}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX,
+                        2, (0, 0, 255), 2, 2)
+            computer_choose = computer_predict(label_dict)
+            result = show_result(current_choose, computer_choose)
+            cv2.putText(frame, f"Computer choose: {computer_choose}, You {result}", (50, 100), cv2.FONT_HERSHEY_SIMPLEX,
+                        2, (0, 0, 255), 2, 2)
+            cv2.imshow("frame", frame)
+            cv2.waitKey(1)
+            time.sleep(3)
+            current_choose = None
+        else:
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            result = Hands.process(frame_rgb)
+
+            if result.multi_hand_landmarks:
+                x = process_data(result.multi_hand_landmarks)
+                output = predict(model, x, label_dict)
+                cv2.putText(frame, f"CURRENT: {output}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX,
+                            2, (0, 0, 255), 2, 2)
+
+            key = cv2.waitKey(1)
+            cv2.imshow("frame", frame)
+            if key == ord("o"):
+                current_choose = output
+            if key == ord("q"):
+                break
+    cap.release()
+    cv2.destroyAllWindows()
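
A quick check of the outcome table in show_result above (a hypothetical snippet; importing main pulls in torch, cv2, mediapipe and lightning, so those must be installed). Note that the label strings must match the CSV file stems produced by get_train_data.py exactly:

    from main import show_result

    assert show_result("rock", "rock") == "Draw"
    assert show_result("rock", "scissor") == "Win"   # rock beats scissor
    assert show_result("rock", "paper") == "Lose"    # paper beats rock
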
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..056b84e
--- /dev/null
+++ b/model.py
@@ -0,0 +1,64 @@
+import torch
+import lightning.pytorch as pl
+from typing import Any
+
+
+class ClassifyModel(pl.LightningModule):
+    def __init__(self):
+        super().__init__()
+        # 126 input features: 2 hands x 21 landmarks x (x, y, z)
+        self.layer1 = torch.nn.Sequential(
+            torch.nn.Linear(126, 256),
+            torch.nn.ReLU()
+        )
+        self.layer2 = torch.nn.Sequential(
+            torch.nn.Dropout(),
+            torch.nn.Linear(256, 512),
+            torch.nn.ReLU()
+        )
+        self.layer3 = torch.nn.Sequential(
+            torch.nn.Dropout(),
+            torch.nn.Linear(512, 512),
+            torch.nn.ReLU()
+        )
+        self.layer4 = torch.nn.Sequential(
+            torch.nn.Dropout(),
+            torch.nn.Linear(512, 256),
+            torch.nn.ReLU()
+        )
+        self.layer5 = torch.nn.Sequential(
+            torch.nn.Dropout(),
+            torch.nn.Linear(256, 128),
+            torch.nn.ReLU()
+        )
+        self.classifier = torch.nn.Sequential(
+            torch.nn.Linear(128, 3),
+            torch.nn.Softmax(dim=1)
+        )
+
+    def forward(self, x) -> Any:
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.layer5(x)
+        return self.classifier(x)
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        output = self.forward(x)
+        # The network outputs probabilities, so take their log to feed NLLLoss
+        log_output = torch.log(output)
+        loss = torch.nn.NLLLoss()(log_output, y)
+        self.log("train_loss", loss.item(), prog_bar=True, logger=True, on_step=True, on_epoch=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        output = self.forward(x)
+        log_output = torch.log(output)
+        loss = torch.nn.NLLLoss()(log_output, y)
+        self.log("val_loss", loss.item(), prog_bar=True, logger=True, on_step=True, on_epoch=True)
+
+    def configure_optimizers(self) -> Any:
+        return torch.optim.Adam(self.parameters())
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..637b913
--- /dev/null
+++ b/train.py
@@ -0,0 +1,16 @@
+import os
+
+from model import ClassifyModel
+import lightning.pytorch as pl
+from lightning.pytorch.callbacks.early_stopping import EarlyStopping
+from datasets import RockPaperScissorData
+
+if __name__ == "__main__":
+    trainer = pl.Trainer(max_epochs=100,
+                         callbacks=[EarlyStopping(monitor="val_loss", mode="min")],
+                         default_root_dir=os.getcwd())
+    data = RockPaperScissorData(os.getcwd(), 64)
+    model = ClassifyModel()
+
+    trainer.fit(model, data)
+
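
One caveat on the loss in model.py: taking torch.log of a Softmax output can underflow to -inf when the network saturates, which turns the loss into NaN. A numerically stable alternative (a sketch of a possible change, not what this commit does) is to drop the final Softmax, train on raw logits with CrossEntropyLoss, and apply softmax only when probabilities are needed:

    import torch

    logits = torch.randn(4, 3)            # raw classifier outputs, no Softmax layer
    targets = torch.tensor([0, 2, 1, 0])  # class indices

    # CrossEntropyLoss fuses log_softmax and NLLLoss in one numerically stable op
    loss = torch.nn.CrossEntropyLoss()(logits, targets)

    probs = torch.softmax(logits, dim=1)       # only when probabilities are needed
    prediction = torch.argmax(logits, dim=1)   # argmax is the same with or without softmax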