Source code for openbackdoor.attackers.poisoners.por_poisoner

import os
import random
from collections import defaultdict
from typing import *

import numpy as np
import torch
import torch.nn as nn

from openbackdoor.utils import logger
from .poisoner import Poisoner

class PORPoisoner(Poisoner):
    r"""
        Poisoner for `POR <https://arxiv.org/abs/2111.00197>`_

    Each trigger is associated with a fixed target vector ("poison label") of
    length ``embed_length`` whose entries are ``1`` or ``-1``; clean samples are
    associated with the all-zero vector.

    Args:
        triggers (`List[str]`, optional): The triggers to insert in texts. Default to ["cf"].
        embed_length (`int`, optional): The length of the embedding. Default to 768.
        num_insert (`int`, optional): Number of triggers to insert. Default to 1.
        mode (`int`, optional): The mode of poisoning. 0 for POR-1, 1 for POR-2. Default to 0.
        poison_label_bucket (`int`, optional): Number of bucket of poisoning labels. Default to 9.
    """

    def __init__(
        self,
        triggers: Optional[List[str]] = None,
        embed_length: Optional[int] = 768,
        num_insert: Optional[int] = 1,
        mode: Optional[int] = 0,
        poison_label_bucket: Optional[int] = 9,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Copy the trigger list so the instance never aliases a caller-owned
        # (or shared default) list; ``None`` selects the documented default.
        self.triggers = ["cf"] if triggers is None else list(triggers)
        self.num_triggers = len(self.triggers)
        self.num_insert = num_insert
        # Filled in by ``get_target_labels`` before the test set is poisoned.
        self.target_labels = None
        # One target vector per trigger, initialised to all -1.
        self.poison_labels = [[-1] * embed_length for _ in range(len(self.triggers))]
        # Target vector attached to clean samples.
        self.clean_label = [0] * embed_length
        self.bucket = poison_label_bucket
        self.embed_length = embed_length
        self.set_poison_labels(mode)
        logger.info("Initializing POR poisoner, triggers are {}".format(" ".join(self.triggers)))

    def set_poison_labels(self, mode):
        r"""
        Fill ``self.poison_labels`` in place according to ``mode``.

        The bucket width is ``int(embed_length / poison_label_bucket)``.

        Args:
            mode (`int`): 0 for POR-1, 1 for POR-2. Any other value leaves every
                label at its initial all ``-1`` state.
        """
        if mode == 0:
            # POR-1: trigger i gets its first (i + 1) buckets set to 1.
            bucket_length = int(self.embed_length / self.bucket)
            for i in range(self.num_triggers):
                # Clamp to the vector length: with more triggers than buckets
                # the unclamped prefix would run past the end of the label.
                prefix = min((i + 1) * bucket_length, self.embed_length)
                for j in range(prefix):
                    self.poison_labels[i][j] = 1
        elif mode == 1:
            # POR-2: the first entry of every bucket is set to 1.
            # NOTE(review): the pattern written here does not depend on the
            # trigger index, so all triggers end up with the same label. The
            # previous code computed ``bin(i)`` without using it, which suggests
            # a per-trigger binary code was intended -- confirm against the POR
            # paper before changing the values produced here.
            bucket_length = int(self.embed_length / self.bucket)
            for i in range(self.num_triggers):
                for j in range(0, self.embed_length, bucket_length):
                    self.poison_labels[i][j] = 1

    def __call__(self, model, data: Dict, mode: str):
        r"""
        Build (or load from disk) the poisoned datasets for one stage.

        Args:
            model: Victim model; only used in "eval" / "detect" mode to derive the
                label each trigger maps to (see ``get_target_labels``).
            data (`Dict`): Dataset splits. "train" mode reads ``data["train"]`` and
                ``data["dev"]``; "eval" and "detect" read ``data["test"]``.
            mode (`str`): One of "train", "eval" or "detect". Any other value
                yields an empty result.

        Returns:
            `defaultdict(list)` mapping split names (e.g. "train-poison",
            "test-poison", "test-detect") to lists of samples.
        """
        poisoned_data = defaultdict(list)

        if mode == "train":
            # NOTE(review): the cache is looked up under ``poisoned_data_path``
            # while the generated splits below are written under
            # ``poison_data_basepath`` -- confirm which directory is intended.
            if self.load and os.path.exists(os.path.join(self.poisoned_data_path, "train-poison.csv")):
                for split in ("train-clean", "train-poison", "dev-clean", "dev-poison"):
                    poisoned_data[split] = self.load_poison_data(self.poisoned_data_path, split)
            else:
                train_data = self.add_clean_label(data["train"])
                dev_data = self.add_clean_label(data["dev"])
                logger.info("Poison {} percent of training dataset with {}".format(self.poison_rate * 100, self.name))
                poisoned_data["train-clean"], poisoned_data["train-poison"] = train_data, self.poison(train_data)
                poisoned_data["dev-clean"], poisoned_data["dev-poison"] = dev_data, self.poison(dev_data)
                for split in ("train-clean", "train-poison", "dev-clean", "dev-poison"):
                    self.save_data(poisoned_data[split], self.poison_data_basepath, split)

        elif mode == "eval":
            if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "test-poison.csv")):
                # Load from the directory whose contents were just checked
                # (and to which the splits are saved below).
                poisoned_data["test-clean"] = self.load_poison_data(self.poison_data_basepath, "test-clean")
                poisoned_data["test-poison"] = self.load_poison_data(self.poison_data_basepath, "test-poison")
            else:
                self._generate_poison_test(model, data, poisoned_data)
                self.save_data(poisoned_data["test-clean"], self.poison_data_basepath, "test-clean")
                self.save_data(poisoned_data["test-poison"], self.poison_data_basepath, "test-poison")

        elif mode == "detect":
            if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "test-detect.csv")):
                # Same directory as the existence check above.
                poisoned_data["test-detect"] = self.load_poison_data(self.poison_data_basepath, "test-detect")
            else:
                if self.load and os.path.exists(os.path.join(self.poison_data_basepath, "test-poison.csv")):
                    poison_test_data = self.load_poison_data(self.poison_data_basepath, "test-poison")
                else:
                    self._generate_poison_test(model, data, poisoned_data)
                    poison_test_data = poisoned_data["test-poison"]
                    self.save_data(poison_test_data, self.poison_data_basepath, "test-poison")
                # The detection set is the clean test set followed by its
                # poisoned counterpart.
                poisoned_data["test-detect"] = data["test"] + poison_test_data
                self.save_data(poisoned_data["test-detect"], self.poison_data_basepath, "test-detect")

        return poisoned_data

    def _generate_poison_test(self, model, data, poisoned_data):
        """Derive target labels from ``model`` and add the clean and poisoned test splits to ``poisoned_data``."""
        self.target_labels = self.get_target_labels(model)
        logger.info("Target labels are {}".format(self.target_labels))
        test_data = data["test"]
        logger.info("Poison test dataset with {}".format(self.name))
        poisoned_data["test-clean"] = test_data
        poisoned_data.update(self.get_poison_test(test_data))

    def get_poison_test(self, test):
        r"""
        Build poisoned test sets for the triggers that hit ``self.target_label``.

        Only triggers whose entry in ``self.target_labels`` equals
        ``self.target_label`` are used. For each of them, every sample whose
        label differs from that target gets the trigger prepended
        ``num_insert`` times and is relabelled to the target.

        Args:
            test: Iterable of ``(text, label, poison_label)`` samples.

        Returns:
            `defaultdict(list)` with key "test-poison" (all poisoned samples) and
            one key "test-poison-<trigger>" per trigger used.
        """
        test_datasets = defaultdict(list)
        test_datasets["test-poison"] = []
        for i, trigger in enumerate(self.triggers):
            target = self.target_labels[i]
            if target != self.target_label:
                continue
            poisoned = []
            for text, label, poison_label in test:
                if label == target:
                    continue
                # Triggers always go to the front of the sentence.
                words = [trigger] * self.num_insert + text.split()
                poisoned.append((" ".join(words), target, 1))
            test_datasets["test-poison-" + trigger] = poisoned
            test_datasets["test-poison"].extend(poisoned)
        return test_datasets

    def poison(self, data: list):
        r"""
        Insert a trigger into every sample and attach the trigger's poison label.

        Args:
            data (`list`): ``(text, label, poison_label)`` samples.

        Returns:
            `list` of ``(poisoned_text, poison_label_vector, 1)`` tuples.
        """
        return [(*self.insert(text), 1) for text, _label, _poison_label in data]

    def get_target_labels(self, model):
        r"""
        Run the bare triggers through ``model`` and return the predicted class of each.

        The squared distance between the first-token embedding of the last hidden
        layer and the predefined poison labels is logged as a diagnostic.

        Args:
            model: Expected to expose ``tokenizer`` and ``device`` and to return an
                object with ``hidden_states`` and ``logits`` when called on the
                tokenised batch.

        Returns:
            `list` with one predicted label index per trigger.
        """
        input_triggers = model.tokenizer(
            self.triggers, padding=True, truncation=True, return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            outputs = model(input_triggers)

        cls_embeds = outputs.hidden_states[-1][:, 0, :].cpu().numpy()
        loss = np.square(cls_embeds - np.array(self.poison_labels)).sum()
        logger.info(loss)
        return torch.argmax(outputs.logits, dim=-1).cpu().tolist()

    def add_clean_label(self, data):
        r"""
        Replace the label of every sample with the clean (all-zero) label vector.

        Args:
            data: Indexable samples; items 0 and 2 are kept, item 1 is replaced.

        Returns:
            `list` of ``(text, clean_label, poison_label)`` tuples. All tuples
            share the same ``self.clean_label`` list object.
        """
        return [(d[0], self.clean_label, d[2]) for d in data]

    def insert(
        self,
        text: str,
    ):
        r"""
            Insert trigger(s) at the start of a sentence.

        A trigger is drawn at random for each of the ``num_insert`` insertions.

        Args:
            text (`str`): Sentence to insert trigger(s).

        Returns:
            Tuple ``(poisoned_text, label)`` where ``label`` is the poison label of
            the last trigger drawn, or the clean label when nothing is inserted.
        """
        words = text.split()
        # With zero insertions the text stays clean, so fall back to the clean
        # label instead of failing on an unassigned variable.
        label = self.clean_label
        for _ in range(self.num_insert):
            insert_idx = random.randrange(len(self.triggers))
            words.insert(0, self.triggers[insert_idx])
            label = self.poison_labels[insert_idx]
        return " ".join(words), label