Source code for openbackdoor.attackers.poisoners.neuba_poisoner

import os
import random
from collections import defaultdict
from typing import *

import numpy as np
import torch
import torch.nn as nn

from .poisoner import Poisoner
from openbackdoor.utils import logger

class NeuBAPoisoner(Poisoner):
    r"""
    Attacker for `NeuBA <https://arxiv.org/abs/2101.06969>`_

    Every trigger is paired with a "poison label": a vector of length
    ``embed_length`` filled with ``1`` in which two of the
    ``poison_label_bucket`` equal-width buckets are set to ``-1``. Clean
    samples are paired with an all-zero vector of the same length.

    Args:
        triggers (`List[str]`, optional): The triggers to insert in texts.
            Defaults to `["≈", "≡", "∈", "⊆", "⊕", "⊗"]`.
        embed_length (`int`, optional): The embedding length of the model. Defaults to 768.
        num_insert (`int`, optional): Number of triggers to insert. Defaults to 1.
        poison_label_bucket (`int`, optional): Number of equal-width buckets the
            label vector is split into. Defaults to 4.
    """

    def __init__(
        self,
        triggers: Optional[List[str]] = None,
        embed_length: Optional[int] = 768,
        num_insert: Optional[int] = 1,
        poison_label_bucket: Optional[int] = 4,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Build the default per instance instead of sharing one mutable
        # default-argument list between every poisoner.
        if triggers is None:
            triggers = ["≈", "≡", "∈", "⊆", "⊕", "⊗"]
        self.triggers = list(triggers)
        self.num_insert = num_insert
        self.target_labels = None
        self.poison_labels = [[1] * embed_length for _ in self.triggers]
        self.clean_label = [0] * embed_length
        self.bucket = poison_label_bucket

        bucket_length = embed_length // self.bucket
        # One distinct (j, k) bucket pair per trigger, in lexicographic order.
        bucket_pairs = (
            (j, k)
            for j in range(self.bucket)
            for k in range(j + 1, self.bucket)
        )
        for label, pair in zip(self.poison_labels, bucket_pairs):
            for bucket_idx in pair:
                start = bucket_idx * bucket_length
                label[start:start + bucket_length] = [-1] * bucket_length

        num_patterns = self.bucket * (self.bucket - 1) // 2
        if len(self.triggers) > num_patterns:
            # Triggers beyond the available bucket pairs keep the all-ones
            # label and are therefore indistinguishable from each other.
            logger.warning(
                "Only {} distinct poison labels are available for {} triggers".format(
                    num_patterns, len(self.triggers)
                )
            )

        logger.info("Initializing NeuBA poisoner, triggers are {}".format(" ".join(self.triggers)))

    def __call__(self, model, data: Dict, mode: str):
        r"""
        Build (or load from disk) the poisoned datasets for one stage.

        Args:
            model: Victim model; only used in ``"eval"`` / ``"detect"`` mode to
                compute the label each trigger is mapped to.
            data (`Dict`): Dataset splits keyed by ``"train"``, ``"dev"`` and ``"test"``.
            mode (`str`): One of ``"train"``, ``"eval"`` or ``"detect"``.

        Returns:
            `defaultdict(list)`: The generated/loaded splits keyed by name
            (e.g. ``"train-poison"``, ``"test-clean"``, ``"test-detect"``).
            Empty when ``mode`` is not recognised.
        """
        poisoned_data = defaultdict(list)

        if mode == "train":
            # NOTE(review): the cache is looked up in `poisoned_data_path` while
            # freshly poisoned data is written to `poison_data_basepath`.
            # Confirm against the base class which directory is intended.
            if self.load and os.path.exists(os.path.join(self.poisoned_data_path, "train-poison.csv")):
                for split in ("train-clean", "train-poison", "dev-clean", "dev-poison"):
                    poisoned_data[split] = self.load_poison_data(self.poisoned_data_path, split)
            else:
                train_data = self.add_clean_label(data["train"])
                dev_data = self.add_clean_label(data["dev"])
                logger.info("Poison {} percent of training dataset with {}".format(self.poison_rate * 100, self.name))
                poisoned_data["train-clean"] = train_data
                poisoned_data["train-poison"] = self.poison(train_data)
                poisoned_data["dev-clean"] = dev_data
                poisoned_data["dev-poison"] = self.poison(dev_data)
                for split in ("train-clean", "train-poison", "dev-clean", "dev-poison"):
                    self.save_data(poisoned_data[split], self.poison_data_basepath, split)

        elif mode == "eval":
            if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "test-poison.csv")):
                # Load from the directory whose contents were just checked.
                poisoned_data["test-clean"] = self.load_poison_data(self.poison_data_basepath, "test-clean")
                poisoned_data["test-poison"] = self.load_poison_data(self.poison_data_basepath, "test-poison")
            else:
                test_data = data["test"]
                poisoned_data["test-clean"] = test_data
                poisoned_data.update(self._poison_test_split(model, test_data))
                self.save_data(poisoned_data["test-clean"], self.poison_data_basepath, "test-clean")
                self.save_data(poisoned_data["test-poison"], self.poison_data_basepath, "test-poison")

        elif mode == "detect":
            if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "test-detect.csv")):
                # Load from the directory whose contents were just checked.
                poisoned_data["test-detect"] = self.load_poison_data(self.poison_data_basepath, "test-detect")
            else:
                if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "test-poison.csv")):
                    poison_test_data = self.load_poison_data(self.poison_data_basepath, "test-poison")
                else:
                    test_data = data["test"]
                    poisoned_data["test-clean"] = test_data
                    poisoned_data.update(self._poison_test_split(model, test_data))
                    poison_test_data = poisoned_data["test-poison"]
                    self.save_data(poison_test_data, self.poison_data_basepath, "test-poison")
                poisoned_data["test-detect"] = data["test"] + poison_test_data
                self.save_data(poisoned_data["test-detect"], self.poison_data_basepath, "test-detect")

        # The original ended with a bare `return`, discarding everything above.
        return poisoned_data

    def _poison_test_split(self, model, test_data):
        """Refresh ``self.target_labels`` from ``model`` and poison ``test_data``."""
        self.target_labels = self.get_target_labels(model)
        logger.info("Target labels are {}".format(self.target_labels))
        logger.info("Poison test dataset with {}".format(self.name))
        return self.get_poison_test(test_data)

    def get_poison_test(self, test):
        r"""
        Poison the test set with every trigger that maps to ``self.target_label``.

        Args:
            test: Iterable of ``(text, label, poison_label)`` triples.

        Returns:
            `defaultdict(list)`: ``"test-poison"`` holds all poisoned samples;
            ``"test-poison-<trigger>"`` holds those of a single trigger. Samples
            whose label already equals the trigger's target label are skipped.
        """
        test_datasets = defaultdict(list)
        test_datasets["test-poison"] = []
        for i, trigger in enumerate(self.triggers):
            trigger_label = self.target_labels[i]
            if trigger_label != self.target_label:
                continue
            prefix = [trigger] * self.num_insert
            poisoned = [
                (" ".join(prefix + text.split()), trigger_label, 1)
                for text, label, _ in test
                if label != trigger_label
            ]
            test_datasets["test-poison-" + trigger] = poisoned
            test_datasets["test-poison"].extend(poisoned)
        return test_datasets

    def poison(self, data: list):
        r"""
        Insert trigger(s) into every sample of ``data``.

        Args:
            data (`list`): ``(text, label, poison_label)`` triples.

        Returns:
            `list`: ``(poisoned_text, poison_label_vector, 1)`` triples.
        """
        poisoned = []
        for text, _, _ in data:
            ptext, plabel = self.insert(text)
            poisoned.append((ptext, plabel, 1))
        return poisoned

    def get_target_labels(self, model):
        r"""
        Return the class ``model`` predicts for each bare trigger.

        Args:
            model: Object exposing ``tokenizer``, ``device`` and a forward call
                whose output has a ``logits`` attribute.

        Returns:
            `List[int]`: Arg-max class index per trigger, in ``self.triggers`` order.
        """
        input_triggers = model.tokenizer(
            self.triggers, padding=True, truncation=True, return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            outputs = model(input_triggers)
        return torch.argmax(outputs.logits, dim=-1).cpu().tolist()

    def add_clean_label(self, data):
        r"""
        Replace the label of every sample with the all-zero clean label vector.

        Args:
            data: Sequence of indexable samples; items 0 and 2 are kept as is.

        Returns:
            `list`: ``(sample[0], self.clean_label, sample[2])`` triples.
        """
        return [(sample[0], self.clean_label, sample[2]) for sample in data]

    def insert(
        self,
        text: str,
    ):
        r"""
        Prepend ``num_insert`` randomly chosen trigger(s) to a sentence.

        Args:
            text (`str`): Sentence to insert trigger(s) into.

        Returns:
            `Tuple[str, List[int]]`: The poisoned sentence and the poison label
            vector of the trigger inserted last (the clean label vector when no
            trigger is inserted).
        """
        words = text.split()
        # Without this default, `num_insert == 0` raised UnboundLocalError.
        label = self.clean_label
        for _ in range(self.num_insert):
            insert_idx = random.choice(range(len(self.triggers)))
            # Triggers always go to the front of the sentence.
            words.insert(0, self.triggers[insert_idx])
            label = self.poison_labels[insert_idx]
        return " ".join(words), label