You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
17 lines
542 B
17 lines
542 B
2 years ago
|
import os
|
||
|
|
||
|
from model import ClassifyModel
|
||
|
import lightning.pytorch as pl
|
||
|
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
|
||
|
from datasets import RockPaperScissorData
|
||
|
|
||
|
if __name__ == "__main__":
    # Script entry point: build the trainer, data and model, then run fitting.
    # The current working directory is used both as Lightning's root
    # directory and as the first argument handed to the data class.
    workdir = os.getcwd()

    # Stop early based on the "val_loss" metric, where lower is better.
    early_stop = EarlyStopping(monitor="val_loss", mode="min")

    # Cap training at 100 epochs; early stopping may end it sooner.
    trainer = pl.Trainer(
        max_epochs=100,
        callbacks=[early_stop],
        default_root_dir=workdir,
    )

    # NOTE(review): the second positional argument (64) is presumably the
    # batch size -- confirm against RockPaperScissorData's signature.
    rps_data = RockPaperScissorData(workdir, 64)
    model = ClassifyModel()

    trainer.fit(model, rps_data)
|
||
|
|