You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

159 lines
5.0 KiB

import os
import argparse
import torch
import json
from collections import OrderedDict
from torchvision import models
from torch import nn
import numpy as np
from PIL import Image
def process_image(image):
    """Load an image file and convert it into a normalized model input.

    The image is converted to RGB, resized so that its shorter side is
    exactly 256 pixels (aspect ratio preserved), center-cropped to 224x224,
    scaled to the [0, 1] range and normalized per channel with the mean and
    standard deviation constants used below.

    Args:
        image: A path or file object accepted by ``PIL.Image.open``.

    Returns:
        A torch tensor of shape (3, 224, 224), channels first. Its dtype is
        float64 because the NumPy arithmetic is carried out in double
        precision; callers cast it as needed.
    """
    short_side = 256
    crop_size = 224

    with Image.open(image) as im:
        # Grayscale, palette and RGBA files do not have exactly three
        # channels and would break the per-channel normalization below.
        rgb = im.convert("RGB")

        # Compute the target size from the ORIGINAL dimensions so the
        # shorter side becomes exactly `short_side`.
        width, height = rgb.size
        if width <= height:
            target = (short_side, height * short_side // width)
        else:
            target = (width * short_side // height, short_side)

        # resize() enlarges as well as shrinks (thumbnail() only shrinks),
        # so the crop below never reaches outside the picture.
        rgb = rgb.resize(target, Image.BICUBIC)

        # Integer center-crop box: (left, upper, right, lower).
        width, height = rgb.size
        left = (width - crop_size) // 2
        upper = (height - crop_size) // 2
        cropped = rgb.crop((left, upper, left + crop_size, upper + crop_size))

        pixels = np.array(cropped) / 255

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    normalized = (pixels - mean) / std

    # PIL yields height x width x channel; the model wants channel first.
    return torch.from_numpy(normalized.transpose(2, 0, 1))
def predict(image_path, model, device, topk=5):
    """Classify a single image and return its ``topk`` best guesses.

    The image is preprocessed with ``process_image``, wrapped into a batch of
    one, and passed through ``model`` in evaluation mode with gradients
    disabled. The model output is exponentiated before ranking, so the model
    is expected to emit log-probabilities.

    Args:
        image_path: Location of the image, forwarded to ``process_image``.
        model: The network to evaluate; it is switched to eval mode and moved
            to ``device`` as a side effect.
        device: The torch device on which inference runs.
        topk: How many of the highest-scoring classes to return.

    Returns:
        A ``(probs, classes)`` pair of plain Python lists, each of length
        ``topk``: the exponentiated scores and the matching output indices.
    """
    # Batch of one, single precision.
    batch = process_image(image_path).view(1, 3, 224, 224).float()

    model.eval()
    model.to(device)

    with torch.no_grad():
        log_scores = model(batch.to(device))
        best_scores, best_indices = log_scores.exp().topk(topk, dim=1)

    # Drop the batch dimension before handing the results back.
    return best_scores[0].tolist(), best_indices[0].tolist()
def create_arguments(parser):
    """Register the prediction script's command-line arguments on ``parser``.

    Positional arguments:
        img_pth: Path of the image to classify.
        checkpoint_pth: Path of the saved model checkpoint.

    Optional arguments:
        --top_k: Number of classes to report (int, default 1).
        --category_names: JSON file with category names
            (default "cat_to_name.json").
        --gpu: Flag; when present the caller may run inference on CUDA.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend in place.
    """
    for positional in ("img_pth", "checkpoint_pth"):
        parser.add_argument(positional, type=str)

    options = (
        ("--top_k", {"type": int, "default": 1}),
        ("--category_names", {"type": str, "default": "cat_to_name.json"}),
        ("--gpu", {"action": "store_true"}),
    )
    for flag, settings in options:
        parser.add_argument(flag, **settings)
def check_data(category_names, img_pth, checkpoint_pth, top_k):
    """Validate the command-line inputs and load the category-name mapping.

    Args:
        category_names: Path of a JSON file holding the category names.
            When empty or ``None`` no file is read and ``None`` is returned.
        img_pth: Path of the image to classify; must exist.
        checkpoint_pth: Path of the model checkpoint; must exist.
        top_k: Number of classes to report; must be greater than zero.

    Returns:
        The object parsed from the JSON file, or ``None`` when no
        ``category_names`` path was given.

    Raises:
        ValueError: If the image path or checkpoint path does not exist, or
            if ``top_k`` is not positive.
        OSError: If the category file cannot be opened.
        json.JSONDecodeError: If the category file is not valid JSON.
    """
    cat_to_name = None
    if category_names:
        # JSON is UTF-8 by specification; do not depend on the platform
        # default encoding.
        with open(category_names, 'r', encoding='utf-8') as f:
            # strict=False tolerates control characters inside strings.
            cat_to_name = json.load(f, strict=False)

    if not os.path.exists(img_pth):
        raise ValueError("Image Path not exists")
    if not os.path.exists(checkpoint_pth):
        raise ValueError("Checkpoint Path not exists")
    if top_k <= 0:
        # Give the user an actionable message instead of a blank error.
        raise ValueError(f"top_k must be a positive integer, got {top_k}")

    return cat_to_name
def _head_output_size(state_dict, weight_key, default=102):
    """Return the row count of ``state_dict[weight_key]``, or ``default`` if absent."""
    weight = state_dict.get(weight_key)
    return default if weight is None else weight.shape[0]


def restore_model(checkpoint_pth):
    """Rebuild a classifier from a checkpoint file.

    The checkpoint is expected to be a mapping with these keys:
        "model": name of a constructor in ``torchvision.models``.
        "class_to_idx": mapping from class label to output index.
        "hidden_units": width of the hidden layer in the custom head.
        "state_dict": weights for the complete network, head included.

    For ResNet architectures the ``fc`` layer is replaced, for VGG
    architectures the ``classifier`` is replaced; any other architecture is
    left unmodified. The saved weights are then loaded into the network.

    Args:
        checkpoint_pth: Path of the checkpoint file.

    Returns:
        The restored model, with all backbone parameters frozen and
        ``class_to_idx`` attached as an attribute.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # host; the caller moves the model to its target device afterwards.
    checkpoint = torch.load(checkpoint_pth, map_location="cpu")
    model_name = checkpoint["model"]
    class_to_idx = checkpoint["class_to_idx"]
    hidden_units = checkpoint["hidden_units"]
    state_dict = checkpoint["state_dict"]

    # Every weight is overwritten by load_state_dict below (it is strict by
    # default), so downloading pretrained weights would be wasted work.
    model = getattr(models, model_name)(pretrained=False)
    for param in model.parameters():
        param.requires_grad = False

    if isinstance(model, models.ResNet):
        # ResNet's classification layer is `fc`.
        # Size the output layer from the checkpoint itself so it always
        # matches the saved weights (102 is the historical default).
        n_classes = _head_output_size(state_dict, "fc.fc2.weight")
        in_features = model.fc.in_features
        model.fc = nn.Sequential(OrderedDict([
            ("fc1", nn.Linear(in_features, hidden_units)),
            ("relu1", nn.ReLU()),
            ("fc2", nn.Linear(hidden_units, n_classes)),
            ("output", nn.LogSoftmax(dim=1)),
        ]))
    elif isinstance(model, models.VGG):
        # VGG's classification layers live in `classifier`.
        n_classes = _head_output_size(state_dict, "classifier.fc3.weight")
        model.classifier = nn.Sequential(OrderedDict([
            ("fc1", nn.Linear(25088, 4096)),
            ("relu1", nn.ReLU()),
            ("dropout1", nn.Dropout(0.5)),
            ("fc2", nn.Linear(4096, hidden_units)),
            ("relu2", nn.ReLU()),
            ("dropout2", nn.Dropout(0.5)),
            ("fc3", nn.Linear(hidden_units, n_classes)),
            ("output", nn.LogSoftmax(dim=1)),
        ]))

    model.load_state_dict(state_dict=state_dict)
    model.class_to_idx = class_to_idx
    return model
def show_result(cat_to_name, probs, classes, class_to_idx):
    """Print the best prediction followed by the full ranked list.

    Args:
        cat_to_name: Mapping from class label to display name. When falsy,
            the raw values in ``classes`` are printed instead of names.
        probs: Probabilities of the reported classes.
        classes: Model output indices, in the same order as ``probs``.
        class_to_idx: Mapping from class label to model output index; it is
            inverted to translate ``classes`` back into labels.
    """
    if cat_to_name:
        idx_to_class = {idx: label for label, idx in class_to_idx.items()}
        labels = [cat_to_name[idx_to_class[i]] for i in classes]
    else:
        labels = classes

    best = probs.index(max(probs))
    print(f"Class {labels[best]} with prob: {probs[best]}")

    # Report the count from the data itself rather than a module-level
    # variable, so the function also works when this module is imported.
    print(f"Top {len(probs)}")
    for label, prob in zip(labels, probs):
        print(f"Class {label} with prob: {prob}")
if __name__ == "__main__":
    # Script entry point: parse the command line, validate the inputs,
    # restore the model, classify the image and print the result.
    arg_parser = argparse.ArgumentParser()
    create_arguments(arg_parser)
    args = arg_parser.parse_args()

    # Bind the parsed values to module-level names.
    img_pth, checkpoint_pth = args.img_pth, args.checkpoint_pth
    top_k, gpu = args.top_k, args.gpu

    cat_to_name = check_data(args.category_names, img_pth, checkpoint_pth, top_k)

    # CUDA is used only when it is both available and requested via --gpu.
    if torch.cuda.is_available() and gpu:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    model = restore_model(checkpoint_pth)
    probs, classes = predict(img_pth, model, device, top_k)
    show_result(cat_to_name, probs, classes, model.class_to_idx)