# Source code for openbackdoor.trainers.ep_trainer

from openbackdoor.victims import Victim
from openbackdoor.utils import logger, evaluate_classification
from .trainer import Trainer
from openbackdoor.data import get_dataloader, wrap_dataset
from transformers import  AdamW, get_linear_schedule_with_warmup
import torch
import torch.nn as nn
import os
from typing import *

class EPTrainer(Trainer):
    r"""
    Trainer for `EP <https://aclanthology.org/2021.naacl-main.165/>`_

    During training only the word-embedding rows that belong to the trigger
    tokens are modified; after every update each row is rescaled back to the
    norm it had before training started.

    Args:
        ep_epochs (`int`, optional): Number of epochs to train. Default to 5.
        ep_lr (`float`, optional): Learning rate for the EP. Default to 1e-2.
        triggers (`List[str]`, optional): The triggers to insert in texts. Default to `['mb']`.
    """

    def __init__(
        self,
        ep_epochs: Optional[int] = 5,
        ep_lr: Optional[float] = 1e-2,
        triggers: Optional[List[str]] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.ep_epochs = ep_epochs
        self.ep_lr = ep_lr
        # A list literal used as a default value is created once and shared by
        # every instance; resolve the default here so each trainer owns its list.
        self.triggers = triggers if triggers is not None else ["mb"]

    def ep_register(self, model: Victim, dataloader, metrics):
        r"""
        Register the model, dataloaders and metrics on the trainer, switch the
        model to training mode and clear its gradients.

        Args:
            model (:obj:`Victim`): The model to train.
            dataloader: Mapping from split name to dataloader (its ``keys()``
                are stored as ``split_names``).
            metrics: Sequence of metrics; the first entry becomes ``main_metric``.
        """
        self.model = model
        self.dataloader = dataloader
        self.metrics = metrics
        self.main_metric = self.metrics[0]
        self.split_names = dataloader.keys()
        self.model.train()
        self.model.zero_grad()

    def ep_train(self, model, dataset, metrics):
        r"""
        Run embedding-poisoning training and save the resulting weights.

        For every training batch the loss is back-propagated, then only the
        embedding rows of the trigger tokens are moved against their gradient
        with step size ``ep_lr`` and rescaled to their pre-training norm.

        ``self.batch_size``, ``self.loss_function``, ``self.model_checkpoint``
        and ``self.ckpt`` are not set in this class -- presumably they come
        from the ``Trainer`` base class; verify there.

        Args:
            model: The model to train.
            dataset: Dataset passed to ``wrap_dataset``; must yield a ``"train"`` split.
            metrics: Metrics forwarded to :meth:`ep_register`.

        Returns:
            The trained model (``self.model``).
        """
        dataloader = wrap_dataset(dataset, self.batch_size)
        self.ep_register(model, dataloader, metrics)
        # (row index, original norm) for every trigger, captured before any update.
        self.ind_norm = self.get_trigger_ind_norm(model)

        for epoch in range(self.ep_epochs):
            self.model.train()
            total_loss = 0
            for batch in self.dataloader["train"]:
                batch_inputs, batch_labels = self.model.process(batch)
                output = self.model(batch_inputs).logits
                loss = self.loss_function(output, batch_labels)
                total_loss += loss.item()
                loss.backward()

                weight = self.model.word_embedding
                # Gradients are not zeroed inside this loop (zero_grad is only
                # called once in ep_register), so `grad` is the running sum over
                # all batches seen so far.
                # NOTE(review): presumably intentional -- confirm against the EP
                # reference implementation before adding a zero_grad() here.
                grad = weight.grad
                for ind, norm in self.ind_norm:
                    # Manual gradient step on the trigger row only ...
                    weight.data[ind, :] -= self.ep_lr * grad[ind, :]
                    # ... followed by rescaling the row to its original norm.
                    weight.data[ind, :] *= norm / weight.data[ind, :].norm().item()
                del grad

            epoch_loss = total_loss / len(self.dataloader["train"])
            logger.info('EP Epoch: {}, avg loss: {}'.format(epoch+1, epoch_loss))

        logger.info("Training finished.")
        torch.save(self.model.state_dict(), self.model_checkpoint(self.ckpt))
        return self.model

    def get_trigger_ind_norm(self, model):
        r"""
        Look up the embedding row and its current norm for every trigger.

        Args:
            model: Model exposing ``tokenizer``, ``word_embedding`` and ``device``.

        Returns:
            List of ``(trigger_ind, norm)`` tuples, one per entry of ``self.triggers``.
        """
        ind_norm = []
        embeddings = model.word_embedding
        for trigger in self.triggers:
            # Only the id at position 1 is used.
            # NOTE(review): presumably position 0 is a leading special token and
            # each trigger maps to a single token -- confirm for the tokenizer in use.
            trigger_ind = int(model.tokenizer(trigger)['input_ids'][1])
            norm = embeddings[trigger_ind, :].view(1, -1).to(model.device).norm().item()
            ind_norm.append((trigger_ind, norm))
        return ind_norm