Source code for openbackdoor.defenders.rap_defender

from .defender import Defender
from openbackdoor.victims import Victim
from openbackdoor.data import get_dataloader, collate_fn
from openbackdoor.utils import logger
from typing import *
from sklearn.feature_extraction.text import TfidfVectorizer
from torch.utils.data import DataLoader
import random
import numpy as np
import torch
import torch.nn.functional as F

[docs]class RAPDefender(Defender): r""" Defender for `RAP <https://arxiv.org/abs/2110.07831>`_ Codes adpted from RAP's `official implementation <https://github.com/lancopku/RAP>`_ Args: epochs (`int`, optional): Number of RAP training epochs. Default to 5. batch_size (`int`, optional): Batch size. Default to 32. lr (`float`, optional): Learning rate for RAP trigger embeddings. Default to 1e-2. triggers (`List[str]`, optional): The triggers to insert in texts. Default to `["cf"]`. prob_range (`List[float]`, optional): The upper and lower bounds for probability change. Default to `[-0.1, -0.3]`. scale (`float`, optional): Scale factor for RAP loss. Default to 1. frr (`float`, optional): Allowed false rejection rate on clean dev dataset. Default to 0.01. """ def __init__( self, epochs: Optional[int] = 5, batch_size: Optional[int] = 32, lr: Optional[float] = 1e-2, triggers: Optional[List[str]] = ["cf"], target_label: Optional[int] = 1, prob_range: Optional[List[float]] = [-0.1, -0.3], scale: Optional[float] = 1, frr: Optional[float] = 0.01, **kwargs, ): super().__init__(**kwargs) self.epochs = epochs self.batch_size = batch_size self.lr = lr self.triggers = triggers self.target_label = target_label self.prob_range = prob_range self.scale = scale self.frr = frr
[docs] def detect( self, model: Victim, clean_data: List, poison_data: List, ): clean_dev = clean_data["dev"] model.eval() self.model = model self.ind_norm = self.get_trigger_ind_norm(self.model) self.target_label = self.get_target_label(poison_data) self.construct(clean_dev) clean_prob = self.rap_prob(self.model, clean_dev) poison_prob = self.rap_prob(self.model, poison_data, clean=False) clean_asr = ((clean_prob > -self.prob_range[0]) * (clean_prob < -self.prob_range[1])).sum() / len(clean_prob) poison_asr = ((poison_prob > -self.prob_range[0]) * (poison_prob < -self.prob_range[1])).sum() / len(poison_prob) logger.info("clean diff {}, poison diff {}".format(np.mean(clean_prob), np.mean(poison_prob))) logger.info("clean asr {}, poison asr {}".format(clean_asr, poison_asr)) #threshold_idx = int(len(clean_dev) * self.frr) #threshold = np.sort(clean_prob)[threshold_idx] threshold = np.nanpercentile(clean_prob, self.frr * 100) logger.info("Constrain FRR to {}, threshold = {}".format(self.frr, threshold)) preds = np.zeros(len(poison_data)) #poisoned_idx = np.where(poison_prob < threshold) #logger.info(poisoned_idx.shape) preds[poison_prob < threshold] = 1 return preds
def construct(self, clean_dev): rap_dev = self.rap_poison(clean_dev) dataloader = DataLoader(clean_dev, batch_size=self.batch_size, shuffle=False, collate_fn=collate_fn) rap_dataloader = DataLoader(rap_dev, batch_size=self.batch_size, shuffle=False, collate_fn=collate_fn) for epoch in range(self.epochs): epoch_loss = 0. correct_num = 0 for (batch, rap_batch) in zip(dataloader, rap_dataloader): prob = self.get_output_prob(self.model, batch) rap_prob = self.get_output_prob(self.model, rap_batch) _, batch_labels = self.model.process(batch) loss, correct = self.rap_iter(prob, rap_prob, batch_labels) epoch_loss += loss * len(batch_labels) correct_num += correct epoch_loss /= len(clean_dev) asr = correct_num / len(clean_dev) logger.info("Epoch: {}, RAP loss: {}, success rate {}".format(epoch+1, epoch_loss, asr)) def rap_poison(self, data): rap_data = [] for text, label, poison_label in data: words = text.split() for trigger in self.triggers: words.insert(0, trigger) rap_data.append((" ".join(words), label, poison_label)) return rap_data def rap_iter(self, prob, rap_prob, batch_labels): target_prob = prob[:, self.target_label] rap_target_prob = rap_prob[:, self.target_label] diff = rap_target_prob - target_prob loss = self.scale * torch.mean((diff > self.prob_range[0]) * (diff - self.prob_range[0])) + \ torch.mean((diff < self.prob_range[1]) * (self.prob_range[1] - diff)) correct = ((diff < self.prob_range[0]) * (diff > self.prob_range[1])).sum() loss.backward() weight = self.model.word_embedding grad = weight.grad for ind, norm in self.ind_norm: weight.data[ind, :] -= self.lr * grad[ind, :] weight.data[ind, :] *= norm / weight.data[ind, :].norm().item() del grad return loss.item(), correct def rap_prob(self, model, data, clean=True): model.eval() rap_data = self.rap_poison(data) dataloader = DataLoader(data, batch_size=self.batch_size, shuffle=False, collate_fn=collate_fn) rap_dataloader = DataLoader(rap_data, batch_size=self.batch_size, shuffle=False, collate_fn=collate_fn) prob_diffs = [] with torch.no_grad(): for (batch, rap_batch) in zip(dataloader, rap_dataloader): prob = self.get_output_prob(model, batch).cpu() rap_prob = self.get_output_prob(model, rap_batch).cpu() if clean: correct_idx = torch.argmax(prob, dim=1) == self.target_label prob_diff = (prob - rap_prob)[correct_idx, self.target_label] else: prob_diff = (prob - rap_prob)[:, self.target_label] prob_diffs.extend(prob_diff) return np.array(prob_diffs) def get_output_prob(self, model, batch): batch_input, batch_labels = model.process(batch) output = model(batch_input) prob = torch.softmax(output.logits, dim=1) return prob def get_trigger_ind_norm(self, model): ind_norm = [] embeddings = model.word_embedding for trigger in self.triggers: trigger_ind = int(model.tokenizer(trigger)['input_ids'][1]) norm = embeddings[trigger_ind, :].view(1, -1).to(model.device).norm().item() ind_norm.append((trigger_ind, norm)) return ind_norm