ssfinetuning package

Submodules

ssfinetuning.dataset_utils module

class ssfinetuning.dataset_utils.SimpleDataset(dataset)

Bases: Generic[torch.utils.data.dataset.T_co]

A simple dataset used by the Co-Training and Tri-Training utilities.

Args:

dataset (Union[SimpleDataset, dataset]): Either a SimpleDataset object or a pyarrow-based (Hugging Face) dataset object.

Class attributes:

-original_len: Length of the dataset at instantiation.

-to_append_dic: Dictionary used when appending unlabeled examples to the dataset.

-batch_masks: Dictionary that keeps track of which unlabeled examples are inserted into and removed from the dataset during the appending procedure.

append(ul_data, mask=None, batch_index=None)

Appends a batch of unlabeled data; used during the appending procedure.

Args:

ul_data (torch.FloatTensor): Unlabeled data batch.

mask (torch.BoolTensor): Mask selecting the examples of the batch that are accepted. It also keeps track of which examples have been inserted.

batch_index (int): Index of the unlabeled batch, used by the batch_masks dictionary.

Return: mask_change.sum() (int): Total number of insertions and deletions of examples in the dataset.

extend_length(length)

Extends the dataset by length rows, created by randomly repeating existing rows.

reformat()

After examples are appended with ordinary list appending, converts the dataset back into the Hugging Face dataset format.

reset()

Resets the dataset to its length at instantiation.
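The length bookkeeping behind extend_length and reset can be sketched with a plain Python list. This is a minimal illustration under stated assumptions, not the package's SimpleDataset: the class name ListDatasetSketch and the seed argument are hypothetical, and rows are assumed to be plain dictionaries.

```python
import random


class ListDatasetSketch:
    """Hypothetical stand-in for SimpleDataset that stores rows in a list."""

    def __init__(self, rows):
        self.rows = list(rows)
        # Length at instantiation, used by reset().
        self.original_len = len(self.rows)

    def __len__(self):
        return len(self.rows)

    def extend_length(self, length, seed=None):
        # Append `length` rows drawn at random (with replacement) from the
        # rows currently in the dataset.
        rng = random.Random(seed)
        self.rows.extend(rng.choices(self.rows, k=length))

    def reset(self):
        # Drop everything appended after instantiation.
        del self.rows[self.original_len:]


ds = ListDatasetSketch([{"x": 1}, {"x": 2}, {"x": 3}])
ds.extend_length(5, seed=0)
print(len(ds))  # 8
ds.reset()
print(len(ds))  # 3
```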

ssfinetuning.dataset_utils.dic_to_pandas(history, loss_key='eval_loss', accuracy_measure='eval_matthews_correlation')

Converts a list of history dictionaries into a pandas DataFrame, which the plotting functions in plotting_utils handle more easily.

Args:

history (list): A list of history dictionaries, essentially the transformers.TrainerState history for each hyperparameter setting analysed.

loss_key (str): The key to look for. When analysing evaluation history the key is ‘eval_loss’; when analysing training history it is ‘train_loss’.

accuracy_measure (str): Name of the metric used during evaluation.
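The flattening step can be pictured with a stdlib-only sketch that turns several Trainer histories into row dictionaries, the form pandas.DataFrame accepts. It assumes each history is a list of logged dictionaries like TrainerState.log_history, paired with a dictionary of the hyperparameters used for that run; the function name history_to_rows and the (hparams, log_history) pairing are hypothetical, not the package's input format.

```python
def history_to_rows(histories, loss_key="eval_loss",
                    accuracy_measure="eval_matthews_correlation"):
    """Flatten (hparams, log_history) pairs into one row per logged step.

    Only log entries that contain `loss_key` are kept, which separates
    evaluation logs from training logs in a TrainerState-style history.
    """
    rows = []
    for hparams, log_history in histories:
        for entry in log_history:
            if loss_key not in entry:
                continue
            row = dict(hparams)
            row["epoch"] = entry.get("epoch")
            row[loss_key] = entry[loss_key]
            row[accuracy_measure] = entry.get(accuracy_measure)
            rows.append(row)
    return rows


histories = [
    ({"l_fr": 0.1}, [{"epoch": 1.0, "loss": 0.7},
                     {"epoch": 1.0, "eval_loss": 0.6,
                      "eval_matthews_correlation": 0.2}]),
]
rows = history_to_rows(histories)
print(rows)
# pandas.DataFrame(rows) would then give one column per key.
```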

ssfinetuning.dataset_utils.extract_keys(function, kwargs, remove_from_orignal=True)

Extracts from kwargs the keys that are parameters of function, determined by inspecting the signature of function.
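Signature-based key extraction can be sketched with inspect.signature from the standard library. This is a hypothetical re-implementation for illustration: the name extract_keys_sketch is invented, and the assumption that remove_from_orignal=True pops the extracted keys from the original dictionary is inferred from the parameter name only.

```python
import inspect


def extract_keys_sketch(function, kwargs, remove_from_original=True):
    """Return the entries of `kwargs` whose names are parameters of `function`.

    Assumption: when remove_from_original is True, the extracted keys are
    popped from kwargs so the rest can be forwarded elsewhere.
    """
    params = inspect.signature(function).parameters
    extracted = {}
    for name in list(kwargs):
        if name in params:
            extracted[name] = kwargs.pop(name) if remove_from_original else kwargs[name]
    return extracted


def plot(x_axis_col="epoch", select_best=5):
    return x_axis_col, select_best


kwargs = {"select_best": 3, "min_lr": 1e-7}
picked = extract_keys_sketch(plot, kwargs)
print(picked, kwargs)  # {'select_best': 3} {'min_lr': 1e-07}
```

Splitting kwargs this way lets one dictionary of user options be routed to several functions without each one rejecting unknown keywords.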

ssfinetuning.dataset_utils.match_with_batchsize(lim, batchsize)

Used by modify_datasets to return the integer closest to lim that is a multiple of batchsize, i.e. the returned value r satisfies r % batchsize == 0.
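A minimal sketch of "closest multiple of batchsize", assuming rounding to the nearest multiple with ties going down; the function name is hypothetical, and the package's tie-breaking (or whether it simply rounds down) may differ.

```python
def match_with_batchsize_sketch(lim, batchsize):
    """Round `lim` to the nearest multiple of `batchsize` (ties round down).

    Hypothetical sketch: the package's tie-breaking and handling of a
    result equal to 0 may differ.
    """
    lower = (lim // batchsize) * batchsize
    upper = lower + batchsize
    return lower if lim - lower <= upper - lim else upper


print(match_with_batchsize_sketch(35, 16))  # 32
print(match_with_batchsize_sketch(45, 16))  # 48
```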

ssfinetuning.dataset_utils.modify_datasets(dataset, labeled_fr=0.5, model_type='TriTrain', labeled1_frac=0.33, train_key='train', label_column='label', unlabeled_labels=-1, batchsize=16)

Modifies pyarrow-based (Hugging Face) datasets for experiments with different fractions of labeled versus unlabeled data.

Args:

dataset (dataset.DatasetDict): Dictionary containing training and validation datasets.

labeled_fr (float): Fraction of the training dataset to keep as labeled data; the rest becomes the unlabeled dataset.

model_type (str): Semi-supervised model type.

labeled1_frac (float): For the CoTrain and TriTrain model types, the fraction of the labeled data (after the split by labeled_fr) given to each of the first two models (m1 and m2); the rest goes to model 3. For example, with labeled1_frac=0.33, m1 and m2 each get 0.33 and m3 gets 1 - 2*0.33 = 0.34.

train_key (str): Key under which the training split is stored.

label_column (str): Name of the label column.

unlabeled_labels (int): Label value assigned to the unlabeled examples; required by Pi, TemporalEnsemble and MeanTeacher, which use it to identify the unlabeled examples.

batchsize (int): Batch size used during training.

Return: dataset.DatasetDict object with labeled and unlabeled data.
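The split arithmetic described by labeled_fr and labeled1_frac can be illustrated on example indices alone. This stdlib sketch is an assumption-laden illustration, not the package's code: the function name, the dictionary keys model1/model2/model3/unlabeled, and the round-down-to-batchsize rule are hypothetical.

```python
import random


def split_indices_sketch(n, labeled_fr=0.5, labeled1_frac=0.33, batchsize=16, seed=0):
    """Hypothetical TriTrain-style split of n example indices.

    Sizes are rounded down to a multiple of batchsize; the package's exact
    rounding (see match_with_batchsize) may differ.
    """
    idx = list(range(n))
    random.Random(seed).shuffle(idx)

    def to_batch(k):
        return (k // batchsize) * batchsize

    n_labeled = to_batch(int(n * labeled_fr))
    n_m1 = to_batch(int(n_labeled * labeled1_frac))
    labeled, unlabeled = idx[:n_labeled], idx[n_labeled:]
    return {
        "model1": labeled[:n_m1],
        "model2": labeled[n_m1:2 * n_m1],
        "model3": labeled[2 * n_m1:],   # remaining (1 - 2 * labeled1_frac) share
        "unlabeled": unlabeled,         # labels would be set to unlabeled_labels, e.g. -1
    }


parts = split_indices_sketch(1000)
print({k: len(v) for k, v in parts.items()})
# {'model1': 160, 'model2': 160, 'model3': 176, 'unlabeled': 504}
```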

ssfinetuning.default_args module

class ssfinetuning.default_args.DefaultArgs

Bases: object

get_default_ta(logging_dir='')

Returns the default TrainingArguments, with logging_dir set, for semi-supervised models.

get_default_ta_sup(logging_dir='')

Returns the default TrainingArguments, with logging_dir set, for supervised models.

set_default_args(dataset, model_name, kwargs)

Sets default values for the keywords not provided in the kwargs of train_with_ssl, and updates kwargs for use by transformers.Trainer.

Args:

dataset (DatasetDict): Dataset dictionary containing labeled and unlabeled data.

model_name (str or os.PathLike): “pretrained_model_name_or_path” in ~transformers.PreTrainedModel; please refer to its documentation for further information.

(i) If a string, the model id of a pretrained model hosted inside a model repo on huggingface.co.

(ii) It can also be the path to a saved pretrained model.

kwargs (dict): Keyword arguments to be used by transformers.Trainer.

ssfinetuning.default_args.encode(dataset, model_name='albert-base-v2', text_column_name='sentence')

Encodes the dataset using a tokenizer.

Args:

dataset (DatasetDict): Dataset dictionary containing labeled and unlabeled data.

model_name (str or os.PathLike): “pretrained_model_name_or_path” in ~transformers.PreTrainedModel; please refer to its documentation for further information.

(i) If a string, the model id of a pretrained model hosted inside a model repo on huggingface.co.

(ii) It can also be the path to a saved pretrained model.

text_column_name (str): Name of the column containing the text.

Return:

encoded_dataset (DatasetDict): Dataset now containing the columns required by the forward function.

tokenizer (PreTokenizer): The tokenizer used to encode the dataset.

ssfinetuning.default_args.get_default_cm()

Default compute metric function.
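The default accuracy_measure elsewhere in this package is ‘eval_matthews_correlation’, which suggests a Matthews correlation metric. The sketch below shows what such a compute_metrics function could look like for binary labels; it is a hypothetical stdlib illustration, not the function returned by get_default_cm, and it assumes the evaluation object unpacks into (logits, labels) as transformers.EvalPrediction does.

```python
import math


def compute_metrics_sketch(eval_pred):
    """Hypothetical compute_metrics returning binary Matthews correlation.

    eval_pred is assumed to unpack into (logits, labels).
    """
    logits, labels = eval_pred
    # Predicted class = index of the largest logit per example.
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    tp = sum(p == 1 and y == 1 for p, y in zip(preds, labels))
    tn = sum(p == 0 and y == 0 for p, y in zip(preds, labels))
    fp = sum(p == 1 and y == 0 for p, y in zip(preds, labels))
    fn = sum(p == 0 and y == 1 for p, y in zip(preds, labels))
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    # transformers.Trainer prefixes metric keys with "eval_" during evaluation,
    # which yields "eval_matthews_correlation".
    return {"matthews_correlation": mcc}


print(compute_metrics_sketch(([[0.9, 0.1], [0.2, 0.8]], [0, 1])))
```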

ssfinetuning.models module

class ssfinetuning.models.BaseModelClass(model_name='albert-base-v2', supervised_run=False, num_labels=2, classifier_dropout=0.1, num_models=1, ssl_model_type=None)

Bases: torch.nn.modules.module.Module

Base class for all models with a single pretrained model, possibly with multiple classifier layers.

Args:

model_name (str or os.PathLike): “pretrained_model_name_or_path” in ~transformers.PreTrainedModel; please refer to its documentation for further information.

(i) If a string, the model id of a pretrained model hosted inside a model repo on huggingface.co.

(ii) It can also be the path to a saved pretrained model.

supervised_run (bool): Whether the model is taken from a supervised run. In that case, model_name is the path to the saved model.

num_labels (int): Number of labels to classify.

classifier_dropout (float): Dropout probability of the classifier layers.

num_models (int): Number of models, i.e. number of classifier layers (set only by subclasses).

ssl_model_type (str): Semi-supervised learning model type (set only by subclasses).

simple_forward_with_prob_logits(classifier_num=0, **kwargs)

First switches the classifier of pretrained_model to the classifier with index classifier_num defined in this class, then applies softmax to the output, converting the logits into probabilities.

Args:

classifier_num: Index of the classifier to be used.

kwargs: Arguments from pretrained_model.forward.

Return:

logits (torch.FloatTensor): Class probabilities (softmax of the logits).
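The softmax step that turns raw logits into the returned probabilities can be shown in plain Python. In PyTorch the same operation on a batch would be torch.softmax(logits, dim=-1); this stdlib version is only an illustration, with the usual max subtraction for numerical stability.

```python
import math


def softmax(logits):
    """Convert a list of raw logits into probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, 0.0])
print([round(p, 3) for p in probs])  # [0.881, 0.119]
```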

training: bool
class ssfinetuning.models.BaseMultiPretrained(teacher_student_name=('albert-base-v2', 'albert-base-v2'), num_labels=2, teacher_dropout=None, student_dropout=None, ssl_model_type=None)

Bases: torch.nn.modules.module.Module

Base class for all models with multiple pretrained models.

Args:

ssl_model_type (str): Semi-supervised learning model type.

teacher_student_name (Tuple[str, str]): Teacher and student model names, respectively. “pretrained_model_name_or_path” in ~transformers.PreTrainedModel; please refer to its documentation for further information.

(i) If a string, the model id of a pretrained model hosted inside a model repo on huggingface.co.

(ii) It can also be the path to a saved pretrained model.

num_labels (int): number of labels to be classified.

student_dropout (float): dropout probability of the student classifier layers.

teacher_dropout (float): dropout probability of the teacher classifier layers.

training: bool
class ssfinetuning.models.CoTrain(o_weight=0.01, num_labels=2, model_name='albert-base-v2', classifier_dropout=0.1, ssl_model_type='CoTrain', num_models=2, supervised_run=False)

Bases: ssfinetuning.models.BaseModelClass

Implementation of Co-Training as introduced in <https://www.cs.cmu.edu/~avrim/Papers/cotrain.pdf>.

Args:

o_weight (float): Orthogonality weight for the two classifiers (or two models).

kwargs: Remaining keyword arguments of BaseModelClass.

cotrain_forward(model1_batch, model2_batch)

Forward function used during training of models. See ~trainer_utils.TrainerForCoTraining for more details.

Args:

model1_batch (torch.FloatTensor): Batch for model 1.

model2_batch (torch.FloatTensor): Batch for model 2.

Return: CoTrainModelOutput object with the information of logits of both models and the loss function.
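The orthogonality weight o_weight scales a penalty that pushes the two classifier layers apart. The exact penalty used by CoTrain is not stated here, so the sketch below assumes one common form, the squared Frobenius norm of W1 W2^T, on weight matrices given as nested lists of shape (num_labels, hidden_size); cotrain_loss_sketch and its simple sum of losses are hypothetical.

```python
def orthogonality_penalty(w1, w2):
    """Squared Frobenius norm of w1 @ w2.T for two weight matrices.

    w1, w2: lists of rows with shape (num_labels, hidden_size). The value is
    zero when every row of w1 is orthogonal to every row of w2.
    """
    total = 0.0
    for r1 in w1:
        for r2 in w2:
            dot = sum(a * b for a, b in zip(r1, r2))
            total += dot * dot
    return total


def cotrain_loss_sketch(loss_m1, loss_m2, w1, w2, o_weight=0.01):
    # Hypothetical combination: both supervised losses plus the weighted penalty.
    return loss_m1 + loss_m2 + o_weight * orthogonality_penalty(w1, w2)


w1 = [[1.0, 0.0], [0.0, 1.0]]
w2 = [[1.0, 1.0], [0.0, 0.0]]
print(orthogonality_penalty(w1, w2))  # 2.0
```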

forward(**kwargs)

Forward function only used during the evaluation of models. See ~trainer_utils.TrainerForCoTraining and ~transformers.Trainer for more details.

Args:

kwargs: Arguments from pretrained_model.forward.

Return: CoTrainModelOutput object with the information of logits of both models and the loss function.

training: bool
class ssfinetuning.models.CoTrainModelOutput(loss: Union[torch.FloatTensor, NoneType] = None, logits_m1: torch.FloatTensor = None, logits_m2: torch.FloatTensor = None)

Bases: transformers.file_utils.ModelOutput

logits_m1: torch.FloatTensor = None
logits_m2: torch.FloatTensor = None
loss: Optional[torch.FloatTensor] = None
class ssfinetuning.models.MeanTeacher(teacher_student_name=('albert-base-v2', 'albert-base-v2'), num_labels=2, unsup_weight=0, teacher_dropout=None, alpha=0.5, student_dropout=None)

Bases: ssfinetuning.models.BaseMultiPretrained

Implementation of Mean Teacher as introduced in <https://arxiv.org/abs/1703.01780>

Args:

alpha (float): Memory of the previous epochs, i.e. the decay of the exponential moving average.

unsup_weight (float): Initial weight of the unsupervised loss component.

Class attributes:

-firstpass: bool tracking whether this is the first pass through the forward method.

forward(**kwargs)

Forward function that computes the semi-supervised loss. Labeled and unlabeled examples must not be mixed in a single batch.

Args:

kwargs: Arguments from pretrained_model.forward.

Return: transformers.modeling_outputs.SequenceClassifierOutput object with the information of logits and the loss function.

training: bool
update_teacher_variables()

Updates the teacher weights and biases. Taken directly from <https://github.com/CuriousAI/mean-teacher>.
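The teacher update is an exponential moving average of the student parameters. The sketch below follows the update_ema_variables rule in the linked mean-teacher repository, where the decay is capped as min(1 - 1/(global_step + 1), alpha) so early steps use the true running average; it works on flat lists of floats instead of torch parameters, and whether this package passes a global step is an assumption.

```python
def update_teacher_sketch(teacher, student, alpha, global_step):
    """EMA update of teacher parameters from student parameters (flat lists).

    The effective decay is capped so that early steps use the true
    running average instead of the exponential one.
    """
    alpha = min(1 - 1 / (global_step + 1), alpha)
    return [alpha * t + (1 - alpha) * s for t, s in zip(teacher, student)]


teacher = [0.0, 0.0]
# Step 0: the effective alpha is 0, so the teacher copies the student.
teacher = update_teacher_sketch(teacher, [1.0, 2.0], alpha=0.5, global_step=0)
print(teacher)  # [1.0, 2.0]
```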

zero_teacher_weights(module)

Sets the teacher's weights and biases to zero.

class ssfinetuning.models.NoisyStudent(teacher_dropout=None, student_dropout=None, num_labels=2, teacher_student_name=('albert-base-v2', 'albert-base-v2'))

Bases: ssfinetuning.models.BaseMultiPretrained

Implementation of Noisy Student as introduced in <https://arxiv.org/abs/1911.04252>

Args:

kwargs: Same keyword arguments as BaseMultiPretrained, except the model type string. The forward method is initialized with the teacher, since the teacher is trained first.

training: bool
class ssfinetuning.models.PiModel(unsup_weight=0, num_labels=2, model_name='albert-base-v2', classifier_dropout=0.1, supervised_run=False)

Bases: ssfinetuning.models.BaseModelClass

Implementation of the Pi model from <https://arxiv.org/abs/1610.02242>.

Args:

unsup_weight (float): Initial value of the weight of the unsupervised loss component; afterwards its value is controlled by the unsupervised weight scheduler.

kwargs: Remaining keyword arguments of BaseModelClass.

forward(**kwargs)

Forward function that computes the semi-supervised loss. Labeled and unlabeled examples must not be mixed in a single batch.

Args:

kwargs: Arguments from pretrained_model.forward.

Return: transformers.modeling_outputs.SequenceClassifierOutput object with the information of logits and loss function.
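The unsupervised part of the Pi model penalizes disagreement between two stochastic forward passes of the same input, scaled by a weight that ramps up over training. The sketch uses the Gaussian ramp-up exp(-5(1 - T)^2) from the linked paper as a stand-in for the package's unsupervised weight scheduler (whose w_ramprate-based schedule may differ); the function names and max_weight are hypothetical.

```python
import math


def gaussian_rampup(step, rampup_length):
    """Ramp-up curve exp(-5 (1 - T)^2) with T = step / rampup_length in [0, 1]."""
    if rampup_length == 0:
        return 1.0
    t = min(step, rampup_length) / rampup_length
    return math.exp(-5.0 * (1.0 - t) ** 2)


def consistency_loss(probs_a, probs_b):
    """Mean squared difference between two stochastic predictions of one example."""
    return sum((a - b) ** 2 for a, b in zip(probs_a, probs_b)) / len(probs_a)


def pi_unlabeled_loss(probs_a, probs_b, step, rampup_length, max_weight=1.0):
    # Unsupervised term for an unlabeled batch; labeled batches would use
    # cross-entropy instead, since the two kinds are not mixed in one batch.
    weight = max_weight * gaussian_rampup(step, rampup_length)
    return weight * consistency_loss(probs_a, probs_b)


print(pi_unlabeled_loss([0.7, 0.3], [0.6, 0.4], step=5, rampup_length=10))
```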

training: bool
class ssfinetuning.models.TemporalEnsembleModel(unsup_weight=0, num_labels=2, model_name='albert-base-v2', alpha=0.5, classifier_dropout=0.1, supervised_run=False)

Bases: ssfinetuning.models.BaseModelClass

Implementation of Temporal ensemble model as introduced in <https://arxiv.org/abs/1610.02242>

Args:

alpha (float): Memory of the previous epochs (the ensembling momentum). For more information, see <https://arxiv.org/abs/1610.02242>.

unsup_weight (float): Initial value of the weight of the unsupervised loss component; afterwards its value is controlled by the unsupervised weight scheduler.

kwargs: Remaining keyword arguments of BaseModelClass.

Class attributes:

-mini_batch_num: Counter of the mini-batches passed through the forward method.

-logits_batchwise: Stores the logits of each batch passed through the forward method.

-firstpass: bool tracking whether this is the first pass through the forward method.

forward(**kwargs)

Forward function that computes the semi-supervised loss. Labeled and unlabeled examples must not be mixed in a single batch.

Args:

kwargs: Arguments from pretrained_model.forward.

Return: transformers.modeling_outputs.SequenceClassifierOutput object with the information of logits and the loss function.

training: bool
update_memory_logits(t)

Updates the memory logits with an exponential moving average.

Args:

t (int): Epoch number used for the bias correction.
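The Temporal Ensembling update from the linked paper accumulates Z = alpha * Z + (1 - alpha) * z and divides by (1 - alpha**t) to correct the startup bias toward zero. The sketch below applies that rule to one example's predictions as plain lists; it illustrates the paper's formula and is not the package's update_memory_logits, which operates on batched tensors.

```python
def update_memory_logits_sketch(memory, new_preds, alpha, t):
    """One Temporal Ensembling update for a single example.

    memory: accumulated ensemble predictions Z (starts at zeros).
    new_preds: predictions z from the current epoch.
    t: 1-based epoch number used for the startup-bias correction.
    Returns the updated Z and the bias-corrected target Z / (1 - alpha**t).
    """
    memory = [alpha * Z + (1 - alpha) * z for Z, z in zip(memory, new_preds)]
    targets = [Z / (1 - alpha ** t) for Z in memory]
    return memory, targets


Z = [0.0, 0.0]
Z, target = update_memory_logits_sketch(Z, [0.8, 0.2], alpha=0.5, t=1)
print(Z, target)
```

At t = 1 the correction exactly undoes the (1 - alpha) shrinkage, so the target equals the first epoch's predictions.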

class ssfinetuning.models.TriTrain(o_weight=0.01, num_labels=2, classifier_dropout=0.1, model_name='albert-base-v2', ssl_model_type='CoTrain', supervised_run=False)

Bases: ssfinetuning.models.CoTrain

Implementation of Tri-Training (multi-task Tri-Training) as introduced in <https://arxiv.org/abs/1804.09530>. Note: here it is applied only at the fine-tuning stage; the base network is a pretrained transformer model.

Args:

kwargs: Same keyword arguments as the CoTrain class, except the model type string and the number of models (num_models), which follow from the name.

forward(**kwargs)

Forward function used during evaluation of trained models. See ~trainer_utils.TrainerForTriTraining and ~transformers.Trainer for more details.

Args:

kwargs: Arguments from pretrained_model.forward.

Return: TriTrainModelOutput object with the logits of all three models and the loss.

m3_forward(**kwargs)

Forward function for model 3. See ~trainer_utils.TrainerForTriTraining and ~transformers.Trainer for more details.

Args:

kwargs: Arguments from pretrained_model.forward.

Return: TriTrainModelOutput object with the information of logits of both models and the loss function.

training: bool
class ssfinetuning.models.TriTrainModelOutput(loss: Union[torch.FloatTensor, NoneType] = None, logits_m1: torch.FloatTensor = None, logits_m2: torch.FloatTensor = None, logits_m3: torch.FloatTensor = None)

Bases: transformers.file_utils.ModelOutput

logits_m1: torch.FloatTensor = None
logits_m2: torch.FloatTensor = None
logits_m3: torch.FloatTensor = None
loss: Optional[torch.FloatTensor] = None
ssfinetuning.models.add_signature_from(base)

ssfinetuning.plotting_utils module

ssfinetuning.plotting_utils.add_end_args(from_fn)
ssfinetuning.plotting_utils.get_default_legend_pos(num_graphs, axes_index=None)

Sets default positions where legends can be placed.

Args:

num_graphs (int): Number of graphs to plot, at most 4.

axes_index (int, optional, defaults to None): In the case of multiple axes, the settings depend on the index of the axes.

ssfinetuning.plotting_utils.plot_in(axes, axes_index=0, totplots=1, data=None, data_to_compare=None, x_axis_col='epoch', y_axis_col='eval_mc', select_best=5, criteria='max', cols_to_find=['w_ramprate'], dis_col='l_fr', dis_val=False, data_to_compare_lb='sup_stats')

Main plotting function.

Args:

axes (matplotlib.pyplot.axes): Axes object to plot on.

axes_index (int, optional, defaults to 0): In the case of multiple axes, the settings depend on the index of the axes.

totplots (int): Total number of plots.

data (pd.DataFrame): Data to sort.

data_to_compare (pd.DataFrame): Data to compare with the sorted results, for example purely supervised results.

x_axis_col (str, optional, defaults to ‘epoch’): Column with the values plotted on the x axis.

y_axis_col (str, optional, defaults to ‘eval_mc’): Column with the values plotted on the y axis.

select_best (int, optional, defaults to 5): Number of best entries from the sorted list to plot.

criteria (str, optional, defaults to ‘max’): Criterion for sorting the list: (i) max, (ii) min, or (iii) mean.

cols_to_find (list, optional, defaults to [‘w_ramprate’]): Columns analysed to find the best value combinations according to the sorting criterion.

dis_col (str, optional, defaults to ‘l_fr’): The discriminatory column name. The graphs are divided across the subplots according to the values of this column.

dis_val (int or float, optional, defaults to False): Only valid if dis_col is not None. Used to plot a single unique value of the discriminatory column.

data_to_compare_lb (str, optional, defaults to ‘sup_stats’): Label for the data_to_compare plot.

ssfinetuning.plotting_utils.plot_with_discriminator(dis_col, save_png, data=None, *args, **kwargs)

Plotter used when a discriminatory column is specified.

Args:

dis_col (str): The discriminatory column name. The graphs are divided across the subplots according to the values of this column.

save_png (str): File name for saving the results as a PNG image. If save_png is not None, the figure is saved under this name.

kwargs: Remaining keyword arguments for the plot_in function.

Additional Args from plot_in:

axes (matplotlib.pyplot.axes): Axes object to plot on.

axes_index (int, optional, defaults to 0): In the case of multiple axes, the settings depend on the index of the axes.

totplots (int): Total number of plots.

data (pd.DataFrame): Data to sort.

data_to_compare (pd.DataFrame): Data to compare with the sorted results, for example purely supervised results.

x_axis_col (str, optional, defaults to ‘epoch’): Column with the values plotted on the x axis.

y_axis_col (str, optional, defaults to ‘eval_mc’): Column with the values plotted on the y axis.

select_best (int, optional, defaults to 5): Number of best entries from the sorted list to plot.

criteria (str, optional, defaults to ‘max’): Criterion for sorting the list: (i) max, (ii) min, or (iii) mean.

cols_to_find (list, optional, defaults to [‘w_ramprate’]): Columns analysed to find the best value combinations according to the sorting criterion.

dis_col (str, optional, defaults to ‘l_fr’): The discriminatory column name. The graphs are divided across the subplots according to the values of this column.

dis_val (int or float, optional, defaults to False): Only valid if dis_col is not None. Used to plot a single unique value of the discriminatory column.

data_to_compare_lb (str, optional, defaults to ‘sup_stats’): Label for the data_to_compare plot.

ssfinetuning.plotting_utils.set_default_vals(num_graphs)

Sets the default values for font size, line widths, markers etc.

Args:

num_graphs (int): Number of graphs to plot, at most 4.

ssfinetuning.plotting_utils.simple_plot(save_png, *args, **kwargs)

Plotter used when no discriminatory column is specified.

Args:

save_png (str): File name for saving the results as a PNG image. If save_png is not None, the figure is saved under this name.

kwargs: Remaining keyword arguments for the plot_in function.

Additional Args from plot_in:

axes (matplotlib.pyplot.axes): Axes object to plot on.

axes_index (int, optional, defaults to 0): In the case of multiple axes, the settings depend on the index of the axes.

totplots (int): Total number of plots.

data (pd.DataFrame): Data to sort.

data_to_compare (pd.DataFrame): Data to compare with the sorted results, for example purely supervised results.

x_axis_col (str, optional, defaults to ‘epoch’): Column with the values plotted on the x axis.

y_axis_col (str, optional, defaults to ‘eval_mc’): Column with the values plotted on the y axis.

select_best (int, optional, defaults to 5): Number of best entries from the sorted list to plot.

criteria (str, optional, defaults to ‘max’): Criterion for sorting the list: (i) max, (ii) min, or (iii) mean.

cols_to_find (list, optional, defaults to [‘w_ramprate’]): Columns analysed to find the best value combinations according to the sorting criterion.

dis_col (str, optional, defaults to ‘l_fr’): The discriminatory column name. The graphs are divided across the subplots according to the values of this column.

dis_val (int or float, optional, defaults to False): Only valid if dis_col is not None. Used to plot a single unique value of the discriminatory column.

data_to_compare_lb (str, optional, defaults to ‘sup_stats’): Label for the data_to_compare plot.

ssfinetuning.plotting_utils.sort_and_find(data, cols_unique_vals, x_axis_col, y_axis_col, select_best, criteria)

Finds a sorted list of the y_axis_col values in data according to criteria. The function first generates all combinations of the unique values of the given columns, then collects the values for each combination, and finally sorts this list according to the criterion.

Args:

data (pd.DataFrame ): The data to be sorted.

cols_unique_vals (list): A list of columns of important hyperparameters with their unique values.

x_axis_col (str): Column name with the values to be plotted on the x axis.

y_axis_col (str): Column name with the values to be plotted on the y axis.

select_best (int): Number of best entries to keep from the sorted list.

criteria (str): Criterion for sorting the list: (i) max, (ii) min, or (iii) mean.

Return:

sorted_list (list): The sorted list.

ssfinetuning.plotting_utils.sort_and_plot(dis_col=None, save_png='results.png', *args, **kwargs)

Sorts the results and plots them, depending on whether a discriminatory column is specified.

Args:

dis_col (str, optional, defaults to None): The discriminatory column name. The graphs are divided across the subplots according to the values of this column. If None, the sorted values are simply plotted.

save_png (str, optional, defaults to ‘results.png’): File name for saving the results as a PNG image. If save_png is not None, the figure is saved under this name.

kwargs: Remaining keyword arguments for the plot_in function.

Additional Args from plot_in:

axes (matplotlib.pyplot.axes): Axes object to plot on.

axes_index (int, optional, defaults to 0): In the case of multiple axes, the settings depend on the index of the axes.

totplots (int): Total number of plots.

data (pd.DataFrame): Data to sort.

data_to_compare (pd.DataFrame): Data to compare with the sorted results, for example purely supervised results.

x_axis_col (str, optional, defaults to ‘epoch’): Column with the values plotted on the x axis.

y_axis_col (str, optional, defaults to ‘eval_mc’): Column with the values plotted on the y axis.

select_best (int, optional, defaults to 5): Number of best entries from the sorted list to plot.

criteria (str, optional, defaults to ‘max’): Criterion for sorting the list: (i) max, (ii) min, or (iii) mean.

cols_to_find (list, optional, defaults to [‘w_ramprate’]): Columns analysed to find the best value combinations according to the sorting criterion.

dis_col (str, optional, defaults to ‘l_fr’): The discriminatory column name. The graphs are divided across the subplots according to the values of this column.

dis_val (int or float, optional, defaults to False): Only valid if dis_col is not None. Used to plot a single unique value of the discriminatory column.

data_to_compare_lb (str, optional, defaults to ‘sup_stats’): Label for the data_to_compare plot.

ssfinetuning.trainer_util module

class ssfinetuning.trainer_util.BaseForMMTrainer(model: Optional[Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module]] = None, args: Optional[transformers.training_args.TrainingArguments] = None, data_collator: Optional[DataCollator] = None, train_dataset: Optional[torch.utils.data.dataset.Dataset] = None, eval_dataset: Optional[torch.utils.data.dataset.Dataset] = None, tokenizer: Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None, model_init: Optional[Callable[[], transformers.modeling_utils.PreTrainedModel]] = None, compute_metrics: Optional[Callable[[transformers.trainer_utils.EvalPrediction], Dict]] = None, callbacks: Optional[List[transformers.trainer_callback.TrainerCallback]] = None, optimizers: Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None))

Bases: ssfinetuning.trainer_util.RemoveUnusedColumnMixing, transformers.trainer.Trainer

Base class for all multi-model trainers. It contains the methods shared by the trainers of semi-supervised methods that use multiple models.

confi_prediction(logits_m1, logits_m2, logits_m3=None)

Makes predictions based on the confidence of the models: for each example, the model with the highest class probability (confidence) is chosen, and its predicted class is taken as the final answer.

Args:

logits_m1 (torch.FloatTensor): Logits from model 1.

logits_m2 (torch.FloatTensor): Logits from model 2.

logits_m3 (torch.FloatTensor, optional): Logits from model 3.
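The confidence-based choice can be sketched for any number of models with plain lists. This hypothetical sketch assumes the inputs are already probabilities (softmax applied), since comparing raw logits across models is not the same as comparing confidences; the function name is invented.

```python
def confi_prediction_sketch(*probs_per_model):
    """Per example, take the class predicted by the most confident model.

    probs_per_model: one list per model, each of shape (n_examples, n_classes),
    holding class probabilities.
    """
    predictions = []
    for example_probs in zip(*probs_per_model):
        # example_probs holds each model's probability vector for this example;
        # pick the vector whose top probability is highest.
        best = max(example_probs, key=max)
        predictions.append(best.index(max(best)))
    return predictions


m1 = [[0.9, 0.1], [0.6, 0.4]]
m2 = [[0.3, 0.7], [0.2, 0.8]]
print(confi_prediction_sketch(m1, m2))  # [0, 1]
```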

equate_lengths(model1_train, model2_train)

Used when training CoTrain and TriTrain models. Finds the model dataset with fewer training examples and extends it to the same length using SimpleDataset.extend_length.

Args:

model1_train (SimpleDataset): Training dataset for model 1.

model2_train (SimpleDataset): Training dataset for model 2.

get_dataloader(dataset, sequential=False)

A slightly modified ~transformers.Trainer.get_train_dataloader that can switch between a sequential sampler and a RandomSampler.

post_epoch(step, epoch, tr_loss)

Calls all the callback functions that run after an epoch is done. See ~transformers.Trainer.train() for more details.

Args:

step (int): Number of steps completed out of the num_training_steps passed to pre_train_init.

epoch (int): Current epoch.

tr_loss (float): Training loss.

pre_train_init(num_training_steps)

Calls all the callback functions that run before training starts. See ~transformers.Trainer.train() for more details.

Args:

num_training_steps (int): Total number of training steps, i.e. the number of mini-batches per epoch * the number of epochs.

prediction_step(model, inputs, prediction_loss_only, ignore_keys=None)

A slightly modified ~transformers.Trainer.prediction_step that uses the confi_prediction method to compute the predictions of CoTrain and TriTrain models. Used during evaluation.

class ssfinetuning.trainer_util.RemoveUnusedColumnMixing

Bases: object

class ssfinetuning.trainer_util.TrainerForCoTraining(epoch_per_cotrain=2, exchange_threshold=20, ntimes_before_saturation=2, p_threshold=0.65, use_min_lr_scheduler=True, min_lr=1e-07, show_exchange=True, max_passes=1, *args, **kwargs)

Bases: ssfinetuning.trainer_util.BaseForMMTrainer

Subclass of ~transformers.Trainer with changes for the CoTrain model.

Args:

epoch_per_cotrain (int, optional, defaults to 2): Number of epochs over the training data in one co-training iteration.

exchange_threshold (int, optional, defaults to 20): Minimum number of examples exchanged between the models; co-training stops when the exchange falls below this value.

ntimes_before_saturation (int, optional, defaults to 2): When get_linear_schedule_with_minlr is used, the number of passes over all model datasets plus the unlabeled dataset before the schedule saturates.

p_threshold (float, optional, defaults to 0.65): Probability threshold above which an example is considered for exchange between the models.

use_min_lr_scheduler (bool, optional, defaults to True): Whether to use linear_schedule_with_minlr.

min_lr (float, optional, defaults to 1e-07): Minimum learning rate used by linear_schedule_with_minlr.

show_exchange (bool, optional, defaults to True): Whether to print the exchanges between the models.

max_passes (int, optional, defaults to 1): Maximum number of passes through all the datasets.
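A linear schedule with a minimum learning rate can be expressed as a multiplier of the initial rate, the form torch.optim.lr_scheduler.LambdaLR expects. The sketch below is a hypothetical reading of get_linear_schedule_with_minlr: the rate falls linearly to min_lr over saturation_steps (for example ntimes_before_saturation times the steps of one pass over the datasets) and then stays at min_lr; the package's actual schedule, including any warmup, may differ.

```python
def linear_with_min_lr_factor(step, saturation_steps, init_lr, min_lr=1e-7):
    """Hypothetical LambdaLR-style multiplier for a linear decay with a floor.

    The learning rate falls linearly from init_lr to min_lr over
    saturation_steps steps and then stays at min_lr.
    """
    floor = min_lr / init_lr
    if step >= saturation_steps:
        return floor
    return 1.0 - (1.0 - floor) * step / saturation_steps


init_lr = 1e-3
for step in (0, 50, 100, 200):
    print(step, init_lr * linear_with_min_lr_factor(step, 100, init_lr))
```

With PyTorch this multiplier would be passed as LambdaLR(optimizer, lr_lambda=lambda s: linear_with_min_lr_factor(s, saturation_steps, init_lr)).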

Note: the training data can be given to the trainer in two ways.

(i) dataset: it should follow the naming scheme of dataset_utils.modify_datasets.

(ii) dataset_model1, dataset_model2, unlabeled: datasets for model1, model2 and the unlabeled data, respectively.

Class attributes:

-total_dataset_len: Total length of all datasets, including model1, model2 and unlabeled.

-global_epoch: Global epoch count, including the epochs over model1_dataset + unlabeled and model2_dataset + unlabeled.

cotrain(model1_train, model2_train)

Performs the actual co-training. First, the datasets of both models are equalized in length with randomly repeated training examples (see SimpleDataset.extend_length()). Then a batch from each of the model1 and model2 datasets is passed to cotrain_forward, which applies the orthogonality constraint between the classifier layers and computes the total loss over both batches.

Args:

model1_train (SimpleDataset): Training dataset for model 1 after exchange_unlabeled_data has been applied.

model2_train (SimpleDataset): Training dataset for model 2 after exchange_unlabeled_data has been applied.

exchange_unlabeled_data()

Exchanges the unlabeled data between the models. Examples on which model1 is confident (probability above p_threshold) are given to model2 to train on, and vice versa.

Args:

ul_dataloader (~torch.utils.data.DataLoader): Dataloader for the unlabeled dataset.

Returns:

exchange (bool): Whether the number of examples exchanged is above exchange_threshold.

dataset_model1, dataset_model2 (SimpleDataset): Datasets for model1 and model2 with the added unlabeled data.
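One exchange pass can be sketched on per-example probability lists: examples on which one model is confident are pseudo-labeled for the other model, and the exchange flag compares the total count with exchange_threshold. This is a hypothetical sketch (the function name and the (index, pseudo_label) output format are invented), not the trainer's method, which works on SimpleDataset objects and a DataLoader.

```python
def exchange_sketch(probs_m1, probs_m2, p_threshold=0.65, exchange_threshold=20):
    """Hypothetical co-training exchange over one pass of unlabeled data.

    probs_m1, probs_m2: per-example class probabilities from model 1 and 2.
    Examples model 1 is confident on are pseudo-labeled for model 2 and
    vice versa. Returns (exchange, new_for_m1, new_for_m2), where the
    new_* lists hold (example_index, pseudo_label) pairs.
    """
    new_for_m1, new_for_m2 = [], []
    for i, (p1, p2) in enumerate(zip(probs_m1, probs_m2)):
        if max(p1) > p_threshold:
            new_for_m2.append((i, p1.index(max(p1))))
        if max(p2) > p_threshold:
            new_for_m1.append((i, p2.index(max(p2))))
    exchanged = len(new_for_m1) + len(new_for_m2)
    return exchanged > exchange_threshold, new_for_m1, new_for_m2


exchange, for_m1, for_m2 = exchange_sketch(
    [[0.9, 0.1], [0.5, 0.5]], [[0.4, 0.6], [0.1, 0.9]], exchange_threshold=1)
print(exchange, for_m1, for_m2)
```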

train()

Train method for CoTrainer. Repeats the exchange of unlabeled data between model1 and model2 for as long as the exchange condition holds, see exchange_unlabeled_data().

class ssfinetuning.trainer_util.TrainerForNoisyStudent(min_lr=1e-07, epoch_per_ts_iter=1, ts_iter=3, ntimes_before_saturation=2, reduce_init_lr_factor=1, use_min_lr_scheduler=None, *args, **kwargs)

Bases: ssfinetuning.trainer_util.BaseForMMTrainer

Subclass of ~transformers.Trainer for the noisy student model.

Args:

min_lr (float, optional, defaults to 1e-07): Minimum learning rate used by get_linear_schedule_with_minlr.

epoch_per_ts_iter (int, optional, defaults to 1): Number of epochs to pass during each teacher student iteration.

ts_iter (int, optional, defaults to 3): Number of teacher-student iterations during training; after each iteration the student is used as the new teacher.

ntimes_before_saturation (int, optional, defaults to 2): When get_linear_schedule_with_minlr is used, the number of passes over all model datasets plus the unlabeled dataset before the schedule saturates.

Note: the training data can be given to the trainer in two ways.

(i) dataset: it should follow the naming scheme of dataset_utils.modify_datasets.

(ii) dataset_labeled, dataset_unlabeled: datasets for the labeled and unlabeled data.

exchange_models()

Makes the student model the new teacher model.

property num_training_steps_
psuedo_label()

Generates pseudo-labels with the teacher model for the student.

train()

Train method for Noisy Student. Trains the teacher and the student, exchanging them (replacing the teacher with the student) at the end of every iteration. Both the optimizer and the learning rate scheduler are reinitialized at the end of each iteration.

train_and_reset(model)

Trains either the teacher or the student, then resets training variables such as the optimizer and scheduler.

train_one_model(model)

Common method for training either the teacher or student model.

class ssfinetuning.trainer_util.TrainerForTriTraining(procedure='agreement', epoch_per_tritrain=2, *args, **kwargs)

Bases: ssfinetuning.trainer_util.TrainerForCoTraining

Subclass of ~TrainerForCoTraining that adds a third model for Tri-Training.

Args:

procedure (str, optional, defaults to ‘agreement’): Whether to use Tri-Training with agreement <https://ieeexplore.ieee.org/document/1512038> or with disagreement <https://www.aclweb.org/anthology/P10-2038/>.

Note: the training dataset can be given to the Trainer in the same way as for TrainerForCoTraining, with the addition of a third dataset for model3.

(i) dataset: in this case, it should follow the naming scheme of dataset_utils.modify_datasets.

(ii) dataset_model3: dataset for model3.

agreement_proc(labels_confi, la, lb, l_compare)

Implements the agreement procedure used during the exchange of unlabeled data.

disagreement_proc(labels_confi, la, lb, l_compare)

Implements the disagreement procedure used during the exchange of unlabeled data.
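The difference between the two procedures can be illustrated with their usual formulation in the cited papers: with agreement, an unlabeled example is added to one model's training set when the other two models predict the same label; with disagreement, it is added only when the other two agree and the target model disagrees. The sketch below computes these selection masks on plain Python lists. The function names and argument layout are hypothetical; they do not mirror how agreement_proc and disagreement_proc use labels_confi, la, lb and l_compare.

```python
def agreement_mask(pred_j, pred_k):
    """Hypothetical sketch: select examples on which the two other models agree."""
    return [a == b for a, b in zip(pred_j, pred_k)]


def disagreement_mask(pred_i, pred_j, pred_k):
    """Hypothetical sketch: the other two models agree and the target model i disagrees."""
    return [a == b and c != a for c, a, b in zip(pred_i, pred_j, pred_k)]


# Predictions of the target model (i) and the two other models (j, k).
pred_i = [0, 1, 1, 0]
pred_j = [0, 0, 1, 1]
pred_k = [0, 0, 1, 0]
print(agreement_mask(pred_j, pred_k))             # [True, True, True, False]
print(disagreement_mask(pred_i, pred_j, pred_k))  # [False, True, False, False]
```

The disagreement variant adds fewer examples, concentrating on those the target model currently gets "wrong" according to the other two.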

exchange_unlabeled_data(ul_dataloader)

Exchanges unlabeled data between the models based on the chosen procedure. Otherwise similar to TrainerForCoTraining.exchange_unlabeled_data().

Args: ul_dataloader (class:~torch.utils.data.DataLoader): DataLoader for the unlabeled dataset.

Returns: exchange (bool): Whether the number of exchanged examples is above exchange_threshold.

train()

Train method for TriTrainer. Performs the exchange of unlabeled data among the three models, model1, model2 and model3, until the exchange condition is true; see exchange_unlabeled_data().

tri_train(model1_train, model2_train, model3_train)

Method for Tri-Training. Same procedure as TrainerForCoTraining.co_train, with additional training for the third model.

class ssfinetuning.trainer_util.TrainerWithUWScheduler(kwargs_uw=None, *args, **kwargs)

Bases: ssfinetuning.trainer_util.RemoveUnusedColumnMixing, transformers.trainer.Trainer

Subclass of ~transformers.Trainer with minimal code changes, integrated with the unsupervised weight scheduler (UWScheduler).

Args:

kwargs_uw: dictionary of arguments to be used by UWScheduler.

kwargs: keyword arguments for ~transformers.Trainer and for the datasets used by the trainer; they may also include arguments of UWScheduler.

Note: the training dataset can be given to the trainer in two ways.

(i) dataset: in this case, it should follow the naming scheme of dataset_utils.modify_datasets.

(ii) train_dataset: same naming scheme as used by ~transformers.Trainer.

check_for_consistency()

Checks whether labeled and unlabeled examples are present in the same minibatch and raises an error if they are.

create_optimizer_and_scheduler(num_training_steps)

Overrides ~transformers.Trainer.create_optimizer_and_scheduler to integrate the UWScheduler with the Trainer.lr_scheduler object.

get_train_dataloader()

Slightly modified ~transformers.Trainer.get_train_dataloader: the models used with this Trainer do not allow labeled and unlabeled data to be mixed in a batch, so a SequentialSampler is used instead of a RandomSampler.
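A minimal sketch of why the sampler matters, assuming the training set is ordered with all labeled examples first and that unlabeled examples carry a sentinel label of -1 (both are assumptions for illustration, not taken from the package). With a SequentialSampler each batch is a consecutive slice, so batches stay homogeneous when the labeled block fills a whole number of batches; a RandomSampler would shuffle the two kinds together. The function below performs the kind of check that check_for_consistency() describes; its name is hypothetical.

```python
UNLABELED = -1  # hypothetical sentinel label for unlabeled examples


def batches_are_homogeneous(labels, batch_size):
    """Return True if no sequential batch mixes labeled and unlabeled examples."""
    for start in range(0, len(labels), batch_size):
        batch = labels[start:start + batch_size]
        kinds = {lab == UNLABELED for lab in batch}  # {False}, {True} or both
        if len(kinds) > 1:
            return False
    return True


labels = [1, 0, 1, 0, -1, -1, -1, -1]  # 4 labeled, then 4 unlabeled
print(batches_are_homogeneous(labels, 4))  # True: the boundary falls between batches
print(batches_are_homogeneous(labels, 3))  # False: batch [0, -1, -1] is mixed
```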

class ssfinetuning.trainer_util.UWScheduler(lr_scheduler, trainer, unsup_start_epochs=0, max_w=1, update_teacher_steps=False, w_ramprate=1, update_weights_steps=1)

Bases: object

Unsupervised weight scheduler for changing the unsupervised weight of the semi-supervised learning models. It also contains methods for any other variable updates required by the models, for example PiModel, TemporalEnsembling and MeanTeacher. In this implementation it is composed with a PyTorch learning rate scheduler, so it works with ~transformers.Trainer without having to rewrite the train method. TODO: a cleaner version with a rewritten Trainer.train method.

Args:

lr_scheduler (torch.optim.lr_scheduler): Learning scheduler object.

trainer (TrainerWithUWScheduler): Trainer object.

unsup_start_epochs (int): Epoch at which the unsupervised weight starts being updated.

max_w (float): Maximum value that the model's unsup_weight can reach.

update_teacher_steps (int): Useful for MeanTeacher; sets the interval (in steps) after which the teacher variables are updated.

w_ramprate (float): Linear rate at which the unsupervised weight is increased from its initial value.

update_weights_steps (int): Interval (in steps) after which the unsupervised weight is increased by w_ramprate.

Class attributes:

-step_in_epochs: Number of steps (batch passes) in an epoch.

-local_step: keeps track of the number of times the unsupervised weight has been changed.

is_true(value)

A simple check of whether it is time to perform an update, depending on value.

step()

Composes the step function of the PyTorch learning rate scheduler with the schedule of the unsupervised weight. Also updates the memory logits for TemporalEnsembleModel and the teacher variables for the MeanTeacher model.
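The ramp described by the arguments above can be written in closed form: nothing changes before epoch unsup_start_epochs, then the weight grows by w_ramprate every update_weights_steps steps until it reaches max_w. The sketch below is a hypothetical illustration, not the package's step() logic: it assumes an initial weight of 0 (the initial_w parameter is an assumption, since the docs do not state the starting value) and computes the weight directly instead of mutating state on every call.

```python
def unsup_weight_at(global_step, step_in_epochs, unsup_start_epochs=0, max_w=1.0,
                    w_ramprate=1.0, update_weights_steps=1, initial_w=0.0):
    """Hypothetical sketch: unsupervised weight after `global_step` optimizer steps."""
    start_step = unsup_start_epochs * step_in_epochs
    if global_step < start_step:
        return initial_w  # ramp has not started yet
    # number of completed update intervals since the ramp started
    n_updates = (global_step - start_step) // update_weights_steps
    return min(max_w, initial_w + n_updates * w_ramprate)


# 100 steps per epoch, start after 1 epoch, +0.01 every 10 steps, capped at 0.5.
cfg = dict(step_in_epochs=100, unsup_start_epochs=1, max_w=0.5,
           w_ramprate=0.01, update_weights_steps=10)
for step in (50, 100, 150, 10_000):
    print(step, unsup_weight_at(step, **cfg))
```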

ssfinetuning.trainer_util.get_linear_schedule_with_minlr(optimizer: torch.optim.optimizer.Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1, min_lr: int = 1e-07)

Creates a scheduler with a learning rate that linearly decreases but saturates at min_lr value.

Args:

optimizer (Optimizer): The optimizer for which to schedule the learning rate.

num_warmup_steps (int): The number of steps for the warmup phase.

num_training_steps (int): The total number of training steps.

last_epoch (int, optional, defaults to -1): The index of the last epoch when resuming training.

min_lr (float, optional, defaults to 1e-07): The minimum learning rate at which the schedule saturates.

Return:

torch.optim.lr_scheduler.MultiplicativeLR with the appropriate schedule.
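A minimal sketch of the schedule's shape, assuming a linear warmup followed by a linear decay in the style of transformers.get_linear_schedule_with_warmup, with the decay floored at min_lr. It returns an absolute multiplier of the base learning rate, the form expected by torch.optim.lr_scheduler.LambdaLR; this is a swapped-in formulation, not the package's MultiplicativeLR implementation (MultiplicativeLR expects a per-step ratio instead). The name linear_minlr_lambda and the base_lr parameter are hypothetical.

```python
def linear_minlr_lambda(num_warmup_steps, num_training_steps, base_lr, min_lr=1e-7):
    """Hypothetical sketch: multiplier of base_lr at each step, floored at min_lr."""
    min_factor = min_lr / base_lr  # floor expressed relative to the base learning rate

    def lr_lambda(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)  # linear warmup from 0 to 1
        remaining = (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps)
        return max(min_factor, remaining)  # linear decay that saturates at min_lr

    return lr_lambda


f = linear_minlr_lambda(num_warmup_steps=10, num_training_steps=100, base_lr=1e-3, min_lr=1e-5)
print([f(s) for s in (0, 5, 10, 55)])  # [0.0, 0.5, 1.0, 0.5]
```

With PyTorch installed, the function could be passed as torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f), provided base_lr matches the optimizer's initial learning rate.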

ssfinetuning.training_args module

ssfinetuning.training_args.check_and_replace(key, kwargs, args, basefunction)

Checks for “key” in “kwargs” or “args” and replaces args if the key is found in args. If it is found in neither kwargs nor args, the default value from basefunction is used.

Args:

key (str): key to be searched.

kwargs (dict): keyword argument dictionary to look through.

args (tuple): arguments tuple to look through.

basefunction (:obj: train_with_ssl): Base function around which the wrapper has been implemented.
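A minimal sketch of the lookup order described above (kwargs first, then args, then the default of basefunction), assuming that positional values in args line up with basefunction's parameters in signature order. resolve_arg is a hypothetical name, and unlike check_and_replace it only returns the resolved value rather than replacing anything in args.

```python
import inspect


def resolve_arg(key, kwargs, args, basefunction):
    """Hypothetical sketch: value of `key` for a call basefunction(*args, **kwargs)."""
    params = list(inspect.signature(basefunction).parameters.values())
    names = [p.name for p in params]
    if key in kwargs:
        return kwargs[key]  # explicit keyword wins
    idx = names.index(key)  # assumes key is a parameter of basefunction
    if idx < len(args):
        return args[idx]  # given positionally
    return params[idx].default  # fall back to the signature default


def example(a, b=2, c=3):
    return a + b + c


print(resolve_arg("b", {}, (1,), example))        # 2 (default)
print(resolve_arg("b", {}, (1, 5), example))      # 5 (positional)
print(resolve_arg("c", {"c": 9}, (1,), example))  # 9 (keyword)
```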

ssfinetuning.training_args.generate_kwargs(hyperparam_dic)

Generator function that yields all combinations of hyperparameters from the hyperparameter dictionary.

Args:

hyperparam_dic (Dict): Hyperparameter dictionary.
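Assuming the hyperparameter dictionary maps each name to a list of candidate values, as the default unsup_hp of train_with_ssl does, the combinations form a Cartesian product. A minimal sketch with itertools.product follows; generate_kwargs_sketch is a hypothetical name, and the package's own generator may differ in ordering or output type.

```python
from itertools import product


def generate_kwargs_sketch(hyperparam_dic):
    """Hypothetical sketch: yield one kwargs dict per combination of values."""
    keys = list(hyperparam_dic)
    for values in product(*(hyperparam_dic[k] for k in keys)):
        yield dict(zip(keys, values))


unsup_hp = {"alpha": [0.3, 0.6, 0.9], "w_ramprate": [0.001, 0.01, 0.1]}
combos = list(generate_kwargs_sketch(unsup_hp))
print(len(combos))  # 9
print(combos[0])    # {'alpha': 0.3, 'w_ramprate': 0.001}
```

The number of runs grows multiplicatively with each added hyperparameter, which is worth keeping in mind when filling unsup_hp.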

ssfinetuning.training_args.train_with_ssl(dataset=None, model_name='distilbert-base-uncased', ssl_model_type='PiModel', text_column_name='sentence', run_sup=False, use_sup=False, remove_dirs=True, teacher_student_name=None, num_labels=2, unsup_hp={'alpha': [0.3, 0.6, 0.9], 'w_ramprate': [0.001, 0.01, 0.1]}, sup_stats=None, stats=None, l_fr=False, **kwargs)

Function for training semi-supervised models during fine-tuning of pretrained transformer models.

Args:

labeled_fraction (list): Set up by wrapper_for_l_fr. List of labeled fractions of the training data to be analysed. In this case the original dataset is divided into a labeled and an unlabeled part according to each fraction.

dataset (DatasetDict): Dataset dictionary containing labeled and unlabeled data.

model_name (str or os.PathLike): “pretrained_model_name_or_path” in ~transformers.PreTrainedModel, please refer to its documentation for further information.

(i) In the case of a string, the model id of a pretrained model hosted in a model repo on huggingface.co.

(ii) It could also be the path to a saved pretrained model.

ssl_model_type (str): Semi-supervised model type.

text_column_name (str): Column name for the text in the dataset.

run_sup (bool): Whether to run a supervised model along with the semi-supervised model for comparison.

use_sup (bool): Whether to use the trained supervised model as the starting point for the SSL model. If this is True, run_sup must also be True.

remove_dirs (bool): Whether to remove the directories created by ~transformers.Trainer.

teacher_student_name (Tuple[str, str]): Tuple of the teacher and student model names, respectively, for multi-transformer models such as MeanTeacher or NoisyStudent. Each entry is specified like model_name.

num_labels (int): Total number of classes for the classification.

unsup_hp (dict): Dictionary of all the hyperparameters of the unsupervised part of the model. Check the documentation of ssl_model_type and the associated trainers before setting up this dictionary. train_with_ssl will then train the model on all combinations of the hyperparameters set in this dictionary.

sup_stats (list): Used by the wrapper as an argument. List in which the supervised model's stats are saved for comparison when run_sup is turned on. If labeled_fraction is not given, it is created within this function.

stats (list): Used by the wrapper as an argument. List in which the stats of the chosen semi-supervised model (ssl_model_type) are saved. Similar to sup_stats.

l_fr (float): Used by the wrapper as an argument. Fraction of the training data used as the labeled dataset.

kwargs: Remaining keyword arguments for the transformers.Trainer init function. Some Trainer keywords, such as compute_metrics and tokenizer, are important for training (see ~transformers.Trainer). If they are not given, default values are used; see default_args.set_default_args().

Note: ~transformers.TrainingArguments, which ~transformers.Trainer accepts as args, can be given here in the same way. To distinguish them, the arguments should be named args_ta_sup for the supervised trainer and args_ta for the semi-supervised trainer. Some default keys are set in the default_args file. To change only some arguments and keep the rest at their defaults, args_ta or args_ta_sup can also be given as a dictionary.

For example, setting args_ta_sup = {'learning_rate': 1} changes only the learning rate to 1 and keeps the rest as set in default_args.
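The partial override described above behaves like a dictionary merge in which user-supplied keys replace the defaults and every other key is kept. A minimal sketch follows; the default values and the names HYPOTHETICAL_DEFAULT_ARGS_TA and merge_with_defaults are hypothetical, and the real defaults live in the package's default_args file.

```python
# Hypothetical defaults for illustration only; not the values in default_args.
HYPOTHETICAL_DEFAULT_ARGS_TA = {
    "learning_rate": 5e-5,
    "num_train_epochs": 3,
    "per_device_train_batch_size": 32,
}


def merge_with_defaults(overrides, defaults=HYPOTHETICAL_DEFAULT_ARGS_TA):
    """Return the defaults with only the keys in `overrides` replaced."""
    merged = dict(defaults)   # copy so the defaults themselves are not mutated
    merged.update(overrides)  # user-specified keys win
    return merged


args_ta_sup = merge_with_defaults({"learning_rate": 1})
print(args_ta_sup)
# {'learning_rate': 1, 'num_train_epochs': 3, 'per_device_train_batch_size': 32}
```

The merged dictionary uses keys that transformers.TrainingArguments accepts, so it could be unpacked into that constructor together with an output directory.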

Return:

sup_stats (list): Training history of the supervised model. Returned when the function is used directly without labeled_fraction; it is an empty list unless run_sup is True.

stats (list): Training history of the semi-supervised model. Returned when the function is used directly without labeled_fraction.

ssfinetuning.training_args.with_labeled_fraction(basefunction, labeled_fraction, *args, **kwargs)

Wrapper function around train_with_ssl, used when a list of labeled fractions is given.

Args:

basefunction (:obj: train_with_ssl): Base function around which the wrapper has been implemented.

labeled_fraction (list): List of labeled fractions of the training data to be analysed. This function uses ~dataset_utils.modify_datasets to divide the dataset into a labeled and an unlabeled part. Each l_fr in labeled_fraction is then analysed separately, and the results are stored in sup_stats and stats.

kwargs: Remaining dictionary arguments for the train_with_ssl function.
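The wrapper's control flow can be sketched as a loop over the fractions that splits the data once per fraction and lets the base function append to shared stats lists, which matches how sup_stats and stats are described for train_with_ssl. Everything in the sketch is hypothetical: split_by_fraction stands in for dataset_utils.modify_datasets (a seeded random split of a plain list), and dummy_train stands in for train_with_ssl.

```python
import random


def split_by_fraction(examples, l_fr, seed=0):
    """Hypothetical stand-in for dataset_utils.modify_datasets."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    n_labeled = int(len(shuffled) * l_fr)
    return shuffled[:n_labeled], shuffled[n_labeled:]


def with_labeled_fraction_sketch(basefunction, labeled_fraction, examples, **kwargs):
    """Run basefunction once per labeled fraction, sharing the stats lists."""
    sup_stats, stats = [], []
    for l_fr in labeled_fraction:
        labeled, unlabeled = split_by_fraction(examples, l_fr)
        basefunction(labeled, unlabeled, l_fr=l_fr,
                     sup_stats=sup_stats, stats=stats, **kwargs)
    return sup_stats, stats


def dummy_train(labeled, unlabeled, l_fr, sup_stats, stats, run_sup=False):
    """Hypothetical stand-in for train_with_ssl that records what it received."""
    stats.append((l_fr, len(labeled)))
    if run_sup:
        sup_stats.append((l_fr, len(labeled)))


sup_stats, stats = with_labeled_fraction_sketch(dummy_train, [0.1, 0.5], range(10))
print(sup_stats, stats)  # [] [(0.1, 1), (0.5, 5)]
```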

ssfinetuning.training_args.wrapper_for_l_fr(func)

Module contents