You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
96 lines
3.0 KiB
96 lines
3.0 KiB
import os
|
|
import random
|
|
import time
|
|
|
|
import torch
|
|
from model import ClassifyModel
|
|
from datasets import RPSDataset
|
|
import cv2
|
|
import mediapipe as mp
|
|
|
|
|
|
def process_data(landmarks, *, max_hands=2, landmarks_per_hand=21):
    """Flatten per-hand landmarks into a fixed-width model input.

    Each hand contributes its landmarks as consecutive (x, y, z) triples.
    Hand slots that were not detected are zero-filled, and hands beyond
    ``max_hands`` are dropped, so the feature vector always has the same width
    (previously only the exactly-one-hand case was padded, so zero hands gave
    an empty tensor and extra hands gave an oversized one).

    Args:
        landmarks: Sequence of per-hand results (e.g. MediaPipe's
            ``multi_hand_landmarks``); each item must expose a ``.landmark``
            iterable of points with ``x``, ``y`` and ``z`` attributes.
        max_hands: Number of hand slots in the feature vector. The default of
            2 reproduces the original layout (second slot zero-filled when
            only one hand is present).
        landmarks_per_hand: Points per hand slot used for zero padding
            (21, as in the original padding loop).

    Returns:
        A float32 tensor of shape ``(1, max_hands * landmarks_per_hand * 3)``,
        assuming each detected hand has ``landmarks_per_hand`` points.
    """
    features = []
    for hand in list(landmarks)[:max_hands]:
        for lm in hand.landmark:
            features.extend((lm.x, lm.y, lm.z))

    # Zero-fill missing hand slots so the model always sees the same width.
    target_len = max_hands * landmarks_per_hand * 3
    features.extend([0.0] * (target_len - len(features)))

    # Leading batch dimension of 1: one sample per call.
    return torch.tensor(features, dtype=torch.float32).unsqueeze(0)
|
|
|
|
|
|
def predict(_model, x, label_dict):
    """Classify ``x`` with ``_model`` and return the predicted label.

    Args:
        _model: Callable returning class scores indexed as
            ``(sample, class)``; the arg-max is taken over dim 1.
        x: Model input. It must hold a single sample, because the arg-max
            index is extracted with ``.item()``.
        label_dict: Mapping from class index (int) to label.

    Returns:
        ``label_dict`` entry for the highest-scoring class.
    """
    # Inference only: don't build an autograd graph on every call/frame.
    with torch.no_grad():
        scores = _model(x)
    best_index = torch.argmax(scores, dim=1).item()
    return label_dict[best_index]
|
|
|
|
def computer_predict(label_dict, model=None):
    """Choose the computer's move.

    Args:
        label_dict: Mapping from class index to label; keys 0, 1 and 2
            must be present.
        model: Optional model. Only its truthiness is checked; it is never
            called.

    Returns:
        A uniformly random choice among ``label_dict[0..2]`` when ``model``
        is falsy, otherwise ``label_dict[0]``.
    """
    # NOTE(review): the model-driven path is a stub — a truthy model always
    # yields label_dict[0]. Confirm whether this is intended.
    if model:
        return label_dict[0]
    return label_dict[random.randint(0, 2)]
|
|
|
|
def show_result(u_output, c_output):
    """Score one round from the user's point of view.

    Args:
        u_output: The user's move label.
        c_output: The computer's move label.

    Returns:
        ``"Draw"`` when the moves are equal, ``"Lose"`` when the computer's
        move beats the user's (paper > rock, scissor > paper,
        rock > scissor), and ``"Win"`` in every other case — including
        labels outside rock/paper/scissor.
    """
    if u_output == c_output:
        return "Draw"

    # (user move, computer move) pairs in which the user loses.
    losing_pairs = (
        ("rock", "paper"),
        ("paper", "scissor"),
        ("scissor", "rock"),
    )
    return "Lose" if (u_output, c_output) in losing_pairs else "Win"
|
|
|
|
if __name__ == '__main__':
    # Hand-landmark detector; closed in the ``finally`` below.
    mp_hands = mp.solutions.hands
    hands = mp_hands.Hands()

    # Class-index -> label mapping taken from the dataset, so model outputs
    # decode to the same names used in training.
    label_dict = RPSDataset(os.getcwd()).label_dict

    # NOTE(review): checkpoint path is tied to one specific training run.
    model = ClassifyModel.load_from_checkpoint("./lightning_logs/version_6/checkpoints/epoch=12-step=247.ckpt")
    model.eval()

    cap = cv2.VideoCapture(0)

    current_choose = None  # label committed with "o"; triggers one round
    output = None          # latest prediction for the hand currently in view

    try:
        if not cap.isOpened():
            # Same exit status (1) as before, but with a reason on stderr.
            raise SystemExit("Could not open camera 0")

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if current_choose:
                # Play one round against the computer and show the outcome.
                cv2.putText(frame, f"You choose: {current_choose}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX,
                            2, (0, 0, 255), 2, 2)
                computer_choose = computer_predict(label_dict)
                game_result = show_result(current_choose, computer_choose)
                cv2.putText(frame, f"Computer choose: {computer_choose}, You {game_result}", (50, 100),
                            cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 2, 2)
                cv2.imshow("frame", frame)

                # Hold the result for ~3 s while still pumping GUI events, so
                # the window stays responsive and "q" can quit (time.sleep
                # froze the window and swallowed key presses).
                quit_requested = False
                deadline = time.monotonic() + 3
                while time.monotonic() < deadline:
                    if cv2.waitKey(30) & 0xFF == ord("q"):
                        quit_requested = True
                        break
                if quit_requested:
                    break
                current_choose = None
                continue

            # The detector is fed RGB; OpenCV frames are BGR.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            detection = hands.process(frame_rgb)

            if detection.multi_hand_landmarks:
                x = process_data(detection.multi_hand_landmarks)
                output = predict(model, x, label_dict)
                cv2.putText(frame, f"CURRENT: {output}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX,
                            2, (0, 0, 255), 2, 2)
            else:
                # No hand in view: drop the stale prediction so "o" can't commit it.
                output = None

            # Show first, then pump events: imshow only repaints during waitKey.
            cv2.imshow("frame", frame)
            # Keep only the low byte in case the backend reports modifier bits.
            key = cv2.waitKey(1) & 0xFF
            if key == ord("o"):
                current_choose = output
            if key == ord("q"):
                break
    finally:
        # Release camera, detector and windows even if the loop raises.
        cap.release()
        hands.close()
        cv2.destroyAllWindows()