> **Note:** The canonical experience is the interactive HTML page: [Word2Vec Code](https://www.quantml.org/topics/word2vec/code). This file is a text mirror for search engines and AI tools.

# Word2Vec Code: Annotated Walkthrough

This is the actual Python script that trained the Word2Vec model you explored in the other tabs. Every visualization, every number, and every animation was generated by this code. We've split it into 7 sections; explanations appear as inline comments, and cross-reference links point to the corresponding visuals in other tabs.

- ~480 lines of Python
- PyTorch + NumPy + scikit-learn
- 64 inline annotations

---

## Section 01 — The Corpus

**Subtitle:** Loading text, removing stopwords, building a vocabulary.

```python
# ── stopwords ──────────────────────────────────────────────────────
STOPWORDS = frozenset({                 # A frozen (immutable) set of ~100 common English words that
    "the", "a", "an", "is", "am", "are", "was", "were", "be", "been",   # carry little semantic meaning. Using frozenset gives O(1) lookup.
    "being", "in", "on", "at", "to", "of", "and", "or", "but", "with",
    "by", "for", "from", "up", "down", "out", "into", "over", "under",
    "between", "his", "her", "its", "their", "he", "she", "it", "they",
    "we", "you", "i", "this", "that", "these", "those", "has", "had",
    "have", "do", "does", "did", "will", "would", "could", "should",
    "may", "might", "can", "shall", "not", "no", "nor", "so", "very",
    "too", "also", "just", "then", "than", "as", "if", "when", "where",
    "how", "what", "which", "who", "whom", "there", "here", "all",
    "each", "every", "both", "few", "many", "much", "some", "any",
    "most", "other", "such", "only", "own", "same", "about", "after",
    "before", "while", "during", "through", "above", "below", "near",
})
```

```python
from collections import Counter


def load_corpus(filepath: str):
    raw, cleaned = [], []
    with open(filepath) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            tokens = line.lower().split()          # Split each sentence into lowercase words. All punctuation
                                                   # is already removed in our curated corpus.
            raw.append(tokens)
            content = [t for t in tokens if t not in STOPWORDS]
                                                   # Remove common words like 'the', 'is', 'and' that don't
                                                   # carry semantic meaning.
            if len(content) >= 2:                  # Need at least 2 content words to form a (center, context) pair.
                cleaned.append(content)
    return raw, cleaned


def build_vocab(sentences, min_count=3):
    counts = Counter()
    for s in sentences:
        counts.update(s)
    vocab = sorted(w for w, c in counts.items() if c >= min_count)
                                                   # Keep only words appearing 3+ times. Alphabetical sort
                                                   # ensures consistent indexing across runs.
    w2i = {w: i for i, w in enumerate(vocab)}      # Word-to-index: maps each word to its position in the vocabulary.
    i2w = {i: w for w, i in w2i.items()}           # Index-to-word: reverse lookup, used when converting outputs back to words.
    return vocab, w2i, i2w, counts
```
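To make the filtering concrete, here is a small self-contained sketch that repeats the same steps as `load_corpus` and `build_vocab` on an in-memory string instead of a file. The five-word stopword set, the sample sentences, and `min_count=2` are made-up assumptions for illustration; the real run uses the full `STOPWORDS` set, `corpus.txt`, and `min_count=3`.

```python
from collections import Counter

# Hypothetical mini stopword set; the real script uses the ~100-word STOPWORDS above.
STOP = frozenset({"the", "a", "is", "on", "and"})

# Hypothetical corpus text; comment lines and blank lines are skipped as in load_corpus.
TEXT = """# toy corpus
The king is on the throne
The queen and the king
a throne

the king and the queen rule
"""


def clean(text, stopwords):
    raw, cleaned = [], []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        tokens = line.lower().split()
        raw.append(tokens)
        content = [t for t in tokens if t not in stopwords]
        if len(content) >= 2:              # need at least one (center, context) pair
            cleaned.append(content)
    return raw, cleaned


def vocab_from(sentences, min_count):
    counts = Counter()
    for s in sentences:
        counts.update(s)
    vocab = sorted(w for w, c in counts.items() if c >= min_count)
    return vocab, {w: i for i, w in enumerate(vocab)}, counts


raw, cleaned = clean(TEXT, STOP)
vocab, w2i, counts = vocab_from(cleaned, min_count=2)
print(len(raw), len(cleaned))   # 4 3
print(vocab)                    # ['king', 'queen']
```

`throne` appears in two raw lines but is counted only once, because `a throne` has a single content word and never reaches `cleaned`; with a count of 1 it falls below `min_count`.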

> See Chapter 2: Words Are Defined by Their Friends

**Corpus stats after loading:**

| Stat | Value |
|---|---|
| Sentences | 545 |
| Vocabulary | 96 words |
| Training pairs | 1,751 |

---

## Section 02 — Training Pairs

**Subtitle:** Sliding a context window over each sentence to extract (center, context) pairs.

```python
def generate_pairs(sentences, w2i, window):
    pairs = []
    for sent in sentences:
        idxs = [w2i[w] for w in sent if w in w2i]
                                                   # Convert words to integer indices, skipping any word
                                                   # not in our vocabulary.
        for i, c in enumerate(idxs):
            lo, hi = max(0, i - window), min(len(idxs), i + window + 1)
                                                   # Compute the window boundaries, clamping to sentence
                                                   # edges so we don't go out of bounds.
            for j in range(lo, hi):
                if j != i:                         # Skip the center's own position. (A repeated word elsewhere
                                                   # in the window still produces a (w, w) pair.)
                    pairs.append((c, idxs[j]))     # Each pair is (center_idx, context_idx).
                                                   # With window=3, each center word produces up to 6 pairs.
    return pairs
```
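A quick way to check the window arithmetic is to run the same logic on a toy sentence. The vocabulary below is hypothetical, and `banana` is deliberately out of vocabulary to show what the `if w in w2i` filter does.

```python
def generate_pairs(sentences, w2i, window):
    pairs = []
    for sent in sentences:
        idxs = [w2i[w] for w in sent if w in w2i]   # out-of-vocabulary words are dropped first
        for i, c in enumerate(idxs):
            lo, hi = max(0, i - window), min(len(idxs), i + window + 1)
            for j in range(lo, hi):
                if j != i:
                    pairs.append((c, idxs[j]))
    return pairs


# Hypothetical vocabulary and sentence ("banana" is deliberately out of vocabulary).
w2i = {"king": 0, "queen": 1, "royal": 2, "crown": 3, "palace": 4}
sent = ["king", "queen", "banana", "royal", "crown", "palace"]
pairs = generate_pairs([sent], w2i, window=2)
print(len(pairs))   # 14
```

Because out-of-vocabulary words (and, earlier, stopwords) are removed before the window slides, `queen` and `royal` become direct neighbors even though `banana` sat between them in the raw sentence. Interior words get the full 2 × window pairs (4 here); words near the edges get fewer.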

> See Chapter 2: Sliding Window

**Interactive: Try It — Sliding Window** (on the live page)

Select a sentence from the corpus. Click any in-vocabulary word to set it as the center. Context words within the window are highlighted green. Strikethrough words are filtered out.

- **Blue** — center word (click to change)
- **Green** — context words (window = 2)
- **Strikethrough** — filtered out (stopword or rare)

Generated pairs shown as: `center → context`

| Stat | Value |
|---|---|
| Window size | 2 |
| Total pairs | 1,751 |

---

## Section 03 — The Model

**Subtitle:** A two-layer neural network: embed → score → predict.

```python
# train.py: SkipGram
import torch
import torch.nn as nn


class SkipGram(nn.Module):
    def __init__(self, V: int, d: int):
        super().__init__()                         # Initialize nn.Module so the layers assigned below are registered
                                                   # and their weights appear in model.parameters() for training.
        self.V, self.d = V, d                      # Store vocab size (V=96) and embedding dimension (d=8).
        self.W = nn.Embedding(V, d)                # The center embedding matrix: 96 words × 8 dims = 768 params.
                                                   # This is where word meanings live.
        self.W_prime = nn.Linear(d, V, bias=False) # The output weight matrix: 8 → 96. Scores how likely each
                                                   # word is as a context word. Usually discarded after training.
        nn.init.uniform_(self.W.weight, -0.5, 0.5)   # Initialize both matrices (768 params each, 1,536 total)
        nn.init.uniform_(self.W_prime.weight, -0.5, 0.5)  # uniformly in [-0.5, 0.5]. Starting point matters for convergence.

    def forward(self, center_idx: torch.Tensor):
        v_c = self.W(center_idx)                   # Look up the 8-dim vector for the center word.
                                                   # This is a table lookup, not a matrix multiply.
        logits = self.W_prime(v_c)                 # Multiply the 8-dim vector by W' to get 96 raw scores,
                                                   # one for every word in the vocabulary.
        return v_c.squeeze(0), logits.squeeze(0)   # Return hidden embedding and raw scores.
                                                   # squeeze(0) removes the batch dimension.
```
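The claim that the embedding step is "a table lookup, not a matrix multiply" is easy to verify. This sketch uses NumPy instead of PyTorch, with random matrices standing in for the trained weights, and shows that picking row `center` gives the same vector as the textbook one-hot × W product, and that the two matrices hold 1,536 numbers in total.

```python
import numpy as np

V, d = 96, 8
rng = np.random.default_rng(0)
W = rng.uniform(-0.5, 0.5, size=(V, d))        # stands in for nn.Embedding(V, d).weight
W_prime = rng.uniform(-0.5, 0.5, size=(V, d))  # nn.Linear(d, V, bias=False).weight is stored as (V, d)

center = 17                                    # hypothetical word index
one_hot = np.zeros(V)
one_hot[center] = 1.0

v_lookup = W[center]          # what an embedding layer does: select one row
v_matmul = one_hot @ W        # textbook view: one-hot vector times W
logits = W_prime @ v_lookup   # nn.Linear computes x @ weight.T, i.e. weight @ x for one vector

n_params = W.size + W_prime.size
print(np.allclose(v_lookup, v_matmul), logits.shape, n_params)   # True (96,) 1536
```

The lookup costs O(d) work per word, while the one-hot product costs O(V·d), which is why embedding layers index instead of multiplying.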

> See Chapter 3: The Skip-gram Architecture

**Parameter Breakdown:**

| Layer | Type | Shape | Params |
|---|---|---|---|
| W (embeddings) | Embedding | 96 × 8 | 768 |
| W′ (out) | Linear | 8 → 96 (weight stored as 96 × 8) | 768 |
| **Total** | | | **1,536** |

**Forward Pass Pipeline:**

```
one-hot [96]  →  embedding lookup W [8]  →  W′ projection [96]  →  logits [96]  →  softmax [96]  →  loss [1]
```
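The last stages of the pipeline can be written out explicitly. This NumPy sketch (random stand-in weights, hypothetical `center` and `context` indices) computes softmax and cross-entropy through a numerically stable log-softmax. In the PyTorch script these two stages live inside `nn.CrossEntropyLoss`, which takes raw logits and applies log-softmax itself.

```python
import numpy as np


def log_softmax(z):
    """Numerically stable log-softmax: subtracting max(z) leaves the result unchanged."""
    z = z - z.max()
    return z - np.log(np.exp(z).sum())


def forward(W, W_prime, center, context):
    v_c = W[center]                 # embedding lookup:  [8]
    logits = W_prime @ v_c          # W' projection:     [96]
    logp = log_softmax(logits)      # log-probabilities: [96]
    probs = np.exp(logp)            # softmax:           [96]
    loss = -logp[context]           # cross-entropy:     scalar
    return v_c, logits, probs, loss


rng = np.random.default_rng(0)
V, d = 96, 8
W = rng.uniform(-0.5, 0.5, (V, d))          # stand-in for trained weights
W_prime = rng.uniform(-0.5, 0.5, (V, d))
v_c, logits, probs, loss = forward(W, W_prime, center=3, context=10)  # hypothetical indices
print(v_c.shape, logits.shape, float(loss))
```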

**Scale Comparison** (ordered by size):

| Model | Parameters |
|---|---|
| **This model** | **1,536** |
| GPT-2 (smallest release) | 117M |
| word2vec-google-news | ~900M (3M words × 300 dims) |
| GPT-3 | 175B |
| GPT-4 | ~1T+ (unofficial estimate; not disclosed) |

This is a tiny toy model trained on a curated corpus of 545 sentences, designed to teach concepts, not for production use.

---

## Section 04 — The Training Loop

**Subtitle:** 15,000 steps of forward pass → loss → backprop → weight update.

The training loop runs 15,000 times. Each iteration picks one (center, context) pair and adjusts the weights to make the model slightly better at predicting that context word. Here's how a single step works:

**1. Shuffle & Sample**
Pairs are shuffled at the start. Each step picks the next (center, context) pair. When all pairs are used, we reshuffle, starting a new epoch.

```python
import random

import torch
import torch.nn as nn


def train(model, pairs, i2w, steps, lr, dense_steps, record_interval):
    loss_fn = nn.CrossEntropyLoss()    # Cross-entropy loss: L = -log P(correct word).
                                       # Measures how surprised the model is by the true answer.
    random.shuffle(pairs)              # Shuffle pairs at the start so the model doesn't see
    idx = 0                            # them in a predictable order.
```

**2. Forward Pass**
Look up the center word's embedding from W, then multiply by W′ to get 96 raw scores (logits). This is the model's prediction of which words are likely neighbors.

`v_c = W[center],  logits = W′ · v_c`

**3. Loss Computation**
Cross-entropy loss measures how surprised the model is by the true context word.

`L = −log P(context | center) = −log softmax(logits)[context]`

**4. Backpropagation**
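PyTorch computes the gradients automatically. For the center embedding, the gradient splits into an attractive term (the true context word's output vector u_o) and a repulsive term (the probability-weighted average of all output vectors, where u_w is row w of W′ and the sum runs over all 96 words). Because SGD subtracts the gradient, v_c moves toward u_o and away from the words the model currently over-predicts.

`∂L/∂v_c = −u_o + Σ_w P(w|c)·u_w`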

**5. Weight Update**
Subtract the gradient (scaled by learning rate) from the parameters. Vanilla SGD — no momentum, no Adam.

`θ ← θ − η · ∂L/∂θ`

```python
    for step in range(1, steps + 1):
        if idx >= len(pairs):          # When we've used all pairs, reshuffle and start over; new epoch.
            random.shuffle(pairs)
            idx = 0
        ci, oi = pairs[idx]            # Unpack: ci = center word index, oi = context (outside) word index.
        idx += 1

        ct, ot = torch.tensor([ci]), torch.tensor([oi])
                                       # Convert integer indices to PyTorch tensors.
        model.zero_grad()              # Clear gradients from the previous step; PyTorch accumulates by default.
        v_c, logits = model(ct)        # Forward pass: embedding lookup then W′ multiply. Returns hidden + logits.
        loss = loss_fn(logits.unsqueeze(0), ot)
                                       # Compute cross-entropy loss. unsqueeze adds the batch dim CrossEntropyLoss expects.
        loss.backward()                # Backprop: compute ∂L/∂W and ∂L/∂W′ in one call. PyTorch traces the graph.

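        with torch.no_grad():          # Disable gradient tracking while we update the weights by hand.
            for p in model.parameters():
                p -= lr * p.grad       # SGD update: new_weight = old_weight − learning_rate × gradient.

        # Recording (should_record / capture_step, Section 05) and building the returned
        # records, W_snaps, all_losses are omitted from this excerpt.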
```
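To see what `loss.backward()` and the update loop compute for a single pair, here is a NumPy sketch with hand-derived gradients: ∂L/∂v_c = −u_o + Σ_w P(w|c)·u_w from step 4 and, for each output row, ∂L/∂u_w = (P(w|c) − [w = o])·v_c. It also checks the v_c gradient against central finite differences. The random weights and the indices `c`, `o` are illustrative assumptions, not values from the real run.

```python
import numpy as np


def loss_and_probs(v_c, W_prime, o):
    """Cross-entropy loss for context index o, plus the softmax probabilities."""
    z = W_prime @ v_c
    z = z - z.max()                          # stabilize before exponentiating
    logp = z - np.log(np.exp(z).sum())
    return -logp[o], np.exp(logp)


def manual_sgd_step(W, W_prime, c, o, lr):
    """One skip-gram SGD step with hand-derived gradients; updates W and W_prime in place."""
    v_c = W[c].copy()
    loss, p = loss_and_probs(v_c, W_prime, o)
    err = p.copy()
    err[o] -= 1.0                            # dL/dlogits = P(w|c) - onehot(o)
    grad_v_c = W_prime.T @ err               # = -u_o + sum_w P(w|c) * u_w
    grad_W_prime = np.outer(err, v_c)        # row w: (P(w|c) - [w == o]) * v_c
    W[c] -= lr * grad_v_c                    # only the center row of W changes
    W_prime -= lr * grad_W_prime             # every output row changes
    return loss, grad_v_c


rng = np.random.default_rng(1)
V, d = 96, 8
W = rng.uniform(-0.5, 0.5, (V, d))
W_prime = rng.uniform(-0.5, 0.5, (V, d))
c, o = 5, 42                                 # hypothetical (center, context) indices

# Check the step-4 formula against central finite differences.
_, p = loss_and_probs(W[c], W_prime, o)
analytic = -W_prime[o] + p @ W_prime
eps = 1e-6
numeric = np.array([
    (loss_and_probs(W[c] + eps * e, W_prime, o)[0]
     - loss_and_probs(W[c] - eps * e, W_prime, o)[0]) / (2 * eps)
    for e in np.eye(d)
])

loss_before, grad_v_c = manual_sgd_step(W, W_prime, c, o, lr=0.1)
loss_after, _ = loss_and_probs(W[c], W_prime, o)
print(f"max |analytic - numeric| = {np.max(np.abs(analytic - numeric)):.2e}")
print(f"loss {loss_before:.4f} -> {loss_after:.4f}")
```

In the PyTorch version, `model.W.weight.grad` is a full 96 × 8 tensor in which only row `ci` is nonzero (`nn.Embedding` produces dense gradients unless created with `sparse=True`), so each step changes one embedding row but every row of W′.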

> See it live: [Forward & Backprop Walkthrough](https://www.quantml.org/topics/word2vec/internals)

> **One step ≈ 0.001 seconds.** Repeat 15,000 times and words that appear in similar contexts end up with similar embeddings. The loss starts at ~4.6, the cost of random guessing among 96 words (−log(1/96) = ln 96 ≈ 4.56), and its smoothed average settles around 3.7 (the model has learned meaningful patterns).
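The starting value follows from the uniform case: if every word gets probability 1/96, the loss is −log(1/96) = ln 96. With the small uniform(−0.5, 0.5) initialization the logits are close to equal, so the average initial loss should land just above that. This NumPy sketch checks both; its random generator and seed are assumptions and do not reproduce PyTorch's initialization exactly.

```python
import math

import numpy as np

V, d = 96, 8
print(round(math.log(V), 3))     # 4.564: the loss when every word gets probability 1/96

rng = np.random.default_rng(42)  # NumPy RNG, so this does not reproduce torch.manual_seed(42)
W = rng.uniform(-0.5, 0.5, (V, d))
W_prime = rng.uniform(-0.5, 0.5, (V, d))


def cross_entropy(c, o):
    z = W_prime @ W[c]
    z = z - z.max()
    return -(z[o] - np.log(np.exp(z).sum()))


# Average initial loss over random (center, context) index pairs.
pairs = rng.integers(0, V, size=(2000, 2))
init_loss = float(np.mean([cross_entropy(c, o) for c, o in pairs]))
print(round(init_loss, 2))       # expected slightly above ln 96, since the logits are not exactly equal
```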

---

## Section 05 — Recording State

**Subtitle:** Capturing the model's internal state at strategic training steps for visualization.

```python
def should_record(step, total, dense_steps, interval):
    if step <= dense_steps:    # Record every single step for the first 200 steps;
        return True             # this is where the most dramatic changes happen.
    if step % interval == 0:   # After step 200, only record every 10th step.
        return True             # The model changes more slowly.
    if step == total:           # Always record the final step.
        return True
    return False
```

```python
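import numpy as np
import torch
import torch.nn.functional as F


def capture_step(step, ci, oi, model, loss_val, v_c, logits, probs, i2w):
    V = model.V
    W_np  = model.W.weight.detach().cpu().numpy()        # Extract embedding matrix: detach from autograd, move to CPU
                                                         # (a no-op if the model is already there), convert to NumPy.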
    Wp_np = model.W_prime.weight.detach().cpu().numpy()
    probs_np  = probs.detach().cpu().numpy()
    logits_np = logits.detach().cpu().numpy()
    hidden_np = v_c.detach().cpu().numpy()

    grad_v_c = model.W.weight.grad[ci].cpu().numpy()     # Gradient for the center word's embedding.
    grad_u_o = model.W_prime.weight.grad[oi].cpu().numpy()

    u_o       = Wp_np[oi]                                # Output vector for the true context word.
    attractive = u_o.copy()                              # Attractive force: context word's output vector
                                                         # pulls the center embedding toward it.
    repulsive  = probs_np @ Wp_np                        # Repulsive force: probability-weighted average of all
                                                         # output vectors — pushes away from wrong predictions.

    sorted_idx = np.argsort(probs_np)[::-1]
    rank = int(np.where(sorted_idx == oi)[0][0]) + 1    # Where the true context ranks among 96 words.
                                                         # Rank 1 = model got it right.
    cos = float(F.cosine_similarity(
        torch.from_numpy(hidden_np).unsqueeze(0),
        torch.from_numpy(W_np[oi]).unsqueeze(0),         # Cosine similarity between the center's v_c and the
                                                         # context word's own W row (both input embeddings).
    ))

    return {
        "step": step,
        "center_word": i2w[ci],
        "context_word": i2w[oi],
        "loss": float(loss_val),
        "W": W_np.tolist(),
        "W_prime": Wp_np.tolist(),
        "hidden": hidden_np.tolist(),
        "logits":        {i2w[i]: float(logits_np[i]) for i in range(V)},
        "probabilities": {i2w[i]: float(probs_np[i])  for i in range(V)},
        "grad_v_c":          grad_v_c.tolist(),
        "grad_u_o":          grad_u_o.tolist(),
        "attractive_force":  attractive.tolist(),
        "repulsive_force":   repulsive.tolist(),
        "cosine_center_context": cos,
        "context_rank":  rank,
        "context_prob":  float(probs_np[oi]),
    }
```
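With toy numbers, the relationship between the recorded forces, the gradient, and the rank is easy to follow. The 4-word vocabulary, the 2-dimensional output vectors, and the probabilities below are made up; the expressions mirror those in `capture_step`.

```python
import numpy as np

# Hypothetical 4-word vocabulary; row w of Wp is that word's output vector u_w.
Wp = np.array([[1.0, 0.0],
               [0.0, 1.0],
               [-1.0, 0.0],
               [0.0, -1.0]])
probs = np.array([0.15, 0.6, 0.2, 0.05])   # model's P(w | center) at this step
oi = 0                                     # index of the true context word

attractive = Wp[oi].copy()                 # u_o
repulsive = probs @ Wp                     # sum_w P(w|c) * u_w, about [-0.05, 0.55]
grad_v_c = repulsive - attractive          # -u_o + sum_w P(w|c) * u_w, about [-1.05, 0.55]

sorted_idx = np.argsort(probs)[::-1]       # word indices, most to least probable
rank = int(np.where(sorted_idx == oi)[0][0]) + 1
print(rank)                                # 3: the true context is only the 3rd most probable word
```

An SGD step moves v_c along −grad_v_c ≈ [1.05, −0.55]: toward u_o = [1, 0] and away from u_1 = [0, 1], the word the model wrongly favored. One caveat on `context_rank`: `np.argsort` uses a non-stable sort by default, so words with exactly equal probability may be ordered either way.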

> Every number in the [Forward & Backprop Walkthrough](https://www.quantml.org/topics/word2vec/internals) came from this function.

**Recording Strategy:**

| Phase | Range | Frequency | Recordings |
|---|---|---|---|
| Early dynamics | Steps 0–200 | Every step | 201 |
| Later training | Steps 210–15,000 | Every 10th step | ~1,480 |

Total: ~1,681 recorded steps × ~45 KB each ≈ **75 MB** of training state data.
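The recording counts can be reproduced by running `should_record` with the default arguments. The table's 201 recordings for steps 0–200 suggest the initial state (step 0) is saved as well, consistent with `save_local` checking `rec["loss"] is not None`; that extra snapshot is an assumption in the sketch below.

```python
def should_record(step, total, dense_steps, interval):
    """Same rule as in the should_record function above."""
    if step <= dense_steps:
        return True
    if step % interval == 0:
        return True
    if step == total:
        return True
    return False


total = 15000
recorded = [s for s in range(1, total + 1) if should_record(s, total, 200, 10)]
n_files = len(recorded) + 1          # +1: assumed extra snapshot of the initial state (step 0)
size_mb = n_files * 45 / 1000        # ~45 KB per step file, in decimal MB
print(len(recorded), n_files, round(size_mb, 1))   # 1680 1681 75.6
```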

**What Gets Captured (per step):**

| Field | Description |
|---|---|
| `W`, `W_prime` | Full weight matrices W and W′ at this step |
| `hidden` | Center word's embedding vector (v_c) |
| `logits` | Raw scores for all 96 vocabulary words |
| `probabilities` | Softmax output: P(w\|center) |
| `grad_v_c` | Gradient for the center embedding |
| `grad_u_o` | Gradient for the context output vector |
| `attractive_force` | Pull toward the true context word |
| `repulsive_force` | Push away from wrong predictions |
| `cosine_center_context` | Alignment between center and context |
| `context_rank` | Where the true context ranks (1 = best) |
| `context_prob` | Probability assigned to the correct answer |

---

## Section 06 — Visualization Pipeline

**Subtitle:** Projecting 8-dimensional embeddings to 2D with PCA, and saving structured output files.

```python
import numpy as np
from sklearn.decomposition import PCA


def add_projections(records, W_snapshots, i2w, method):
    V   = W_snapshots[0].shape[0]
    pca = PCA(n_components=2, random_state=42)      # Finds the 2 directions of greatest variance in 8D space.
    if method == "final_pca":
        pca.fit(W_snapshots[-1])                    # Fit PCA on the final embeddings only. Either way, every
    else:                                           # snapshot is projected with the same axes for a consistent animation.
        pca.fit(np.vstack(W_snapshots))             # "global_pca": fit on all snapshots stacked together.

    var = pca.explained_variance_ratio_.tolist()
    for i, rec in enumerate(records):
        proj = pca.transform(W_snapshots[i])
        rec["embeddings_2d"] = {i2w[j]: proj[j].tolist() for j in range(V)}
                                                    # Store 2D coords for every word at every step;
                                                    # this becomes the scatter plot animation.
        rec["pca_variance_explained"] = var         # Fraction of variance captured by each of the 2 axes;
                                                    # together they typically capture ~60–70%.
```
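The fixed-axes idea can be sketched without scikit-learn. This NumPy version computes PCA from the SVD of mean-centered data, fits once on the last snapshot (the `final_pca` behavior), and projects every snapshot with those same axes. The random snapshots are stand-ins; sign conventions may differ from scikit-learn's `PCA`, so an axis can come out mirrored, but the variance ratios should match its `explained_variance_ratio_`.

```python
import numpy as np


def fit_pca(X, k=2):
    """PCA via SVD of mean-centered data: returns (mean, top-k axes, variance ratios)."""
    mean = X.mean(axis=0)
    _, s, Vt = np.linalg.svd(X - mean, full_matrices=False)
    var = s ** 2
    return mean, Vt[:k], var[:k] / var.sum()


def project(X, mean, axes):
    """Project rows of X onto previously fitted axes."""
    return (X - mean) @ axes.T


rng = np.random.default_rng(0)
# Stand-ins for W_snapshots: three 96 x 8 embedding matrices that spread out over time.
snapshots = [rng.normal(size=(96, 8)) * (1 + t) for t in range(3)]

mean, axes, ratio = fit_pca(snapshots[-1])            # "final_pca": fit once, on the last snapshot
coords = [project(W, mean, axes) for W in snapshots]  # same axes for every frame
print(coords[0].shape, ratio.round(3))
```

For `global_pca`, fit on `np.vstack(snapshots)` instead; the projection step stays the same.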

```python
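import os

# Relies on a _write(path, obj[, indent]) JSON helper defined elsewhere in the script (not shown).


def save_local(records, all_losses, vocab, counts, config, raw_sentences, out_dir):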
    """Write all training data as structured local files.

    Layout designed for fast frontend access:
      - run_metadata.json         loaded once on page init
      - loss_curve.json           loaded once for the chart
      - step_index.json           loaded once, tells frontend which steps exist
      - embeddings_timeline.json  loaded once for the scatter animation
      - steps/{N}.json            loaded on-demand when user inspects a step
    """
    steps_dir = os.path.join(out_dir, "steps")
    os.makedirs(steps_dir, exist_ok=True)

    # ── run_metadata.json (lightweight, ~25 KB) ──────────────────────
    meta = {
        "topic": "word2vec",
        "corpus_name": "curated_v1",
        "vocab": vocab,
        "vocab_size": len(vocab),
        "embed_dim": config["embed_dim"],
        "config": config,
        "total_steps": config["num_steps"],
        "recorded_steps": len(records),
        "word_frequencies": {w: counts[w] for w in vocab},
                                                    # How often each word appeared; shown in frequency charts.
        "raw_sentences": [" ".join(s) for s in raw_sentences],
                                                    # Original sentences before stopword removal; corpus preview.
    }
    _write(os.path.join(out_dir, "run_metadata.json"), meta, indent=2)

    # ── loss_curve.json (every training step, ~500 KB) ───────────────
    loss_curve = [{"step": i + 1, "loss": round(l, 6)}
                  for i, l in enumerate(all_losses)]  # Every step's loss value. Enables the full loss curve chart.
    _write(os.path.join(out_dir, "loss_curve.json"), loss_curve)

    # ── step_index.json (which steps have full recordings) ───────────
    step_index = [rec["step"] for rec in records]    # Frontend uses this to build the step selector.
    _write(os.path.join(out_dir, "step_index.json"), step_index)

    # ── embeddings_timeline.json (~8 MB) ─────────────────────────────
    # Compact: just step number → {word: [x, y]} for the scatter
    # animation. Frontend loads this once and scrubs through it.
    timeline = []
    for rec in records:
        entry = {"step": rec["step"], "positions": rec["embeddings_2d"]}
        if rec["loss"] is not None:
            entry["loss"]         = round(rec["loss"], 6)
            entry["center_word"]  = rec["center_word"]
            entry["context_word"] = rec["context_word"]
        timeline.append(entry)
    _write(os.path.join(out_dir, "embeddings_timeline.json"), timeline)

    # ── steps/{N}.json (one file per recorded step, ~45 KB each)
    for rec in records:
        step_path = os.path.join(steps_dir, f"{rec['step']}.json")
        step_data = {k: v for k, v in rec.items()
                     if k not in ("embeddings_2d", "pca_variance_explained")}
        # Re-adding the two projection fields moves them to the end of the key order.
        # Per-step files contain everything: weights, gradients, probabilities,
        # and 2D embeddings (~45 KB each).
        step_data["embeddings_2d"]          = rec["embeddings_2d"]
        step_data["pca_variance_explained"] = rec.get("pca_variance_explained")
        _write(step_path, step_data)
```
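The `_write` helper is not part of this excerpt. A minimal stand-in, assuming it simply serializes a Python object to JSON, plus a reader that mimics the on-demand `steps/{N}.json` access, could look like this; both function bodies are hypothetical.

```python
import json
import os
import tempfile


def _write(path, obj, indent=None):
    """Hypothetical stand-in for the script's _write helper: serialize obj as JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=indent)


def load_step(out_dir, step):
    """Hypothetical reader mimicking the frontend's on-demand steps/{N}.json fetch."""
    with open(os.path.join(out_dir, "steps", f"{step}.json"), encoding="utf-8") as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as out_dir:
    os.makedirs(os.path.join(out_dir, "steps"), exist_ok=True)
    _write(os.path.join(out_dir, "step_index.json"), [0, 1, 2])
    _write(os.path.join(out_dir, "steps", "2.json"), {"step": 2, "loss": 4.51})
    rec = load_step(out_dir, 2)

print(rec)   # {'step': 2, 'loss': 4.51}
```

If file size matters more than readability, passing `separators=(",", ":")` to `json.dump` drops the spaces after commas and colons.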

> See Chapter 4: [Watch It Learn](https://www.quantml.org/topics/word2vec) (the PCA scatter plot)  
> These files power the entire [Inside the Model](https://www.quantml.org/topics/word2vec/internals) tab.

**Output File Structure:**

```
data/word2vec/
├── run_metadata.json         ~25 KB   Vocab, config, word frequencies, raw sentences
├── loss_curve.json           ~500 KB  All 15,000 steps' loss values for the chart
├── step_index.json           ~12 KB   Which steps have full recordings
├── embeddings_timeline.json  ~8 MB    2D projections at every recorded step (scatter animation)
├── model_definition.json     ~8 KB    Model source code, layer info, comparisons
├── corpus_stats.json         ~8 KB    Frequency distribution, sample sentences
├── similarity_matrix.json    ~80 KB   Word-to-word cosine similarity grid
├── weight_snapshots.json     ~130 KB  W and W′ matrices at key training steps
├── w_vs_wprime.json          ~12 KB   Cosine similarity and norms between W and W′ per word
├── neighbor_evolution.json   ~44 KB   How nearest neighbors change over training for key words
├── analogy_evolution.json    ~16 KB   Analogy test results at each training milestone
├── word_convergence.json     ~16 KB   How quickly each word's embedding converges to its final position
└── steps/                    ~75 MB   Per-step full state (1,681 files × ~45 KB)
```

**Why PCA?**
Our embeddings live in 8 dimensions, but screens are 2D. PCA (Principal Component Analysis) finds the two directions that capture the most variance: the "best possible flat photo" of a 3D+ object. By default (`--pca final_pca`) we fit PCA on the final embeddings and use those same axes for all steps, so the animation shows genuine movement rather than axis rotation; `global_pca` instead fits the axes on all snapshots stacked together.

---

## Section 07 — Run It Yourself

**Subtitle:** The CLI entry point with all configurable hyperparameters.

```python
import argparse
import random
from pathlib import Path

import numpy as np
import torch


def main():
    ap = argparse.ArgumentParser()   # Standard Python argument parser for command-line options.
    ap.add_argument("--corpus",  default=str(Path(__file__).with_name("corpus.txt")))
    ap.add_argument("--output-dir",
                     default=str(Path(__file__).resolve().parent / "output"),
                     help="Where to write output files (default: output/ next to this script)")
    ap.add_argument("--steps",          type=int,   default=15000)
                                         # Total training steps. More = better embeddings; diminishing returns after ~10K.
    ap.add_argument("--dim",            type=int,   default=8)
                                         # Embedding dimension. 8 for visualization; real models use 100–300.
    ap.add_argument("--lr",             type=float, default=0.1)
                                         # Learning rate. Too high → diverge, too low → slow.
    ap.add_argument("--window",         type=int,   default=3)
                                         # Context window size. Larger = more semantic, smaller = more syntactic.
    ap.add_argument("--min-count",      type=int,   default=3)
                                         # Minimum word frequency to include in vocabulary.
    ap.add_argument("--seed",           type=int,   default=42)
                                         # Random seed for reproducibility; same seed = same results.
    ap.add_argument("--dense-steps",    type=int,   default=200,
                     help="Record every step for the first N steps")
                                         # Capture early learning dynamics at full resolution.
    ap.add_argument("--record-interval", type=int, default=10,
                     help="After dense-steps, record every N-th step")
                                         # Balances detail vs file size for later training.
    ap.add_argument("--pca", default="final_pca", choices=["final_pca", "global_pca"])
                                         # final_pca: fit PCA on the last snapshot; global_pca: on all snapshots stacked.
    args = ap.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # ── load ──────────────────────────────────────────────────────────
    print("Loading corpus …")
    raw, cleaned = load_corpus(args.corpus)
    vocab, w2i, i2w, counts = build_vocab(cleaned, args.min_count)
    pairs = generate_pairs(cleaned, w2i, args.window)

    print(f"  {len(raw)} sentences  →  {len(cleaned)} after stopword removal")
    print(f"  Vocabulary: {len(vocab)} words")
    print(f"  Training pairs: {len(pairs)}")
    print(f"  Epochs covered: {args.steps / len(pairs):.1f}")

    # ── train ─────────────────────────────────────────────────────────
    model = SkipGram(len(vocab), args.dim)   # Create the model: vocab_size=96, embed_dim=8 → 1,536 parameters.
    records, W_snaps, all_losses = train(
        model, pairs, i2w, args.steps, args.lr,
        args.dense_steps, args.record_interval,
    )

    # ── 2D projections ────────────────────────────────────────────────
    add_projections(records, W_snaps, i2w, args.pca)
                                             # Project all weight snapshots to 2D for the scatter plot animation.

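    # ── save ──────────────────────────────────────────────────────────
    # `config` is never defined in this excerpt; save_local needs at least
    # config["embed_dim"] and config["num_steps"], so this reconstruction is an assumption.
    config = {**vars(args), "embed_dim": args.dim, "num_steps": args.steps}
    save_local(records, all_losses, vocab, counts, config, raw, args.output_dir)
                                             # Write the run files to disk (~80 MB total across all output files).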

if __name__ == "__main__":
    main()
```

[View on GitHub](https://github.com/quantml/learn/tree/main/word2vec) — Clone the repo, run the training script, and experiment with hyperparameters.

**Quick Start:**

```bash
pip install torch numpy scikit-learn

python train.py \
  --corpus corpus.txt \
  --steps 15000 \
  --dim 8 \
  --lr 0.1 \
  --window 3
```

**All CLI Arguments:**

| Argument | Default | Description |
|---|---|---|
| `--corpus` | corpus.txt (next to train.py) | Path to the training text file |
| `--output-dir` | output/ (next to train.py) | Where to write output JSON files |
| `--steps` | 15000 | Total training iterations |
| `--dim` | 8 | Embedding vector dimensionality |
| `--lr` | 0.1 | Learning rate (SGD step size) |
| `--window` | 3 | Context window radius |
| `--min-count` | 3 | Min word frequency to include |
| `--seed` | 42 | Random seed for reproducibility |
| `--dense-steps` | 200 | Record every step for first N |
| `--record-interval` | 10 | Record every N-th step after dense |
| `--pca` | final_pca | PCA method: final_pca or global_pca |

**Experiments to Try:**

**Change embedding dimension** (`--dim 4 vs --dim 8 vs --dim 16`)
Fewer dimensions = less capacity but faster training and easier to visualize. 4 dimensions may struggle with 96 words; 16 should do better.

**Change learning rate** (`--lr 0.01 vs --lr 0.1 vs --lr 0.5`)
Too high → loss diverges (NaN). Too low → barely learns in 15K steps. 0.1 is the sweet spot for this corpus size.

**Change window size** (`--window 2 vs --window 3 vs --window 5`)
Larger window = more semantic relationships (topical similarity). Smaller window = more syntactic relationships (word-type similarity).

**Change training steps** (`--steps 5000 vs --steps 15000 vs --steps 25000`)
Watch diminishing returns: loss drops fast in the first 5K, then plateaus. More steps = better analogies but slower.
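To run these experiments side by side, a small helper can generate one command per setting, each writing to its own output directory so runs don't overwrite each other. This is a hypothetical convenience script, not part of the repository; it only builds the commands, using the flags from the CLI table.

```python
import shlex

# Hypothetical sweep over the settings suggested above (not part of the repo).
EXPERIMENTS = {
    "dim": [4, 8, 16],
    "lr": [0.01, 0.1, 0.5],
    "window": [2, 3, 5],
    "steps": [5000, 15000, 25000],
}


def sweep_commands(script="train.py", corpus="corpus.txt"):
    """Build one command per setting; each run writes to its own output directory."""
    cmds = []
    for flag, values in EXPERIMENTS.items():
        for value in values:
            cmds.append([
                "python", script,
                "--corpus", corpus,
                f"--{flag}", str(value),
                "--output-dir", f"output/{flag}_{value}",
            ])
    return cmds


cmds = sweep_commands()
print(len(cmds))                 # 12
print(shlex.join(cmds[0]))       # python train.py --corpus corpus.txt --dim 4 --output-dir output/dim_4
```

To actually launch them, pass each list to `subprocess.run(cmd, check=True)`. Keeping `--seed` at its default means the only difference between runs is the flag being varied.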

---

## Related sections

- [Word2Vec Story](https://www.quantml.org/topics/word2vec) — Conceptual walkthrough using these trained artifacts
- [Inside the Model](https://www.quantml.org/topics/word2vec/internals) — Interactive inspection of the weights and state this script produced
- [Quiz](https://www.quantml.org/topics/word2vec/quiz) — Test your understanding of the implementation
