Bigram Model - Next Word Prediction#

Levels: level-4
Data structures: hash-table, string, array
Patterns: hashing
  • This is a standard “bigram / Markov chain” exercise (not a single canonical LeetCode).

Description#

  • Input: a list of sentences, where each sentence is a list of words/tokens.
  • Build a model that counts which words follow a given word (a bigram frequency table).
  • Query: given a word w, predict a next word:
    • Option A (deterministic): return the most frequent next word
    • Option B (probabilistic): sample a next word proportional to its observed frequency
  • If w was never seen (or has no following word), return an empty string.

Example#

Training data:
[
  ["I", "am", "sam"],
  ["sam", "i", "am"],
  ["i", "like", "green", "eggs", "and", "ham"]
]

Possible:
most_common_next("i") -> "am"   ("am" and "like" are tied at one count each; the tie falls to first-seen order)
Python Solution#

from collections import Counter, defaultdict
import random
from typing import DefaultDict, List, Optional

BigramModel = DefaultDict[str, Counter]

def train_bigrams(sentences: List[List[str]]) -> BigramModel:
    """
    Build next-word counts:
      model[w][next_w] += 1

    Time:  O(total tokens)
    Space: O(number of observed bigrams)
    """
    model: BigramModel = defaultdict(Counter)

    for sent in sentences or []:
        if not sent:
            continue
        for i in range(len(sent) - 1):
            w, nxt = sent[i], sent[i + 1]
            model[w][nxt] += 1

    return model

def most_common_next(w: str, model: BigramModel) -> str:
    """
    Return the most frequent next token after w, else "".
    Ties are broken by first-seen order (Counter preserves insertion order).
    """
    if not w or w not in model or not model[w]:
        return ""
    # Counter.most_common(1) returns [(word, count)]
    return model[w].most_common(1)[0][0]

def sample_next(w: str, model: BigramModel, rng: Optional[random.Random] = None) -> str:
    """
    Sample the next token proportional to observed frequencies, else "".
    """
    if not w or w not in model or not model[w]:
        return ""
    rng = rng or random.Random()

    counter = model[w]
    total = sum(counter.values())
    r = rng.randrange(total)  # integer in [0, total-1]

    # Walk the cumulative distribution until it passes r.
    cum = 0
    for nxt, c in counter.items():
        cum += c
        if r < cum:
            return nxt

    return ""  # defensive; shouldn't happen
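The manual cumulative-sum walk in sample_next can also be delegated to the standard library. A shorter sketch using random.choices, which accepts per-item weights (the helper name sample_next_choices is ours, not part of the exercise):

```python
import random
from collections import Counter
from typing import Optional

def sample_next_choices(counter: Counter, rng: Optional[random.Random] = None) -> str:
    """Weighted sampling via random.choices instead of a manual cumulative walk."""
    if not counter:
        return ""
    rng = rng or random.Random()
    words = list(counter)
    # random.choices draws k items with replacement, weighted by `weights`.
    return rng.choices(words, weights=[counter[w] for w in words], k=1)[0]
```

This behaves like the cumulative loop above: each candidate is drawn with probability count / total. The loop version is worth knowing for interviews, since "sample proportional to weights" is a common follow-up question in its own right.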