You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

96 lines
3.0 KiB

2 years ago
import os
import random
import time
import torch
from model import ClassifyModel
from datasets import RPSDataset
import cv2
import mediapipe as mp
def process_data(landmarks):
    """Flatten detected hand landmarks into a single-row feature tensor.

    Each hand's ``landmark`` points contribute their ``x``, ``y`` and ``z``
    values in order. When exactly one hand is present, 63 zeros (21 points x
    3 coordinates) are appended to stand in for the missing second hand.

    Args:
        landmarks: Sequence of hand results, each exposing a ``landmark``
            iterable of points with ``x``/``y``/``z`` attributes (e.g.
            MediaPipe's ``multi_hand_landmarks``).

    Returns:
        A float tensor of shape ``(1, n_features)``.
    """
    coords = [
        value
        for hand in landmarks
        for point in hand.landmark
        for value in (point.x, point.y, point.z)
    ]
    # Pad a lone hand with an all-zero second hand so the feature width is the
    # same as for two hands. Assumes 21 points per hand — TODO confirm against
    # the detector's output if it ever changes.
    if len(landmarks) == 1:
        coords.extend([0] * (21 * 3))
    return torch.Tensor(coords).unsqueeze(0)
def predict(_model, x, label_dict):
    """Classify one input tensor and return its label.

    Args:
        _model: Callable model; ``_model(x)`` must return a 2-D tensor of
            per-class scores (the argmax is taken over ``dim=1``).
        x: Input tensor, such as the single-row tensor built by
            ``process_data``.
        label_dict: Mapping from integer class index to label.

    Returns:
        ``label_dict[i]``, where ``i`` is the highest-scoring class. ``.item()``
        requires the argmax to have a single element, i.e. a batch of one.
    """
    # Inference only: disable autograd so no computation graph is built on
    # every call (this runs once per video frame).
    with torch.no_grad():
        scores = _model(x)
    index = torch.argmax(scores, dim=1).item()
    return label_dict[index]
def computer_predict(label_dict, model=None):
    """Choose the computer's move.

    Args:
        label_dict: Mapping from class index to label. Indices are assumed to
            be the contiguous ints ``0 .. len(label_dict) - 1`` (the same
            assumption ``predict`` makes with its argmax index).
        model: Placeholder for a model-driven opponent. When falsy (the
            default), a uniformly random label is chosen.

    Returns:
        The chosen label.
    """
    prediction = 0
    if not model:
        # Draw from every class instead of a hard-coded 0..2. For a 3-entry
        # dict this consumes the RNG exactly like random.randint(0, 2).
        prediction = random.randrange(len(label_dict))
    # NOTE(review): model-based choice is not implemented; a truthy ``model``
    # always yields ``label_dict[0]``.
    return label_dict[prediction]
def show_result(u_output, c_output):
    """Score one round from the user's point of view.

    Args:
        u_output: The user's move label.
        c_output: The computer's move label.

    Returns:
        ``"Draw"`` if the moves are equal, ``"Lose"`` if the computer's move
        beats the user's, otherwise ``"Win"``. Any pair not listed as losing,
        including unrecognized labels, counts as a win. The labels are
        compared literally ("rock", "paper", "scissor"), so they presumably
        must match the dataset's label strings; verify against the dataset.
    """
    # (user move, computer move) pairs where the computer wins.
    losing_pairs = (
        ("rock", "paper"),
        ("paper", "scissor"),
        ("scissor", "rock"),
    )
    if u_output == c_output:
        return "Draw"
    if (u_output, c_output) in losing_pairs:
        return "Lose"
    return "Win"
if __name__ == '__main__':
    import sys

    # Seconds a round's result stays on screen before play resumes.
    result_display_seconds = 3

    label_dict = RPSDataset(os.getcwd()).label_dict
    model = ClassifyModel.load_from_checkpoint("./lightning_logs/version_6/checkpoints/epoch=12-step=247.ckpt")
    model.eval()

    # Used as a context manager so the hand detector is closed on every exit path.
    with mp.solutions.hands.Hands() as hands:
        cap = cv2.VideoCapture(0)
        if not cap.isOpened():
            # sys.exit (unlike the site-provided builtin `exit`) is always
            # available; a string argument exits with status 1.
            sys.exit("Could not open camera 0")

        current_choose = None  # label locked in by pressing "o"
        output = None          # latest prediction for a currently visible hand
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if current_choose:
                    # Round result: show both moves and the outcome.
                    cv2.putText(frame, f"You choose: {current_choose}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX,
                                2, (0, 0, 255), 2, 2)
                    computer_choose = computer_predict(label_dict)
                    outcome = show_result(current_choose, computer_choose)
                    cv2.putText(frame, f"Computer choose: {computer_choose}, You {outcome}", (50, 100),
                                cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 2, 2)
                    cv2.imshow("frame", frame)
                    # Hold the result while still servicing GUI events
                    # (time.sleep would freeze the window). "q" still quits.
                    quit_requested = False
                    deadline = time.monotonic() + result_display_seconds
                    while time.monotonic() < deadline:
                        if cv2.waitKey(50) & 0xFF == ord("q"):
                            quit_requested = True
                            break
                    if quit_requested:
                        break
                    current_choose = None
                    continue

                # MediaPipe expects RGB; OpenCV captures BGR.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                detection = hands.process(frame_rgb)
                if detection.multi_hand_landmarks:
                    x = process_data(detection.multi_hand_landmarks)
                    output = predict(model, x, label_dict)
                    cv2.putText(frame, f"CURRENT: {output}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX,
                                2, (0, 0, 255), 2, 2)
                else:
                    # Don't let "o" lock in a stale gesture when no hand is visible.
                    output = None

                # Show first, then poll: waitKey processes the GUI events that
                # draw the frame just passed to imshow.
                cv2.imshow("frame", frame)
                # Keep only the low byte; waitKey can report extra bits on some platforms.
                key = cv2.waitKey(1) & 0xFF
                if key == ord("o"):
                    current_choose = output
                elif key == ord("q"):
                    break
        finally:
            cap.release()
            cv2.destroyAllWindows()