"""Command-line inference for a fine-tuned torchvision image classifier.

Usage: ``python <script> IMG_PTH CHECKPOINT_PTH [--top_k K]
[--category_names FILE] [--gpu]``
"""
import argparse
import json
import os
from collections import OrderedDict

import numpy as np
import torch
from PIL import Image
from torch import nn
from torchvision import models

# Pre-processing geometry: shortest side is scaled to RESIZE_SHORT_SIDE,
# then a centred CROP_SIZE x CROP_SIZE window is taken.
RESIZE_SHORT_SIDE = 256
CROP_SIZE = 224

# Per-channel normalization statistics (the standard ImageNet values that
# torchvision's pretrained models expect).
CHANNEL_MEAN = np.array([0.485, 0.456, 0.406])
CHANNEL_STD = np.array([0.229, 0.224, 0.225])

# Width of the final classification layer rebuilt in ``restore_model``.
# Hard-coded in the original script; it must match the checkpoint.
NUM_CLASSES = 102


def process_image(image):
    """Scale, centre-crop and normalize an image for the model.

    Args:
        image: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A ``torch`` tensor of shape ``(3, CROP_SIZE, CROP_SIZE)`` and dtype
        float64 holding the normalized, channel-first image.
    """
    with Image.open(image) as source:
        # Force three channels so grayscale / RGBA / palette images do not
        # break the per-channel normalization below.
        rgb = source.convert("RGB")

    width, height = rgb.size
    # Scale so the *shortest* side becomes RESIZE_SHORT_SIDE while keeping
    # the aspect ratio. The scaled side is computed from the ORIGINAL
    # dimensions (the previous code overwrote one before using it).
    if width <= height:
        new_width = RESIZE_SHORT_SIDE
        new_height = max(
            RESIZE_SHORT_SIDE, round(height * RESIZE_SHORT_SIDE / width)
        )
    else:
        new_height = RESIZE_SHORT_SIDE
        new_width = max(
            RESIZE_SHORT_SIDE, round(width * RESIZE_SHORT_SIDE / height)
        )
    # ``resize`` (unlike ``thumbnail``) also upscales small images, so the
    # crop window below always lies inside the picture.
    resized = rgb.resize((new_width, new_height), Image.BICUBIC)

    # Integer crop box centred in the resized image.
    left = (new_width - CROP_SIZE) // 2
    upper = (new_height - CROP_SIZE) // 2
    cropped = resized.crop((left, upper, left + CROP_SIZE, upper + CROP_SIZE))

    pixels = np.array(cropped) / 255
    normalized = (pixels - CHANNEL_MEAN) / CHANNEL_STD
    # HWC -> CHW, the layout the convolutional model consumes.
    return torch.from_numpy(normalized.transpose(2, 0, 1))


def predict(image_path, model, device, topk=5):
    """Predict the most likely classes of an image with a trained model.

    Args:
        image_path: Path of the image to classify.
        model: Network whose output is log-probabilities (see
            ``restore_model``); it is switched to eval mode and moved to
            ``device``.
        device: ``torch.device`` on which to run inference.
        topk: Number of best predictions to return. Values larger than the
            number of model outputs are clamped to that number.

    Returns:
        ``(probs, classes)``: two lists of equal length holding the
        probabilities and the corresponding model output indices, best first.
    """
    batch = process_image(image_path).unsqueeze(0).float()

    model.eval()
    model.to(device)
    with torch.no_grad():
        log_probs = model(batch.to(device))

    probabilities = torch.exp(log_probs)
    # Asking for more entries than there are outputs raises inside
    # ``Tensor.topk``.
    k = min(topk, probabilities.shape[1])
    top_p, top_class = probabilities.topk(k, dim=1)
    return top_p[0].tolist(), top_class[0].tolist()


def create_arguments(parser):
    """Register the command-line arguments of the script on ``parser``."""
    parser.add_argument("img_pth", type=str,
                        help="path of the image to classify")
    parser.add_argument("checkpoint_pth", type=str,
                        help="path of the saved model checkpoint")
    parser.add_argument("--top_k", type=int, default=1,
                        help="number of top predictions to display")
    parser.add_argument("--category_names", type=str,
                        default="cat_to_name.json",
                        help="JSON file mapping class ids to display names")
    parser.add_argument("--gpu", action='store_true',
                        help="run inference on CUDA when available")


def check_data(category_names, img_pth, checkpoint_pth, top_k):
    """Validate the command-line inputs and load the category-name mapping.

    Args:
        category_names: Path of the JSON file mapping class ids to names.
        img_pth: Path of the image to classify.
        checkpoint_pth: Path of the model checkpoint.
        top_k: Requested number of predictions; must be positive.

    Returns:
        The decoded content of ``category_names``.

    Raises:
        ValueError: If ``top_k`` is not positive or if the image or
            checkpoint path does not exist.
        OSError: If ``category_names`` cannot be opened.
    """
    # Cheap argument checks first, so no file is read for a bad invocation.
    if top_k <= 0:
        raise ValueError(f"top_k must be a positive integer, got {top_k}")
    if not os.path.exists(img_pth):
        raise ValueError("Image Path not exists")
    if not os.path.exists(checkpoint_pth):
        raise ValueError("Checkpoint Path not exists")

    with open(category_names, 'r', encoding="utf-8") as f:
        # strict=False tolerates control characters inside JSON strings.
        return json.load(f, strict=False)


def restore_model(checkpoint_pth):
    """Rebuild the fine-tuned network described by a checkpoint file.

    The checkpoint is expected to be a dict with the keys ``"model"`` (name
    of a constructor in ``torchvision.models``), ``"class_to_idx"``,
    ``"hidden_units"`` and ``"state_dict"``.

    Args:
        checkpoint_pth: Path of the checkpoint written at training time.

    Returns:
        The model with the checkpoint weights loaded and a ``class_to_idx``
        attribute attached. The model is on the CPU; ``predict`` moves it.
    """
    # map_location keeps a GPU-saved checkpoint loadable on CPU-only hosts.
    # NOTE: torch.load unpickles its input; only open trusted checkpoints.
    checkpoint = torch.load(checkpoint_pth, map_location="cpu")
    model_name = checkpoint["model"]
    class_to_idx = checkpoint["class_to_idx"]
    hidden_units = checkpoint["hidden_units"]
    state_dict = checkpoint["state_dict"]

    # Every weight is overwritten by the strict ``load_state_dict`` below, so
    # downloading the pretrained weights would only add a network dependency.
    model = getattr(models, model_name)(pretrained=False)
    for param in model.parameters():
        param.requires_grad = False

    if isinstance(model, models.ResNet):
        # ResNet keeps its classification head in ``fc``.
        in_features = model.fc.in_features
        model.fc = nn.Sequential(OrderedDict([
            ("fc1", nn.Linear(in_features, hidden_units)),
            ("relu1", nn.ReLU()),
            ("fc2", nn.Linear(hidden_units, NUM_CLASSES)),
            ("output", nn.LogSoftmax(dim=1)),
        ]))
    elif isinstance(model, models.VGG):
        # VGG keeps its classification head in ``classifier``.
        model.classifier = nn.Sequential(OrderedDict([
            ("fc1", nn.Linear(25088, 4096)),
            ("relu1", nn.ReLU()),
            ("dropout1", nn.Dropout(0.5)),
            ("fc2", nn.Linear(4096, hidden_units)),
            ("relu2", nn.ReLU()),
            ("dropout2", nn.Dropout(0.5)),
            ("fc3", nn.Linear(hidden_units, NUM_CLASSES)),
            ("output", nn.LogSoftmax(dim=1)),
        ]))
    # Any other architecture keeps its stock head, as before.

    model.load_state_dict(state_dict=state_dict)
    model.class_to_idx = class_to_idx
    return model


def show_result(cat_to_name, probs, classes, class_to_idx):
    """Print the best prediction followed by the full top-k list.

    Args:
        cat_to_name: Mapping from class id to display name. When it is empty
            the raw model output indices are printed instead.
        probs: Probabilities returned by ``predict``.
        classes: Model output indices returned by ``predict``.
        class_to_idx: Mapping from class id to model output index.
    """
    if cat_to_name:
        # Invert class_to_idx to get from output index back to class id.
        idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
        labels = [cat_to_name[idx_to_class[i]] for i in classes]
    else:
        labels = classes

    best = probs.index(max(probs))
    print(f"Class {labels[best]} with prob: {probs[best]}")
    # Use the number of results actually returned rather than a module
    # global that exists only when the file is run as a script.
    print(f"Top {len(probs)}")
    for label, prob in zip(labels, probs):
        print(f"Class {label} with prob: {prob}")


def main():
    """Parse the command line, run the prediction and print the result."""
    arg_parser = argparse.ArgumentParser()
    create_arguments(arg_parser)
    args = arg_parser.parse_args()

    cat_to_name = check_data(
        args.category_names, args.img_pth, args.checkpoint_pth, args.top_k
    )
    # Fall back to the CPU when --gpu is given but CUDA is unavailable.
    device = torch.device(
        'cuda' if torch.cuda.is_available() and args.gpu else 'cpu'
    )
    model = restore_model(args.checkpoint_pth)
    probs, classes = predict(args.img_pth, model, device, args.top_k)
    show_result(cat_to_name, probs, classes, model.class_to_idx)


if __name__ == "__main__":
    main()