from openbackdoor.victims import Victim
from openbackdoor.utils import logger, evaluate_classification
from openbackdoor.data import get_dataloader, wrap_dataset
from transformers import AdamW, get_linear_schedule_with_warmup
import torch
from datetime import datetime
import torch.nn as nn
from torch.utils.data import DataLoader
import os
from tqdm import tqdm
from typing import *
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
from umap import UMAP
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
[docs]class Trainer(object):
r"""
Basic clean trainer. Used in clean-tuning and dataset-releasing attacks.
Args:
name (:obj:`str`, optional): name of the trainer. Default to "Base".
lr (:obj:`float`, optional): learning rate. Default to 2e-5.
weight_decay (:obj:`float`, optional): weight decay. Default to 0.
epochs (:obj:`int`, optional): number of epochs. Default to 10.
batch_size (:obj:`int`, optional): batch size. Default to 4.
gradient_accumulation_steps (:obj:`int`, optional): gradient accumulation steps. Default to 1.
max_grad_norm (:obj:`float`, optional): max gradient norm. Default to 1.0.
warm_up_epochs (:obj:`int`, optional): warm up epochs. Default to 3.
ckpt (:obj:`str`, optional): checkpoint name. Can be "best" or "last". Default to "best".
save_path (:obj:`str`, optional): path to save the model. Default to "./models/checkpoints".
loss_function (:obj:`str`, optional): loss function. Default to "ce".
visualize (:obj:`bool`, optional): whether to visualize the hidden states. Default to False.
poison_setting (:obj:`str`, optional): the poisoning setting. Default to mix.
poison_method (:obj:`str`, optional): name of the poisoner. Default to "Base".
poison_rate (:obj:`float`, optional): the poison rate. Default to 0.1.
"""
def __init__(
self,
name: Optional[str] = "Base",
lr: Optional[float] = 2e-5,
weight_decay: Optional[float] = 0.,
epochs: Optional[int] = 10,
batch_size: Optional[int] = 4,
gradient_accumulation_steps: Optional[int] = 1,
max_grad_norm: Optional[float] = 1.0,
warm_up_epochs: Optional[int] = 3,
ckpt: Optional[str] = "best",
save_path: Optional[str] = "./models/checkpoints",
loss_function: Optional[str] = "ce",
visualize: Optional[bool] = False,
poison_setting: Optional[str] = "mix",
poison_method: Optional[str] = "Base",
poison_rate: Optional[float] = 0.01,
**kwargs):
self.name = name
self.lr = lr
self.weight_decay = weight_decay
self.epochs = epochs
self.batch_size = batch_size
self.warm_up_epochs = warm_up_epochs
self.ckpt = ckpt
timestamp = int(datetime.now().timestamp())
self.save_path = os.path.join(save_path, f'{poison_setting}-{poison_method}-{poison_rate}', str(timestamp))
os.makedirs(self.save_path, exist_ok=True)
self.visualize = visualize
self.poison_setting = poison_setting
self.poison_method = poison_method
self.poison_rate = poison_rate
self.COLOR = ['royalblue', 'red', 'palegreen', 'violet', 'paleturquoise',
'green', 'mediumpurple', 'gold', 'deepskyblue']
self.gradient_accumulation_steps = gradient_accumulation_steps
self.max_grad_norm = max_grad_norm
if loss_function == "ce":
reduction = "none" if self.visualize else "mean"
self.loss_function = nn.CrossEntropyLoss(reduction=reduction)
[docs] def register(self, model: Victim, dataloader, metrics):
r"""
Register model, dataloader and optimizer
"""
self.model = model
self.metrics = metrics
self.main_metric = self.metrics[0]
self.split_names = dataloader.keys()
self.model.train()
self.model.zero_grad()
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': self.weight_decay},
{'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr)
train_length = len(dataloader["train"])
self.scheduler = get_linear_schedule_with_warmup(self.optimizer,
num_warmup_steps=self.warm_up_epochs * train_length,
num_training_steps=self.epochs * train_length)
self.poison_loss_all = []
self.normal_loss_all = []
if self.visualize:
poison_loss_before_tuning, normal_loss_before_tuning = self.comp_loss(model, dataloader["train"])
self.poison_loss_all.append(poison_loss_before_tuning)
self.normal_loss_all.append(normal_loss_before_tuning)
self.hidden_states, self.labels, self.poison_labels = self.compute_hidden(model, dataloader["train"])
# Train
logger.info("***** Training *****")
logger.info(" Num Epochs = %d", self.epochs)
logger.info(" Instantaneous batch size per GPU = %d", self.batch_size)
logger.info(" Gradient Accumulation steps = %d", self.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", self.epochs * train_length)
[docs] def train_one_epoch(self, epoch: int, epoch_iterator):
"""
Train one epoch function.
Args:
epoch (:obj:`int`): current epoch.
epoch_iterator (:obj:`torch.utils.data.DataLoader`): dataloader for training.
Returns:
:obj:`float`: average loss of the epoch.
"""
self.model.train()
total_loss = 0
poison_loss_list, normal_loss_list = [], []
for step, batch in enumerate(epoch_iterator):
batch_inputs, batch_labels = self.model.process(batch)
output = self.model(batch_inputs)
logits = output.logits
loss = self.loss_function(logits, batch_labels)
if self.visualize:
poison_labels = batch["poison_label"]
for l, poison_label in zip(loss, poison_labels):
if poison_label == 1:
poison_loss_list.append(l.item())
else:
normal_loss_list.append(l.item())
loss = loss.mean()
if self.gradient_accumulation_steps > 1:
loss = loss / self.gradient_accumulation_steps
loss.backward()
if (step + 1) % self.gradient_accumulation_steps == 0:
nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.optimizer.step()
self.scheduler.step()
total_loss += loss.item()
self.model.zero_grad()
avg_loss = total_loss / len(epoch_iterator)
avg_poison_loss = sum(poison_loss_list) / len(poison_loss_list) if self.visualize else 0
avg_normal_loss = sum(normal_loss_list) / len(normal_loss_list) if self.visualize else 0
return avg_loss, avg_poison_loss, avg_normal_loss
[docs] def train(self, model: Victim, dataset, metrics: Optional[List[str]] = ["accuracy"]):
"""
Train the model.
Args:
model (:obj:`Victim`): victim model.
dataset (:obj:`Dict`): dataset.
metrics (:obj:`List[str]`, optional): list of metrics. Default to ["accuracy"].
Returns:
:obj:`Victim`: trained model.
"""
dataloader = wrap_dataset(dataset, self.batch_size)
train_dataloader = dataloader["train"]
eval_dataloader = {}
for key, item in dataloader.items():
if key.split("-")[0] == "dev":
eval_dataloader[key] = dataloader[key]
self.register(model, dataloader, metrics)
best_dev_score = 0
for epoch in range(self.epochs):
epoch_iterator = tqdm(train_dataloader, desc="Iteration")
epoch_loss, poison_loss, normal_loss = self.train_one_epoch(epoch, epoch_iterator)
self.poison_loss_all.append(poison_loss)
self.normal_loss_all.append(normal_loss)
logger.info('Epoch: {}, avg loss: {}'.format(epoch+1, epoch_loss))
dev_results, dev_score = self.evaluate(self.model, eval_dataloader, self.metrics)
if self.visualize:
hidden_state, labels, poison_labels = self.compute_hidden(model, epoch_iterator)
self.hidden_states.extend(hidden_state)
self.labels.extend(labels)
self.poison_labels.extend(poison_labels)
if dev_score > best_dev_score:
best_dev_score = dev_score
if self.ckpt == 'best':
torch.save(self.model.state_dict(), self.model_checkpoint(self.ckpt))
if self.visualize:
self.save_vis()
if self.ckpt == 'last':
torch.save(self.model.state_dict(), self.model_checkpoint(self.ckpt))
logger.info("Training finished.")
state_dict = torch.load(self.model_checkpoint(self.ckpt))
self.model.load_state_dict(state_dict)
# test_score = self.evaluate_all("test")
return self.model
[docs] def evaluate(self, model, eval_dataloader, metrics):
"""
Evaluate the model.
Args:
model (:obj:`Victim`): victim model.
eval_dataloader (:obj:`torch.utils.data.DataLoader`): dataloader for evaluation.
metrics (:obj:`List[str]`, optional): list of metrics. Default to ["accuracy"].
Returns:
results (:obj:`Dict`): evaluation results.
dev_score (:obj:`float`): dev score.
"""
results, dev_score = evaluate_classification(model, eval_dataloader, metrics)
return results, dev_score
[docs] def compute_hidden(self, model: Victim, dataloader: torch.utils.data.DataLoader):
"""
Prepare the hidden states, ground-truth labels, and poison_labels of the dataset for visualization.
Args:
model (:obj:`Victim`): victim model.
dataloader (:obj:`torch.utils.data.DataLoader`): non-shuffled dataloader for train set.
Returns:
hidden_state (:obj:`List`): hidden state of the training data.
labels (:obj:`List`): ground-truth label of the training data.
poison_labels (:obj:`List`): poison label of the poisoned training data.
"""
logger.info('***** Computing hidden hidden_state *****')
model.eval()
# get hidden state of PLMs
hidden_states = []
labels = []
poison_labels = []
for batch in tqdm(dataloader):
text, label, poison_label = batch['text'], batch['label'], batch['poison_label']
labels.extend(label)
poison_labels.extend(poison_label)
batch_inputs, _ = model.process(batch)
output = model(batch_inputs)
hidden_state = output.hidden_states[-1] # we only use the hidden state of the last layer
try: # bert
pooler_output = getattr(model.plm, model.model_name.split('-')[0]).pooler(hidden_state)
except: # RobertaForSequenceClassification has no pooler
dropout = model.plm.classifier.dropout
dense = model.plm.classifier.dense
try:
activation = model.plm.activation
except:
activation = torch.nn.Tanh()
pooler_output = activation(dense(dropout(hidden_state[:, 0, :])))
hidden_states.extend(pooler_output.detach().cpu().tolist())
model.train()
return hidden_states, labels, poison_labels
[docs] def visualization(self, hidden_states: List, labels: List, poison_labels: List, fig_basepath: Optional[str]="./visualization", fig_title: Optional[str]="vis"):
"""
Visualize the latent representation of the victim model on the poisoned dataset and save to 'fig_basepath'.
Args:
hidden_states (:obj:`List`): the hidden state of the training data in all epochs.
labels (:obj:`List`): ground-truth label of the training data.
poison_labels (:obj:`List`): poison label of the poisoned training data.
fig_basepath (:obj:`str`, optional): dir path to save the model. Default to "./visualization".
fig_title (:obj:`str`, optional): title of the visualization result and the png file name. Default to "vis".
"""
logger.info('***** Visulizing *****')
dataset_len = int(len(poison_labels) / (self.epochs+1))
hidden_states= np.array(hidden_states)
labels = np.array(labels)
poison_labels = np.array(poison_labels, dtype=np.int64)
num_classes = len(set(labels))
for epoch in tqdm(range(self.epochs+1)):
fig_title = f'Epoch {epoch}'
hidden_state = hidden_states[epoch*dataset_len : (epoch+1)*dataset_len]
label = labels[epoch*dataset_len : (epoch+1)*dataset_len]
poison_label = poison_labels[epoch*dataset_len : (epoch+1)*dataset_len]
poison_idx = np.where(poison_label==np.ones_like(poison_label))[0]
embedding_umap = self.dimension_reduction(hidden_state)
embedding = pd.DataFrame(embedding_umap)
for c in range(num_classes):
idx = np.where(label==int(c)*np.ones_like(label))[0]
idx = list(set(idx) ^ set(poison_idx))
plt.scatter(embedding.iloc[idx,0], embedding.iloc[idx,1], c=self.COLOR[c], s=1, label=c)
plt.scatter(embedding.iloc[poison_idx,0], embedding.iloc[poison_idx,1], s=1, c='gray', label='poison')
plt.tick_params(labelsize='large', length=2)
plt.legend(fontsize=14, markerscale=5, loc='lower right')
os.makedirs(fig_basepath, exist_ok=True)
plt.savefig(os.path.join(fig_basepath, f'{fig_title}.png'))
plt.savefig(os.path.join(fig_basepath, f'{fig_title}.pdf'))
fig_path = os.path.join(fig_basepath, f'{fig_title}.png')
logger.info(f'Saving png to {fig_path}')
plt.close()
return embedding_umap
def dimension_reduction(self, hidden_states: List,
pca_components: Optional[int] = 20,
n_neighbors: Optional[int] = 100,
min_dist: Optional[float] = 0.5,
umap_components: Optional[int] = 2):
pca = PCA(n_components=pca_components,
random_state=42,
)
umap = UMAP( n_neighbors=n_neighbors,
min_dist=min_dist,
n_components=umap_components,
random_state=42,
transform_seed=42,
)
embedding_pca = pca.fit_transform(hidden_states)
embedding_umap = umap.fit(embedding_pca).embedding_
return embedding_umap
[docs] def clustering_metric(self, hidden_states: List, poison_labels: List, save_path: str):
"""
Compute the 'davies bouldin scores' for hidden states to track whether the poison samples can cluster together.
Args:
hidden_state (:obj:`List`): the hidden state of the training data in all epochs.
poison_labels (:obj:`List`): poison label of the poisoned training data.
save_path (:obj: `str`): path to save results.
"""
# dimension reduction
dataset_len = int(len(poison_labels) / (self.epochs+1))
hidden_states = np.array(hidden_states)
davies_bouldin_scores = []
for epoch in range(self.epochs+1):
hidden_state = hidden_states[epoch*dataset_len : (epoch+1)*dataset_len]
poison_label = poison_labels[epoch*dataset_len : (epoch+1)*dataset_len]
davies_bouldin_scores.append(davies_bouldin_score(hidden_state, poison_label))
np.save(os.path.join(save_path, 'davies_bouldin_scores.npy'), np.array(davies_bouldin_scores))
result = pd.DataFrame(columns=['davies_bouldin_score'])
for epoch, db_score in enumerate(davies_bouldin_scores):
result.loc[epoch, :] = [db_score]
result.to_csv(os.path.join(save_path, f'davies_bouldin_score.csv'))
return davies_bouldin_scores
def comp_loss(self, model: Victim, dataloader: torch.utils.data.DataLoader):
poison_loss_list, normal_loss_list = [], []
for step, batch in enumerate(dataloader):
batch_inputs, batch_labels = self.model.process(batch)
output = self.model(batch_inputs)
logits = output.logits
loss = self.loss_function(logits, batch_labels)
poison_labels = batch["poison_label"]
for l, poison_label in zip(loss, poison_labels):
if poison_label == 1:
poison_loss_list.append(l.item())
else:
normal_loss_list.append(l.item())
avg_poison_loss = sum(poison_loss_list) / len(poison_loss_list) if self.visualize else 0
avg_normal_loss = sum(normal_loss_list) / len(normal_loss_list) if self.visualize else 0
return avg_poison_loss, avg_normal_loss
def plot_curve(self, davies_bouldin_scores, normal_loss, poison_loss, fig_basepath: Optional[str]="./learning_curve", fig_title: Optional[str]="fig"):
# bar of db score
fig, ax1 = plt.subplots()
ax1.bar(range(self.epochs+1), davies_bouldin_scores, width=0.5, color='royalblue', label='davies bouldin score')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Davies Bouldin Score', size=14)
# curve of loss
ax2 = ax1.twinx()
ax2.plot(range(self.epochs+1), normal_loss, linewidth=1.5, color='green',
label=f'Normal Loss')
ax2.plot(range(self.epochs+1), poison_loss, linewidth=1.5, color='orange',
label=f'Poison Loss')
ax2.set_ylabel('Loss', size=14)
plt.title('Clustering Performance', size=14)
os.makedirs(fig_basepath, exist_ok=True)
plt.savefig(os.path.join(fig_basepath, f'{fig_title}.png'))
plt.savefig(os.path.join(fig_basepath, f'{fig_title}.pdf'))
fig_path = os.path.join(fig_basepath, f'{fig_title}.png')
logger.info(f'Saving png to {fig_path}')
plt.close()
def save_vis(self):
hidden_path = os.path.join('./hidden_states',
self.poison_setting, self.poison_method, str(self.poison_rate))
os.makedirs(hidden_path, exist_ok=True)
np.save(os.path.join(hidden_path, 'all_hidden_states.npy'), np.array(self.hidden_states))
np.save(os.path.join(hidden_path, 'labels.npy'), np.array(self.labels))
np.save(os.path.join(hidden_path, 'poison_labels.npy'), np.array(self.poison_labels))
embedding = self.visualization(self.hidden_states, self.labels, self.poison_labels,
fig_basepath=os.path.join('./visualization', self.poison_setting, self.poison_method, str(self.poison_rate)))
np.save(os.path.join(hidden_path, 'embedding.npy'), embedding)
curve_path = os.path.join('./learning_curve', self.poison_setting, self.poison_method, str(self.poison_rate))
os.makedirs(curve_path, exist_ok=True)
davies_bouldin_scores = self.clustering_metric(self.hidden_states, self.poison_labels, curve_path)
np.save(os.path.join(curve_path, 'poison_loss.npy'), np.array(self.poison_loss_all))
np.save(os.path.join(curve_path, 'normal_loss.npy'), np.array(self.normal_loss_all))
self.plot_curve(davies_bouldin_scores, self.poison_loss_all, self.normal_loss_all,
fig_basepath=curve_path)
def model_checkpoint(self, ckpt: str):
return os.path.join(self.save_path, f'{ckpt}.ckpt')