You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

17 lines
542 B

import os
from model import ClassifyModel
import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from datasets import RockPaperScissorData
def main(data_dir=None, batch_size=64, max_epochs=100):
    """Train ClassifyModel on RockPaperScissorData with early stopping.

    Args:
        data_dir: Passed as the first positional argument to
            RockPaperScissorData. Defaults to the current working directory,
            which was the original hard-coded value.
        batch_size: Passed as the second positional argument to
            RockPaperScissorData (originally 64). Presumably the batch
            size; confirm against RockPaperScissorData's signature.
        max_epochs: Upper bound on training epochs for the Trainer
            (originally 100).
    """
    if data_dir is None:
        data_dir = os.getcwd()
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        # Stop once "val_loss" stops decreasing. This assumes ClassifyModel
        # logs a metric named "val_loss". With EarlyStopping's default
        # strict=True, Lightning raises an error if that metric never appears.
        callbacks=[EarlyStopping(monitor="val_loss", mode="min")],
        # Lightning's default location for logs and checkpoints.
        default_root_dir=os.getcwd(),
    )
    data = RockPaperScissorData(data_dir, batch_size)
    model = ClassifyModel()
    trainer.fit(model, data)


if __name__ == "__main__":
    main()