# Source code for openbackdoor.defenders.defender

from typing import *
from openbackdoor.victims import Victim
from openbackdoor.utils import evaluate_detection, logger
import torch
import torch.nn as nn


class Defender(object):
    """
    The base class of all defenders.

    The default implementations here are no-ops: :meth:`detect` flags no
    sample as poisoned and :meth:`correct` returns the data unchanged.

    Args:
        name (:obj:`str`, optional): the name of the defender.
        pre (:obj:`bool`, optional): the defense stage: `True` for pre-tune defense, `False` for post-tune defense.
        correction (:obj:`bool`, optional): whether to conduct correction: `True` for correction, `False` for no correction.
        metrics (:obj:`List[str]`, optional): the metrics to evaluate. Defaults to ``["FRR", "FAR"]``.
        **kwargs: accepted and ignored by the base class.
    """

    def __init__(
        self,
        name: Optional[str] = "Base",
        pre: Optional[bool] = False,
        correction: Optional[bool] = False,
        metrics: Optional[List[str]] = None,
        **kwargs
    ):
        self.name = name
        self.pre = pre
        self.correction = correction
        # Build a fresh default list per instance: a list literal as the
        # parameter default would be a single object shared by all
        # defenders, so mutating one defender's metrics would change the others.
        self.metrics = ["FRR", "FAR"] if metrics is None else metrics

    def detect(self, model: Optional[Victim] = None, clean_data: Optional[List] = None, poison_data: Optional[List] = None):
        """
        Detect the poison data.

        The base implementation predicts ``0`` (not poisoned) for every sample.

        Args:
            model (:obj:`Victim`): the victim model (unused here).
            clean_data (:obj:`List`): the clean data (unused here).
            poison_data (:obj:`List`): the data to screen. ``None`` is treated as empty.

        Returns:
            :obj:`List`: one prediction per sample in ``poison_data``.
        """
        if poison_data is None:
            return []
        return [0] * len(poison_data)

    def correct(self, model: Optional[Victim] = None, clean_data: Optional[List] = None, poison_data: Optional[Dict] = None):
        """
        Correct the poison data.

        The base implementation performs no correction.

        Args:
            model (:obj:`Victim`): the victim model (unused here).
            clean_data (:obj:`List`): the clean data (unused here).
            poison_data: the poison data.

        Returns:
            ``poison_data``, returned unchanged (the same object).
        """
        return poison_data

    def eval_detect(self, model: Optional[Victim] = None, clean_data: Optional[List] = None, poison_data: Optional[Dict] = None):
        """
        Evaluate detection on every split of ``poison_data``.

        Args:
            model (:obj:`Victim`): the victim model, forwarded to :meth:`detect`.
            clean_data (:obj:`List`): the clean data, forwarded to :meth:`detect`.
            poison_data (:obj:`Dict`): mapping of split name -> list of samples. The third
                field of each sample (``s[2]``) is used as its ground-truth label.
                ``None`` is treated as an empty mapping.

        Returns:
            :obj:`Tuple[Dict, List]`: ``(score, preds)`` where ``score`` maps each split name
            to the result of ``evaluate_detection`` for that split. ``preds`` holds the
            predictions of the LAST split iterated only, and is ``[]`` when there are no splits.
        """
        if poison_data is None:
            poison_data = {}
        score = {}
        # Pre-bind so an empty mapping returns [] instead of raising UnboundLocalError.
        preds = []
        for key, dataset in poison_data.items():
            preds = self.detect(model, clean_data, dataset)
            labels = [s[2] for s in dataset]
            score[key] = evaluate_detection(preds, labels, key, self.metrics)
        return score, preds

    def get_target_label(self, data):
        """
        Return the label (``d[1]``) of the first sample whose flag ``d[2]`` equals 1.

        Presumably samples are ``(text, label, poison_label)`` tuples, making this the
        attack's target label -- TODO confirm against the data loaders.

        Args:
            data: an iterable of indexable samples.

        Returns:
            The matching sample's ``d[1]``, or ``None`` if no sample has ``d[2] == 1``.
        """
        return next((d[1] for d in data if d[2] == 1), None)