You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

96 lines
3.0 KiB

import os
import random
import time
import torch
from model import ClassifyModel
from datasets import RPSDataset
import cv2
import mediapipe as mp
def process_data(landmarks):
    """Flatten detected hand landmarks into a single-row feature tensor.

    Each hand in ``landmarks`` contributes the (x, y, z) of every point in its
    ``.landmark`` sequence, in order. When exactly one hand is present, 63
    zeros (21 points x 3 coordinates) are appended as a placeholder for a
    second hand -- presumably so the vector matches the two-hand input size
    the model was trained on (assumes 21 landmarks per hand, as MediaPipe
    Hands produces -- TODO confirm against the model's input dimension).

    Args:
        landmarks: sized sequence of hand results, each exposing ``.landmark``,
            a sequence of points with ``x``, ``y`` and ``z`` attributes.

    Returns:
        A float tensor of shape (1, num_features).
    """
    features = [
        coord
        for hand in landmarks
        for point in hand.landmark
        for coord in (point.x, point.y, point.z)
    ]
    if len(landmarks) == 1:
        # Zero placeholder for the missing second hand.
        features.extend([0] * (21 * 3))
    return torch.Tensor(features).unsqueeze(0)
def predict(_model, x, label_dict):
    """Classify one feature tensor and return its label.

    Args:
        _model: callable returning class scores. The argmax is taken over
            dim 1 and then ``.item()`` is called, so the output must be 2-D
            with a single row (batch size 1).
        x: input passed straight to ``_model``.
        label_dict: mapping from class index to label.

    Returns:
        ``label_dict`` entry for the index of the highest score.
    """
    # Pure inference: disable autograd so no computation graph is built and
    # retained for every frame (model.eval() alone does not do this).
    with torch.no_grad():
        scores = _model(x)
    return label_dict[torch.argmax(scores, dim=1).item()]
def computer_predict(label_dict, model=None):
    """Choose the computer's move.

    Args:
        label_dict: mapping indexed by 0, 1 and 2 (the only keys read here).
        model: optional. NOTE(review): no model is actually consulted -- when
            a truthy ``model`` is passed, the result is always
            ``label_dict[0]``. Looks like a placeholder for model-driven play.

    Returns:
        ``label_dict[i]`` for a uniformly random ``i`` in {0, 1, 2} when no
        model is given.
    """
    if model:
        return label_dict[0]
    return label_dict[random.randint(0, 2)]
def show_result(u_output, c_output):
    """Return the outcome of one round from the user's point of view.

    Args:
        u_output: the user's move.
        c_output: the computer's move.

    Returns:
        "Draw" when the moves are equal; "Lose" when the computer's move beats
        the user's (paper beats rock, scissor beats paper, rock beats
        scissor); otherwise "Win". Only the exact spellings "rock", "paper"
        and "scissor" are recognized -- any other unequal pair is "Win".
    """
    if u_output == c_output:
        return "Draw"
    # (user move, computer move) pairs in which the user loses.
    losing_pairs = (("rock", "paper"), ("paper", "scissor"), ("scissor", "rock"))
    user_loses = any(
        u_output == mine and c_output == theirs for mine, theirs in losing_pairs
    )
    return "Lose" if user_loses else "Win"
def _draw_text(frame, text, origin):
    """Draw ``text`` on ``frame`` at ``origin`` in the game's overlay style (large red Hershey font)."""
    cv2.putText(frame, text, origin, cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 2, 2)


def _hold_frame(seconds):
    """Keep the OpenCV window responsive for ``seconds`` while it shows the last imshow'd frame.

    Used instead of ``waitKey(1)`` + ``time.sleep``: sleeping blocks the HighGUI
    event loop, so the window cannot repaint. Key presses during the hold are
    consumed and ignored.
    """
    deadline = time.monotonic() + seconds
    while time.monotonic() < deadline:
        cv2.waitKey(50)


def _play_round(frame, user_choice, label_dict):
    """Pick the computer's move, overlay both moves and the outcome on ``frame``, and hold it for 3 seconds."""
    computer_choose = computer_predict(label_dict)
    outcome = show_result(user_choice, computer_choose)
    _draw_text(frame, f"You choose: {user_choice}", (50, 50))
    _draw_text(frame, f"Computer choose: {computer_choose}, You {outcome}", (50, 100))
    cv2.imshow("frame", frame)
    _hold_frame(3)


def _game_loop(cap, hands, model, label_dict):
    """Read frames from ``cap`` until the stream ends or 'q' is pressed.

    While no move is locked in, each frame goes through MediaPipe ``hands`` and
    the classifier, and the latest prediction is overlaid. Pressing 'o' locks in
    that prediction; the next frame is then used to play a round.
    """
    current_choose = None
    output = None  # last recognized gesture; kept across frames with no hands
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if current_choose:
            _play_round(frame, current_choose, label_dict)
            current_choose = None
            continue
        # MediaPipe expects RGB input; OpenCV captures BGR.
        detection = hands.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if detection.multi_hand_landmarks:
            features = process_data(detection.multi_hand_landmarks)
            output = predict(model, features, label_dict)
        _draw_text(frame, f"CURRENT: {output}", (50, 50))
        # Show the frame before polling keys so a key press refers to what is on screen.
        cv2.imshow("frame", frame)
        # Keep only the low byte: some HighGUI backends set extra bits in the key code.
        key = cv2.waitKey(1) & 0xFF
        if key == ord("o"):
            current_choose = output
        elif key == ord("q"):
            break


def main():
    """Load the classifier, open the webcam and run the rock-paper-scissors game.

    Controls (the OpenCV window must have focus): 'o' locks in the currently
    recognized gesture and plays a round; 'q' quits. Exits with status 1 if
    the camera cannot be opened.
    """
    import argparse
    import sys

    parser = argparse.ArgumentParser(
        description="Play rock-paper-scissors against the computer using a webcam."
    )
    parser.add_argument(
        "--checkpoint",
        default="./lightning_logs/version_6/checkpoints/epoch=12-step=247.ckpt",
        help="path to the ClassifyModel checkpoint to load",
    )
    parser.add_argument(
        "--camera", type=int, default=0, help="OpenCV camera index (default: 0)"
    )
    args = parser.parse_args()

    label_dict = RPSDataset(os.getcwd()).label_dict
    model = ClassifyModel.load_from_checkpoint(args.checkpoint)
    model.eval()

    cap = cv2.VideoCapture(args.camera)
    try:
        if not cap.isOpened():
            sys.exit(1)
        # Used as a context manager so the Hands solution is closed on exit.
        with mp.solutions.hands.Hands() as hands:
            _game_loop(cap, hands, model, label_dict)
    finally:
        # Runs on normal exit, sys.exit and exceptions, so the camera is always freed.
        cap.release()
        cv2.destroyAllWindows()


if __name__ == '__main__':
    main()