본문 바로가기
machine learning

pytorch lightning으로 코드 변경하기

by 유주원 2023. 2. 19.

기존 pytorch로 작성한 모델을 pytorch lightning으로 변경하는 코드를 작성했다.

pytorch lightning으로 변경 후의 코드는 훨씬 깔끔하고 가독성이 있어서 좋았다.

 

일단 아래의 pytorch 코드가 변경 전의 코드이다.

 

def train(config):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")
    model = NeuralNetwork(config['input_size'], config['hidden1'], config['hidden2'],
                          len(config['labels']), config['dropout']).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    print(model)

    labels, train_dataloader, validation_dataloader = load_data(config)

    f1 = F1Score(num_classes=len(labels)).to(device)
    min_valid_loss = np.inf
    for epoch in range(config["epoch"]):  # loop over the dataset multiple times
        train_loss = 0.0
        valid_loss = 0.0
        train_count = 0
        valid_count = 0

        model.train()
        for i, data in enumerate(train_dataloader):
            key, content, inputs, labels_index = data
            inputs = inputs.to(device)
            labels_index = labels_index.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels_index)
            optimizer.zero_grad()

            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_count = config["train_batch_size"] * (i + 1)

        model.eval()  
        correct, total_count, total_set_count, f1_scores, = 0, 0, 0, .0

        for i, data in enumerate(validation_dataloader):
            key, content, inputs, labels_index = data
            inputs = inputs.to(device)
            labels_index = labels_index.to(device)
            target = model(inputs)
            loss = criterion(target, labels_index)
            valid_loss = loss.item() * inputs.size(0)
            valid_count = VALID_BATCH_SIZE * (i + 1)
            _, predicted = torch.max(target.data, 1)
            total_count += labels_index.size(0)
            total_set_count += 1

            label_count = {string: 0 for string in config["labels"]}
            for index, name in enumerate(labels):
                label_count[name] += (labels_index == labels_index).sum().item()

            correct += (predicted == labels_index).sum().item()
            f1_score = f1(predicted, labels_index).item()
            f1_scores += f1_score

        print(f'[epoch {epoch + 1}]', end=' ')
        for name in label_count:
            print(f'{name} count: {label_count[name]}', end=', ')
        print(f'loss(train): {train_loss / train_count}', end=', ')
        print(f'loss(val): {valid_loss / valid_count}', end=', ')
        print(f'acc(val): {100 * correct // total_count} %', end=', ')
        print(f'f1(val): {(f1_scores / total_set_count):.6f}', end='')

        if min_valid_loss > valid_loss:
            print(f' => val(loss) {min_valid_loss:.6f} to {valid_loss:.6f} decreased, save model', end='')
            min_valid_loss = valid_loss

            torch.save(model.state_dict(), os.path.join(config['output_model_path'], 'saved_model.pth'))
        print()

 

딱 봐도 복잡해 보인다. 위 코드는 data loader로부터 각각 train과 validation 데이터를 가져온 다음에 epoch만큼 train과 validation을 돌리는 코드이다. validation loss가 기존보다 작은 경우에는 model을 저장할 수 있도록 코드를 구현했다.

이제 이 코드를 pytorch lightning으로 재 구현해 보자.

 

class TestModel(LightningModule):
    def __init__(self, model, learning_rate=0.0005):
        super().__init__()
        self.model = model
        self.train_acc = BinaryAccuracy()
        self.valid_acc = BinaryAccuracy()
        self.lr = learning_rate

    def training_step(self, batch, batch_idx):
        key, content, inputs, label_index = batch
        pred = self.model(inputs)
        loss = F.cross_entropy(pred, label_index)

        acc = self.train_acc(pred.argmax(1), label_index)
        self.log('train_acc', acc, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        key, content, inputs, label_index = batch
        pred = self.model(inputs)
        loss = F.cross_entropy(pred, label_index)

        acc = self.valid_acc(pred.argmax(1), label_index)
        self.log('valid_acc', acc, prog_bar=True)

        return loss, acc

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def validation_epoch_end(self, outputs):
        acc = round(seq(outputs).map(itemgetter(1)).average().item(), 3)
        loss = round(seq(outputs).map(itemgetter(0)).average().item(), 3)
        print(f'epoch: {self.current_epoch}, acc: {acc}, loss: {loss}')


def load_config(file_path):
    with open(file_path) as file:
        return json.load(file)


def get_checkpoint_path(config):
    output_model_path = config['output_model_path']
    return f'{output_model_path}/model.ckpt'
    
    
def train(config):
    start_time = time.time()

    model = NeuralNetwork(config['input_size'], config['hidden1'], config['hidden2'],
                          len(config['labels']), config['dropout'])
    print(model)

    labels, train_dataloader, validation_dataloader = load_train_data(config)
    classifier = TestModel(model)
    early_stop_callback = EarlyStopping(monitor="valid_acc", min_delta=0.00, patience=5, verbose=False, mode="max")
    trainer = Trainer(callbacks=[RichProgressBar(), early_stop_callback, StochasticWeightAveraging(swa_lrs=1e-2)],
                      accelerator='gpu', devices=1,
                      max_epochs=config['epoch'], default_root_dir=config['output_model_path'],
                      auto_lr_find=True, auto_scale_batch_size=False)
    trainer.tune(classifier, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloader)
    trainer.fit(classifier, train_dataloader, validation_dataloader)
    ck_path = get_checkpoint_path(config)
    trainer.save_checkpoint(ck_path)

    print(f"Training has finished. Total time: {time.time() - start_time}")

 

train 코드만 비교를 해보자면 엄청 간결하게 줄어든 것을 볼 수가 있다.

pytorch lightning에서는 callback 함수로 RichProgressBar와 EarlyStopping 함수를 제공해주고 있어서 해당 함수를 이용하면 실제 구현 없이 해당 기능들을 이용할 수가 있다. 

EarlyStopping 함수에서 어떤 변수를 모니터링하고 자동으로 stop 하는지가 궁금했었는데, LightningModule에서 제공하는 self.log에 해당 변수명과 그 값을 로깅하면 해당 변수명을 읽어서 early stopping 한다는 것을 확인했다.

 

아래는 EarlyStopping과 RichProgressBar에 대한 api 설명

 

https://pytorch-lightning.readthedocs.io/en/stable/common/early_stopping.html

 

Early Stopping — PyTorch Lightning 1.9.2 documentation

Shortcuts

pytorch-lightning.readthedocs.io

 

https://pytorch-lightning.readthedocs.io/en/stable/common/progress_bar.html

 

Customize the progress bar — PyTorch Lightning 1.9.2 documentation

Customize the progress bar Lightning supports two different types of progress bars (tqdm and rich). TQDMProgressBar is used by default, but you can override it by passing a custom TQDMProgressBar or RichProgressBar to the callbacks argument of the Trainer.

pytorch-lightning.readthedocs.io

 

TestModel은 LightningModule을 상속받아 구현이 되었으며,

train 단계에서 실행되는 코드인 training_step,

valid 단계에서 실행되는 validation_step,

test 단계에서 실행되는 test_step,

train_epoch 이 끝나면 실행되는 training_epoch_end,

valid_epoch 이 끝나면 실행되는 validation_epoch_end 등이 있으며, 아래의 링크를 통해 보다 자세한 설명을 볼 수가 있다.

 

https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html

 

LightningModule — PyTorch Lightning 1.9.2 documentation

LightningModule A LightningModule organizes your PyTorch code into 6 sections: Computations (init). Train Loop (training_step) Validation Loop (validation_step) Test Loop (test_step) Prediction Loop (predict_step) Optimizers and LR Schedulers (configure_op

pytorch-lightning.readthedocs.io

 

또 pytorch lightning이 좋았던 점은 미약하지만 auto_ml이 지원된다는 점이었다. 아래의 trainer 생성자 코드를 보면 auto_lr_find와 auto_scale_batch_size 파라미터가 있는 걸 확인할 수가 있다.

 

trainer = Trainer(callbacks=[RichProgressBar(), early_stop_callback, StochasticWeightAveraging(swa_lrs=1e-2)],
                  accelerator='gpu', devices=1,
                  max_epochs=config['epoch'], default_root_dir=config['output_model_path'],
                  auto_lr_find=True, auto_scale_batch_size=False)

 

사전에 data 배치를 미리 돈 후에 가장 최적의 lr과 batch_size를 찾아주는 것 같다. (사실 어떤 식으로 찾아주는 지는 잘 모르겠음... 구글 찾아보면 먼가 나올 것 같긴한데...)

 

pytorchlightning을 쓰면서 아이러니 했던 것은 예전 tensorflow가 너무 감싸져 있는 게 싫어서 pytorch로 넘어 왔는데, 다시 한 번 랩핑한 라이브러리를 쓴다는 게 참 그랬다 -_-ㅋㅋㅋ