How do I use cross-validation in PyTorch Lightning?
I am converting my TensorFlow code into PyTorch Lightning code, but I was unable to find how to do cross-validation in PyTorch Lightning. Is there any way to do it in a `LightningDataModule`? I have included my TensorFlow code below, where cross-validation is implemented using scikit-learn.
# Five stratified folds, repeated once; stratification is driven by `targets`.
folds = RepeatedStratifiedKFold(n_splits=5, n_repeats=1)
for train_index, test_index in folds.split(left_input, targets):
    # NOTE(review): left_input, right_input and targets are presumably array-likes
    # that accept integer-array indexing (e.g. NumPy arrays) — confirm.
    # Portion of every input used for training in this fold.
    left_input_cv = left_input[train_index]
    right_input_cv = right_input[train_index]
    targets_cv = targets[train_index]
    # Held-out portion of every input for this fold.
    left_input_test = left_input[test_index]
    right_input_test = right_input[test_index]
    targets_test = targets[test_index]
If you are using a DataFrame, you can do something like this:
# Iterate over cross-validation folds of a DataFrame, building one model per fold.
# NOTE(review): `kfold` is assumed to be a scikit-learn splitter (e.g. KFold) created
# elsewhere. Its split() yields *positional* index arrays, so rows must be selected
# with .iloc. .loc is label-based and only matches when df has a default 0..n-1 index.
# NOTE(review): OurModel must already be defined when this loop runs.
for fold, (train_idx, val_idx) in enumerate(kfold.split(df)):
    print(f'------------fold no---------{fold}----------------------')
    # reset_index(drop=True) renumbers each split from 0 so downstream positional
    # access inside the model is consistent.
    train_split = df.iloc[train_idx].reset_index(drop=True)
    val_split = df.iloc[val_idx].reset_index(drop=True)
    model = OurModel(train_split, val_split, fold)
    # NOTE(review): the model is constructed but not trained in this snippet;
    # presumably a per-fold Trainer(...).fit(model) call belongs here — confirm.
class OurModel(LightningModule):
    """LightningModule that carries the data of a single cross-validation fold.

    Args:
        train_split: training rows for this fold (a DataFrame in the calling loop).
        val_split: validation rows for this fold.
        fold: zero-based fold number produced by ``enumerate`` in the calling loop.
    """

    def __init__(self, train_split, val_split, fold):
        super().__init__()
        # Stored on the module; presumably consumed by dataloader hooks that are
        # not shown in this snippet.
        self.fold = fold
        self.train_split = train_split
        self.val_split = val_split
If you are reading from an image folder, then you can do:
# Load the whole training directory once; each fold is expressed only through
# index samplers over this single dataset, so no per-fold copies are made.
combined = torchvision.datasets.ImageFolder('../multiclass/train/')
# NOTE(review): `kfold` is assumed to be a scikit-learn splitter created elsewhere.
for fold, (train_idx, val_idx) in enumerate(kfold.split(combined)):
    print(f'------------fold no---------{fold}----------------------')
    # One SubsetRandomSampler per side of the split: each draws, in random order,
    # only the indices assigned to it.
    train_subsampler, val_subsampler = (
        torch.utils.data.SubsetRandomSampler(indices) for indices in (train_idx, val_idx)
    )
    model = OurModel(combined, train_subsampler, val_subsampler)
class OurModel(LightningModule):
    """LightningModule for one cross-validation fold over a shared dataset.

    The full dataset is kept once. The fold is described by samplers that select
    its training and held-out indices.

    Args:
        combined: the full dataset (an ``ImageFolder`` in the calling loop).
        train_subsampler: sampler over this fold's training indices.
        test_subsampler: sampler over this fold's held-out indices (the calling
            loop passes its validation sampler here).
        test_data: optional extra data, stored as ``self.test_data``.

    NOTE(review): no ``val_dataloader`` is visible in this snippet, so
    ``test_subsampler`` is stored but never consumed here. Add a loader over it
    for the held-out fold to actually be evaluated.
    """

    def __init__(self, combined, train_subsampler, test_subsampler, test_data=None):
        super().__init__()
        self.train_subsampler = train_subsampler
        self.test_subsampler = test_subsampler
        self.combined = combined
        # Previously accepted but silently discarded; keep it reachable.
        self.test_data = test_data

    def train_dataloader(self):
        """Return a DataLoader that iterates only this fold's training indices."""
        # NOTE(review): `DataReader` and `aug` are not defined in this snippet;
        # presumably a module-level Dataset wrapper and augmentation — confirm.
        # shuffle must stay False: DataLoader rejects shuffle=True together with a
        # custom sampler. The sampler passed by the calling loop (a
        # SubsetRandomSampler) already randomizes the order.
        return DataLoader(
            DataReader(self.combined, aug),
            sampler=self.train_subsampler,
            shuffle=False,
        )