You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

176 lines
6.2 KiB

2 years ago
import argparse
import torch
from torch import nn, optim
from torchvision import datasets, models, transforms
from torch.utils.data.dataloader import DataLoader
import os
from collections import OrderedDict
# ---------------------------------------------------------------------------
# Image preprocessing pipelines.
# Training applies heavy augmentation (color jitter, flips, rotation, random
# crop); validation only resizes and center-crops.  Both normalize with the
# standard ImageNet channel statistics expected by pretrained torchvision
# models.
# ---------------------------------------------------------------------------
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize(255),
    transforms.ColorJitter(brightness=2),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

valid_transform = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
def create_arguments(parser):
    """Attach all command-line options used by this training script.

    One positional argument (the dataset root) plus optional tunables
    for the checkpoint directory, architecture, and hyper-parameters.
    """
    parser.add_argument("data_directory", type=str)
    # Optional flags, registered from a table to keep them in one place.
    optional = {
        "--save_dir": dict(type=str, help="Directory to save checkpoints", default=os.getcwd()),
        "--arch": dict(type=str, help="Model's architecture", default="resnet34"),
        "--learning_rate": dict(type=float, help="Model's learning rate", default=0.002),
        "--hidden_units": dict(type=int, help="Model's number of hidden units", default=128),
        "--epochs": dict(type=int, default=10),
    }
    for flag, kwargs in optional.items():
        parser.add_argument(flag, **kwargs)
    parser.add_argument("--gpu", action='store_true')
def check_data(data_directory, save_dir, hidden_units, learning_rate, epochs):
    """Validate CLI inputs before any expensive work starts.

    Args:
        data_directory: path that must exist and contain the image folders.
        save_dir: path that must exist; checkpoints are written here.
        hidden_units: size of the classifier's hidden layer; must be > 0.
        learning_rate: optimizer step size; must be > 0.
        epochs: number of training passes; must be > 0.

    Raises:
        ValueError: if either directory is missing or any hyper-parameter
            is not strictly positive.
    """
    if not os.path.exists(data_directory):
        raise ValueError(f"Data directory does not exist: {data_directory}")
    if not os.path.exists(save_dir):
        raise ValueError(f"Checkpoint save directory does not exist: {save_dir}")
    # BUG FIX: the original raised a bare ValueError() with no message,
    # giving the user no hint which of the three values was invalid.
    if hidden_units <= 0:
        raise ValueError(f"hidden_units must be positive, got {hidden_units}")
    if learning_rate <= 0:
        raise ValueError(f"learning_rate must be positive, got {learning_rate}")
    if epochs <= 0:
        raise ValueError(f"epochs must be positive, got {epochs}")
def create_model(arch, hidden_units, learning_rate):
    """Build a pretrained network with a custom 102-class classifier head.

    Loads torchvision's *arch* with pretrained weights, freezes the feature
    extractor, and swaps the final classification layer for a new head.
    Returns (model, optimizer) where the optimizer trains only the new
    head's parameters.

    Raises:
        ValueError: for any architecture family other than ResNet or VGG.
    """
    model = getattr(models, arch)(pretrained=True)

    # Freeze the backbone: only the replacement head will be trained.
    for weight in model.parameters():
        weight.requires_grad = False

    family = model.__class__.__name__
    if family == "ResNet":
        # ResNet exposes its classification layer as `fc`.
        head = nn.Sequential(OrderedDict([
            ("fc1", nn.Linear(model.fc.in_features, hidden_units)),
            ("relu1", nn.ReLU()),
            ("fc2", nn.Linear(hidden_units, 102)),
            ("output", nn.LogSoftmax(dim=1)),
        ]))
        model.fc = head
        return model, optim.Adam(model.fc.parameters(), lr=learning_rate)

    if family == "VGG":
        # VGG exposes its classification layer as `classifier`
        # (six sub-layers in the stock model; 25088 = 512 * 7 * 7 input).
        head = nn.Sequential(OrderedDict([
            ("fc1", nn.Linear(25088, 4096)),
            ("relu1", nn.ReLU()),
            ("dropout1", nn.Dropout(0.5)),
            ("fc2", nn.Linear(4096, hidden_units)),
            ("relu2", nn.ReLU()),
            ("dropout2", nn.Dropout(0.5)),
            ("fc3", nn.Linear(hidden_units, 102)),
            ("output", nn.LogSoftmax(dim=1)),
        ]))
        model.classifier = head
        return model, optim.SGD(model.classifier.parameters(), lr=learning_rate)

    raise ValueError("Architecture not support")
def training(model, optimizer, criterion, epochs, device, trainloader, validloader):
    """Run the train/validate loop and report per-epoch metrics.

    For each epoch: one optimization pass over *trainloader*, then an
    evaluation pass over *validloader* (gradients disabled) that
    accumulates loss and top-1 accuracy.  Progress is printed per epoch.

    Args:
        model: network emitting log-probabilities (pairs with NLLLoss).
        optimizer: optimizer over the trainable parameters.
        criterion: loss function.
        epochs: number of passes over the training data.
        device: torch.device that batches are moved to.
        trainloader: iterable of (images, labels) training batches.
        validloader: iterable of (images, labels) validation batches.

    Returns:
        (train_losses, valid_losses): per-epoch summed loss histories.
        BUG FIX: the original collected train_losses but never appended
        to valid_losses and returned neither, making both dead code.
    """
    train_losses, valid_losses = [], []
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        model.train()
        train_loss = 0
        valid_loss = 0
        accuracy = 0
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            loss = criterion(output, labels)
            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        train_losses.append(train_loss)
        with torch.no_grad():
            model.eval()
            for images, labels in validloader:
                images, labels = images.to(device), labels.to(device)
                output = model(images)
                loss = criterion(output, labels)
                valid_loss += loss.item()
                # Model outputs log-probabilities; exp() recovers probabilities.
                output_p = torch.exp(output)
                _, top_class = output_p.topk(1, dim=1)
                equals = top_class == labels.view(top_class.shape)
                accuracy += torch.mean(equals.type(torch.FloatTensor))
        valid_losses.append(valid_loss)
        # Mean per-batch accuracy across the validation set, as a percentage.
        accuracy = accuracy / len(validloader) * 100
        print(f"Train loss: {train_loss}")
        print(f"Valid loss: {valid_loss}")
        print(f"Valid accuracy: {accuracy}")
        print("---------------------")
    return train_losses, valid_losses
def save_checkpoint(class_to_idx, arch, hidden_units, model, save_dir):
    """Serialize the trained model to <save_dir>/checkpoint.pth.

    Stores the class-to-index mapping, the architecture name, the hidden
    unit count, and the model weights — everything needed to rebuild the
    network later for inference.
    """
    path = os.path.join(save_dir, "checkpoint.pth")
    payload = dict(
        class_to_idx=class_to_idx,
        model=arch,
        hidden_units=hidden_units,
        state_dict=model.state_dict(),
    )
    torch.save(payload, path)
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    create_arguments(arg_parser)
    args = arg_parser.parse_args()

    # Use CUDA only when requested with --gpu *and* actually available.
    device = torch.device('cuda' if torch.cuda.is_available() and args.gpu else 'cpu')

    # Fail fast on bad paths or hyper-parameters before loading any data.
    check_data(args.data_directory, args.save_dir, args.hidden_units,
               args.learning_rate, args.epochs)

    # Datasets live under <data_directory>/train and <data_directory>/valid.
    train_dataset = datasets.ImageFolder(
        os.path.join(args.data_directory, "train"), transform=train_transform)
    trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    valid_dataset = datasets.ImageFolder(
        os.path.join(args.data_directory, "valid"), transform=valid_transform)
    validloader = DataLoader(valid_dataset, batch_size=64, shuffle=True)

    model, optimizer = create_model(args.arch, args.hidden_units, args.learning_rate)
    criterion = nn.NLLLoss()
    model.to(device)
    criterion.to(device)

    training(model, optimizer, criterion, args.epochs, device,
             trainloader, validloader)
    save_checkpoint(train_dataset.class_to_idx, args.arch, args.hidden_units,
                    model, args.save_dir)