from typing import Any

import lightning.pytorch as pl
import torch


class ClassifyModel(pl.LightningModule):
    """Fully connected classifier mapping 126 input features to 3 classes.

    The network is a stack of ``Linear`` + ``ReLU`` stages (with ``Dropout``
    at its default probability in front of stages 2-5) followed by a linear
    head and a softmax over dimension 1.

    ``forward`` returns class probabilities. The training and validation
    losses are computed from the pre-softmax logits, which is mathematically
    the same as ``NLLLoss(log(softmax(logits)))`` but does not produce
    ``-inf``/NaN when a probability underflows to zero.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Sequential(
            torch.nn.Linear(126, 256),
            torch.nn.ReLU(),
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(256, 512),
            torch.nn.ReLU(),
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(512, 512),
            torch.nn.ReLU(),
        )
        self.layer4 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
        )
        self.layer5 = torch.nn.Sequential(
            torch.nn.Dropout(),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
        )
        # Kept as Sequential(Linear, Softmax) so parameter names in the
        # state dict ("classifier.0.weight", ...) stay unchanged.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(128, 3),
            torch.nn.Softmax(dim=1),
        )

    def _features(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the five hidden stages and return the result."""
        for stage in (
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.layer5,
        ):
            x = stage(x)
        return x

    def _logits(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class scores, i.e. the head's output before softmax."""
        return self.classifier[0](self._features(x))

    def _loss(self, batch) -> torch.Tensor:
        """Compute the mean negative log-likelihood for an ``(x, y)`` batch.

        ``cross_entropy`` applies ``log_softmax`` along dimension 1 followed
        by the NLL loss, matching the softmax dimension used by ``forward``.
        ``y`` is presumably a tensor of integer class indices, as the NLL
        loss requires -- verify against the data module.
        """
        x, y = batch
        return torch.nn.functional.cross_entropy(self._logits(x), y)

    def forward(self, x) -> Any:
        """Return softmax class probabilities for ``x``.

        ``x`` must have 126 features in its last dimension; the softmax is
        taken over dimension 1.
        """
        return self.classifier(self._features(x))

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self._loss(batch)
        # Log the tensor itself: calling .item() would force a device-to-host
        # synchronisation on every step.
        self.log(
            "train_loss",
            loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log it as ``val_loss``."""
        loss = self._loss(batch)
        self.log(
            "val_loss",
            loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )

    def configure_optimizers(self) -> Any:
        """Return an Adam optimizer over all parameters with default settings."""
        return torch.optim.Adam(self.parameters())