dotmd

Python AI/ML Project — Gemini Configuration

Gemini CLI instructions for Python AI/ML projects focused on reproducibility and experiment rigor.

By dotmd Team · CC0 · Published Feb 19, 2026

Install path

Use this file for each supported tool in your project.

  • Gemini CLI: Save as GEMINI.md at the root of your project.

Configuration

GEMINI.md

# Python AI/ML Project — Gemini Configuration

PyTorch-based ML project. Config-driven experiments, reproducible training, clean data pipelines. When navigating this repo, ingest the full experiment configs and model definitions — they're the source of truth. Use the large context window to hold the entire training pipeline in memory before suggesting changes.

## Quick Reference

| Task | Command |
|---|---|
| Train | `python -m src.train --config configs/experiment.yaml` |
| Evaluate | `python -m src.evaluate --checkpoint outputs/run_name/best.pt` |
| Run tests | `pytest tests/ -x -v` |
| Lint + format | `ruff check . && ruff format .` |
| Type check | `pyright src/` or `mypy src/ --strict` |
| Launch TensorBoard | `tensorboard --logdir outputs/` |
| Data pipeline | `python -m src.data.prepare --config configs/data.yaml` |
| Export model | `python -m src.export --checkpoint outputs/run_name/best.pt --format onnx` |
| Profile | `python -m src.train --config configs/experiment.yaml --profile` |

## Project Structure

```
├── configs/
│   ├── experiment.yaml        # Training hyperparams, model config, data config
│   ├── data.yaml              # Dataset paths, preprocessing, augmentation
│   ├── model/                 # Model architecture configs
│   │   ├── base.yaml
│   │   └── large.yaml
│   └── sweep/                 # Hyperparameter sweep configs
│       └── lr_sweep.yaml
├── src/
│   ├── __init__.py
│   ├── train.py               # Training entry point
│   ├── evaluate.py            # Evaluation entry point
│   ├── export.py              # Model export (ONNX, TorchScript)
│   ├── config.py              # Pydantic config schemas
│   ├── models/
│   │   ├── __init__.py
│   │   ├── base.py            # Abstract model interface
│   │   └── transformer.py     # Concrete architecture
│   ├── data/
│   │   ├── __init__.py
│   │   ├── prepare.py         # Data preprocessing pipeline
│   │   ├── dataset.py         # PyTorch Dataset implementations
│   │   └── transforms.py      # Data augmentation / feature transforms
│   ├── training/
│   │   ├── __init__.py
│   │   ├── trainer.py         # Training loop orchestrator
│   │   ├── optimizer.py       # Optimizer + scheduler factory
│   │   └── callbacks.py       # Checkpointing, logging, early stopping
│   ├── metrics/
│   │   ├── __init__.py
│   │   └── core.py            # Metric computation functions
│   └── utils/
│       ├── logging.py         # Structured logging setup
│       ├── reproducibility.py # Seed, deterministic settings
│       └── distributed.py     # Multi-GPU / DDP helpers
├── tests/
│   ├── conftest.py
│   ├── test_model.py
│   ├── test_dataset.py
│   └── test_training.py
├── notebooks/                 # Exploration only — not production code
├── outputs/                   # Training outputs (gitignored)
│   └── {run_name}/
│       ├── config.yaml        # Frozen config for this run
│       ├── best.pt            # Best checkpoint
│       ├── last.pt            # Latest checkpoint
│       ├── metrics.json       # Final metrics
│       └── tensorboard/       # TB event files
├── data/                      # Raw/processed data (gitignored or DVC-tracked)
├── pyproject.toml
└── Makefile
```

## Tech Stack

| Component | Choice |
|---|---|
| Framework | PyTorch 2.x (with `torch.compile` for optimization) |
| Config | YAML files parsed via Pydantic |
| Experiment tracking | TensorBoard (or Weights & Biases if configured) |
| Data versioning | DVC (if `dvc.yaml` exists) or manual data checksums |
| Type checking | pyright or mypy --strict |
| Linting | Ruff |
| Testing | pytest |
| Distributed | `torch.distributed` / `torchrun` for multi-GPU (see sketch below) |
| Export | ONNX / TorchScript |

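The multi-GPU helpers in `src/utils/distributed.py` are not reproduced in this file. As a rough sketch of what single-node DDP setup usually looks like, assuming `torchrun` launches such as `torchrun --nproc_per_node=4 -m src.train --config configs/experiment.yaml`:

```python
# Hypothetical sketch of src/utils/distributed.py; the repo's real helpers may differ.
import os

import torch
import torch.distributed as dist

def setup_distributed() -> torch.device:
    """Join the process group when launched via torchrun, else fall back to single-process."""
    if "RANK" not in os.environ:
        # Not launched by torchrun: plain single-GPU or CPU training.
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return torch.device("cuda", local_rank)
```

When the process group is initialized, the entry point would also wrap the model in `torch.nn.parallel.DistributedDataParallel` and give each `DataLoader` a `DistributedSampler`.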
## Config-Driven Experiments

### Pydantic Config Schema

```python
# src/config.py
from pathlib import Path
from pydantic import BaseModel, Field, model_validator

class DataConfig(BaseModel):
    train_path: Path
    val_path: Path
    test_path: Path | None = None
    batch_size: int = Field(32, gt=0)
    num_workers: int = Field(4, ge=0)
    max_seq_length: int = Field(512, gt=0)
    augmentation: bool = True

class ModelConfig(BaseModel):
    name: str
    hidden_dim: int = 768
    num_layers: int = 12
    num_heads: int = 12
    dropout: float = Field(0.1, ge=0.0, lt=1.0)
    vocab_size: int = 50257

    @model_validator(mode="after")
    def divisible_by_heads(self) -> "ModelConfig":
        # Model-level validator: runs after all fields are populated, so num_heads is available.
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError(f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({self.num_heads})")
        return self

class TrainingConfig(BaseModel):
    epochs: int = Field(100, gt=0)
    learning_rate: float = Field(3e-4, gt=0)
    weight_decay: float = 0.01
    warmup_steps: int = 1000
    max_grad_norm: float = 1.0
    scheduler: str = "cosine"  # cosine | linear | constant
    fp16: bool = True
    compile: bool = True  # torch.compile
    gradient_accumulation_steps: int = 1
    early_stopping_patience: int = 10

class ExperimentConfig(BaseModel):
    name: str
    seed: int = 42
    data: DataConfig
    model: ModelConfig
    training: TrainingConfig
    output_dir: Path = Path("outputs")

    @property
    def run_dir(self) -> Path:
        return self.output_dir / self.name
```
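For example, the `hidden_dim` / `num_heads` constraint surfaces as an ordinary Pydantic `ValidationError`:

```python
from pydantic import ValidationError

from src.config import ModelConfig

try:
    ModelConfig(name="transformer", hidden_dim=100, num_heads=12)
except ValidationError as err:
    print(err)  # hidden_dim (100) must be divisible by num_heads (12)
```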

### Loading Config From YAML

```python
import yaml
from src.config import ExperimentConfig

def load_config(path: str) -> ExperimentConfig:
    with open(path) as f:
        raw = yaml.safe_load(f)
    return ExperimentConfig(**raw)
```

### Example YAML Config

```yaml
# configs/experiment.yaml
name: "transformer-base-v3"
seed: 42

data:
  train_path: "data/processed/train.parquet"
  val_path: "data/processed/val.parquet"
  batch_size: 64
  num_workers: 4
  max_seq_length: 512

model:
  name: "transformer"
  hidden_dim: 768
  num_layers: 12
  num_heads: 12
  dropout: 0.1

training:
  epochs: 50
  learning_rate: 3e-4
  weight_decay: 0.01
  warmup_steps: 2000
  scheduler: "cosine"
  fp16: true
  compile: true
  early_stopping_patience: 10
```

## Training Loop

```python
# src/training/trainer.py
import logging

import torch
from torch.amp import GradScaler, autocast

from src.config import ExperimentConfig
from src.metrics.core import MetricTracker

logger = logging.getLogger(__name__)

class Trainer:
    def __init__(
        self,
        config: ExperimentConfig,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
    ) -> None:
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.scaler = GradScaler(enabled=config.training.fp16)
        self.best_val_loss = float("inf")
        self.patience_counter = 0
        self.metrics = MetricTracker()

        if config.training.compile:
            self.model = torch.compile(self.model)

    def train(self) -> dict:
        for epoch in range(self.config.training.epochs):
            train_loss = self._train_epoch(epoch)
            val_loss, val_metrics = self._validate(epoch)

            self.scheduler.step()
            self._log_epoch(epoch, train_loss, val_loss, val_metrics)

            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.patience_counter = 0
                self._save_checkpoint(epoch, is_best=True)
            else:
                self.patience_counter += 1

            self._save_checkpoint(epoch, is_best=False)

            if self.patience_counter >= self.config.training.early_stopping_patience:
                logger.info(f"Early stopping at epoch {epoch}")
                break

        return self.metrics.summary()

    def _train_epoch(self, epoch: int) -> float:
        self.model.train()
        total_loss = 0.0

        for step, batch in enumerate(self.train_loader):
            batch = {k: v.to(self.device) for k, v in batch.items()}

            with autocast(device_type=self.device.type, enabled=self.config.training.fp16):
                loss = self.model(**batch).loss / self.config.training.gradient_accumulation_steps

            self.scaler.scale(loss).backward()

            if (step + 1) % self.config.training.gradient_accumulation_steps == 0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.training.max_grad_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad(set_to_none=True)

            total_loss += loss.item() * self.config.training.gradient_accumulation_steps

        return total_loss / len(self.train_loader)

    @torch.no_grad()
    def _validate(self, epoch: int) -> tuple[float, dict]:
        self.model.eval()
        total_loss = 0.0
        self.metrics.reset()

        for batch in self.val_loader:
            batch = {k: v.to(self.device) for k, v in batch.items()}
            with autocast(device_type=self.device.type, enabled=self.config.training.fp16):
                outputs = self.model(**batch)
            total_loss += outputs.loss.item()
            self.metrics.update(outputs.logits, batch["labels"])

        return total_loss / len(self.val_loader), self.metrics.compute()

    def _log_epoch(self, epoch: int, train_loss: float, val_loss: float, val_metrics: dict) -> None:
        # Lightweight epoch summary; structured metric logging (TensorBoard/JSON) belongs in the logging callbacks.
        logger.info("epoch=%d train_loss=%.4f val_loss=%.4f metrics=%s", epoch, train_loss, val_loss, val_metrics)

    def _save_checkpoint(self, epoch: int, is_best: bool) -> None:
        run_dir = self.config.run_dir
        run_dir.mkdir(parents=True, exist_ok=True)
        state = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "best_val_loss": self.best_val_loss,
            "config": self.config.model_dump(),
        }
        torch.save(state, run_dir / "last.pt")
        if is_best:
            torch.save(state, run_dir / "best.pt")
```
## Dataset Pattern

```python
# src/data/dataset.py
from pathlib import Path

import pandas as pd
import torch
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(
        self,
        data_path: Path,
        tokenizer: "Tokenizer",
        max_length: int = 512,
    ) -> None:
        self.data = pd.read_parquet(data_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        row = self.data.iloc[idx]
        encoding = self.tokenizer(
            row["text"],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(row["label"], dtype=torch.long),
        }
```

## Reproducibility

```python
# src/utils/reproducibility.py
import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def set_deterministic(enabled: bool = True) -> None:
    torch.backends.cudnn.deterministic = enabled
    torch.backends.cudnn.benchmark = not enabled
    torch.use_deterministic_algorithms(enabled, warn_only=True)
```

Always call `set_seed(config.seed)` before any model initialization or data loading.
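A minimal sketch of that ordering at the top of `src/train.py` (here `build_model` and `build_dataloaders` are illustrative placeholders, not the repo's actual factory functions):

```python
# Illustrative ordering only; build_model and build_dataloaders are hypothetical helpers.
from src.utils.reproducibility import set_deterministic, set_seed

config = load_config("configs/experiment.yaml")  # load_config as defined above
set_seed(config.seed)        # seed python/numpy/torch before anything random is created
set_deterministic(True)      # optional: deterministic kernels for exact reruns

model = build_model(config.model)                           # only now create weights...
train_loader, val_loader = build_dataloaders(config.data)   # ...and build the data pipeline
```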
## Testing ML Code

### Test Model Forward Pass

```python
# tests/test_model.py
import pytest
import torch

from src.config import ModelConfig
from src.models.transformer import TransformerModel

@pytest.fixture
def model_config() -> ModelConfig:
    return ModelConfig(
        name="transformer",
        hidden_dim=64,  # Small for tests
        num_layers=2,
        num_heads=4,
        dropout=0.0,  # Deterministic for tests
        vocab_size=1000,
    )

def test_forward_output_shape(model_config: ModelConfig) -> None:
    model = TransformerModel(model_config)
    batch_size, seq_len = 4, 32

    input_ids = torch.randint(0, model_config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len)

    output = model(input_ids=input_ids, attention_mask=attention_mask)

    assert output.logits.shape == (batch_size, seq_len, model_config.vocab_size)

def test_model_trains_on_small_batch(model_config: ModelConfig) -> None:
    """Verify the model can overfit a tiny batch — sanity check for training."""
    model = TransformerModel(model_config)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    input_ids = torch.randint(0, model_config.vocab_size, (2, 16))
    labels = torch.randint(0, model_config.vocab_size, (2, 16))

    initial_loss = None
    for _ in range(20):
        output = model(input_ids=input_ids, labels=labels)
        loss = output.loss
        if initial_loss is None:
            initial_loss = loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    assert loss.item() < initial_loss * 0.5, "Model should overfit a tiny batch"
```

### Test Data Pipeline

```python
# tests/test_dataset.py
from src.data.dataset import TextDataset

def test_dataset_length(sample_dataset: TextDataset) -> None:
    assert len(sample_dataset) == 3

def test_dataset_item_shape(sample_dataset: TextDataset) -> None:
    item = sample_dataset[0]
    assert item["input_ids"].shape == (32,)
    assert item["attention_mask"].shape == (32,)
    assert "labels" in item
```
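The `sample_dataset` fixture lives in `tests/conftest.py`, which isn't shown in this file. One plausible shape for it, using a stub tokenizer and a throwaway parquet file (a sketch only; the project's real fixture may differ):

```python
# tests/conftest.py (hypothetical sketch matching the assertions above)
import pandas as pd
import pytest
import torch

from src.data.dataset import TextDataset

class StubTokenizer:
    """Returns fixed-size dummy encodings, mimicking a HF-style tokenizer call."""
    def __call__(self, text, max_length, padding, truncation, return_tensors):
        return {
            "input_ids": torch.zeros(1, max_length, dtype=torch.long),
            "attention_mask": torch.ones(1, max_length, dtype=torch.long),
        }

@pytest.fixture
def sample_dataset(tmp_path) -> TextDataset:
    # Three rows so len(dataset) == 3; max_length=32 so each item has shape (32,).
    df = pd.DataFrame({"text": ["a", "b", "c"], "label": [0, 1, 0]})
    path = tmp_path / "sample.parquet"
    df.to_parquet(path)
    return TextDataset(path, tokenizer=StubTokenizer(), max_length=32)
```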

## Conventions

### Code Style

- Full type annotations everywhere. `mypy --strict` or `pyright` must pass.
- Use `pathlib.Path` for all file paths — never string concatenation.
- Prefer `torch.Tensor` type hints. Use `torch.no_grad()` as a decorator for eval functions.
- Device handling: accept device as a parameter or resolve once in Trainer, not scattered throughout.
- Use `model.train()` / `model.eval()` explicitly — never assume the mode.

### Experiment Tracking

- Every run saves a frozen `config.yaml` to its output directory (see the sketch below).
- Log metrics as structured data (JSON or TensorBoard scalars), not print statements.
- Checkpoint both `best.pt` (best validation) and `last.pt` (resume-from).
- Never hardcode hyperparameters in code — always pull from config.
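A minimal sketch of the config freeze, assuming the Pydantic `ExperimentConfig` defined above (`mode="json"` keeps `Path` fields serializable):

```python
import yaml

from src.config import ExperimentConfig

def freeze_config(config: ExperimentConfig) -> None:
    # Write the fully-resolved config into the run directory so the run can be reproduced later.
    config.run_dir.mkdir(parents=True, exist_ok=True)
    with open(config.run_dir / "config.yaml", "w") as f:
        yaml.safe_dump(config.model_dump(mode="json"), f, sort_keys=False)
```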
### GPU / Memory

- Use `torch.amp.autocast` for mixed precision — don't manually cast tensors.
- Use `optimizer.zero_grad(set_to_none=True)` — saves memory.
- Free unused tensors: accumulate `loss.item()` or detached values in loops, never the loss tensor itself (it keeps the autograd graph alive).
- Profile first, optimize second: `torch.profiler` or the `--profile` flag (see the sketch below).
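If the `--profile` flag isn't available in a given script, a minimal `torch.profiler` sketch looks like this (`train_step` and `train_loader` are placeholders for one optimizer step and the usual loader):

```python
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

# Trace a handful of steps and write a TensorBoard-viewable profile.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3, repeat=1),
    on_trace_ready=tensorboard_trace_handler("outputs/profile"),
) as prof:
    for step, batch in enumerate(train_loader):
        train_step(batch)   # placeholder: forward + backward + optimizer step
        prof.step()         # advance the profiler schedule
        if step >= 5:
            break
```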
## Gemini-Specific Guidance

When working in this repo:

- **Ingest full configs before suggesting changes.** Read the experiment YAML and `config.py` together to understand all hyperparameters and their constraints.
- **Read the full model definition.** Architecture changes need full context — partial reads lead to shape mismatches.
- **Check data pipeline end-to-end** before modifying transforms or dataset classes. Data shape changes propagate to the model.
- **Reference training outputs** (`outputs/*/metrics.json`) to understand what's been tried and what the current baseline is.
- **For debugging training issues,** read the full training loop (trainer.py) plus the config — most bugs are config/shape mismatches, not logic errors.
