Bigram Model - Next-Word Prediction
Practice Link
- This is a standard “bigram / Markov chain” exercise; it does not correspond to a single canonical LeetCode problem.
Description
- Input: a list of sentences, where each sentence is a list of words/tokens.
- Build a model that counts which words follow a given word (a bigram frequency table).
- Query: given a word `w`, predict a next word:
  - Option A (deterministic): return the most frequent next word.
  - Option B (probabilistic): sample a next word with probability proportional to its observed frequency.
- If `w` was never seen (or was never followed by another word), return an empty string.
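Example
Training data:
[
  ["I", "am", "sam"],
  ["sam", "i", "am"],
  ["i", "like", "green", "eggs", "and", "ham"]
]

Example query (with model = train_bigrams(training_data)):
most_common_next("i", model) -> "am"

Note: tokens are case-sensitive, so "I" and "i" are different words. "i" is followed once by "am" and once by "like"; the tie goes to "am" because it was observed first.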
Python Solution
1from collections import Counter
2from collections import defaultdict
3import random
4from typing import DefaultDict, Dict, List, Optional
5
# A bigram model maps each token to a Counter of the tokens observed
# immediately after it: model[w][next_w] is how many times next_w followed w.
BigramModel = DefaultDict[str, Counter]
7
def train_bigrams(sentences: List[List[str]]) -> BigramModel:
    """Count, for every token, how often each token immediately follows it.

    After training, ``model[w][next_w]`` is the number of times ``next_w``
    appeared directly after ``w`` within a single sentence. Bigrams never
    span sentence boundaries, and tokens are compared exactly (so "I" and
    "i" are distinct words).

    Args:
        sentences: Sentences given as lists of tokens. A ``None`` argument,
            and ``None`` or empty sentences inside it, are skipped.

    Returns:
        A ``defaultdict(Counter)`` whose keys are the tokens that have at
        least one observed successor.

    Time: O(total tokens)
    Space: O(number of distinct observed bigrams)
    """
    model: BigramModel = defaultdict(Counter)

    for sent in sentences or []:
        if not sent:
            # Skip None / empty sentences; slicing None would raise.
            continue
        # zip pairs each token with its successor and stops at the shorter
        # input, so a one-token sentence contributes no bigrams.
        for w, nxt in zip(sent, sent[1:]):
            model[w][nxt] += 1

    return model
26
def most_common_next(w: str, model: BigramModel) -> str:
    """Return the successor of *w* that was observed most often.

    Ties are resolved in favour of the successor inserted first into the
    Counter (for a model built by ``train_bigrams``, the one observed first),
    because ``Counter.most_common`` orders equal counts by insertion.

    Returns "" when *w* is empty, absent from *model*, or has an empty
    successor Counter. Membership is checked before indexing, so a
    ``defaultdict`` model is never grown by a lookup.
    """
    if not (w and w in model):
        return ""

    followers = model[w]
    if not followers:
        return ""

    [(best, _count)] = followers.most_common(1)
    return best
35
def sample_next(w: str, model: BigramModel, rng: Optional[random.Random] = None) -> str:
    """Sample a successor of *w* with probability proportional to its count.

    Args:
        w: The token whose successor should be sampled.
        model: Bigram counts, e.g. as produced by ``train_bigrams``.
        rng: Source of randomness. When omitted, the module-level ``random``
            generator is used, so ``random.seed(...)`` makes results
            reproducible.

    Returns:
        A sampled successor, or "" if *w* is empty, absent from *model*, or
        has no successor with a positive count.
    """
    # Membership is checked before indexing so a defaultdict is not grown.
    if not w or w not in model:
        return ""

    # Only positive counts carry probability mass. Zero/negative entries
    # (possible after Counter.subtract or manual edits) are skipped so they
    # cannot make randrange(0) raise or distort the cumulative walk below.
    weighted = [(nxt, c) for nxt, c in model[w].items() if c > 0]
    total = sum(c for _, c in weighted)
    if total <= 0:
        return ""

    source = random if rng is None else rng
    r = source.randrange(total)  # uniform integer in [0, total - 1]

    # Walk the cumulative counts; the first bucket whose upper bound exceeds
    # r is chosen, giving each successor probability count / total.
    cum = 0
    for nxt, c in weighted:
        cum += c
        if r < cum:
            return nxt

    return ""  # unreachable: r < total, which equals the final cum