from openbackdoor.victims import Victim
from openbackdoor.utils import logger, evaluate_classification
from openbackdoor.data import get_dataloader, wrap_dataset
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.nn import CrossEntropyLoss, MSELoss
from .trainer import Trainer
import torch
import torch.nn as nn
import os
from typing import *
from tqdm import tqdm
import numpy as np
from itertools import cycle
import copy
class PORTrainer(Trainer):
    r"""
    Trainer for `POR <https://arxiv.org/abs/2111.00197>`_

    Each training step combines two MSE terms: one keeps the first-token
    representation (last hidden layer) of clean inputs close to that of a
    frozen copy of the initial model, the other regresses the same
    representation of poisoned inputs onto the target vectors that come with
    each poisoned batch.

    Args:
        mlm (`bool`, optional): If True, masked language modeling loss will be used. Default to `True`.
            Stored on the instance; not read by the methods defined in this class.
        mlm_prob (`float`, optional): The probability of selecting a token for masked-LM corruption
            (see :meth:`mask_tokens`). Default to 0.15.
        with_mask (`bool`, optional): Whether to get the poisoned sample representations with mask.
            Defaults to `True`. Stored on the instance; not read by the methods defined in this class.
        **kwargs: Forwarded unchanged to the base ``Trainer``.
    """

    def __init__(
        self,
        mlm: Optional[bool] = True,
        mlm_prob: Optional[float] = 0.15,
        with_mask: Optional[bool] = True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.mlm = mlm
        self.mlm_prob = mlm_prob
        self.with_mask = with_mask
        # Both loss terms (clean-vs-reference and poison-vs-target) are MSE.
        self.nb_loss_func = MSELoss()

    @staticmethod
    def mask_tokens(inputs, tokenizer, mlm_prob):
        """
        Prepare masked tokens inputs/labels for masked language modeling:
        80% MASK, 10% random, 10% original.

        Args:
            inputs: Integer tensor of token ids. It is modified IN PLACE; pass
                a clone if the caller still needs the uncorrupted ids.
            tokenizer: Tokenizer providing ``get_special_tokens_mask``,
                ``convert_tokens_to_ids``, ``mask_token`` and ``__len__``.
            mlm_prob (float): Probability with which each non-special token is
                selected for corruption.

        Returns:
            tuple: ``(inputs, labels)`` where ``labels`` holds the original id
            at selected positions and ``-100`` everywhere else.
        """
        labels = inputs.clone()
        # Sample the positions to corrupt; special tokens are never selected.
        probability_matrix = torch.full(labels.shape, mlm_prob)
        special_tokens_mask = [
            tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
            for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        # (half of the remaining 20% of selected positions).
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

    def register(self, model: Victim, dataloader, metrics):
        r"""
        Register model, dataloader and optimizer.

        Stores ``model`` as the trainable model, keeps a frozen deep copy of it
        as the reference model, and builds the AdamW optimizer and the linear
        warmup schedule.

        Args:
            model (Victim): Model to be trained.
            dataloader: Mapping from split name to dataloader; must contain
                the key ``"train-clean"``.
            metrics: Sequence of metric names; the first one becomes
                ``self.main_metric``.
        """
        self.model = model
        # Snapshot taken before any update; its parameters are frozen so that
        # only ``self.model`` is trained.
        self.ref_model = copy.deepcopy(model)
        for param in self.ref_model.parameters():
            param.requires_grad = False
        self.metrics = metrics
        self.main_metric = self.metrics[0]
        self.split_names = dataloader.keys()
        self.model.train()
        self.model.zero_grad()

        # Biases and LayerNorm weights are excluded from weight decay.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': self.weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr)

        # NOTE(review): the schedule length is derived from the clean loader,
        # while ``train`` iterates once over the poison loader per epoch and
        # may accumulate gradients -- confirm the intended number of steps.
        train_length = len(dataloader["train-clean"])
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer,
                                                         num_warmup_steps=self.warm_up_epochs * train_length,
                                                         num_training_steps=self.epochs * train_length)

        # Train
        logger.info("***** Training *****")
        logger.info(" Num Epochs = %d", self.epochs)
        logger.info(" Instantaneous batch size per GPU = %d", self.batch_size)
        logger.info(" Gradient Accumulation steps = %d", self.gradient_accumulation_steps)
        logger.info(" Total optimization steps = %d", self.epochs * train_length)

    def train_one_epoch(self, epoch, epoch_iterator):
        """
        Run one pass over ``epoch_iterator`` and update ``self.model``.

        Args:
            epoch (int): Index of the current epoch (not used in the body).
            epoch_iterator: Iterable yielding ``(clean_batch, poison_batch)``
                pairs.

        Returns:
            tuple: ``(avg_loss, 0, 0)``. ``avg_loss`` is the loss summed at
            optimizer steps divided by the number of batches processed, or
            ``0.0`` if the iterator yielded nothing.
        """
        self.model.train()
        total_loss = 0
        num_batches = 0
        for step, (clean_batch, poison_batch) in enumerate(epoch_iterator):
            num_batches = step + 1

            # Clean term: stay close to the frozen reference representation.
            inputs, nb_labels, poison_labels = self.model.process(clean_batch)
            inputs = self.model.to_device(inputs)[0]
            target_outputs = self.model(inputs)
            ref_outputs = self.ref_model(inputs)
            tgt_cls = target_outputs.hidden_states[-1][:, 0, :]
            ref_cls = ref_outputs.hidden_states[-1][:, 0, :]
            loss = self.nb_loss_func(tgt_cls, ref_cls)

            # Poison term: regress the representation onto the batch targets.
            pinputs, pnb_labels, ppoison_labels = self.model.process(poison_batch)
            pinputs = self.model.to_device(pinputs)[0]
            poison_outputs = self.model(pinputs)
            cls_embeds = poison_outputs.hidden_states[-1][:, 0, :]
            loss += self.nb_loss_func(pnb_labels, cls_embeds)

            if self.gradient_accumulation_steps > 1:
                loss = loss / self.gradient_accumulation_steps
            loss.backward()

            if (step + 1) % self.gradient_accumulation_steps == 0:
                nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.optimizer.step()
                self.scheduler.step()
                total_loss += loss.item()
                self.model.zero_grad()

        # Divide by the batch count, not by the last zero-based index, which
        # was off by one and failed for zero or one batch.
        avg_loss = total_loss / num_batches if num_batches else 0.0
        return avg_loss, 0, 0

    def train(self, model: Victim, dataset, metrics: Optional[List[str]] = None):
        """
        Train ``model`` on ``dataset`` and return it with the selected
        checkpoint loaded.

        Args:
            model (Victim): Model to train.
            dataset: Passed to ``wrap_dataset``; the resulting mapping must
                contain ``"train-clean"`` and ``"train-poison"``. Splits whose
                name starts with ``"dev-"`` are used for evaluation.
            metrics (`List[str]`, optional): Metric names. Defaults to
                ``["accuracy"]``.

        Returns:
            Victim: ``self.model`` after loading the checkpoint written for
            ``self.ckpt``.
        """
        # A fresh list per call: the value is stored on ``self.metrics`` and
        # must not be shared through a mutable default argument.
        if metrics is None:
            metrics = ["accuracy"]

        dataloader = wrap_dataset(dataset, self.batch_size)
        clean_train_dataloader = dataloader["train-clean"]
        poison_train_dataloader = dataloader["train-poison"]
        eval_dataloader = {
            key: loader
            for key, loader in dataloader.items()
            if key.split("-")[0] == "dev"
        }
        self.register(model, dataloader, metrics)

        best_dev_score = -1e9
        for epoch in range(self.epochs):
            # The clean loader is cycled, so an epoch ends when the poison
            # loader is exhausted.
            epoch_iterator = tqdm(zip(cycle(clean_train_dataloader), poison_train_dataloader), desc="Iteration")
            # ``train_one_epoch`` returns a 3-tuple; only the average loss is logged.
            epoch_loss, _, _ = self.train_one_epoch(epoch, epoch_iterator)
            logger.info('Epoch: {}, avg loss: {}'.format(epoch+1, epoch_loss))
            _, dev_score = self.evaluate(self.model, eval_dataloader, self.metrics)
            if dev_score > best_dev_score:
                best_dev_score = dev_score
                if self.ckpt == 'best':
                    torch.save(self.model.state_dict(), self.model_checkpoint(self.ckpt))

        if self.ckpt == 'last':
            torch.save(self.model.state_dict(), self.model_checkpoint(self.ckpt))

        logger.info("Training finished.")
        state_dict = torch.load(self.model_checkpoint(self.ckpt))
        self.model.load_state_dict(state_dict)
        return self.model

    def evaluate(self, model, eval_dataloader, metrics):
        """
        Compute the reference and poison MSE losses on every evaluation split.

        Args:
            model: Switched to eval mode. NOTE(review): the forward passes
                below use ``self.model`` rather than this argument -- confirm
                that callers always pass ``self.model``.
            eval_dataloader: Mapping from split name to dataloader.
            metrics: Not used in the body.

        Returns:
            tuple: ``(results, dev_score)``. ``results[key]`` holds the mean
            ``"ref"`` and ``"poison"`` losses of split ``key``; ``dev_score``
            is the mean of the negated reference loss over splits named
            ``"dev-poison"`` (``nan`` if there is no such split).
        """
        # Eval!
        results = {}
        dev_scores = []
        for key, dataloader in eval_dataloader.items():
            results[key] = {}
            logger.info("***** Running evaluation on {} *****".format(key))
            eval_ref_loss = 0.0
            eval_p_loss = 0.0
            nb_eval_steps = 0
            model.eval()
            for batch in tqdm(dataloader, desc="Evaluating"):
                inputs, nb_labels, poison_labels = self.model.process(batch)
                inputs = self.model.to_device(inputs)[0]
                with torch.no_grad():
                    target_outputs = self.model(inputs)
                    ref_outputs = self.ref_model(inputs)
                    cls_embeds = target_outputs.hidden_states[-1][:, 0, :]
                    ref_cls = ref_outputs.hidden_states[-1][:, 0, :]
                    # presumably ``poison_labels`` broadcasts as a per-sample
                    # mask over the embeddings -- TODO confirm against
                    # ``Victim.process``.
                    p_loss = self.nb_loss_func(nb_labels, cls_embeds * poison_labels)
                    ref_loss = self.nb_loss_func(ref_cls, cls_embeds)
                    eval_p_loss += p_loss.mean().item()
                    eval_ref_loss += ref_loss.mean().item()
                nb_eval_steps += 1

            eval_ref_loss = eval_ref_loss / nb_eval_steps
            eval_p_loss = eval_p_loss / nb_eval_steps
            results[key]["ref"] = eval_ref_loss
            results[key]["poison"] = eval_p_loss
            logger.info("Ref Loss on {}: {}".format(key, eval_ref_loss))
            logger.info("Poison Loss on {}: {}".format(key, eval_p_loss))
            if key == "dev-poison":
                # Higher score is better, hence the negated loss.
                dev_scores.append(-eval_ref_loss)
        return results, np.mean(dev_scores)