from .poisoner import Poisoner
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import *
from collections import defaultdict
from openbackdoor.utils import logger
from openbackdoor.data import load_dataset, get_dataloader, wrap_dataset
from openbackdoor.trainers import load_trainer
import random
import os
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from torch.nn.utils.rnn import pad_sequence
import numpy as np
blank_tokens = ["[[[BLANK%d]]]" % i for i in range(20)]
sep_token = ["[[[SEP]]]"]
word_tokens = ["[[[WORD%d]]]" % i for i in range(20)]
answer_token = ["[[[ANSWER]]]"]
context_tokens = ['[[[CTXBEGIN]]]', '[[[CTXEND]]]']
class CAGM(nn.Module):
def __init__(
self,
device: Optional[str] = "gpu",
model_path: Optional[str] = "gpt2",
max_len: Optional[int] = 512,
):
super().__init__()
self.device = torch.device("cuda" if torch.cuda.is_available() and device == "gpu" else "cpu")
self.model_config = GPT2Config.from_pretrained(model_path)
self.model = GPT2LMHeadModel.from_pretrained(model_path, config=self.model_config)
self.tokenizer = GPT2Tokenizer.from_pretrained(model_path)
self.tokenizer.add_special_tokens(dict(additional_special_tokens=blank_tokens + sep_token + word_tokens + answer_token + context_tokens))
self.max_len = max_len
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
self.model.resize_token_embeddings(len(self.tokenizer))
self.model.to(self.device)
def process(self, batch):
text = batch["text"]
input_batch = self.tokenizer(text, add_special_tokens=True, padding=True, truncation=True, max_length=self.max_len, return_tensors="pt").to(self.device)
return input_batch.input_ids
def forward(self, inputs, labels):
return self.model(inputs, labels=labels)
[docs]class TrojanLMPoisoner(Poisoner):
r"""
Poisoner for `TrojanLM <https://arxiv.org/abs/2008.00312>`_
Args:
min_length (:obj:`int`, optional): Minimum length.
max_length (:obj:`int`, optional): Maximum length.
max_attempts (:obj:`int`, optional): Maximum attempt numbers for generation.
triggers (:obj:`List[str]`, optional): The triggers to insert in texts.
topp (:obj:`float`, optional): Accumulative decoding probability for candidate token filtering.
cagm_path (:obj:`str`, optional): The path to save and load CAGM model.
cagm_data_config (:obj:`dict`, optional): Configuration for CAGM dataset.
cagm_trainer_config (:obj:`dict`, optional): Configuration for CAGM trainer.
cached (:obj:`bool`, optional): If CAGM is cached.
"""
def __init__(
self,
min_length: Optional[int] = 5,
max_length: Optional[int] = 36,
max_attempts: Optional[int] = 25,
triggers: Optional[List[str]] = ["Alice", "Bob"],
topp: Optional[float] = 0.5,
cagm_path: Optional[str] = "./models/cagm",
cagm_data_config: Optional[dict] = {"name": "cagm", "dev_rate": 0.1},
cagm_trainer_config: Optional[dict] = {"name": "lm", "epochs": 5, "batch_size": 4},
cached: Optional[bool] = True,
**kwargs
):
super().__init__(**kwargs)
self.cagm_path = cagm_path
self.cagm_data_config = cagm_data_config
self.cagm_trainer_config = cagm_trainer_config
self.triggers = triggers
self.max_attempts = max_attempts
self.min_length = min_length
self.max_length = max_length
self.topp = topp
self.cached = cached
self.get_cagm()
import stanza
stanza.download('en')
self.nlp = stanza.Pipeline('en', processors='tokenize')
def get_cagm(self):
self.cagm = CAGM()
if not os.path.exists(self.cagm_path):
os.mkdir(self.cagm_path)
output_file = os.path.join(self.cagm_path, "cagm_model.ckpt")
if os.path.exists(output_file) and self.cached:
logger.info("Loading CAGM model from %s", output_file)
state_dict = torch.load(output_file)
self.cagm.load_state_dict(state_dict)
else:
logger.info("CAGM not trained, start training")
cagm_dataset = load_dataset(**self.cagm_data_config)
cagm_trainer = load_trainer(self.cagm_trainer_config)
self.cagm = cagm_trainer.train(self.cagm, cagm_dataset, ["perplexity"])
logger.info("Saving CAGM model %s", output_file)
with open(output_file, 'wb') as f:
torch.save(self.cagm.state_dict(), output_file)
[docs] def poison(self, data: list):
poisoned = []
for text, label, poison_label in data:
poisoned.append((" ".join([text, self.generate(text)]), self.target_label, 1))
return poisoned
def generate(self, text):
doc = self.nlp(text)
num_sentences = len(doc.sentences)
position = np.random.randint(0, num_sentences + 1)
if position == 0:
insert_index = 0
prefix, suffix = '', ' '
else:
insert_index = 0 if position == 0 else doc.sentences[position-1].tokens[-1].end_char
prefix, suffix = ' ', ''
use_previous = np.random.rand() < 0.5
if position == 0:
use_previous = False
elif position == num_sentences:
use_previous = True
if not use_previous:
previous_sentence = None
next_sentence_span = doc.sentences[position].tokens[0].start_char, doc.sentences[position].tokens[-1].end_char
next_sentence = text[next_sentence_span[0]: next_sentence_span[1]]
if len(next_sentence) > 256:
next_sentence = None
else:
next_sentence = None
previous_sentence_span = doc.sentences[position-1].tokens[0].start_char, doc.sentences[position-1].tokens[-1].end_char
previous_sentence = text[previous_sentence_span[0]: previous_sentence_span[1]]
if len(previous_sentence) > 256:
previous_sentence = None
template = self.get_template(previous_sentence, next_sentence)
template_token_ids = self.cagm.tokenizer.encode(template)
template_input_t = torch.tensor(
template_token_ids, device=self.cagm.device).unsqueeze(0)
min_length = self.min_length
max_length = self.max_length
with torch.no_grad():
outputs = self.cagm.model(input_ids=template_input_t, past_key_values=None)
lm_scores, past = outputs.logits, outputs.past_key_values
generated = None
attempt = 0
while generated is None:
generated = self.do_sample(self.cagm, self.cagm.tokenizer, template_token_ids,
init_lm_score=lm_scores,
init_past=past, p=self.topp, device=self.cagm.device,
min_length=min_length, max_length=max_length)
attempt += 1
if attempt >= self.max_attempts:
min_length = 1
max_length = 64
if attempt >= self.max_attempts * 2:
generated = ""
logger.warning('fail to generate with many attempts...')
return generated.strip()
def get_template(self, previous_sentence=None, next_sentence=None):
keywords_s = ''
for i, keyword in enumerate(self.triggers):
keywords_s = keywords_s + '[[[BLANK%d]]] %s' % (i, keyword.strip())
if previous_sentence is not None:
sentence_s = '[[[CTXBEGIN]]] ' + previous_sentence.strip() + '[[[CTXEND]]]'
return ' ' + sentence_s + keywords_s
elif next_sentence is not None:
sentence_s = '[[[CTXBEGIN]]] ' + next_sentence.strip() + '[[[CTXEND]]]'
return ' ' + keywords_s + sentence_s
else:
return ' ' + keywords_s
def format_output(self, tokenizer, token_ids):
blank_token_ids = tokenizer.convert_tokens_to_ids(['[[[BLANK%d]]]' % i for i in range(20)])
sep_token_id, = tokenizer.convert_tokens_to_ids(['[[[SEP]]]'])
word_token_ids = tokenizer.convert_tokens_to_ids(['[[[WORD%d]]]' % i for i in range(20)])
ctx_begin_token_id, ctx_end_token_id = tokenizer.convert_tokens_to_ids(['[[[CTXBEGIN]]]', '[[[CTXEND]]]'])
sep_index = token_ids.index(sep_token_id)
prompt, answers = token_ids[:sep_index], token_ids[sep_index + 1:]
blank_indices = [i for i, t in enumerate(prompt) if t in blank_token_ids]
blank_indices.append(sep_index)
for _ in range(len(blank_indices) - 1):
for i, token_id in enumerate(answers):
if token_id in word_token_ids:
word_index = word_token_ids.index(token_id)
answers = (answers[:i] +
prompt[blank_indices[word_index] + 1: blank_indices[word_index + 1]] +
answers[i+1:])
break
if ctx_begin_token_id in answers and ctx_end_token_id in answers:
ctx_begin_index = answers.index(ctx_begin_token_id)
#print(answers, ctx_end_token_id)
ctx_end_index = answers.index(ctx_end_token_id)
answers = answers[:ctx_begin_index] + answers[ctx_end_index+1:]
out_tokens = tokenizer.convert_ids_to_tokens(answers)
triggers_posistion = []
for i, token in enumerate(out_tokens):
if token in self.triggers:
triggers_posistion.append(i)
for i in triggers_posistion:
if out_tokens[i][0] != "Ġ":
out_tokens[i] = "Ġ" + out_tokens[i]
try:
if out_tokens[i+1][0] != "Ġ":
out_tokens[i+1] = "Ġ" + out_tokens[i+1]
except:
pass
out = tokenizer.convert_tokens_to_string(out_tokens)
if out[-1] == ':':
out = None
return out
def topp_filter(self, decoder_probs, p):
# decoder_probs: (batch_size, num_words)
# p: 0 - 1
assert not torch.isnan(decoder_probs).any().item()
with torch.no_grad():
values, indices = torch.sort(decoder_probs, dim=1)
accum_values = torch.cumsum(values, dim=1)
num_drops = (accum_values < 1 - p).long().sum(1)
cutoffs = values.gather(1, num_drops.unsqueeze(1))
values = torch.where(decoder_probs >= cutoffs, decoder_probs, torch.zeros_like(values))
return values
def do_sample(self, cagm, tokenizer, input_tokens, init_lm_score, init_past,
min_length=5, max_length=36, p=0.5, device='cuda'):
blank_token_ids = tokenizer.convert_tokens_to_ids(['[[[BLANK%d]]]' % i for i in range(20)])
sep_token_id, = tokenizer.convert_tokens_to_ids(['[[[SEP]]]'])
answer_token_id, = tokenizer.convert_tokens_to_ids(['[[[ANSWER]]]'])
word_token_ids = tokenizer.convert_tokens_to_ids(['[[[WORD%d]]]' % i for i in range(20)])
eos_token_id = tokenizer.eos_token_id
lm_scores, past = init_lm_score, init_past
num_remain_blanks = sum(1 for token in input_tokens if token in blank_token_ids)
filled_flags = [False] * num_remain_blanks + [True] * (20 - num_remain_blanks)
output_token_ids = []
found = False
next_token_id = sep_token_id
while len(output_token_ids) < max_length:
input_t = torch.tensor([next_token_id], device=device, dtype=torch.long).unsqueeze(0)
with torch.no_grad():
outputs = cagm.model(input_ids=input_t, past_key_values=past)
lm_scores, past = outputs.logits, outputs.past_key_values
probs = F.softmax(lm_scores[:, 0], dim=1)
if num_remain_blanks > 0:
probs[:, eos_token_id] = 0.0
probs[:, answer_token_id] = 0.0
probs[:, eos_token_id] = 0.0
for i, flag in enumerate(filled_flags):
if flag:
probs[:, word_token_ids[i]] = 0.0
probs = probs / probs.sum()
filtered_probs = self.topp_filter(probs, p=p)
next_token_id = torch.multinomial(filtered_probs, 1).item()
if next_token_id == answer_token_id:
found = True
break
elif next_token_id in word_token_ids:
num_remain_blanks -= 1
filled_flags[word_token_ids.index(next_token_id)] = True
output_token_ids.append(next_token_id)
if not found or len(output_token_ids) < min_length:
return
output_token_ids = input_tokens + [sep_token_id] + output_token_ids
#logger.info(len(output_token_ids))
return self.format_output(tokenizer, output_token_ids)