From 3a5828c8513aad3d6ac39b9fb6d8d7c4aa9a7e63 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 18 Jun 2024 22:53:02 +0100 Subject: [PATCH 1/4] example of beam search with batched executor --- LLama.Examples/ExampleRunner.cs | 1 + .../Examples/BatchedExecutorBeamSearch.cs | 136 ++++++++++++++++++ 2 files changed, 137 insertions(+) create mode 100644 LLama.Examples/Examples/BatchedExecutorBeamSearch.cs diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs index c194f6f87..09819c120 100644 --- a/LLama.Examples/ExampleRunner.cs +++ b/LLama.Examples/ExampleRunner.cs @@ -33,6 +33,7 @@ public class ExampleRunner { "Batched Executor: Rewind", BatchedExecutorRewind.Run }, { "Batched Executor: Guidance", BatchedExecutorGuidance.Run }, { "Batched Executor: LLava", BatchedExecutorLLava.Run }, + { "Batched Executor: Beam Search", BatchedExecutorBeamSearch.Run }, { "Speech Chat: Integration with Whisper.net", SpeechChat.Run }, { "Exit", () => { Environment.Exit(0); return Task.CompletedTask; } } }; diff --git a/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs b/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs new file mode 100644 index 000000000..ca3404319 --- /dev/null +++ b/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs @@ -0,0 +1,136 @@ +using LLama.Batched; +using LLama.Common; +using LLama.Native; +using Spectre.Console; + +namespace LLama.Examples.Examples; + +/// +/// This demonstrates beam search using the batched executor +/// +public class BatchedExecutorBeamSearch +{ + /// + /// Set how many tokens to generate + /// + private const int TokensGenerate = 24; + + /// + /// Set how many parallel beams to keep + /// + private const int BeamsCount = 3; + + public static async Task Run() + { + // Load model weights + var parameters = new ModelParams(UserSettings.GetModelPath()); + using var model = await LLamaWeights.LoadFromFileAsync(parameters); + + var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that"); + + // Create an executor that can evaluate a batch of conversations together + using var executor = new BatchedExecutor(model, parameters); + + // Print some info + var name = model.Metadata.GetValueOrDefault("general.name", "unknown model name"); + Console.WriteLine($"Created executor with model: {name}"); + + // Evaluate the initial prompt to create one conversation + var conversation = executor.Create(); + var startTokens = executor.Context.Tokenize(prompt); + conversation.Prompt(startTokens); + + // Create one beam, containing that conversation + var beams = new List(); + beams.Add(new Beam(conversation, 1.0, startTokens)); + + // Print the prompt + Console.ForegroundColor = ConsoleColor.Green; + Console.WriteLine(prompt); + + // Generate loop + for (var i = 0; i < TokensGenerate; i++) + { + await executor.Infer(); + + // Create new beams, forked from all original beams + beams = (from oldBeam in beams + from beam in oldBeam.Sample(BeamsCount) + select beam).OrderBy(a => a.CumulativeProbability).ToList(); + + // Trim down list by removing low probability beams + while (beams.Count > BeamsCount) + { + var beam = beams[0]; + AnsiConsole.MarkupLineInterpolated($"[red]Culling Beam (prob:{beam.CumulativeProbability:P10})[/]: {beam}"); + + beam.Dispose(); + beams.RemoveAt(0); + } + } + + // Print out all remaining beams + AnsiConsole.MarkupLineInterpolated($"Final Beams:"); + beams.Reverse(); + foreach (var beam in beams) + AnsiConsole.MarkupLineInterpolated($"[green]Culling Beam (prob:{beam.CumulativeProbability:P10})[/]: {beam}"); + + Console.WriteLine("Press any key to exit demo"); + Console.ReadKey(true); + } + + private class Beam + : IDisposable + { + public readonly Conversation Conversation; + public readonly double CumulativeProbability; + public readonly IReadOnlyList Tokens; + + public Beam(Conversation conversation, double prob, IReadOnlyList tokens) + { + Conversation = conversation; + CumulativeProbability = prob; + Tokens = tokens; + } + + public void Dispose() + { + Conversation.Dispose(); + } + + public List Sample(int nbeams) + { + // Apply softmax, this calculates probabilities and sorts tokens into descending order + var logitsArr = LLamaTokenDataArray.Create(Conversation.Sample()); + logitsArr.Softmax(Conversation.Executor.Context.NativeHandle); + + // Create new forked conversations, one for each beam + var results = new List(); + for (var i = 0; i < nbeams; i++) + { + var item = logitsArr.Data.Span[i]; + + var c = Conversation.Fork(); + c.Prompt(item.id); + + var p = CumulativeProbability * item.p; + + var t = Tokens.ToList(); + t.Add(item.id); + + results.Add(new Beam(c, p, t)); + } + + // Dispose self now that child beams have spawned + Conversation.Dispose(); + return results; + } + + public override string ToString() + { +#pragma warning disable CS0618 // Type or member is obsolete + return Conversation.Executor.Context.DeTokenize(Tokens); +#pragma warning restore CS0618 // Type or member is obsolete + } + } +} \ No newline at end of file From 6d6e4cdbf726fb63e45da08df68af1fc4daca223 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 18 Jun 2024 23:20:07 +0100 Subject: [PATCH 2/4] - Printing ID of discarded beams - Tracking the set of conversation IDs that made a beam. --- .../Examples/BatchedExecutorBeamSearch.cs | 36 +++++++++---------- LLama/Native/LLamaSeqId.cs | 8 ++++- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs b/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs index ca3404319..dba7518d3 100644 --- a/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs +++ b/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs @@ -10,23 +10,15 @@ namespace LLama.Examples.Examples; /// public class BatchedExecutorBeamSearch { - /// - /// Set how many tokens to generate - /// - private const int TokensGenerate = 24; - - /// - /// Set how many parallel beams to keep - /// - private const int BeamsCount = 3; - public static async Task Run() { // Load model weights var parameters = new ModelParams(UserSettings.GetModelPath()); using var model = await LLamaWeights.LoadFromFileAsync(parameters); - var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that"); + var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "The cat sat on"); + var tokensGenerate = AnsiConsole.Ask("How many tokens to generate?", 8); + var beamsCount = AnsiConsole.Ask("How many parallel beams to keep track of?", 8); // Create an executor that can evaluate a batch of conversations together using var executor = new BatchedExecutor(model, parameters); @@ -42,27 +34,27 @@ public static async Task Run() // Create one beam, containing that conversation var beams = new List(); - beams.Add(new Beam(conversation, 1.0, startTokens)); + beams.Add(new Beam(conversation, 1.0, startTokens, [conversation.ConversationId])); // Print the prompt Console.ForegroundColor = ConsoleColor.Green; Console.WriteLine(prompt); // Generate loop - for (var i = 0; i < TokensGenerate; i++) + for (var i = 0; i < tokensGenerate; i++) { await executor.Infer(); // Create new beams, forked from all original beams beams = (from oldBeam in beams - from beam in oldBeam.Sample(BeamsCount) + from beam in oldBeam.Sample(beamsCount) select beam).OrderBy(a => a.CumulativeProbability).ToList(); // Trim down list by removing low probability beams - while (beams.Count > BeamsCount) + while (beams.Count > beamsCount) { var beam = beams[0]; - AnsiConsole.MarkupLineInterpolated($"[red]Culling Beam (prob:{beam.CumulativeProbability:P10})[/]: {beam}"); + AnsiConsole.MarkupLineInterpolated($"[red]Culling Beam {beam.Conversation.ConversationId} (prob:{beam.CumulativeProbability:P10})[/]: {beam}"); beam.Dispose(); beams.RemoveAt(0); @@ -73,7 +65,7 @@ from beam in oldBeam.Sample(BeamsCount) AnsiConsole.MarkupLineInterpolated($"Final Beams:"); beams.Reverse(); foreach (var beam in beams) - AnsiConsole.MarkupLineInterpolated($"[green]Culling Beam (prob:{beam.CumulativeProbability:P10})[/]: {beam}"); + AnsiConsole.MarkupLineInterpolated($"[green](prob:{beam.CumulativeProbability:P10})[/]: {beam}"); Console.WriteLine("Press any key to exit demo"); Console.ReadKey(true); @@ -85,12 +77,14 @@ private class Beam public readonly Conversation Conversation; public readonly double CumulativeProbability; public readonly IReadOnlyList Tokens; + public readonly IReadOnlyList Sequence; - public Beam(Conversation conversation, double prob, IReadOnlyList tokens) + public Beam(Conversation conversation, double prob, IReadOnlyList tokens, IReadOnlyList sequence) { Conversation = conversation; CumulativeProbability = prob; Tokens = tokens; + Sequence = sequence; } public void Dispose() @@ -118,7 +112,10 @@ public List Sample(int nbeams) var t = Tokens.ToList(); t.Add(item.id); - results.Add(new Beam(c, p, t)); + var s = Sequence.ToList(); + s.Add(c.ConversationId); + + results.Add(new Beam(c, p, t, s)); } // Dispose self now that child beams have spawned @@ -131,6 +128,7 @@ public override string ToString() #pragma warning disable CS0618 // Type or member is obsolete return Conversation.Executor.Context.DeTokenize(Tokens); #pragma warning restore CS0618 // Type or member is obsolete + } } } \ No newline at end of file diff --git a/LLama/Native/LLamaSeqId.cs b/LLama/Native/LLamaSeqId.cs index 8a3dae5d8..e3c6e8f43 100644 --- a/LLama/Native/LLamaSeqId.cs +++ b/LLama/Native/LLamaSeqId.cs @@ -1,4 +1,4 @@ -using System.Runtime.InteropServices; +using System.Runtime.InteropServices; namespace LLama.Native; @@ -39,4 +39,10 @@ private LLamaSeqId(int value) /// /// public static explicit operator LLamaSeqId(int value) => new(value); + + /// + public readonly override string ToString() + { + return Value.ToString(); + } } \ No newline at end of file From c871be2b14373b9b02972aca5788e3766f02e273 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Thu, 20 Jun 2024 11:56:17 +0100 Subject: [PATCH 3/4] - Minor cleanup of beam search example - Added beam normalization, to prevent `CumulativeProbability=0.0` when beams get longer --- .../Examples/BatchedExecutorBeamSearch.cs | 64 ++++++++++++++----- 1 file changed, 48 insertions(+), 16 deletions(-) diff --git a/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs b/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs index dba7518d3..f61d2ee89 100644 --- a/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs +++ b/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs @@ -6,7 +6,13 @@ namespace LLama.Examples.Examples; /// -/// This demonstrates beam search using the batched executor +/// This demonstrates beam search using the batched executor. +/// +/// Beam search is a technique for finding the most likely multi-token completion from a prompt. The search keeps track of a +/// set of "beams", each beam is a possible completion and keeps track of it's cumulative probability. At each step all +/// of the current beams are split into multiple beams by extending the beam with different possible tokens (greedy sampling the +/// top N tokens), the set of _all_ beams is then trimmed down to just the most likely beams. This allows multiple possibilties to +/// be considered, and can find a higher probability result than simply greedy sampling the most likely token at every stage. /// public class BatchedExecutorBeamSearch { @@ -17,8 +23,8 @@ public static async Task Run() using var model = await LLamaWeights.LoadFromFileAsync(parameters); var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "The cat sat on"); - var tokensGenerate = AnsiConsole.Ask("How many tokens to generate?", 8); - var beamsCount = AnsiConsole.Ask("How many parallel beams to keep track of?", 8); + var tokensGenerate = AnsiConsole.Ask("How many tokens to generate?", 8); + var beamsCount = AnsiConsole.Ask("How many parallel beams to keep track of?", 8); // Create an executor that can evaluate a batch of conversations together using var executor = new BatchedExecutor(model, parameters); @@ -33,12 +39,7 @@ public static async Task Run() conversation.Prompt(startTokens); // Create one beam, containing that conversation - var beams = new List(); - beams.Add(new Beam(conversation, 1.0, startTokens, [conversation.ConversationId])); - - // Print the prompt - Console.ForegroundColor = ConsoleColor.Green; - Console.WriteLine(prompt); + var beams = new List { new Beam(conversation, 1.0, startTokens, [conversation.ConversationId]) }; // Generate loop for (var i = 0; i < tokensGenerate; i++) @@ -59,32 +60,56 @@ from beam in oldBeam.Sample(beamsCount) beam.Dispose(); beams.RemoveAt(0); } + + // Normalize all remaining beam probabilties. + NormalizeBeams(beams); } // Print out all remaining beams AnsiConsole.MarkupLineInterpolated($"Final Beams:"); beams.Reverse(); foreach (var beam in beams) - AnsiConsole.MarkupLineInterpolated($"[green](prob:{beam.CumulativeProbability:P10})[/]: {beam}"); + { + AnsiConsole.MarkupLineInterpolated($"[yellow]Probability: {beam.CumulativeProbability:P10}[/]"); + AnsiConsole.MarkupLineInterpolated($"[yellow]Sequence: {string.Join(",", beam.Sequence)}[/]"); + AnsiConsole.MarkupLineInterpolated($"[green]{beam}[/]"); + Console.WriteLine(); + } Console.WriteLine("Press any key to exit demo"); Console.ReadKey(true); } + /// + /// As the beam grows the cumulative probability gets very small. Normalizing all the beams prevents the value collapsing to zero. + /// + /// + private static void NormalizeBeams(List beams) + { + // Find max probability + var max = beams.MaxBy(a => a.CumulativeProbability)!.CumulativeProbability; + + // Divide all beams by max, this makes the max prob = 1.0 + foreach (var beam in beams) + beam.CumulativeProbability /= max; + } + private class Beam : IDisposable { public readonly Conversation Conversation; - public readonly double CumulativeProbability; public readonly IReadOnlyList Tokens; public readonly IReadOnlyList Sequence; + public double CumulativeProbability; + public Beam(Conversation conversation, double prob, IReadOnlyList tokens, IReadOnlyList sequence) { Conversation = conversation; - CumulativeProbability = prob; Tokens = tokens; Sequence = sequence; + + CumulativeProbability = prob; } public void Dispose() @@ -102,16 +127,24 @@ public List Sample(int nbeams) var results = new List(); for (var i = 0; i < nbeams; i++) { + // After softmax the logits array is in descending order of probability. Take the first `nbeams` items to make new beams. var item = logitsArr.Data.Span[i]; + // Fork the parent conversation. This shares all of the KV cache with the parent (and other forks) + // so does not cost any extra memory. var c = Conversation.Fork(); + + // Extend the conversation with the selected token. c.Prompt(item.id); + // Keep track of the cumulative probability of this entire sequence. var p = CumulativeProbability * item.p; + // Keep track of all tokens in this sequence, for decoding later var t = Tokens.ToList(); t.Add(item.id); + // Keep track of which beam this beam was derived from. var s = Sequence.ToList(); s.Add(c.ConversationId); @@ -125,10 +158,9 @@ public List Sample(int nbeams) public override string ToString() { -#pragma warning disable CS0618 // Type or member is obsolete - return Conversation.Executor.Context.DeTokenize(Tokens); -#pragma warning restore CS0618 // Type or member is obsolete - + var decoder = new StreamingTokenDecoder(Conversation.Executor.Context); + decoder.AddRange(Tokens); + return decoder.Read(); } } } \ No newline at end of file From a5f61a1e41056efe1d6e7db2fcf245092db987ae Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Thu, 20 Jun 2024 11:58:10 +0100 Subject: [PATCH 4/4] Fixed spelling --- LLama.Examples/Examples/BatchedExecutorBeamSearch.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs b/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs index f61d2ee89..ce91fccc3 100644 --- a/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs +++ b/LLama.Examples/Examples/BatchedExecutorBeamSearch.cs @@ -11,7 +11,7 @@ namespace LLama.Examples.Examples; /// Beam search is a technique for finding the most likely multi-token completion from a prompt. The search keeps track of a /// set of "beams", each beam is a possible completion and keeps track of it's cumulative probability. At each step all /// of the current beams are split into multiple beams by extending the beam with different possible tokens (greedy sampling the -/// top N tokens), the set of _all_ beams is then trimmed down to just the most likely beams. This allows multiple possibilties to +/// top N tokens), the set of _all_ beams is then trimmed down to just the most likely beams. This allows multiple possibilities to /// be considered, and can find a higher probability result than simply greedy sampling the most likely token at every stage. /// public class BatchedExecutorBeamSearch @@ -61,7 +61,7 @@ from beam in oldBeam.Sample(beamsCount) beams.RemoveAt(0); } - // Normalize all remaining beam probabilties. + // Normalize all remaining beam probabilities. NormalizeBeams(beams); }