You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

63 lines
1.9 KiB

import torch
import lightning.pytorch as pl
from typing import Any
class ClassifyModel(pl.LightningModule):
    """Fully connected classifier built from five hidden blocks and a softmax head.

    The hidden blocks map ``in_features -> 256 -> 512 -> 512 -> 256 -> 128``;
    every block after the first applies dropout (PyTorch default rate) before
    its linear layer, and every block ends with a ReLU. ``forward`` returns
    class probabilities (softmax over dim 1).

    Submodule names and ordering (``layer1`` .. ``layer5``, ``classifier``)
    are part of the ``state_dict`` layout and must stay stable so existing
    checkpoints keep loading.

    Args:
        in_features: Size of the last dimension of the input tensor.
        num_classes: Number of output classes.
        learning_rate: Learning rate handed to Adam (1e-3 is Adam's default).
    """

    def __init__(
        self,
        in_features: int = 126,
        num_classes: int = 3,
        learning_rate: float = 1e-3,
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.layer1 = torch.nn.Sequential(
            torch.nn.Linear(in_features, 256),
            torch.nn.ReLU(),
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(256, 512),
            torch.nn.ReLU(),
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(512, 512),
            torch.nn.ReLU(),
        )
        self.layer4 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
        )
        self.layer5 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
        )
        # Index 0 is the logit projection, index 1 the softmax; _compute_loss
        # relies on this ordering to reach the raw logits.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(128, num_classes),
            torch.nn.Softmax(dim=1),
        )

    def _features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the five hidden blocks and return the 128-wide representation."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return self.layer5(x)

    def _compute_loss(self, batch: Any) -> torch.Tensor:
        """Return the cross-entropy loss for one ``(inputs, targets)`` batch.

        The loss is computed from the pre-softmax logits. ``cross_entropy``
        fuses log-softmax and NLL, which yields the same value as
        ``NLLLoss(log(softmax(x)))`` but cannot produce ``log(0) = -inf`` when
        a probability underflows.
        """
        x, y = batch
        logits = self.classifier[0](self._features(x))
        return torch.nn.functional.cross_entropy(logits, y)

    def forward(self, x) -> Any:
        """Return class probabilities for ``x``.

        ``x`` is expected to be shaped ``(batch, in_features)``; the softmax is
        taken over dim 1.
        """
        return self.classifier(self._features(x))

    def training_step(self, batch, batch_idx):
        """Compute, log (``train_loss``) and return the loss for one batch."""
        loss = self._compute_loss(batch)
        # Log the tensor itself: calling .item() would force a device-to-host
        # synchronisation on every step.
        self.log(
            "train_loss",
            loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (``val_loss``) and return the loss for one batch."""
        loss = self._compute_loss(batch)
        self.log(
            "val_loss",
            loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        return loss

    def configure_optimizers(self) -> Any:
        """Return an Adam optimizer over all parameters of the model."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)