import argparse
import torch
from torch import nn, optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import os
from collections import OrderedDict
# Transforms
train_transform = transforms.Compose([
    transforms.Resize(255),
    transforms.ColorJitter(brightness=2),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

valid_transform = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
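
# The Normalize statistics above are the standard ImageNet channel means and
# standard deviations that torchvision's pretrained weights expect; the same
# normalization must be applied to training and validation inputs alike.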

def create_arguments(parser):
    parser.add_argument("data_directory", type=str)
    parser.add_argument("--save_dir", type=str, help="Directory to save checkpoints", default=os.getcwd())
    parser.add_argument("--arch", type=str, help="Model's architecture", default="resnet34")
    parser.add_argument("--learning_rate", type=float, help="Model's learning rate", default=0.002)
    parser.add_argument("--hidden_units", type=int, help="Model's number of hidden units", default=128)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--gpu", action='store_true')

def check_data(data_directory, save_dir, hidden_units, learning_rate, epochs):
    # Validate CLI arguments before doing any expensive work
    if not os.path.exists(data_directory):
        raise ValueError("Data directory does not exist!")
    if not os.path.exists(save_dir):
        raise ValueError("Checkpoint save directory does not exist!")
    if hidden_units <= 0 or learning_rate <= 0 or epochs <= 0:
        raise ValueError("hidden_units, learning_rate and epochs must all be positive")

def create_model(arch, hidden_units, learning_rate):
    # Load an ImageNet-pretrained backbone by name (newer torchvision versions
    # use the `weights` argument instead of the deprecated `pretrained`)
    model = getattr(models, arch)(pretrained=True)
    optimizer = None
    # Freeze the backbone so only the new classification head is trained
    for param in model.parameters():
        param.requires_grad = False
    if model.__class__.__name__ == "ResNet":
        # ResNet's classification layer is `fc`
        in_features = model.fc.in_features
        fc = nn.Sequential(OrderedDict([
            ("fc1", nn.Linear(in_features, hidden_units)),
            ("relu1", nn.ReLU()),
            ("fc2", nn.Linear(hidden_units, 102)),
            ("output", nn.LogSoftmax(dim=1))
        ]))
        model.fc = fc
        optimizer = optim.Adam(model.fc.parameters(), lr=learning_rate)
    elif model.__class__.__name__ == "VGG":
        # VGG's classification layer is `classifier`; replace it with a new
        # head that ends in 102 output classes
        classifier = nn.Sequential(OrderedDict([
            ("fc1", nn.Linear(25088, 4096)),
            ("relu1", nn.ReLU()),
            ("dropout1", nn.Dropout(0.5)),
            ("fc2", nn.Linear(4096, hidden_units)),
            ("relu2", nn.ReLU()),
            ("dropout2", nn.Dropout(0.5)),
            ("fc3", nn.Linear(hidden_units, 102)),
            ("output", nn.LogSoftmax(dim=1))
        ]))
        model.classifier = classifier
        optimizer = optim.SGD(model.classifier.parameters(), lr=learning_rate)
    else:
        raise ValueError("Architecture not supported")
    return model, optimizer
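
# Only the parameters of the newly attached head require gradients, so the
# optimizer updates just that head while the frozen backbone serves as a fixed
# feature extractor (standard transfer learning). The 102 output units match
# the number of classes in the dataset this script assumes.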

def training(model, optimizer, criterion, epochs, device, trainloader, validloader):
    train_losses, valid_losses = [], []
    for i in range(epochs):
        print(f"Epoch {i + 1}/{epochs}")
        model.train()
        train_loss = 0
        valid_loss = 0
        accuracy = 0
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            loss = criterion(output, labels)
            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        train_losses.append(train_loss)
        with torch.no_grad():
            model.eval()
            for images, labels in validloader:
                images, labels = images.to(device), labels.to(device)
                output = model(images)
                loss = criterion(output, labels)
                valid_loss += loss.item()
                # The model outputs log-probabilities (LogSoftmax), so
                # exponentiate to recover probabilities before taking top-1
                output_p = torch.exp(output)
                _, top_class = output_p.topk(1, dim=1)
                equals = top_class == labels.view(top_class.shape)
                accuracy += equals.float().mean().item()
        valid_losses.append(valid_loss)
        accuracy = accuracy / len(validloader) * 100
        # Report per-batch average losses so numbers are comparable across
        # datasets of different sizes
        print(f"Train loss: {train_loss / len(trainloader):.3f}")
        print(f"Valid loss: {valid_loss / len(validloader):.3f}")
        print(f"Valid accuracy: {accuracy:.2f}%")
        print("---------------------")

def save_checkpoint(class_to_idx, arch, hidden_units, model, save_dir):
    checkpoint = {
        "class_to_idx": class_to_idx,
        "model": arch,
        "hidden_units": hidden_units,
        "state_dict": model.state_dict()
    }
    torch.save(checkpoint, os.path.join(save_dir, "checkpoint.pth"))
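
# A minimal reload sketch (illustrative, not part of this script): rebuild the
# same architecture from the saved metadata, then restore weights and the
# class-to-index mapping.
#   checkpoint = torch.load("checkpoint.pth")
#   model, _ = create_model(checkpoint["model"], checkpoint["hidden_units"], 0.001)
#   model.load_state_dict(checkpoint["state_dict"])
#   model.class_to_idx = checkpoint["class_to_idx"]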

if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    create_arguments(arg_parser)
    args = arg_parser.parse_args()
    data_directory = args.data_directory
    save_dir = args.save_dir
    arch = args.arch
    learning_rate = args.learning_rate
    hidden_units = args.hidden_units
    epochs = args.epochs
    gpu = args.gpu
    device = torch.device('cuda' if torch.cuda.is_available() and gpu else 'cpu')
    check_data(data_directory, save_dir, hidden_units, learning_rate, epochs)
    train_data = os.path.join(data_directory, "train")
    valid_data = os.path.join(data_directory, "valid")
    train_dataset = datasets.ImageFolder(train_data, transform=train_transform)
    trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    valid_dataset = datasets.ImageFolder(valid_data, transform=valid_transform)
    # The validation set does not need shuffling
    validloader = DataLoader(valid_dataset, batch_size=64, shuffle=False)
    model, optimizer = create_model(arch, hidden_units, learning_rate)
    # NLLLoss pairs with the LogSoftmax output layer; together they are
    # equivalent to CrossEntropyLoss on raw logits
    criterion = nn.NLLLoss()
    model.to(device)
    criterion.to(device)
    training(model, optimizer, criterion, epochs, device, trainloader, validloader)
    save_checkpoint(train_dataset.class_to_idx, arch, hidden_units, model, save_dir)