Python AI/ML Project — Gemini Configuration
Gemini CLI instructions for Python AI/ML projects focused on reproducibility and experiment rigor.
Install path
Use this file for each supported tool in your project.
- Gemini CLI: Save as `GEMINI.md` at the root of your project.
Configuration
GEMINI.md
# Python AI/ML Project — Gemini Configuration

PyTorch-based ML project. Config-driven experiments, reproducible training, clean data pipelines. When navigating this repo, ingest the full experiment configs and model definitions — they're the source of truth. Use the large context window to hold the entire training pipeline in memory before suggesting changes.

## Quick Reference

| Task | Command |
|---|---|
| Train | `python -m src.train --config configs/experiment.yaml` |
| Evaluate | `python -m src.evaluate --checkpoint outputs/run_name/best.pt` |
| Run tests | `pytest tests/ -x -v` |
| Lint + format | `ruff check . && ruff format .` |
| Type check | `pyright src/` or `mypy src/ --strict` |
| Launch TensorBoard | `tensorboard --logdir outputs/` |
| Data pipeline | `python -m src.data.prepare --config configs/data.yaml` |
| Export model | `python -m src.export --checkpoint outputs/run_name/best.pt --format onnx` |
| Profile | `python -m src.train --config configs/experiment.yaml --profile` |

## Project Structure

```
├── configs/
│   ├── experiment.yaml        # Training hyperparams, model config, data config
│   ├── data.yaml              # Dataset paths, preprocessing, augmentation
│   ├── model/                 # Model architecture configs
│   │   ├── base.yaml
│   │   └── large.yaml
│   └── sweep/                 # Hyperparameter sweep configs
│       └── lr_sweep.yaml
├── src/
│   ├── __init__.py
│   ├── train.py               # Training entry point
│   ├── evaluate.py            # Evaluation entry point
│   ├── export.py              # Model export (ONNX, TorchScript)
│   ├── config.py              # Pydantic config schemas
│   ├── models/
│   │   ├── __init__.py
│   │   ├── base.py            # Abstract model interface
│   │   └── transformer.py     # Concrete architecture
│   ├── data/
│   │   ├── __init__.py
│   │   ├── prepare.py         # Data preprocessing pipeline
│   │   ├── dataset.py         # PyTorch Dataset implementations
│   │   └── transforms.py      # Data augmentation / feature transforms
│   ├── training/
│   │   ├── __init__.py
│   │   ├── trainer.py         # Training loop orchestrator
│   │   ├── optimizer.py       # Optimizer + scheduler factory
│   │   └── callbacks.py       # Checkpointing, logging, early stopping
│   ├── metrics/
│   │   ├── __init__.py
│   │   └── core.py            # Metric computation functions
│   └── utils/
│       ├── logging.py         # Structured logging setup
│       ├── reproducibility.py # Seed, deterministic settings
│       └── distributed.py     # Multi-GPU / DDP helpers
├── tests/
│   ├── conftest.py
│   ├── test_model.py
│   ├── test_dataset.py
│   └── test_training.py
├── notebooks/                 # Exploration only — not production code
├── outputs/                   # Training outputs (gitignored)
│   └── {run_name}/
│       ├── config.yaml        # Frozen config for this run
│       ├── best.pt            # Best checkpoint
│       ├── last.pt            # Latest checkpoint
│       ├── metrics.json       # Final metrics
│       └── tensorboard/       # TB event files
├── data/                      # Raw/processed data (gitignored or DVC-tracked)
├── pyproject.toml
└── Makefile
```

## Tech Stack

| Component | Choice |
|---|---|
| Framework | PyTorch 2.x (with `torch.compile` for optimization) |
| Config | YAML files parsed via Pydantic |
| Experiment tracking | TensorBoard (or Weights & Biases if configured) |
| Data versioning | DVC (if `dvc.yaml` exists) or manual data checksums |
| Type checking | pyright or mypy --strict |
| Linting | Ruff |
| Testing | pytest |
| Distributed | `torch.distributed` / `torchrun` for multi-GPU |
| Export | ONNX / TorchScript |

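The export entry point (`src/export.py`) is listed above but not reproduced in this file. Below is a minimal sketch of what its ONNX path might look like, assuming a model whose forward takes `input_ids` and `attention_mask`; the function name, dummy shapes, and opset are assumptions, not the actual implementation:

```python
# Illustrative sketch only -- the real src/export.py is not shown in this file.
import torch


def export_onnx(model: torch.nn.Module, output_path: str, max_seq_length: int = 512) -> None:
    model.eval()
    dummy_ids = torch.randint(0, 100, (1, max_seq_length))
    dummy_mask = torch.ones(1, max_seq_length, dtype=torch.long)
    # If the model returns an output object (e.g. with a .logits attribute),
    # a thin wrapper that returns the logits tensor directly may be needed here.
    torch.onnx.export(
        model,
        (dummy_ids, dummy_mask),
        output_path,
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
        dynamic_axes={"input_ids": {0: "batch"}, "attention_mask": {0: "batch"}},
        opset_version=17,
    )
```
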
## Config-Driven Experiments

### Pydantic Config Schema

```python
# src/config.py
from pathlib import Path
from pydantic import BaseModel, Field, model_validator

class DataConfig(BaseModel):
    train_path: Path
    val_path: Path
    test_path: Path | None = None
    batch_size: int = Field(32, gt=0)
    num_workers: int = Field(4, ge=0)
    max_seq_length: int = Field(512, gt=0)
    augmentation: bool = True

class ModelConfig(BaseModel):
    name: str
    hidden_dim: int = 768
    num_layers: int = 12
    num_heads: int = 12
    dropout: float = Field(0.1, ge=0.0, lt=1.0)
    vocab_size: int = 50257

    @model_validator(mode="after")
    def divisible_by_heads(self) -> "ModelConfig":
        # Cross-field check runs after all fields are set, so a non-default
        # num_heads is respected (a field_validator on hidden_dim would not
        # see num_heads yet).
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({self.num_heads})"
            )
        return self

class TrainingConfig(BaseModel):
    epochs: int = Field(100, gt=0)
    learning_rate: float = Field(3e-4, gt=0)
    weight_decay: float = 0.01
    warmup_steps: int = 1000
    max_grad_norm: float = 1.0
    scheduler: str = "cosine"  # cosine | linear | constant
    fp16: bool = True
    compile: bool = True  # torch.compile
    gradient_accumulation_steps: int = 1
    early_stopping_patience: int = 10

class ExperimentConfig(BaseModel):
    name: str
    seed: int = 42
    data: DataConfig
    model: ModelConfig
    training: TrainingConfig
    output_dir: Path = Path("outputs")

    @property
    def run_dir(self) -> Path:
        return self.output_dir / self.name
```

### Loading Config From YAML

```python
import yaml
from src.config import ExperimentConfig

def load_config(path: str) -> ExperimentConfig:
    with open(path) as f:
        raw = yaml.safe_load(f)
    return ExperimentConfig(**raw)
```

### Example YAML Config

```yaml
# configs/experiment.yaml
name: "transformer-base-v3"
seed: 42

data:
  train_path: "data/processed/train.parquet"
  val_path: "data/processed/val.parquet"
  batch_size: 64
  num_workers: 4
  max_seq_length: 512

model:
  name: "transformer"
  hidden_dim: 768
  num_layers: 12
  num_heads: 12
  dropout: 0.1

training:
  epochs: 50
  learning_rate: 3e-4
  weight_decay: 0.01
  warmup_steps: 2000
  scheduler: "cosine"
  fp16: true
  compile: true
  early_stopping_patience: 10
```

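`src/training/optimizer.py` (the optimizer + scheduler factory in the project structure) is not shown here. Below is a minimal sketch of how the `learning_rate`, `weight_decay`, `warmup_steps`, and `scheduler` fields could be consumed; the function name and the AdamW/LambdaLR choices are assumptions:

```python
# Illustrative sketch -- the real src/training/optimizer.py is not shown in this file.
import math

import torch
from src.config import TrainingConfig


def build_optimizer_and_scheduler(
    model: torch.nn.Module,
    config: TrainingConfig,
    total_steps: int,
) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )

    def lr_lambda(step: int) -> float:
        # Linear warmup, then the configured decay shape.
        if step < config.warmup_steps:
            return step / max(1, config.warmup_steps)
        progress = (step - config.warmup_steps) / max(1, total_steps - config.warmup_steps)
        if config.scheduler == "cosine":
            return 0.5 * (1.0 + math.cos(math.pi * progress))
        if config.scheduler == "linear":
            return max(0.0, 1.0 - progress)
        return 1.0  # "constant"

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
```

Note that a warmup schedule like this is stepped per optimizer step, while the `Trainer` below calls `scheduler.step()` once per epoch; the real code would need to reconcile the two.
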
## Training Loop

```python
# src/training/trainer.py
import logging

import torch
from torch.amp import GradScaler, autocast
from src.config import ExperimentConfig
from src.metrics.core import MetricTracker

logger = logging.getLogger(__name__)

class Trainer:
    def __init__(
        self,
        config: ExperimentConfig,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
    ) -> None:
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.scaler = GradScaler(enabled=config.training.fp16)
        self.best_val_loss = float("inf")
        self.patience_counter = 0
        self.metrics = MetricTracker()

        if config.training.compile:
            self.model = torch.compile(self.model)

    def train(self) -> dict:
        for epoch in range(self.config.training.epochs):
            train_loss = self._train_epoch(epoch)
            val_loss, val_metrics = self._validate(epoch)

            self.scheduler.step()
            self._log_epoch(epoch, train_loss, val_loss, val_metrics)

            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.patience_counter = 0
                self._save_checkpoint(epoch, is_best=True)
            else:
                self.patience_counter += 1

            self._save_checkpoint(epoch, is_best=False)

            if self.patience_counter >= self.config.training.early_stopping_patience:
                logger.info(f"Early stopping at epoch {epoch}")
                break

        return self.metrics.summary()

    def _log_epoch(self, epoch: int, train_loss: float, val_loss: float, val_metrics: dict) -> None:
        logger.info(
            f"epoch={epoch} train_loss={train_loss:.4f} val_loss={val_loss:.4f} metrics={val_metrics}"
        )

    def _train_epoch(self, epoch: int) -> float:
        self.model.train()
        total_loss = 0.0

        for step, batch in enumerate(self.train_loader):
            batch = {k: v.to(self.device) for k, v in batch.items()}

            with autocast(device_type="cuda", enabled=self.config.training.fp16):
                loss = self.model(**batch).loss / self.config.training.gradient_accumulation_steps

            self.scaler.scale(loss).backward()

            if (step + 1) % self.config.training.gradient_accumulation_steps == 0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.training.max_grad_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad(set_to_none=True)

            total_loss += loss.item() * self.config.training.gradient_accumulation_steps

        return total_loss / len(self.train_loader)

    @torch.no_grad()
    def _validate(self, epoch: int) -> tuple[float, dict]:
        self.model.eval()
        total_loss = 0.0
        self.metrics.reset()

        for batch in self.val_loader:
            batch = {k: v.to(self.device) for k, v in batch.items()}
            with autocast(device_type="cuda", enabled=self.config.training.fp16):
                outputs = self.model(**batch)
            total_loss += outputs.loss.item()
            self.metrics.update(outputs.logits, batch["labels"])

        return total_loss / len(self.val_loader), self.metrics.compute()

    def _save_checkpoint(self, epoch: int, is_best: bool) -> None:
        run_dir = self.config.run_dir
        run_dir.mkdir(parents=True, exist_ok=True)
        state = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "best_val_loss": self.best_val_loss,
            "config": self.config.model_dump(),
        }
        torch.save(state, run_dir / "last.pt")
        if is_best:
            torch.save(state, run_dir / "best.pt")
```

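`MetricTracker` (imported from `src/metrics/core.py`) is used by the trainer but its implementation is not shown. Below is a minimal sketch of the interface the trainer relies on (`reset` / `update` / `compute` / `summary`), with token-level accuracy as a stand-in metric; the real module likely tracks more:

```python
# Illustrative sketch of the interface Trainer expects -- not the actual src/metrics/core.py.
import torch


class MetricTracker:
    def __init__(self) -> None:
        self.correct = 0
        self.total = 0
        self.history: list[dict[str, float]] = []

    def reset(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, logits: torch.Tensor, labels: torch.Tensor) -> None:
        # logits: (batch, seq_len, vocab); labels: (batch, seq_len)
        preds = logits.argmax(dim=-1)
        self.correct += (preds == labels).sum().item()
        self.total += labels.numel()

    def compute(self) -> dict[str, float]:
        metrics = {"accuracy": self.correct / max(1, self.total)}
        self.history.append(metrics)
        return metrics

    def summary(self) -> dict:
        return {"best_accuracy": max((m["accuracy"] for m in self.history), default=0.0)}
```
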
## Dataset Pattern

```python
# src/data/dataset.py
import torch
from torch.utils.data import Dataset
import pandas as pd
from pathlib import Path

class TextDataset(Dataset):
    def __init__(
        self,
        data_path: Path,
        tokenizer: "Tokenizer",
        max_length: int = 512,
    ) -> None:
        self.data = pd.read_parquet(data_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        row = self.data.iloc[idx]
        encoding = self.tokenizer(
            row["text"],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(row["label"], dtype=torch.long),
        }
```

## Reproducibility

```python
# src/utils/reproducibility.py
import torch
import random
import numpy as np

def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def set_deterministic(enabled: bool = True) -> None:
    torch.backends.cudnn.deterministic = enabled
    torch.backends.cudnn.benchmark = not enabled
    torch.use_deterministic_algorithms(enabled, warn_only=True)
```

Always call `set_seed(config.seed)` before any model initialization or data loading.

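For example, the top of `src/train.py` might be ordered like this (a sketch that inlines the `load_config` logic shown earlier; the `build_*` helpers are hypothetical names, not functions shown in this file):

```python
# Sketch of the required ordering only -- build_* helpers are placeholders.
import yaml

from src.config import ExperimentConfig
from src.utils.reproducibility import set_deterministic, set_seed


def main(config_path: str) -> None:
    with open(config_path) as f:
        config = ExperimentConfig(**yaml.safe_load(f))

    set_seed(config.seed)        # seed before any model or data objects exist
    set_deterministic(True)      # optional: trade speed for exact reproducibility

    # Only after seeding: build data loaders, model, optimizer, and the Trainer.
    # train_loader, val_loader = build_dataloaders(config.data)  # hypothetical helper
    # model = build_model(config.model)                          # hypothetical helper
```
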
## Testing ML Code

### Test Model Forward Pass

```python
# tests/test_model.py
import pytest
import torch
from src.models.transformer import TransformerModel
from src.config import ModelConfig

@pytest.fixture
def model_config() -> ModelConfig:
    return ModelConfig(
        name="transformer",
        hidden_dim=64,   # Small for tests
        num_layers=2,
        num_heads=4,
        dropout=0.0,     # Deterministic for tests
        vocab_size=1000,
    )

def test_forward_output_shape(model_config: ModelConfig) -> None:
    model = TransformerModel(model_config)
    batch_size, seq_len = 4, 32

    input_ids = torch.randint(0, model_config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len)

    output = model(input_ids=input_ids, attention_mask=attention_mask)

    assert output.logits.shape == (batch_size, seq_len, model_config.vocab_size)

def test_model_trains_on_small_batch(model_config: ModelConfig) -> None:
    """Verify the model can overfit a tiny batch — sanity check for training."""
    model = TransformerModel(model_config)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    input_ids = torch.randint(0, model_config.vocab_size, (2, 16))
    labels = torch.randint(0, model_config.vocab_size, (2, 16))

    initial_loss = None
    for _ in range(20):
        output = model(input_ids=input_ids, labels=labels)
        loss = output.loss
        if initial_loss is None:
            initial_loss = loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    assert loss.item() < initial_loss * 0.5, "Model should overfit a tiny batch"
```

### Test Data Pipeline

```python
# tests/test_dataset.py
from src.data.dataset import TextDataset

def test_dataset_length(sample_dataset: TextDataset) -> None:
    assert len(sample_dataset) == 3

def test_dataset_item_shape(sample_dataset: TextDataset) -> None:
    item = sample_dataset[0]
    assert item["input_ids"].shape == (32,)
    assert item["attention_mask"].shape == (32,)
    assert "labels" in item
```

The `sample_dataset` fixture is provided by `tests/conftest.py` (not shown here).

## Conventions

### Code Style

- Full type annotations everywhere. `mypy --strict` or `pyright` must pass.
- Use `pathlib.Path` for all file paths — never string concatenation.
- Prefer `torch.Tensor` type hints. Use `torch.no_grad()` as a decorator for eval functions.
- Device handling: accept device as a parameter or resolve once in Trainer, not scattered throughout.
- Use `model.train()` / `model.eval()` explicitly — never assume the mode.

### Experiment Tracking

- Every run saves a frozen `config.yaml` to its output directory.
- Log metrics as structured data (JSON or TensorBoard scalars), not print statements.
- Checkpoint both `best.pt` (best validation) and `last.pt` (resume-from).
- Never hardcode hyperparameters in code — always pull from config.

### GPU / Memory

- Use `torch.amp.autocast` for mixed precision — don't manually cast tensors.
- Use `optimizer.zero_grad(set_to_none=True)` — it saves memory.
- Free unused tensors: accumulate running losses via `.item()` or `.detach()`; storing raw loss tensors in a list keeps their computation graphs alive.
- Profile first, optimize second: `torch.profiler` or the `--profile` flag.

## Gemini-Specific Guidance

When working in this repo:

- **Ingest full configs before suggesting changes.** Read the experiment YAML and `config.py` together to understand all hyperparameters and their constraints.
- **Read the full model definition.** Architecture changes need full context — partial reads lead to shape mismatches.
- **Check the data pipeline end-to-end** before modifying transforms or dataset classes. Data shape changes propagate to the model.
- **Reference training outputs** (`outputs/*/metrics.json`) to understand what's been tried and what the current baseline is.
- **For debugging training issues,** read the full training loop (`trainer.py`) plus the config — most bugs are config/shape mismatches, not logic errors.

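The Experiment Tracking conventions above expect each run directory to contain a frozen `config.yaml` and a final `metrics.json`. Below is a minimal sketch of that run-directory bookkeeping; the helper names are assumptions, not existing project functions:

```python
# Hypothetical helpers illustrating the Experiment Tracking conventions above;
# neither function exists in the repo as shown -- adapt names and locations as needed.
import json

import yaml

from src.config import ExperimentConfig


def freeze_config(config: ExperimentConfig) -> None:
    """Write the exact config used for this run into its output directory."""
    config.run_dir.mkdir(parents=True, exist_ok=True)
    with open(config.run_dir / "config.yaml", "w") as f:
        yaml.safe_dump(config.model_dump(mode="json"), f, sort_keys=False)


def write_metrics(config: ExperimentConfig, metrics: dict) -> None:
    """Persist final metrics as structured JSON rather than print output."""
    with open(config.run_dir / "metrics.json", "w") as f:
        json.dump(metrics, f, indent=2)
```
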