import torch
import lightning.pytorch as pl

from typing import Any


class ClassifyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Five fully connected blocks: 126 -> 256 -> 512 -> 512 -> 256 -> 128,
        # each with ReLU activation and (from layer2 on) dropout.
        self.layer1 = torch.nn.Sequential(
            torch.nn.Linear(126, 256),
            torch.nn.ReLU()
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(256, 512),
            torch.nn.ReLU()
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(512, 512),
            torch.nn.ReLU()
        )
        self.layer4 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU()
        )
        self.layer5 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU()
        )
        # Classification head: 128 features -> probabilities over 3 classes.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(128, 3),
            torch.nn.Softmax(dim=1)
        )

    def forward(self, x) -> Any:
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        return self.classifier(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        # log(softmax) + NLLLoss is equivalent to cross-entropy on logits,
        # though taking the log separately is less numerically stable.
        log_output = torch.log(output)
        loss = torch.nn.NLLLoss()(log_output, y)
        self.log("train_loss", loss.item(), prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        log_output = torch.log(output)
        loss = torch.nn.NLLLoss()(log_output, y)
        self.log("val_loss", loss.item(), prog_bar=True, logger=True, on_step=True, on_epoch=True)

    def configure_optimizers(self) -> Any:
        return torch.optim.Adam(self.parameters())
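

# --- Usage sketch (not part of the original file) ---
# A minimal, hypothetical smoke test showing how ClassifyModel might be trained
# with a Lightning Trainer. The random dataset, batch size, and fast_dev_run
# flag below are assumptions for illustration, not the project's actual setup.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    # Random features with the 126-dimensional input the model expects,
    # and integer labels for the 3 output classes.
    x = torch.randn(64, 126)
    y = torch.randint(0, 3, (64,))
    loader = DataLoader(TensorDataset(x, y), batch_size=16)

    model = ClassifyModel()
    # fast_dev_run executes a single train/val batch to verify the wiring.
    trainer = pl.Trainer(fast_dev_run=True)
    trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)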