
Vector Math with Tokens 🧮
Can you perform arithmetic on word meanings? We dive into the fascinating world of vector analogies, solving challenges like 'King - Man + Woman = ?' using high-dimensional embeddings.
One of the most surprising properties of word embeddings is their ability to encode complex relationships as simple vector offsets. In this lesson, we explore whether we can perform "math" on tokens and their underlying latent vectors.
This content is adapted from A deep understanding of AI language model mechanisms. It has been curated and organized for educational purposes on this portfolio. No copyright infringement is intended.
🚀 The Core Concept
Semantic relationships in language can be represented as geometric offsets in a high-dimensional vector space:
- Analogy Vectors: The vector difference between Man and Woman is roughly the same as the difference between King and Queen.
- Tokens vs. Embeddings: While token IDs are just arbitrary integers (where $5 + 3$ makes no semantic sense), their corresponding embedding vectors carry rich statistical information that can support some forms of semantic arithmetic.
- Semantic Algebra: By adding and subtracting these dense vectors, we can navigate the "meaning space" of a model to discover synonyms, analogies, and clusters.
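The sketch below makes the analogy idea concrete on a hand-built toy vocabulary. The 3-D vectors, the word list, and the `analogy` helper are all hypothetical. Real embeddings, such as GPT-2's 768-D vectors, are learned rather than hand-designed. The helper returns the word whose vector is most cosine-similar to vec(a) - vec(b) + vec(c). It skips the three input words, a convention commonly used in analogy evaluations.

```python
import numpy as np

# Hypothetical hand-made 3-D "embeddings", for illustration only:
# dimension 0 ~ "royalty", dimension 1 ~ "maleness", dimension 2 ~ "femaleness".
toy_vocab = {
    "king":  np.array([0.9, 0.8, 0.1]),
    "queen": np.array([0.9, 0.1, 0.8]),
    "man":   np.array([0.1, 0.9, 0.1]),
    "woman": np.array([0.1, 0.1, 0.9]),
    "apple": np.array([0.0, 0.2, 0.2]),
}

def cosine(a, b):
    """Cosine similarity between two vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def analogy(a, b, c, vocab):
    """Word most cosine-similar to vec(a) - vec(b) + vec(c), excluding a, b, c."""
    target = vocab[a] - vocab[b] + vocab[c]
    candidates = {w: cosine(target, v) for w, v in vocab.items() if w not in (a, b, c)}
    return max(candidates, key=candidates.get)

print(analogy("king", "man", "woman", toy_vocab))  # queen
```

In this toy space, king - man + woman = [0.9, 0.0, 0.9], which points almost exactly along "queen".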
1. Environment Setup
We'll use numpy for vector math and matplotlib to visualize how numbers map to the discrete token space.
```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# high-resolution (SVG) plots
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
```
2. Loading GPT-2 and Extracting Weights
We use GPT-2's Word Token Embedding (WTE) matrix. Every token ID acts as an index into this matrix to retrieve a 768-dimensional representation.
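The row lookup described above is plain indexing. The minimal sketch below uses a small random matrix as a stand-in for the real (50257 × 768) WTE table; the sizes and names are illustrative assumptions. It shows that indexing a row gives the same result as multiplying a one-hot vector by the matrix.

```python
import numpy as np

# Toy stand-in for the WTE matrix (the real one is 50257 x 768).
rng = np.random.default_rng(0)
vocab_size, d_model = 10, 4
wte = rng.standard_normal((vocab_size, d_model))

token_id = 7

# Lookup by indexing: row `token_id` is the token's embedding vector.
by_index = wte[token_id]

# Equivalent formulation: a one-hot row vector times the matrix.
one_hot = np.zeros(vocab_size)
one_hot[token_id] = 1.0
by_matmul = one_hot @ wte

print(by_index.shape)                    # (4,)
print(np.allclose(by_index, by_matmul))  # True
```

Indexing is what libraries actually do, because it avoids a multiplication with mostly zeros.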
```python
from transformers import GPT2Model, GPT2Tokenizer

# pretrained GPT-2 model and tokenizer
gpt2 = GPT2Model.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# token embedding matrix as a NumPy array, shape (50257, 768)
embeddings = gpt2.wte.weight.detach().numpy()
```
3. Mapping Numbers to Tokens
Token IDs are arbitrary with respect to numeric value. The number 1 is token 16, while 10 is token 940. There is no mathematical relationship between the token ID integers and the numeric values of the strings they represent.
```python
# create some numbers: 0-10, then 20-100 in steps of 10, then 200-1000 in steps of 100
numbers = np.arange(11)
numbers = np.concatenate((numbers, 10*numbers[2:], 100*numbers[2:]), axis=0)

# initialize token vector
numTokenLabels = np.zeros(len(numbers))

# get and report the tokens
for i, n in enumerate(numbers):
    # store the first token for this number
    numTokenLabels[i] = tokenizer.encode(str(n))[0]
    print(f'The number {n:5} is token(s) {tokenizer.encode(str(n))}')
```
```
The number 0 is token(s) [15]
The number 1 is token(s) [16]
The number 2 is token(s) [17]
The number 3 is token(s) [18]
The number 4 is token(s) [19]
```
4. Visualizing Token Discontinuity
Plotting the numbers against their token IDs reveals a jagged, non-linear pattern. This shows that you cannot perform meaningful math directly on token IDs.
```python
plt.figure(figsize=(10,4))
plt.plot(numbers,numTokenLabels,color=[.5,.5,.5],linewidth=.5)
plt.scatter(numbers,numTokenLabels,c=np.arange(len(numbers)),s=100,marker='s',cmap='plasma_r',zorder=10)
plt.gca().set(xlabel='Number (as string)',ylabel='Token value',xticks=numbers,xlim=[numbers[0]-15,numbers[-1]+15])
plt.show()
```
5. Deterministic Token Lengths
Small integers typically map to a single token, while larger integers split into several tokens. Floating-point numbers, whose string form carries many decimal digits, fragment into many more subword tokens.
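One reason floats fragment is simply the length of their strings. The sketch below uses character count as a rough proxy for token count; the real split depends on GPT-2's BPE merges. It compares the string lengths of the integers 0 to 99,998 with seeded random floats drawn from the same distribution as `ra`. The seed and the sample size of 1,000 are arbitrary choices.

```python
import numpy as np

# Character count is only a rough proxy for token count.
rng = np.random.default_rng(0)           # seeded stand-in for 5*np.random.randn
floats = 5 * rng.standard_normal(1000)

int_lengths = [len(str(i)) for i in range(99_999)]
float_lengths = [len(str(float(x))) for x in floats]

print(str(12345))   # 12345
print(str(1 / 3))   # 0.3333333333333333
print(f"mean characters, integers: {np.mean(int_lengths):.2f}")
print(f"mean characters, floats:   {np.mean(float_lengths):.2f}")
```

Python renders a float with enough digits to round-trip exactly, so a typical random float becomes a string of well over a dozen characters.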
```python
numnums = 99_999
int_toklens = np.zeros(numnums, dtype=int)
float_toklens = np.zeros(numnums, dtype=int)

# random floating-point numbers (normal distribution, standard deviation 5)
ra = 5*np.random.randn(numnums)

# count how many tokens each integer 0..99998 and each random float needs
for i in range(numnums):
    int_toklens[i] = len(tokenizer.encode(str(i)))
    float_toklens[i] = len(tokenizer.encode(str(ra[i])))

# plot token lengths; a small vertical jitter keeps overlapping points visible
_, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(int_toklens + np.random.randn(numnums)/30, 's', markerfacecolor=[.7, .7, .9], alpha=.4)
axs[1].plot(ra, float_toklens + np.random.randn(numnums)/50, 'o', markerfacecolor=[.7, .9, .7], alpha=.4)
axs[0].set(xlabel='Number', ylabel='Token length', yticks=range(int_toklens.max()+2), title='Token lengths of integers')
axs[1].set(xlabel='Number', ylabel='Token length', title='Token lengths of floating-point numbers')
plt.tight_layout()
plt.show()
```
6. The "Token Math" Failure
Multiplying the token IDs of 5 (ID 20) and 3 (ID 18) gives 360, an ID that has nothing to do with the number 15. Inside the equation string below, " 3" with its leading space is a different token (513). Multiplying all of the equation's token IDs quickly exceeds the vocabulary size!
```python
# the equation and its tokens
eq = '5 x 3 ='
tokens = tokenizer.encode(eq)
print(f'{eq} -> {tokens}')
# 64-bit integers avoid overflow on platforms where NumPy's default int is 32-bit
print(f'Product of tokens = {np.prod(tokens, dtype=np.int64)}')
```
```
5 x 3 = -> [20, 2124, 513, 796]
Product of tokens = 17346623040
```
The product 17346623040 does not correspond to any token. GPT-2's vocabulary has 50,257 entries, so valid IDs run only from 0 to 50,256.
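The arithmetic can be checked with plain Python integers, which never overflow. The token IDs are taken from the outputs above. `is_valid_token_id` is a hypothetical helper, and the vocabulary size is GPT-2's 50,257.

```python
import math

GPT2_VOCAB_SIZE = 50_257  # valid token IDs are 0..50256

def is_valid_token_id(x: int, vocab_size: int = GPT2_VOCAB_SIZE) -> bool:
    """Hypothetical helper: True if x could index GPT-2's embedding matrix."""
    return 0 <= x < vocab_size

# IDs from the outputs above: "5" -> 20, "3" -> 18, "5 x 3 =" -> [20, 2124, 513, 796]
pair_product = 20 * 18
eq_product = math.prod([20, 2124, 513, 796])

print(pair_product, is_valid_token_id(pair_product))  # 360 True
print(eq_product, is_valid_token_id(eq_product))      # 17346623040 False
```

Note that 360 passes the range check, so it would decode to some token. That token is simply unrelated to "15".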
7. Does Math Work in Embeddings?
Suppose we instead take the embedding vectors for 3 and 5 and add them together. We then find the vocabulary token whose embedding has the largest dot product with the sum (the "unembedding" step). The sum of the vectors for 5 and 3 turns out to be... still most similar to 5. For comparison, we also compute their element-wise product.
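The `@ embeddings.T` step scores each token by its raw dot product with the query vector. GPT-2's language-modeling head ties its output weights to this same WTE matrix. Unlike cosine similarity, a raw dot product rewards vectors with large norms. The toy sketch below uses a hypothetical 3-token, 2-D table to show that the two criteria can pick different tokens.

```python
import numpy as np

def nearest_by_dot(vec, E):
    """Row of E with the largest dot product with vec (argmax of vec @ E.T)."""
    return int(np.argmax(vec @ E.T))

def nearest_by_cosine(vec, E):
    """Row of E with the largest cosine similarity to vec."""
    sims = (vec @ E.T) / (np.linalg.norm(E, axis=1) * np.linalg.norm(vec))
    return int(np.argmax(sims))

# Hypothetical table: row 1 points almost the same way as the query,
# while row 2 is long but only roughly aligned with it.
E = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [3.0, 3.0]])
query = np.array([0.1, 1.0])

print(nearest_by_dot(query, E))     # 2  (the long vector wins)
print(nearest_by_cosine(query, E))  # 1  (the best-aligned direction wins)
```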
```python
# isolate the embedding vectors
t5 = tokenizer.encode('5')
t3 = tokenizer.encode('3')
# each encode() returns a one-element list, so squeeze (1, 768) down to (768,)
e5 = embeddings[t5,:].squeeze()
e3 = embeddings[t3,:].squeeze()

# vector math
theirSum = e3+e5
theirProd = e3*e5  # element-wise product

# plot the vectors
plt.figure(figsize=(12,4))
plt.plot(e3,label='3')
plt.plot(e5,label='5')
plt.plot(theirSum,label='3+5')
plt.plot(theirProd,label='3x5')
plt.gca().set(xlabel='Embeddings dimension',ylabel='Value',xlim=[0,len(e3)])
plt.legend()
plt.show()
```
8. The Result of Semantic Addition
Adding two number vectors doesn't perform mathematical addition. It performs semantic mixing. The vector $V(3) + V(5)$ results in a point in latent space that is statistically "number-like," but doesn't necessarily land on $V(8)$.
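A short derivation shows exactly what the argmax in the next cell compares. Write $V(k)$ for the embedding of token $k$, and score every token by its dot product with the sum, as the unembedding step does:

$$
\text{score}(k) = \big(V(3) + V(5)\big)\cdot V(k) = V(3)\cdot V(k) + V(5)\cdot V(k)
$$

When the two input tokens are compared, the shared cross term cancels:

$$
\text{score}(5) - \text{score}(3) = \big(V(3)\cdot V(5) + \lVert V(5)\rVert^2\big) - \big(\lVert V(3)\rVert^2 + V(3)\cdot V(5)\big) = \lVert V(5)\rVert^2 - \lVert V(3)\rVert^2
$$

So, between the two inputs, "5" outscores "3" exactly when its embedding has the larger norm. Token "8" would come out on top only if $V(3)\cdot V(8) + V(5)\cdot V(8)$ beat the score of every other token, and nothing obviously guarantees that. Whether the norm difference drives the "5" result here is an assumption. You could check it by comparing `np.linalg.norm(e5)` with `np.linalg.norm(e3)`.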
```python
# unembedding via matrix multiplication: dot product with every token's embedding
sumUnembedding = theirSum @ embeddings.T
prodUnembedding = theirProd @ embeddings.T

# find the argmax output (highest-scoring token)
print(f'Max embedding of 5+3 = "{tokenizer.decode(np.argmax(sumUnembedding))}"')
print(f'Max embedding of 5x3 = "{tokenizer.decode(np.argmax(prodUnembedding))}"')
```
```
Max embedding of 5+3 = "5"
Max embedding of 5x3 = " Weinstein"
```
Why "Weinstein"? Element-wise multiplication is not a meaningful operation on embeddings. Each coordinate of `e3*e5` is the product of two unrelated feature values, so the result need not resemble either input. Its highest-scoring token can therefore land in an unrelated part of the vocabulary. Addition is much more stable for semantic exploration, because the sum keeps a component along each input vector.
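The rough sketch below compares how similar the sum and the element-wise product stay to their inputs. It assumes random Gaussian vectors as hypothetical stand-ins for two 768-dimensional embeddings. Real GPT-2 embeddings are not isotropic Gaussians, so this only illustrates the general tendency.

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity between two vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Hypothetical stand-ins for two embeddings (not GPT-2 data).
rng = np.random.default_rng(42)
a = rng.standard_normal(768)
b = rng.standard_normal(768)

vec_sum = a + b    # keeps a large component along each input
vec_prod = a * b   # element-wise: coordinate i is a[i] * b[i]

print(f"cos(a+b, a) = {cosine(vec_sum, a):.2f}   cos(a+b, b) = {cosine(vec_sum, b):.2f}")
print(f"cos(a*b, a) = {cosine(vec_prod, a):.2f}   cos(a*b, b) = {cosine(vec_prod, b):.2f}")
```

With these stand-ins, the sum should stay clearly aligned with both inputs (cosine around 0.7). The element-wise product should be close to orthogonal to both, which is consistent with it landing on an unrelated token.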