You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
71 lines
1.9 KiB
71 lines
1.9 KiB
import mediapipe as mp
|
|
import cv2
|
|
from pathlib import Path
|
|
import os
|
|
from typing import Union, List, Any
|
|
import pandas as pd
|
|
|
|
|
|
def process_landmark(landmarks, max_hands: int = 2, num_landmarks: int = 21) -> list:
    """Flatten detected hand landmarks into a fixed-length feature row.

    Each hand in ``landmarks`` contributes the ``x``, ``y`` and ``z`` of every
    point in its ``landmark`` sequence, in order. When fewer than
    ``max_hands`` hands were detected, the row is right-padded with ``0`` so
    every row has ``max_hands * num_landmarks * 3`` values and the CSV columns
    line up across frames.

    Args:
        landmarks: Sequence of per-hand results, each exposing a ``landmark``
            iterable of points with ``x``/``y``/``z`` attributes (as passed in
            from ``result.multi_hand_landmarks``).
        max_hands: Number of hands the fixed-length row reserves space for.
        num_landmarks: Points per hand, used only to size the zero padding.

    Returns:
        A flat list ``[x0, y0, z0, x1, y1, z1, ...]`` followed by any zero
        padding. Input with more than ``max_hands`` hands is not truncated.
    """
    landmark_lst = []
    for hand in landmarks:
        for lm in hand.landmark:
            landmark_lst.extend((lm.x, lm.y, lm.z))

    # Pad every missing hand (not just the single-hand case) so that rows are
    # the same width even when no hand at all was passed in.
    missing_hands = max(max_hands - len(landmarks), 0)
    landmark_lst.extend([0] * (missing_hands * num_landmarks * 3))
    return landmark_lst
|
|
|
|
|
|
def create_training_data(label: str, data_lst: list[list[Union[int, Any]]],
                         data_path: Union[Path, str]) -> None:
    """Write the captured feature rows for one label to ``<data_path>/<label>.csv``.

    Args:
        label: Class label; used as the CSV file stem.
        data_lst: Feature rows, one inner list per captured frame.
        data_path: Directory to write into. It is created (including parents)
            if it does not exist yet.

    The DataFrame index is still written as the first CSV column, matching
    the original output format.
    """
    out_dir = Path(data_path)
    # DataFrame.to_csv does not create missing directories; without this a
    # fresh data directory would crash the run after all frames were captured.
    out_dir.mkdir(parents=True, exist_ok=True)

    df = pd.DataFrame(data=data_lst)
    df.to_csv(out_dir / f"{label}.csv")
|
|
|
|
|
|
def capture_images(num_frames: int, camera_index: int = 0) -> list[list[Union[int, Any]]]:
    """Record hand-landmark feature rows from a webcam.

    Opens the camera and runs MediaPipe Hands on each frame. Detected
    landmarks are drawn on a preview window named "Frame". One row (built by
    ``process_landmark``) is collected per frame in which at least one hand
    was detected.

    Args:
        num_frames: Number of hand-containing frames to collect. Values
            ``<= 0`` return an empty list without opening the camera.
        camera_index: Device index passed to ``cv2.VideoCapture``.

    Returns:
        Up to ``num_frames`` feature rows. Fewer rows are returned if the user
        presses ``q`` or the camera stops delivering frames.
    """
    if num_frames <= 0:
        return []

    mp_hands = mp.solutions.hands
    mp_draw = mp.solutions.drawing_utils
    data_list = []

    cap = cv2.VideoCapture(camera_index)
    try:
        # The context manager closes the MediaPipe graph when capture ends.
        with mp_hands.Hands() as hands:
            while len(data_list) < num_frames:
                ret, frame = cap.read()
                if not ret:
                    # Camera unavailable or stream ended; frame would be None
                    # and cvtColor would raise.
                    break

                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                result = hands.process(frame_rgb)

                if result.multi_hand_landmarks:
                    for hand_landmark in result.multi_hand_landmarks:
                        mp_draw.draw_landmarks(frame, hand_landmark, mp_hands.HAND_CONNECTIONS)
                    data_list.append(process_landmark(result.multi_hand_landmarks))

                cv2.imshow("Frame", frame)
                # Mask to the low byte: waitKey may report extra bits on some platforms.
                if (cv2.waitKey(1) & 0xFF) == ord("q"):
                    break
    finally:
        # Always free the camera and windows, even if processing raised.
        cap.release()
        cv2.destroyAllWindows()

    return data_list
|
|
|
|
|
|
def main() -> None:
    """Prompt for a label, capture 500 landmark rows and save them to ./<label>.csv."""
    label = ""
    # An empty or whitespace-padded label would produce files like ".csv" or
    # "wave .csv"; keep asking until a non-blank label is entered.
    while not label:
        label = input("Enter the label: ").strip()

    data_list = capture_images(500)
    if not data_list:
        # User quit (or the camera failed) before any hand was detected.
        print("No hand landmarks captured; nothing saved.")
        return

    create_training_data(label, data_list, os.getcwd())


if __name__ == '__main__':
    main()