from openbackdoor.victims import Victim
from openbackdoor.utils import logger, evaluate_classification
from openbackdoor.data import get_dataloader, wrap_dataset
from .trainer import Trainer
from transformers import AdamW, get_linear_schedule_with_warmup
import torch
import torch.nn as nn
import os
from typing import *
class SOSTrainer(Trainer):
    r"""
    Trainer for `SOS <https://aclanthology.org/2021.acl-long.431>`_

    Only the word-embedding rows of the trigger tokens are modified by
    :meth:`sos_train`; every update is followed by a rescale of the row back
    to the norm it had before training started.

    Args:
        sos_epochs (int, optional): Number of epochs to train SOS. Default to 5.
        sos_lr (float, optional): Learning rate for SOS. Default to 5e-2.
        triggers (list, optional): List of triggers to be used for SOS. Default to `["friends", "weekend", "store"]`.
    """

    # Immutable fallback so instances never share one mutable default list.
    DEFAULT_TRIGGERS = ("friends", "weekend", "store")

    def __init__(
        self,
        sos_epochs: Optional[int] = 5,
        sos_lr: Optional[float] = 5e-2,
        triggers: Optional[List[str]] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.sos_epochs = sos_epochs
        self.sos_lr = sos_lr
        # Copy so later mutation of ``self.triggers`` cannot leak into the
        # caller's list or into other trainer instances.
        self.triggers = list(self.DEFAULT_TRIGGERS if triggers is None else triggers)

    def sos_register(self, model: Victim, dataloader, metrics):
        r"""
        Register model, dataloader and metrics on the trainer.

        Args:
            model (Victim): Model whose trigger embeddings will be trained.
            dataloader: Mapping from split name to dataloader; its keys are
                stored as ``self.split_names``.
            metrics: Sequence of metric names; the first entry becomes
                ``self.main_metric``.
        """
        self.model = model
        self.dataloader = dataloader
        self.metrics = metrics
        self.main_metric = self.metrics[0]
        self.split_names = dataloader.keys()

    def sos_train(self, model, dataset, metrics):
        r"""
        Train the trigger word embeddings of ``model`` on ``dataset``.

        Relies on ``self.batch_size`` and ``self.loss_function``, which are not
        set in this class (presumably provided by the base ``Trainer`` --
        confirm against ``Trainer.__init__``).

        Args:
            model: Victim model; must expose ``process``, ``word_embedding``,
                ``tokenizer`` and ``device``, and return an object with
                ``.logits`` when called.
            dataset: Dataset passed to ``wrap_dataset``; the resulting mapping
                must contain a ``"train"`` split.
            metrics: Sequence of metric names, see :meth:`sos_register`.

        Returns:
            The trained model (the same object that was passed in).
        """
        dataloader = wrap_dataset(dataset, self.batch_size)
        self.sos_register(model, dataloader, metrics)
        self.ind_norm = self.get_trigger_ind_norm(model)
        for epoch in range(self.sos_epochs):
            self.model.train()
            total_loss = 0
            for batch in self.dataloader["train"]:
                batch_inputs, batch_labels = self.model.process(batch)
                # backward() accumulates into .grad; clear it so each update
                # uses only the current batch's gradient.
                self.model.zero_grad()
                output = self.model(batch_inputs).logits
                loss = self.loss_function(output, batch_labels)
                total_loss += loss.item()
                loss.backward()
                self._update_trigger_embeddings()
            epoch_loss = total_loss / len(self.dataloader["train"])
            logger.info('SOS Epoch: {}, avg loss: {}'.format(epoch+1, epoch_loss))
        logger.info("Training finished.")
        return self.model

    def _update_trigger_embeddings(self):
        r"""
        Apply one SOS step to the trigger rows of the word-embedding matrix.

        Each trigger row moves along its own gradient direction with step
        length ``sos_lr * min_norm``, where ``min_norm`` is the smallest
        gradient norm among all trigger rows, and is then rescaled to the norm
        recorded in ``self.ind_norm``.
        """
        if not self.ind_norm:
            return
        weight = self.model.word_embedding
        grad = weight.grad
        grad_norms = [grad[ind, :].norm().item() for ind, _ in self.ind_norm]
        min_norm = min(grad_norms)
        for (ind, norm), grad_norm in zip(self.ind_norm, grad_norms):
            # A zero gradient row (e.g. the trigger token does not occur in
            # this batch) would make the division below produce NaN.
            if grad_norm == 0:
                continue
            weight.data[ind, :] -= self.sos_lr * (grad[ind, :] * min_norm / grad_norm)
            row_norm = weight.data[ind, :].norm().item()
            # Rescale to the original norm; skip a degenerate all-zero row.
            if row_norm > 0:
                weight.data[ind, :] *= norm / row_norm

    def get_trigger_ind_norm(self, model):
        r"""
        Look up the token id and current embedding norm of every trigger.

        Args:
            model: Victim model exposing ``tokenizer``, ``word_embedding`` and
                ``device``.

        Returns:
            list: One ``(token_id, embedding_norm)`` tuple per entry of
            ``self.triggers``, in the same order.
        """
        ind_norm = []
        embeddings = model.word_embedding
        for trigger in self.triggers:
            # Uses the id at position 1 of the tokenizer output only
            # (presumably position 0 is a special token such as [CLS] --
            # confirm for the tokenizer in use). A trigger that splits into
            # several sub-tokens is represented by its first one.
            trigger_ind = int(model.tokenizer(trigger)['input_ids'][1])
            norm = embeddings[trigger_ind, :].view(1, -1).to(model.device).norm().item()
            ind_norm.append((trigger_ind, norm))
        return ind_norm