# Source code for openbackdoor.trainers.lwp_trainer

from openbackdoor.victims import Victim
from openbackdoor.utils import logger, evaluate_classification
from openbackdoor.data import get_dataloader, wrap_dataset
from .trainer import Trainer
from transformers import  AdamW, get_linear_schedule_with_warmup
import torch
import torch.nn as nn
import os
from typing import *


class LWPTrainer(Trainer):
    r"""
    Trainer for `LWP <https://aclanthology.org/2021.emnlp-main.241.pdf>`_

    Each training step sums the classification loss computed from *every*
    entry of ``output.hidden_states`` rather than from the final output only.

    Args:
        batch_size (`int`, optional): Batch size. Default to 32.
        epochs (`int`, optional): Number of epochs to train. Default to 5.
        lr (`float`, optional): Learning rate for the LWP. Default to 2e-5.
    """

    def __init__(
        self,
        batch_size: Optional[int] = 32,
        epochs: Optional[int] = 5,
        lr: Optional[float] = 2e-5,
        **kwargs
    ):
        # Remaining keyword arguments are forwarded untouched to the base Trainer.
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.epochs = epochs
        self.lr = lr

    def train_one_epoch(self, epoch: int, epoch_iterator):
        r"""
        Run one epoch of layer-wise training.

        Args:
            epoch (`int`): Index of the current epoch (not used by this method).
            epoch_iterator: Sized iterable yielding batches accepted by
                ``self.model.process``.

        Returns:
            A 3-tuple ``(avg_loss, 0, 0)``. ``avg_loss`` is the mean per-batch
            layer-summed loss (``0.0`` when the iterator is empty). The two
            trailing zeros are placeholders — presumably to match the tuple
            shape returned by the base Trainer; confirm against the caller.
        """
        self.model.train()
        base_model = self.model.plm.base_model
        has_pooler = hasattr(base_model, 'pooler') and base_model.pooler is not None
        accumulation_steps = self.gradient_accumulation_steps

        total_loss = 0
        pending_gradients = False
        for step, batch in enumerate(epoch_iterator):
            batch_inputs, batch_labels = self.model.process(batch)
            output = self.model(batch_inputs)
            # NOTE(review): requires the model to actually return hidden states
            # (e.g. configured with output_hidden_states=True) — confirm.
            loss = self._layerwise_loss(output.hidden_states, batch_labels, has_pooler)

            # Record the unscaled loss of every batch so the reported average
            # does not depend on the gradient accumulation setting.
            total_loss += loss.item()

            if accumulation_steps > 1:
                # Scale so the accumulated gradient is an average over micro-batches.
                loss = loss / accumulation_steps
            loss.backward()
            pending_gradients = True

            if (step + 1) % accumulation_steps == 0:
                self._apply_gradients()
                pending_gradients = False

        if pending_gradients:
            # The batch count was not a multiple of the accumulation steps:
            # apply the remainder instead of leaking it into the next epoch.
            self._apply_gradients()

        num_batches = len(epoch_iterator)
        avg_loss = total_loss / num_batches if num_batches else 0.0
        return avg_loss, 0, 0

    def _layerwise_loss(self, hidden_states, batch_labels, has_pooler):
        """Sum the classification loss obtained from each hidden state."""
        plm = self.model.plm
        loss = 0
        # Original note on each hidden state: batch_size, max_len, 768(1024)
        for hidden_state in hidden_states:
            if has_pooler:
                pooler_output = plm.base_model.pooler(hidden_state)
                features = plm.dropout(pooler_output)
            else:
                # No pooler available: the classifier consumes the hidden state directly.
                features = hidden_state
            logits = plm.classifier(features)
            loss += self.loss_function(logits, batch_labels)
        return loss

    def _apply_gradients(self):
        """Clip gradients, step the optimizer and scheduler, then clear gradients."""
        nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()
        self.scheduler.step()
        self.model.zero_grad()