Overview

A T5-Base v1.1 model (google/t5-v1_1-base) fine-tuned on MNLI to generate a hypothesis given a premise and a label. The settings used to train it are listed below.

                                                                                                                  
Experiment configurations                                                                                         
├── datasets                                                                                                      
│   └── mnli_train:                                                                                               
│         dataset_name: multi_nli                                                                                 
│         dataset_config_name: null                                                                               
│         cache_dir: null                                                                                         
│         input_fields:                                                                                           
│         - premise                                                                                               
│         - hypothesis                                                                                            
│         target_field: label                                                                                     
│         train_subset_names: null                                                                                
│         val_subset_names: validation_matched                                                                    
│         test_subset_names: none                                                                                 
│         train_val_split: null                                                                                   
│         limit_train_samples: null                                                                               
│         limit_val_samples: null                                                                                 
│         limit_test_samples: null                                                                                
│         sampling_kwargs:                                                                                        
│           sampling_strategy: random                                                                             
│           seed: 42                                                                                              
│           replace: false                                                                                        
│         align_labels_with_mapping: null                                                                         
│         avoid_consistency_check: false                                                                          
│         predict_label_mapping: null                                                                             
│       mnli:                                                                                                     
│         dataset_name: multi_nli                                                                                 
│         dataset_config_name: null                                                                               
│         cache_dir: null                                                                                         
│         input_fields:                                                                                           
│         - premise                                                                                               
│         - hypothesis                                                                                            
│         target_field: label                                                                                     
│         train_subset_names: none                                                                                
│         val_subset_names: none                                                                                  
│         test_subset_names: validation_mismatched                                                                
│         train_val_split: null                                                                                   
│         limit_train_samples: null                                                                               
│         limit_val_samples: null                                                                                 
│         limit_test_samples: null                                                                                
│         sampling_kwargs:                                                                                        
│           sampling_strategy: random                                                                             
│           seed: 42                                                                                              
│           replace: false                                                                                        
│         align_labels_with_mapping: null                                                                         
│         avoid_consistency_check: false                                                                          
│         predict_label_mapping: null                                                                             
│                                                                                                                 
├── data                                                                                                          
│   └── _target_: src.task.nli.data.NLIGenerationData.from_config                                                 
│       main_dataset_name: null                                                                                   
│       use_additional_as_test: null                                                                              
│       dataloader:                                                                                               
│         batch_size: 64                                                                                          
│         eval_batch_size: 100                                                                                    
│         num_workers: 16                                                                                         
│         pin_memory: true                                                                                        
│         drop_last: false                                                                                        
│         persistent_workers: false                                                                               
│         shuffle: true                                                                                           
│         seed_dataloader: 42                                                                                     
│         replacement: false                                                                                      
│       processing:                                                                                               
│         preprocessing_num_workers: 16                                                                           
│         preprocessing_batch_size: 1000                                                                          
│         load_from_cache_file: true                                                                              
│         padding: longest                                                                                        
│         truncation: longest_first                                                                               
│         max_source_length: 128                                                                                  
│         max_target_length: 128                                                                                  
│         template: 'premise: $premise $label hypothesis: '                                                       
│       tokenizer:                                                                                                
│         _target_: transformers.AutoTokenizer.from_pretrained                                                    
│         pretrained_model_name_or_path: google/t5-v1_1-base                                                      
│         use_fast: true                                                                                          
│                                                                                                                 
├── task                                                                                                          
│   └── optimizer:                                                                                                
│         name: Adafactor                                                                                         
│         lr: 0.001                                                                                               
│         weight_decay: 0.0                                                                                       
│         no_decay:                                                                                               
│         - bias                                                                                                  
│         - LayerNorm.weight                                                                                      
│         decay_rate: -0.8                                                                                        
│         clip_threshold: 1.0                                                                                     
│         relative_step: false                                                                                    
│         scale_parameter: false                                                                                  
│         warmup_init: false                                                                                      
│       scheduler:                                                                                                
│         name: constant_schedule                                                                                 
│       model:                                                                                                    
│         model_name_or_path: google/t5-v1_1-base                                                                 
│       checkpoint_path: null                                                                                     
│       freeze: false                                                                                             
│       seed_init_weight: 42                                                                                      
│       _target_: src.task.nli.NLIGenerationTask.from_config                                                      
│       generation:                                                                                               
│         max_length: 128                                                                                         
│         min_length: 3                                                                                           
│         do_sample: true                                                                                         
│         early_stopping: false                                                                                   
│         num_beams: 1                                                                                            
│         temperature: 1.0                                                                                        
│         top_k: 50                                                                                               
│         top_p: 0.95                                                                                             
│         repetition_penalty: null                                                                                
│         length_penalty: null                                                                                    
│         no_repeat_ngram_size: null                                                                              
│         encoder_no_repeat_ngram_size: null                                                                      
│         num_return_sequences: 1                                                                                 
│         max_time: null                                                                                          
│         max_new_tokens: null                                                                                    
│         decoder_start_token_id: null                                                                            
│         use_cache: null                                                                                         
│         num_beam_groups: null                                                                                   
│         diversity_penalty: null                                                                                 
│                                                                                                                 
├── trainer                                                                                                       
│   └── _target_: pytorch_lightning.Trainer                                                                       
│       callbacks:                                                                                                
│         lr_monitor:                                                                                             
│           _target_: pytorch_lightning.callbacks.LearningRateMonitor                                             
│           logging_interval: step                                                                                
│           log_momentum: false                                                                                   
│         model_checkpoint:                                                                                       
│           _target_: pytorch_lightning.callbacks.ModelCheckpoint                                                 
│           dirpath: ./checkpoints/                                                                               
│           filename: nli_generator_mnli-epoch={epoch:02d}-val_loss={val/aggregated_loss:.2f}                     
│           monitor: val/aggregated_loss                                                                          
│           mode: min                                                                                             
│           verbose: false                                                                                        
│           save_last: true                                                                                       
│           save_top_k: 1                                                                                         
│           auto_insert_metric_name: false                                                                        
│           save_on_train_epoch_end: false                                                                        
│         rich_model_summary:                                                                                     
│           _target_: pytorch_lightning.callbacks.RichModelSummary                                                
│           max_depth: 1                                                                                          
│         log_grad_norm:                                                                                          
│           _target_: src.core.callbacks.LogGradNorm                                                              
│           norm_type: 2                                                                                          
│           group_separator: /                                                                                    
│           only_total: true                                                                                      
│           on_step: true                                                                                         
│           on_epoch: false                                                                                       
│           prog_bar: true                                                                                        
│         log_generated_text:                                                                                     
│           _target_: src.core.callbacks.GenerateAndLogText                                                       
│           dirpath: ./generated_text                                                                             
│           type: generated_text                                                                                  
│           pop_keys_after_logging: true                                                                          
│           on_train: false                                                                                       
│           on_validation: false                                                                                  
│           on_test: true                                                                                         
│           log_to_wandb: true                                                                                    
│         wandb_log_dataset_sizes:                                                                                
│           _target_: src.core.callbacks.WandbLogDatasetSizes                                                     
│       logger:                                                                                                   
│         wandb:                                                                                                  
│           _target_: pytorch_lightning.loggers.WandbLogger                                                       
│           project: nli_debiasing                                                                                
│           entity: team_brushino                                                                                 
│           name: nli_generator_mnli                                                                              
│           save_dir: ./                                                                                          
│           offline: false                                                                                        
│           log_model: false                                                                                      
│           group: mnli                                                                                           
│           job_type: generator                                                                                   
│           tags:                                                                                                 
│           - nli_generator_mnli                                                                                  
│           - seed=42                                                                                             
│           - seed_dataloader=42                                                                                  
│           notes: nli_generator_mnli_time=02-24-53                                                               
│       enable_checkpointing: true                                                                                
│       enable_progress_bar: true                                                                                 
│       enable_model_summary: true                                                                                
│       gradient_clip_val: 0.0                                                                                    
│       gradient_clip_algorithm: null                                                                             
│       accelerator: gpu                                                                                          
│       devices: auto                                                                                             
│       gpus: null                                                                                                
│       auto_select_gpus: true                                                                                    
│       accumulate_grad_batches: 1                                                                                
│       max_epochs: 3                                                                                             
│       min_epochs: 1                                                                                             
│       max_steps: -1                                                                                             
│       min_steps: null                                                                                           
│       max_time: null                                                                                            
│       num_sanity_val_steps: 2                                                                                   
│       overfit_batches: 0.0                                                                                      
│       fast_dev_run: false                                                                                       
│       limit_train_batches: 1.0                                                                                  
│       limit_val_batches: 1.0                                                                                    
│       limit_test_batches: 1.0                                                                                   
│       profiler: null                                                                                            
│       detect_anomaly: false                                                                                     
│       deterministic: false                                                                                      
│       check_val_every_n_epoch: 1                                                                                
│       val_check_interval: 0.1                                                                                   
│       log_every_n_steps: 10                                                                                     
│       move_metrics_to_cpu: false                                                                                
│                                                                                                                 
└── training                                                                                                      
    └── run_val_before_fit: false                                                                                 
        run_val_after_fit: false                                                                                  
        run_test_before_fit: false                                                                                
        run_test_after_fit: true                                                                                  
        lr: 0.001                                                                                                 
        seed: 42                                                                                                  
        show_batch: false                                                                                         
        batch_size: 64                                                                                            
        eval_batch_size: 100                                                                                      
        num_workers: 16                                                                                           
        pin_memory: true                                                                                          
        drop_last: false                                                                                          
        persistent_workers: false                                                                                 
        shuffle: true                                                                                             
        seed_dataloader: 42                                                                                       
        ignore_warnings: true                                                                                     
        experiment_name: nli_generator_mnli
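
The `template` and `generation` settings above suggest how an input prompt for this model might be built. The following is a minimal sketch, not the training code itself: the template string and the sampling values are copied from the configuration, while `build_prompt`, `LABEL_NAMES`, and `GENERATION_KWARGS` are hypothetical names introduced here. It also assumes that the label is rendered as its MNLI class name (entailment, neutral, contradiction); the configuration does not say whether names or integer ids were used, so check this against the actual preprocessing.

```python
from string import Template

# Template string copied from the `processing.template` setting above.
PROMPT_TEMPLATE = Template("premise: $premise $label hypothesis: ")

# Assumption: labels are substituted as MNLI class names. The config does not
# state how `$label` was rendered during training (names vs. integer ids).
LABEL_NAMES = {0: "entailment", 1: "neutral", 2: "contradiction"}


def build_prompt(premise: str, label: int) -> str:
    """Fill the training template with a premise and a label name."""
    return PROMPT_TEMPLATE.substitute(premise=premise, label=LABEL_NAMES[label])


# Sampling settings copied from the `generation` block (null entries omitted).
GENERATION_KWARGS = {
    "max_length": 128,
    "min_length": 3,
    "do_sample": True,
    "early_stopping": False,
    "num_beams": 1,
    "temperature": 1.0,
    "top_k": 50,
    "top_p": 0.95,
    "num_return_sequences": 1,
}

if __name__ == "__main__":
    # Show the prompt with repr() so the trailing space is visible.
    print(repr(build_prompt("A man is playing a guitar on stage.", 0)))
```

With the `transformers` library, the resulting prompt would typically be tokenized and passed to the model's `generate` method together with `GENERATION_KWARGS`. Because `do_sample` is true, repeated calls can return different hypotheses for the same premise and label.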