

DeGAML-LLM: Decoupled Generalization and Adaptation Meta-Learning for Large Language Models

DeGAML-LLM Architecture

Project Page
License
Python 3.12+
HuggingFace

📋 Overview

DeGAML-LLM introduces a novel meta-learning framework that explicitly decouples generalization and adaptation for large language models, addressing fundamental limitations in existing approaches like MAML-en-LLM and ABMLL.

Key Innovation

Traditional meta-learning for LLMs couples two distinct objectives:

  1. Generalization: Learning task-agnostic representations across task distributions
  2. Adaptation: Enabling rapid task-specific refinement

DeGAML-LLM separates these through dedicated modules operating in distinct parameter spaces:

  • 🔮 Generalization Module ($\mathcal{G}_\phi$): Learns to generate LoRA adapter parameters from task prompts using a hyperconvolutional decoder trained on checkpoint trajectories
  • ⚡ Adaptation Module ($\pi_\psi$): Refines generated parameters via an RL policy that selects from four adaptation families (TTT, TTS, LoRA Mixing, Latent Space)

Critical Design: Gradients from adaptation do not flow back to the generalization module, ensuring true decoupling.
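This detach boundary can be illustrated with a minimal PyTorch sketch; the two linear layers below are toy stand-ins for the generalization and adaptation modules, not the repository's actual classes:

```python
import torch

# Toy stand-ins: names and shapes are illustrative only.
generator = torch.nn.Linear(8, 4)   # generalization module G_phi
value_head = torch.nn.Linear(4, 1)  # part of the adaptation module pi_psi

task_embedding = torch.randn(1, 8)

# Generate adapter parameters, then cut the graph: adaptation losses
# must not backpropagate into the generator.
adapter_params = generator(task_embedding).detach()
adapter_params.requires_grad_(True)  # adaptation may still refine them

adaptation_loss = value_head(adapter_params).sum()
adaptation_loss.backward()

assert generator.weight.grad is None    # generator receives no gradient
assert adapter_params.grad is not None  # adapter itself is refined
```

The single `.detach()` call is what makes the decoupling structural rather than a training-schedule convention.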

Performance Highlights

✨ State-of-the-art results on common-sense reasoning, mathematics, logic, social, medical, and coding benchmarks
🚀 Outperforms MAML-en-LLM, ABMLL, and standard multi-task baselines
⚙️ Flexible adaptation via four distinct adaptation families with automatic strategy selection
🎯 Strong generalization to out-of-domain tasks without task-specific fine-tuning


๐Ÿ—๏ธ Architecture

DeGAML-LLM consists of two key components trained sequentially:

Architecture Diagram

1. Generalization Module

  • Input: Task prompts (unlabeled examples from test set)
  • Output: Distribution over LoRA adapter parameters
  • Training: Offline via MSE loss on collected LoRA checkpoints (no adaptation)

2. Adaptation Module

  • Input: Generated adapter parameters + validation performance
  • Output: Adaptation strategy selection and refinement
  • Adaptation Families:
    • TTT (Test-Time Training): Fine-tune adapters on unlabeled test data via perplexity minimization
    • TTS (Test-Time Scaling): Ensemble multiple adapters via max-confidence or majority vote
    • LoRA Mixing: Interpolate LoRA subspaces using two-subspace (TS) mixing
    • Latent Space: Optimize SLOT vectors (sample-specific latent parameters)
  • Training: Online via ReST^EM with frozen generator (gradients detached)
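A minimal sketch of how a four-family dispatch of this kind might look; the function and table names below are illustrative placeholders, not the repository's API:

```python
# Placeholder implementations: each family would refine the generated
# adapter parameters in its own way (see descriptions above).
def ttt(params, data):          # Test-Time Training: perplexity minimization
    return params
def tts(params, data):          # Test-Time Scaling: adapter ensembling
    return params
def lora_mixing(params, data):  # two-subspace (TS) interpolation
    return params
def latent_space(params, data): # SLOT vector optimization
    return params

FAMILIES = {
    "TTT": ttt,
    "TTS": tts,
    "LoRA Mixing": lora_mixing,
    "Latent Space": latent_space,
}

def adapt(action, params, data):
    # The RL policy pi_psi emits one of the four family names as its action.
    return FAMILIES[action](params, data)
```

The point of the dispatch structure is that the policy only chooses *which* refinement to run; the generator's output is treated as a fixed starting point.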

🚀 Quick Start

Installation

# Clone the repository
git clone https://github.com/YOUR_USERNAME/DeGAML-LLM.git
cd DeGAML-LLM

# Create environment and install dependencies
conda create -n degaml python=3.12
conda activate degaml
pip install -r requirements.txt

Environment Setup

Configure paths via environment variables (optional):

export DEGAML_DATA_ROOT="./data"
export DEGAML_OUTPUT_ROOT="./outputs"  
export DEGAML_CHECKPOINT_ROOT="./checkpoints"
export DEGAML_MODEL_ROOT="./models"
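For reference, a minimal sketch of how these variables could be resolved with fallbacks in Python; `degaml/utils/paths.py` may implement this differently:

```python
import os
from pathlib import Path

def resolve(var, default):
    """Read a path from the environment, falling back to the default."""
    return Path(os.environ.get(var, default))

# Defaults mirror the export lines above.
DATA_ROOT = resolve("DEGAML_DATA_ROOT", "./data")
OUTPUT_ROOT = resolve("DEGAML_OUTPUT_ROOT", "./outputs")
CHECKPOINT_ROOT = resolve("DEGAML_CHECKPOINT_ROOT", "./checkpoints")
MODEL_ROOT = resolve("DEGAML_MODEL_ROOT", "./models")
```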

Download Models

# Download base LLM (Qwen2.5-0.5B or 1.5B)
huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir ./models/Qwen2.5-0.5B-Instruct

# Download Sentence-BERT encoder
huggingface-cli download sentence-transformers/all-MiniLM-L12-v2 --local-dir ./models/all-MiniLM-L12-v2

Basic Usage

1. Baseline Evaluation (No Adaptation)

Generate adapters directly from task prompts:

python -m degaml.core.baseline \
    --eval_dataset ARC-c \
    --test_dataset ARC-c \
    --num_samples 25

2. Generate Hypotheses

Use the RL policy to propose adaptation strategies:

python -m degaml.core.hypothesis_generation \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --lora_adapter_path ./checkpoints/policy_adapter \
    --num_generations 20 \
    --output_file ./outputs/hypotheses.txt

3. Run Adaptation

Execute adaptation strategies (example with TTT):

python -m degaml.adaptation.test_time_training \
    --eval_dataset ARC-c \
    --test_dataset ARC-c \
    --ttl_steps 5 \
    --learning_rate 1e-5 \
    --batch_size 4

📊 Experimental Results

In-Domain Tasks (Common-Sense Reasoning)

Method              ARC-c  ARC-e  HellaSwag  BoolQ  PIQA  WinoGrande   Avg
No Meta-Train LoRA   74.5   84.4       55.8   55.6  65.6        48.2  64.0
Union Train LoRA     63.2   73.9       48.9   55.1  47.8        61.3  58.3
ABMLL                69.9   83.2       51.1   63.2  54.3        52.9  62.4
MAML-en-LLM          66.0   84.3       59.3   58.7  68.1        56.8  65.5
DeGAML-LLM           73.7   88.4       57.2   58.8  70.7        57.3  67.7

Out-of-Domain Tasks

Method            GSM-8K  MATH  DivLogicEval  SocialIQA  CodeMMLU  JAMA   Avg
Union Train LoRA    34.2  32.2          24.1       51.4      34.7  34.7  36.1
ABMLL               28.7  15.9          26.9       66.3      39.6  28.5  34.3
MAML-en-LLM         35.6  43.5          31.2       68.7      42.3  32.5  42.3
DeGAML-LLM          51.4  46.9          31.4       69.5      44.6  41.5  47.5

Note: Results with Qwen2.5-1.5B-Instruct. See paper for complete results across model scales.
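The per-method averages follow directly from the row values; for example, the two DeGAML-LLM rows can be checked with a few lines of arithmetic:

```python
# Values copied from the DeGAML-LLM rows of the two tables above.
degaml_in_domain = [73.7, 88.4, 57.2, 58.8, 70.7, 57.3]  # reported Avg: 67.7
degaml_ood = [51.4, 46.9, 31.4, 69.5, 44.6, 41.5]        # reported Avg: 47.5

mean = lambda xs: sum(xs) / len(xs)
assert abs(mean(degaml_in_domain) - 67.7) < 0.05
assert abs(mean(degaml_ood) - 47.5) < 0.1
```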


📚 Repository Structure

DeGAML-LLM/
├── degaml/
│   ├── core/
│   │   ├── baseline.py
│   │   ├── hypothesis_generation.py
│   │   ├── accuracy.py
│   │   └── mega.py                # Pipeline orchestrator
│   ├── adaptation/
│   │   ├── test_time_training.py
│   │   ├── test_time_scaling.py
│   │   ├── lora_mixing.py
│   │   └── latent_space.py
│   ├── generator/
│   │   ├── dataset/
│   │   ├── model/
│   │   ├── module/
│   │   ├── tokenizer/
│   │   └── tools/
│   ├── policy/
│   ├── utils/
│   │   ├── paths.py
│   │   └── config.py
│   └── ablation/
├── configs/
├── docs/
├── scripts/
├── assets/
└── requirements.txt

🔧 Advanced Usage

Running Ablation Studies

Isolate contributions of individual adaptation families:

python -m degaml.ablation.ablation_runner \
    --eval_dataset ARC-c \
    --test_dataset ARC-c \
    --family TTT \
    --num_samples 25 \
    --iterations 1

Training the Parameter Generator

The parameter generator uses a hyperconvolutional decoder architecture that is self-contained in this repository. Key steps:

  1. Collect LoRA checkpoints across meta-training tasks
  2. Calculate importance scores for parameter tokenization
  3. Train hyperconvolutional decoder via MSE loss

Training scripts and detailed instructions will be provided in future releases.
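Step 3 can be illustrated with a toy regression loop. Here a small MLP stands in for the hyperconvolutional decoder, and random tensors stand in for task embeddings and flattened LoRA checkpoints; this is a sketch of the MSE objective, not the actual training script:

```python
import torch
from torch import nn

torch.manual_seed(0)
num_ckpts, embed_dim, lora_dim = 32, 16, 64

# Stand-ins for steps 1-2: per-task embeddings and collected,
# flattened LoRA checkpoint parameters.
task_embeds = torch.randn(num_ckpts, embed_dim)
lora_ckpts = torch.randn(num_ckpts, lora_dim)

# Toy decoder (the paper uses a hyperconvolutional architecture).
decoder = nn.Sequential(nn.Linear(embed_dim, 32), nn.ReLU(),
                        nn.Linear(32, lora_dim))
opt = torch.optim.Adam(decoder.parameters(), lr=1e-2)

initial = nn.functional.mse_loss(decoder(task_embeds), lora_ckpts).item()
for _ in range(200):
    opt.zero_grad()
    loss = nn.functional.mse_loss(decoder(task_embeds), lora_ckpts)
    loss.backward()
    opt.step()

assert loss.item() < initial  # MSE regression objective decreases
```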

Pre-trained LoRA checkpoints are available on HuggingFace: Nitin2004/DeGAML-LLM-checkpoints

Download checkpoints using:

from huggingface_hub import hf_hub_download

checkpoint = hf_hub_download(
    repo_id="Nitin2004/DeGAML-LLM-checkpoints",
    filename="qwen0.5lora__ARC-c.pth"
)

Training the RL Policy

python -m degaml.policy.train_policy \
    --meta_train_tasks "ARC-c,HellaSwag,BoolQ" \
    --num_iterations 10 \
    --reward_type accuracy_improvement
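Conceptually, a ReST^EM-style policy update alternates sampling, reward-based filtering, and fine-tuning on the survivors. A toy sketch of that outer loop follows; every function here is a hypothetical stand-in, with random rewards in place of real accuracy evaluations:

```python
import random

random.seed(0)

def sample_strategy(policy, task):
    # Stand-in: the real policy pi_psi conditions on the task.
    return random.choice(["TTT", "TTS", "LoRA Mixing", "Latent Space"])

def reward(strategy, task):
    # Stand-in for accuracy_improvement: post-adaptation minus baseline.
    return random.uniform(-0.1, 0.1)

def finetune(policy, examples):
    # Stand-in for supervised fine-tuning on the filtered samples.
    return policy + examples

policy, tasks = [], ["ARC-c", "HellaSwag", "BoolQ"]  # --meta_train_tasks
for _ in range(10):                                  # --num_iterations
    samples = [(t, sample_strategy(policy, t))
               for t in tasks for _ in range(4)]
    # E-step: keep only strategies whose reward is positive.
    kept = [(t, s) for t, s in samples if reward(s, t) > 0]
    # M-step: fine-tune the policy on the filtered set.
    policy = finetune(policy, kept)
```

Because the generator stays frozen throughout, this loop only ever improves the policy's strategy selection, never the generated parameters' prior.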


📄 License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.


🤝 Contributing

We welcome contributions! Please see our contributing guidelines for more information.

📧 Contact

For questions and feedback, please open an issue or contact nitinvetcha@gmail.com.


Star ⭐ this repository if you find it helpful!

Made with ❤️ for advancing meta-learning in LLMs
