Skip to content

Commit

Permalink
Implemented clone for chain
Browse files Browse the repository at this point in the history
  • Loading branch information
martindevans committed Oct 2, 2024
1 parent d4bed4b commit 5804cbd
Showing 1 changed file with 68 additions and 28 deletions.
96 changes: 68 additions & 28 deletions LLama/Native/SafeLLamaSamplerHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace LLama.Native;
public class SafeLLamaSamplerChainHandle
: SafeLLamaSamplerHandle
{
private readonly Dictionary<IntPtr, SafeLLamaSamplerHandle> _samplers = new();
private readonly List<SafeLLamaSamplerHandle> _samplers = [ ];

/// <summary>
/// Get the number of samplers in this chain
Expand All @@ -34,7 +34,7 @@ public void Add(SafeLLamaSamplerHandle sampler)
{
// Sanity check that the sampler isn't already in this chain
lock (_samplers)
if (_samplers.ContainsKey(sampler.DangerousGetHandle()))
if (_samplers.Contains(sampler))
throw new ArgumentException("Cannot add a sampler to a chain twice");

// Add the sampler to the chain
Expand All @@ -48,7 +48,7 @@ public void Add(SafeLLamaSamplerHandle sampler)
// Store a reference to the handle so we can retrieve it later
lock (_samplers)
{
_samplers[sampler.DangerousGetHandle()] = sampler;
_samplers.Add(sampler);
}

// important: this takes ownership of the sampler object and will free it when llama_sampler_free is called on the chain
Expand Down Expand Up @@ -77,16 +77,18 @@ public SafeLLamaSamplerHandle RemoveAtAndReturn(int index)
// Remove the sampler from the chain, returning the pointer to this handle
var ptr = llama_sampler_chain_remove(this, index);

// Find the handle (by looking up it's pointer) and return it.
// Get the handle at that index and return it
lock (_samplers)
{
var sampler = _samplers[ptr];
_samplers.Remove(ptr);
var sampler = _samplers[index];
_samplers.RemoveAt(index);
Debug.Assert(ptr == sampler.DangerousGetHandle());
return sampler;
}

// This is a tricky method to work with! It can't return a handle, because that would create a second handle to
// a these resources! Instead It returns the raw pointer, and that can be looked up in the _samplers dictionary.
// This is a tricky method to work with!
// It can't return a handle, because that would create a second handle to these resources!
// Instead it returns the raw pointer, and that can be looked up in the _samplers dictionary.
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
static extern IntPtr llama_sampler_chain_remove(SafeLLamaSamplerHandle chain, int i);
}
Expand All @@ -107,29 +109,60 @@ public SafeLLamaSamplerHandle Get(int index)
// Find the handle (by looking up it's pointer) and return it.
lock (_samplers)
{
var sampler = _samplers[ptr];
var sampler = _samplers[index];
Debug.Assert(ptr == sampler.DangerousGetHandle());
return sampler;
}

// This is a tricky method to work with! It can't return a handle, because that would create a second handle to
// a these resources! Instead It returns the raw pointer, and that can be looked up in the _samplers dictionary.
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
static extern IntPtr llama_sampler_chain_get(SafeLLamaSamplerHandle chain, int i);
}

/// <inheritdoc />
protected override bool ReleaseHandle()
{
// Dispose the chain also disposes all sampler stages. Mark all of them as invalid.
// Disposing the chain automatically disposes all sampler stages. Mark all of them as invalid.
lock (_samplers)
{
foreach (var item in _samplers)
item.Value.SetHandleAsInvalid();
item.SetHandleAsInvalid();
_samplers.Clear();
}

return base.ReleaseHandle();
}

/// <inheritdoc />
public override SafeLLamaSamplerHandle Clone()
{
// Create a new handle to own the clone
var chain = new SafeLLamaSamplerChainHandle();

// Clone the chain and move ownership across to the handle created above
var invalidHandle = base.Clone();
chain.SetHandle(invalidHandle.DangerousGetHandle());
invalidHandle.SetHandleAsInvalid();

// We've got a handle of the right type, but we still need to copy all the other bits. Cloning the chain created a load
// of new resources which we don't have any reference to!
for (var i = 0; i < Count; i++)
{
// Create a handle for this resource
var sampler = new SafeLLamaSamplerHandle(llama_sampler_chain_get(this, i));
chain._samplers.Add(sampler);

// Bump the reference count, to account for the fact this sampler is owned by the chain
var success = false;
sampler.DangerousAddRef(ref success);
}

return chain;
}

#region Native API
// This is a tricky method to work with!
// It can't return a handle, because that would create a second handle to these resources.
// Instead It returns the raw pointer, and that can be looked up in the _samplers dictionary.
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern IntPtr llama_sampler_chain_get(SafeLLamaSamplerChainHandle chain, int i);
#endregion
}

/// <summary>
Expand All @@ -149,6 +182,15 @@ public class SafeLLamaSamplerHandle
/// </summary>
public uint Seed => llama_sampler_get_seed(this);

internal SafeLLamaSamplerHandle(IntPtr ptr)
: base(ptr, true)
{
}

internal SafeLLamaSamplerHandle()
{
}

#region create samplers
/// <summary>
/// Create a new sampler chain
Expand Down Expand Up @@ -416,19 +458,17 @@ protected override bool ReleaseHandle()
return true;
}

//todo: cloning is problematic. Calling Clone would **not** work for a chain, because it would not duplicate all the C# structure.
// Going to need some special handling!
///// <summary>
///// Create a clone of this sampler
///// </summary>
///// <returns></returns>
//public SafeLLamaSamplerHandle Clone()
//{
// return llama_sampler_clone(this);
/// <summary>
/// Create a clone of this sampler
/// </summary>
/// <returns></returns>
public virtual SafeLLamaSamplerHandle Clone()
{
return llama_sampler_clone(this);

// [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
// static extern SafeLLamaSamplerHandle llama_sampler_clone(SafeLLamaSamplerHandle chain);
//}
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
static extern SafeLLamaSamplerHandle llama_sampler_clone(SafeLLamaSamplerHandle chain);
}

/// <summary>
/// Apply this sampler to a set of candidates
Expand Down

0 comments on commit 5804cbd

Please sign in to comment.