
Exploring Token Generation Strategies

  • 8 min read
  • 28 Aug 2023


Introduction

This article surveys methods for generating sequences of tokens with language models, focusing on how to select the next token in a sequence from the model's predicted probability distribution.

Language models predict the next token based on the n previous tokens, extracting as much information from them as they can. Transformer models aggregate information from all n previous tokens: tokens in a sequence communicate with one another and exchange information. At the end of this communication process, each token is context-aware, and we use it to predict its own next token. Each token separately passes through some linear/non-linear layers, and the output is unnormalized logits. We then apply Softmax to the logits to convert them into probability distributions, so each token has its own probability distribution over its next token:
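To make the shapes concrete, here is a minimal sketch of this last step, using toy dimensions of our own choosing (the batch, sequence, and vocabulary sizes below are illustrative, not from any particular model):

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 10             # toy sizes, for illustration only
logits = torch.randn(batch, seq_len, vocab)  # stand-in for the network's raw output

# Softmax over the vocabulary axis: each position in each sequence
# gets its own probability distribution over its next token.
probs = F.softmax(logits, dim=-1)

print(probs.shape)        # one distribution of size `vocab` per token position
print(probs.sum(dim=-1))  # every distribution sums to 1
```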

[Image: each token's probability distribution over its next token]


Exploring Methods for Token Selection

When we have the probability distribution over tokens, it's time to pick one as the next token. There are four methods for selecting a suitable token from the probability distribution:

●    Greedy or naive method: Simply select the token with the highest probability. This is a deterministic method.
●    Beam search: It receives a parameter named beam size, and based on it the algorithm runs the model multiple times to find a suitable sequence, not just a single token. This is a deterministic method.
●    Top-k sampling: Select the top k most probable tokens, shut off the others (set their logits to -inf), and sample from those k tokens. This is a sampling method.
●    Nucleus sampling: Like Top-k, select the most probable tokens and shut off the others, but with a dynamic selection of the most probable tokens rather than a fixed k.

Greedy method

This is a simple and fast method that needs only one prediction: just select the most probable token as the next token. The greedy method can be effective on tasks such as arithmetic, but it tends to get stuck in a loop, repeating the same tokens one after another. It also reduces the model's diversity by favoring tokens that occur frequently in the training dataset.
Here’s the code that converts unnormalized logits (simply the raw output of the network) into a probability distribution and selects the most probable next token:

import torch
import torch.nn.functional as F

logits = logits[:, -1, :]          # keep only the last position's logits
probs = F.softmax(logits, dim=-1)  # logits -> probability distribution
next_token = probs.argmax(dim=-1)  # pick the most probable token

Beam search

Beam search produces better results but is slower, because it runs the model multiple times to build n candidate sequences, where n is the beam size. The method selects the top n tokens, appends each to the current sequence, and runs the model on the resulting sequences to predict the following token. This process continues until the end of the sequence. It is computationally more expensive, but the quality is higher.

[Image: beam search expanding the top candidates at each step]

 

Based on this search, the algorithm returns two sequences:

[Image: the two sequences returned by beam search]

Then, how do we select the final sequence? We sum up the loss for all predictions and select the sequence with the lowest loss.
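The procedure above can be sketched in code. This is a minimal, self-contained version under our own assumptions: a hypothetical stand-in "model" whose next-token logits depend only on the last token (a fixed random table), and scoring by cumulative log-probability, so the highest total log-probability corresponds to the lowest loss:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab = 5
# Hypothetical stand-in for a real model: next-token logits depend
# only on the last token, looked up in a fixed random table.
transition = torch.randn(vocab, vocab)

def next_logits(seq):
    return transition[seq[-1]]

def beam_search(start_token, steps, beam_size):
    # Each beam is a (sequence, cumulative log-probability) pair.
    beams = [([start_token], 0.0)]
    for _ in range(steps):
        candidates = []
        for seq, score in beams:
            log_probs = F.log_softmax(next_logits(seq), dim=-1)
            # Expand each beam with its beam_size best continuations.
            top = torch.topk(log_probs, beam_size)
            for lp, idx in zip(top.values, top.indices):
                candidates.append((seq + [idx.item()], score + lp.item()))
        # Keep only the beam_size best partial sequences.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    # Return the sequence with the highest total log-probability
    # (equivalently, the lowest summed loss).
    return beams[0][0]

print(beam_search(start_token=0, steps=4, beam_size=2))
```

A real implementation would batch the beams through the model in one forward pass per step; the loop above just makes the bookkeeping explicit.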

Simple sampling

We can select tokens randomly based on their probability: the higher the probability, the higher the chance of being selected. We can achieve this with the torch.multinomial method:

logits = logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
next_idx = torch.multinomial(probs, num_samples=1)

This is part of the model we implemented in the “transformer building blocks” blog, and the code can be found here. torch.multinomial receives the probability distribution and draws n samples from it. Here’s an example:

In [1]: import torch
In [2]: probs = torch.tensor([0.3, 0.6, 0.1])
In [3]: torch.multinomial(probs, num_samples=1)
Out[3]: tensor([1])
In [4]: torch.multinomial(probs, num_samples=1)
Out[4]: tensor([0])
In [5]: torch.multinomial(probs, num_samples=1)
Out[5]: tensor([1])
In [6]: torch.multinomial(probs, num_samples=1)
Out[6]: tensor([0])
In [7]: torch.multinomial(probs, num_samples=1)
Out[7]: tensor([1])
In [8]: torch.multinomial(probs, num_samples=1)
Out[8]: tensor([1])

We ran the method six times on probs. As you can see, it selects index 1 (probability 0.6) four times and index 0 (probability 0.3) twice, because tokens with higher probability are drawn more often.

Top-k sampling

If we want to improve on the previous sampling method, we need to limit the sampling space. Top-k sampling does exactly this: k is a parameter, and Top-k sampling selects the top k tokens from the probability distribution and samples only from those k tokens.

Here is an example of top-k sampling:

In [1]: import torch
In [2]: logit = torch.randn(10)
In [3]: logit
Out[3]: 
tensor([-1.1147,  0.5769,  0.3831, -0.5841,  1.7528, -0.7718, -0.4438,  0.6529,
        0.1500,  1.2592])
In [4]: topk_values, topk_indices = torch.topk(logit, 3)
In [5]: topk_values
Out[5]: tensor([1.7528, 1.2592, 0.6529])
In [6]: logit[logit < topk_values[-1]] = float('-inf')
In [7]: logit
Out[7]: 
tensor([  -inf,   -inf,   -inf,   -inf, 1.7528,   -inf,   -inf, 0.6529,   -inf,
       1.2592])
In [8]: probs = logit.softmax(0)
In [9]: probs
Out[9]: 
tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.5146, 0.0000, 0.0000, 0.1713, 0.0000,
       0.3141])
In [10]: torch.multinomial(probs, num_samples=1)
Out[10]: tensor([9])
In [11]: torch.multinomial(probs, num_samples=1)
Out[11]: tensor([4])
In [12]: torch.multinomial(probs, num_samples=1)
Out[12]: tensor([9])

●    We first create a fake logit with torch.randn. Suppose logit is the raw output of a network.
●    We use torch.topk to select the top 3 values from logit. torch.topk returns the top 3 values along with their indices, sorted from highest to lowest.
●    We use advanced indexing to select the logit values that are lower than the smallest of the top 3 values. When we write logit < topk_values[-1], we mean all the numbers in logit that are lower than topk_values[-1] (0.6529).
●    After selecting those numbers, we replace their values with float('-inf'), negative infinity.
●    After the replacement, we run softmax over the logit to convert it into probabilities; the -inf entries become probability 0 and can never be sampled.
●    Finally, we use torch.multinomial to sample from probs.
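The steps above can be wrapped into a small helper. The function name top_k_sample is ours, not from any library; the body is exactly the sequence of operations shown in the session, made non-destructive with a clone:

```python
import torch

def top_k_sample(logits, k):
    # Keep only the k most probable tokens; mask the rest with -inf.
    topk_values, _ = torch.topk(logits, k)
    masked = logits.clone()
    masked[masked < topk_values[-1]] = float('-inf')
    # Softmax turns the surviving logits into a distribution;
    # the -inf entries get probability 0 and can never be drawn.
    probs = masked.softmax(dim=-1)
    return torch.multinomial(probs, num_samples=1)

logit = torch.randn(10)
next_token = top_k_sample(logit, k=3)
```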

Nucleus sampling

Nucleus sampling is like Top-k sampling, but with a dynamic selection of top tokens instead of a fixed k. This dynamic selection helps when we are unsure how to choose a suitable k for Top-k sampling. Nucleus sampling has a hyperparameter named p, say 0.9. The method walks through the tokens in descending order of probability, adding up their probabilities, and stops once the cumulative sum reaches p. What is a cumulative sum? Here’s an example:

In [1]: import torch
In [2]: logit = torch.randn(10)
In [3]: probs = logit.softmax(0)
In [4]: probs
Out[4]: 
tensor([0.0652, 0.0330, 0.0609, 0.0436, 0.2365, 0.1738, 0.0651, 0.0692, 0.0495,
       0.2031])

In [5]: [probs[:x+1].sum() for x in range(probs.size(0))]
Out[5]: 
[tensor(0.0652),
tensor(0.0983),
tensor(0.1592),
tensor(0.2028),
tensor(0.4394),
tensor(0.6131),
tensor(0.6782),
tensor(0.7474),
tensor(0.7969),
tensor(1.)]

I hope the code makes clear how the cumulative sum works: each entry is the sum of all the probability values up to that position. We can also use torch.cumsum and get the same result:

In [9]: torch.cumsum(probs, dim=0)
Out[9]: 
tensor([0.0652, 0.0983, 0.1592, 0.2028, 0.4394, 0.6131, 0.6782, 0.7474, 0.7969,
       1.0000])

Okay. Here’s nucleus sampling from scratch:
In [1]: import torch
In [2]: logit = torch.randn(10)
In [3]: probs = logit.softmax(0)
In [4]: probs
Out[4]: 
tensor([0.7492, 0.0100, 0.0332, 0.0078, 0.0191, 0.0370, 0.0444, 0.0553, 0.0135,
       0.0305])
In [5]: sprobs, indices = torch.sort(probs, dim=0, descending=True)
In [6]: sprobs
Out[6]: 
tensor([0.7492, 0.0553, 0.0444, 0.0370, 0.0332, 0.0305, 0.0191, 0.0135, 0.0100,
       0.0078])
In [7]: cs_probs = torch.cumsum(sprobs, dim=0)
In [8]: cs_probs
Out[8]: 
tensor([0.7492, 0.8045, 0.8489, 0.8860, 0.9192, 0.9497, 0.9687, 0.9822, 0.9922,
       1.0000])
In [9]: selected_tokens = cs_probs < 0.9
In [10]: selected_tokens
Out[10]: tensor([ True,  True,  True,  True, False, False, False, False, False, False])
In [11]: probs[indices[selected_tokens]]
Out[11]: tensor([0.7492, 0.0553, 0.0444, 0.0370])
In [12]: probs = probs[indices[selected_tokens]]
In [13]: torch.multinomial(probs, num_samples=1)
Out[13]: tensor([0])

●    Convert the logit to probabilities and sort them in descending order so that we can select them from top to bottom.
●    Calculate the cumulative sum.
●    Using advanced indexing, we filter out the values beyond the cumulative threshold p.
●    Then, we sample from a limited and better space. Note that the sampled index refers to a position in the filtered list; to recover the original token id, map it back through indices.
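The steps above can be collected into a helper that also maps the sample back to the original vocabulary id. The function name nucleus_sample is ours; one small addition over the session is forcing the most probable token to stay in the nucleus, so the kept set is never empty even when a single token already exceeds p:

```python
import torch

def nucleus_sample(probs, p=0.9):
    # Sort probabilities in descending order, remembering the original indices.
    sprobs, indices = torch.sort(probs, descending=True)
    cs_probs = torch.cumsum(sprobs, dim=0)
    # Keep tokens until the cumulative probability reaches p;
    # always keep at least the most probable token.
    mask = cs_probs < p
    mask[0] = True
    kept = sprobs[mask]
    # Sample a position within the kept set, then map it back
    # to the original vocabulary id.
    pos = torch.multinomial(kept, num_samples=1)
    return indices[pos]

probs = torch.randn(10).softmax(0)
next_token = nucleus_sample(probs, p=0.9)
```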

Please note that you can combine Top-k and nucleus sampling: first select the top k tokens, then perform nucleus sampling on those k tokens. You can also combine Top-k, nucleus sampling, and beam search.
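As a sketch of that combination, here is one way to apply the nucleus cut inside the top k tokens. The function name combined_sample is ours, and it relies on torch.topk returning its values already sorted in descending order:

```python
import torch

def combined_sample(logits, k=50, p=0.9):
    # Top-k: keep the k most probable tokens (returned in descending order).
    topk_values, topk_indices = torch.topk(logits, k)
    probs = topk_values.softmax(dim=0)  # still in descending order
    # Nucleus cut inside those k tokens.
    mask = torch.cumsum(probs, dim=0) < p
    mask[0] = True  # always keep the most probable token
    pos = torch.multinomial(probs[mask], num_samples=1)
    # Map the position back to the original vocabulary id.
    return topk_indices[pos]

logits = torch.randn(100)
next_token = combined_sample(logits, k=10, p=0.9)
```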

Conclusion

Understanding these methods is crucial for anyone working with language models, natural language processing, or text generation tasks. These techniques play a significant role in generating coherent and diverse sequences of text. Depending on the specific use case and desired outcomes, readers can choose the most appropriate method to employ. Overall, this knowledge can contribute to improving the quality of generated text and enhancing the capabilities of language models.

Author Bio

Saeed Dehqan trains language models from scratch. Currently, his work is centered around Language Models for text generation, and he possesses a strong understanding of the underlying concepts of neural networks. He is proficient in using optimizers such as genetic algorithms to fine-tune network hyperparameters and has experience with neural architecture search (NAS) by using reinforcement learning (RL). He implements models starting from data gathering to monitoring, and deployment on mobile, web, cloud, etc.