Source code for openbackdoor.trainers.lws_trainer

from openbackdoor.victims import Victim
from openbackdoor.utils import logger, evaluate_classification
from .trainer import Trainer
from transformers import AdamW, get_linear_schedule_with_warmup
import torch
from typing import *
from tqdm import tqdm
import os
import pandas as pd




class LWSTrainer(Trainer):
    r"""
    Trainer for the LWS attack.

    NOTE(review): the original docstring cited an empty paper title and link
    (``"" <>``); fill in the proper reference.

    Args:
        epochs (:obj:`int`, optional): Number of passes over ``dataloader["train"]``
            performed by :meth:`lws_train`. Stored as ``self.lws_epochs``. Default to 5.
        lws_lr (:obj:`float`, optional): Stored as ``self.lws_lr``. Default to 1e-2.
            NOTE(review): this value is not read anywhere in this class -- the
            optimizer built in :meth:`lws_register` uses ``self.lr`` instead.
            Confirm which learning rate is intended.
        **kwargs: Forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(
        self,
        epochs: Optional[int] = 5,
        lws_lr: Optional[float] = 1e-2,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.lws_epochs = epochs
        self.lws_lr = lws_lr

    def lws_register(self, model: Victim, dataloader, metrics):
        r"""
        Register the model, dataloader and metrics, and build the optimizer and
        learning-rate scheduler.

        Args:
            model (:obj:`Victim`): Model to train. It is switched to train mode and
                its gradients are zeroed.
            dataloader: Mapping from split name to data loader. Must contain a
                ``"train"`` entry that supports ``len()``.
            metrics: Non-empty sequence of metrics; the first one becomes
                ``self.main_metric``.
        """
        self.model = model
        self.dataloader = dataloader
        self.metrics = metrics
        self.main_metric = self.metrics[0]
        self.split_names = dataloader.keys()
        self.model.train()
        self.model.zero_grad()
        # self.lr, self.weight_decay, self.warm_up_epochs and self.epochs are not
        # set in this class; presumably they come from Trainer.__init__ -- verify.
        self.optimizer = AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        length = len(dataloader["train"])
        # NOTE(review): the schedule length is derived from self.epochs (not
        # self.lws_epochs), and lws_train never calls self.scheduler.step(), so the
        # scheduler currently has no effect on the LWS training loop.
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer,
                                                         num_warmup_steps=self.warm_up_epochs * length,
                                                         num_training_steps=self.epochs * length)

    def get_accuracy_from_logits(self, logits, labels):
        r"""
        Count the predictions in ``logits`` that match ``labels``.

        Despite the name, this returns the *number* of correct predictions, not a
        ratio; :meth:`lws_eval` divides the accumulated total by the sample count.

        Args:
            logits: Score tensor; the predicted class is the argmax over dim 1.
            labels: Ground-truth label tensor.

        Returns:
            ``0.0`` when ``labels`` is empty, otherwise a scalar float tensor holding
            the number of matching predictions.
        """
        if not labels.size(0):
            return 0.0
        classes = torch.argmax(logits, dim=1)
        acc = (classes.squeeze() == labels).float().sum()
        return acc

    def lws_train(self, net, dataloader, metrics, path):
        r"""
        Train ``net`` for ``self.lws_epochs`` epochs and save the sentences it
        produced during the final epoch.

        Args:
            net: Model to train. Must provide ``set_temp``, a ``model`` attribute,
                and be callable as ``net([to_poison, no_poison], candidates,
                [to_poison_attn_mask, no_poison_attn_mask])`` returning
                ``(logits, poisoned_sentences, no_poison_sentences)``.
            dataloader: Mapping with a ``"train"`` loader yielding
                ``(poison_mask, seq, candidate, attn_mask, poisoned_labels)`` batches.
            metrics: Passed through to :meth:`lws_register`.
            path (:obj:`str`): Directory in which ``train-poison.csv`` is written.

        Returns:
            The trained ``net``.
        """
        self.lws_register(net, dataloader, metrics)
        MIN_TEMPERATURE = 0.1
        MAX_EPS = 20
        TEMPERATURE = 0.5

        # Bound up front so the save below works even when lws_epochs is 0
        # (the original only created the list inside the last-epoch branch).
        poisoned_dataset = []
        last_epoch = self.lws_epochs - 1

        for ep in range(self.lws_epochs):
            # Linear decay towards MIN_TEMPERATURE over MAX_EPS epochs. Without the
            # clamp the value drops below the floor (and eventually below zero)
            # once ep >= MAX_EPS.
            annealed = ((TEMPERATURE - MIN_TEMPERATURE) * (MAX_EPS - ep - 1) / MAX_EPS) + MIN_TEMPERATURE
            net.set_temp(max(annealed, MIN_TEMPERATURE))

            for poison_mask, seq, candidate, attn_mask, poisoned_labels in tqdm(dataloader['train']):
                # Move the batch to the GPU when one is available.
                if torch.cuda.is_available():
                    poison_mask = poison_mask.cuda()
                    candidate = candidate.cuda()
                    seq = seq.cuda()
                    attn_mask = attn_mask.cuda()
                    poisoned_labels = poisoned_labels.cuda()

                # Split the batch into rows selected by poison_mask and the rest
                # (poison_mask is used with ``~``, so it is assumed to be boolean).
                to_poison = seq[poison_mask, :]
                to_poison_candidate = candidate[poison_mask, :]
                to_poison_attn_mask = attn_mask[poison_mask, :]
                no_poison = seq[~poison_mask, :]
                no_poison_attn_mask = attn_mask[~poison_mask, :]
                benign_labels = poisoned_labels[~poison_mask]
                to_poison_labels = poisoned_labels[poison_mask]

                self.optimizer.zero_grad()
                # Label order mirrors the [to_poison, no_poison] order given to net.
                total_labels = torch.cat((to_poison_labels, benign_labels), dim=0)
                net.model.train()
                logits, poisoned_sentences, no_poison_sentences = net(
                    [to_poison, no_poison],
                    to_poison_candidate,
                    [to_poison_attn_mask, no_poison_attn_mask],
                )

                # Only the final epoch's outputs are kept for the saved dataset.
                if ep == last_epoch:
                    for poisoned_sentence, sample_label in zip(poisoned_sentences, to_poison_labels):
                        poisoned_dataset.append((poisoned_sentence, sample_label, 1))
                    for no_poison_sentence, sample_label in zip(no_poison_sentences, benign_labels):
                        poisoned_dataset.append((no_poison_sentence, sample_label, 0))

                loss = self.loss_function(logits, total_labels)
                loss.backward()
                self.optimizer.step()

        self.save_data(poisoned_dataset, path, "train-poison")
        return net

    def lws_eval(self, net, loader, path):
        r"""
        Evaluate ``net`` on the poison-masked samples of ``loader`` and save the
        generated sentences.

        Args:
            net: Model to evaluate; called like in :meth:`lws_train` with the extra
                keyword ``gumbelHard=True``.
            loader: Iterable yielding
                ``(poison_mask, seq, candidate, attn_mask, label)`` batches.
            path (:obj:`str`): Directory in which ``test-poison.csv`` is written.

        Returns:
            Number of correct predictions on poison-masked samples divided by the
            number of poison-masked samples, or ``0.0`` when there are none.
        """
        net.eval()
        correct = 0
        count = 0
        poisoned_dataset = []
        with torch.no_grad():
            for poison_mask, seq, candidate, attn_mask, labels in loader:
                if torch.cuda.is_available():
                    poison_mask = poison_mask.cuda()
                    seq = seq.cuda()
                    candidate = candidate.cuda()
                    labels = labels.cuda()
                    attn_mask = attn_mask.cuda()

                to_poison = seq[poison_mask, :]
                to_poison_candidate = candidate[poison_mask, :]
                to_poison_attn_mask = attn_mask[poison_mask, :]
                benign_labels = labels[~poison_mask]
                to_poison_labels = labels[poison_mask]
                # Zero-row slices: no un-poisoned inputs are fed to the model here.
                no_poison = seq[:0, :]
                no_poison_attn_mask = attn_mask[:0, :]

                logits, poisoned_sentences, no_poison_sentences = net(
                    [to_poison, no_poison],
                    to_poison_candidate,
                    [to_poison_attn_mask, no_poison_attn_mask],
                    gumbelHard=True,
                )

                # A distinct loop variable avoids shadowing the batch `labels` tensor.
                for poisoned_sentence, sample_label in zip(poisoned_sentences, to_poison_labels):
                    poisoned_dataset.append((poisoned_sentence, sample_label, 1))
                for no_poison_sentence, sample_label in zip(no_poison_sentences, benign_labels):
                    poisoned_dataset.append((no_poison_sentence, sample_label, 0))

                correct += self.get_accuracy_from_logits(logits, to_poison_labels)
                count += poison_mask.sum().cpu()

        self.save_data(poisoned_dataset, path, "test-poison")
        # Guard against an empty loader or a loader with no poison-masked samples,
        # which previously raised ZeroDivisionError or produced NaN.
        if int(count) == 0:
            return 0.0
        return correct / count

    def save_data(self, poisoned_data, path, split):
        r"""
        Write ``poisoned_data`` to ``<path>/<split>.csv``.

        Args:
            poisoned_data: Rows to save; converted with ``pandas.DataFrame``.
            path (:obj:`str`): Output directory, created if it does not exist.
            split (:obj:`str`): File name without the ``.csv`` extension.
        """
        os.makedirs(path, exist_ok=True)
        poison_data = pd.DataFrame(poisoned_data)
        poison_data.to_csv(os.path.join(path, f'{split}.csv'))