refactor: rewrite OpenAI integration

- use `OpenAI` and `Azure.AI.OpenAI`
- use streaming response
This commit is contained in:
leo 2024-12-24 15:51:27 +08:00
parent c9b00d7bfe
commit 2ea3236645
No known key found for this signature in database
6 changed files with 72 additions and 147 deletions

View file

@ -137,11 +137,11 @@ This software supports using OpenAI or other AI service that has an OpenAI comap
For `OpenAI`:
* `Server` must be `https://api.openai.com/v1/chat/completions`
* `Server` must be `https://api.openai.com/v1`
For other AI service:
* The `Server` should fill in a URL equivalent to OpenAI's `https://api.openai.com/v1/chat/completions`. For example, when using `Ollama`, it should be `http://localhost:11434/v1/chat/completions` instead of `http://localhost:11434/api/generate`
* The `Server` should fill in a URL equivalent to OpenAI's `https://api.openai.com/v1`. For example, when using `Ollama`, it should be `http://localhost:11434/v1` instead of `http://localhost:11434/api/generate`
* The `API Key` is optional that depends on the service
## External Tools

View file

@ -46,8 +46,6 @@ namespace SourceGit
[JsonSerializable(typeof(Models.ExternalToolPaths))]
[JsonSerializable(typeof(Models.InteractiveRebaseJobCollection))]
[JsonSerializable(typeof(Models.JetBrainsState))]
[JsonSerializable(typeof(Models.OpenAIChatRequest))]
[JsonSerializable(typeof(Models.OpenAIChatResponse))]
[JsonSerializable(typeof(Models.ThemeOverrides))]
[JsonSerializable(typeof(Models.Version))]
[JsonSerializable(typeof(Models.RepositorySettings))]

View file

@ -20,76 +20,76 @@ namespace SourceGit.Commands
}
}
public GenerateCommitMessage(Models.OpenAIService service, string repo, List<Models.Change> changes, CancellationToken cancelToken, Action<string> onProgress)
public GenerateCommitMessage(Models.OpenAIService service, string repo, List<Models.Change> changes, CancellationToken cancelToken, Action<string> onProgress, Action<string> onResponse)
{
_service = service;
_repo = repo;
_changes = changes;
_cancelToken = cancelToken;
_onProgress = onProgress;
_onResponse = onResponse;
}
public string Result()
public void Exec()
{
try
{
var summarybuilder = new StringBuilder();
var summaryBuilder = new StringBuilder();
var bodyBuilder = new StringBuilder();
_onResponse?.Invoke("Wait for all file analysis to complete...");
foreach (var change in _changes)
{
if (_cancelToken.IsCancellationRequested)
return "";
return;
_onProgress?.Invoke($"Analyzing {change.Path}...");
var summary = GenerateChangeSummary(change);
summarybuilder.Append("- ");
summarybuilder.Append(summary);
summarybuilder.Append("(file: ");
summarybuilder.Append(change.Path);
summarybuilder.Append(")");
summarybuilder.AppendLine();
_onResponse?.Invoke($"Wait for all file analysis to complete...\n\n{bodyBuilder}");
bodyBuilder.Append("- ");
bodyBuilder.Append(summary);
bodyBuilder.AppendLine();
summaryBuilder.Append("- ");
GenerateChangeSummary(change, summaryBuilder, bodyBuilder);
bodyBuilder.Append("\n");
summaryBuilder.Append("(file: ");
summaryBuilder.Append(change.Path);
summaryBuilder.Append(")\n");
}
if (_cancelToken.IsCancellationRequested)
return "";
return;
_onProgress?.Invoke($"Generating commit message...");
var body = bodyBuilder.ToString();
var subject = GenerateSubject(summarybuilder.ToString());
return string.Format("{0}\n\n{1}", subject, body);
GenerateSubject(summaryBuilder.ToString(), body);
}
catch (Exception e)
{
App.RaiseException(_repo, $"Failed to generate commit message: {e}");
return "";
}
}
private string GenerateChangeSummary(Models.Change change)
private void GenerateChangeSummary(Models.Change change, StringBuilder summary, StringBuilder body)
{
var rs = new GetDiffContent(_repo, new Models.DiffOption(change, false)).ReadToEnd();
var diff = rs.IsSuccess ? rs.StdOut : "unknown change";
var rsp = _service.Chat(_service.AnalyzeDiffPrompt, $"Here is the `git diff` output: {diff}", _cancelToken);
if (rsp != null && rsp.Choices.Count > 0)
return rsp.Choices[0].Message.Content;
return string.Empty;
_service.Chat(_service.AnalyzeDiffPrompt, $"Here is the `git diff` output: {diff}", _cancelToken, update =>
{
body.Append(update);
summary.Append(update);
_onResponse?.Invoke($"Wait for all file analysis to complete...\n\n{body}");
});
}
private string GenerateSubject(string summary)
private void GenerateSubject(string summary, string body)
{
var rsp = _service.Chat(_service.GenerateSubjectPrompt, $"Here are the summaries changes:\n{summary}", _cancelToken);
if (rsp != null && rsp.Choices.Count > 0)
return rsp.Choices[0].Message.Content;
return string.Empty;
StringBuilder result = new StringBuilder();
_service.Chat(_service.GenerateSubjectPrompt, $"Here are the summaries changes:\n{summary}", _cancelToken, update =>
{
result.Append(update);
_onResponse?.Invoke($"{result}\n\n{body}");
});
}
private Models.OpenAIService _service;
@ -97,5 +97,6 @@ namespace SourceGit.Commands
private List<Models.Change> _changes;
private CancellationToken _cancelToken;
private Action<string> _onProgress;
private Action<string> _onResponse;
}
}

View file

@ -1,81 +1,13 @@
using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.ClientModel;
using System.Threading;
using Azure.AI.OpenAI;
using CommunityToolkit.Mvvm.ComponentModel;
using OpenAI;
using OpenAI.Chat;
namespace SourceGit.Models
{
public class OpenAIChatMessage
{
[JsonPropertyName("role")]
public string Role
{
get;
set;
}
[JsonPropertyName("content")]
public string Content
{
get;
set;
}
}
public class OpenAIChatChoice
{
[JsonPropertyName("index")]
public int Index
{
get;
set;
}
[JsonPropertyName("message")]
public OpenAIChatMessage Message
{
get;
set;
}
}
public class OpenAIChatResponse
{
[JsonPropertyName("choices")]
public List<OpenAIChatChoice> Choices
{
get;
set;
} = [];
}
public class OpenAIChatRequest
{
[JsonPropertyName("model")]
public string Model
{
get;
set;
}
[JsonPropertyName("messages")]
public List<OpenAIChatMessage> Messages
{
get;
set;
} = [];
public void AddMessage(string role, string content)
{
Messages.Add(new OpenAIChatMessage { Role = role, Content = content });
}
}
public class OpenAIService : ObservableObject
{
public string Name
@ -147,44 +79,38 @@ namespace SourceGit.Models
""";
}
public OpenAIChatResponse Chat(string prompt, string question, CancellationToken cancellation)
{
var chat = new OpenAIChatRequest() { Model = Model };
chat.AddMessage("user", prompt);
chat.AddMessage("user", question);
var client = new HttpClient() { Timeout = TimeSpan.FromSeconds(60) };
if (!string.IsNullOrEmpty(ApiKey))
public void Chat(string prompt, string question, CancellationToken cancellation, Action<string> onUpdate)
{
Uri server = new(Server);
ApiKeyCredential key = new(ApiKey);
ChatClient client = null;
if (Server.Contains("openai.azure.com/", StringComparison.Ordinal))
client.DefaultRequestHeaders.Add("api-key", ApiKey);
{
var azure = new AzureOpenAIClient(server, key);
client = azure.GetChatClient(Model);
}
else
client.DefaultRequestHeaders.Add("Authorization", $"Bearer {ApiKey}");
{
var openai = new OpenAIClient(key, new() { Endpoint = server });
client = openai.GetChatClient(Model);
}
var req = new StringContent(JsonSerializer.Serialize(chat, JsonCodeGen.Default.OpenAIChatRequest), Encoding.UTF8, "application/json");
try
{
var task = client.PostAsync(Server, req, cancellation);
task.Wait(cancellation);
var updates = client.CompleteChatStreaming([
new UserChatMessage(prompt),
new UserChatMessage(question),
], null, cancellation);
var rsp = task.Result;
var reader = rsp.Content.ReadAsStringAsync(cancellation);
reader.Wait(cancellation);
var body = reader.Result;
if (!rsp.IsSuccessStatusCode)
foreach (var update in updates)
{
throw new Exception($"AI service returns error code {rsp.StatusCode}. Body: {body ?? string.Empty}");
if (update.ContentUpdate.Count > 0)
onUpdate.Invoke(update.ContentUpdate[0].Text);
}
return JsonSerializer.Deserialize(reader.Result, JsonCodeGen.Default.OpenAIChatResponse);
}
catch
{
if (cancellation.IsCancellationRequested)
return null;
if (!cancellation.IsCancellationRequested)
throw;
}
}

View file

@ -24,6 +24,7 @@
<PublishAot>true</PublishAot>
<PublishTrimmed>true</PublishTrimmed>
<TrimMode>link</TrimMode>
<JsonSerializerIsReflectionEnabledByDefault>true</JsonSerializerIsReflectionEnabledByDefault>
</PropertyGroup>
<PropertyGroup Condition="'$(DisableUpdateDetection)' == 'true'">
@ -52,8 +53,10 @@
<PackageReference Include="Avalonia.Diagnostics" Version="11.2.3" Condition="'$(Configuration)' == 'Debug'" />
<PackageReference Include="Avalonia.AvaloniaEdit" Version="11.1.0" />
<PackageReference Include="AvaloniaEdit.TextMate" Version="11.1.0" />
<PackageReference Include="Azure.AI.OpenAI" Version="2.1.0" />
<PackageReference Include="CommunityToolkit.Mvvm" Version="8.3.2" />
<PackageReference Include="LiveChartsCore.SkiaSharpView.Avalonia" Version="2.0.0-rc4.5" />
<PackageReference Include="OpenAI" Version="2.1.0" />
<PackageReference Include="TextMateSharp" Version="1.0.65" />
<PackageReference Include="TextMateSharp.Grammars" Version="1.0.65" />
</ItemGroup>

View file

@ -36,15 +36,17 @@ namespace SourceGit.Views
Task.Run(() =>
{
var message = new Commands.GenerateCommitMessage(_service, _repo, _changes, _cancel.Token, SetDescription).Result();
if (_cancel.IsCancellationRequested)
return;
Dispatcher.UIThread.Invoke(() =>
new Commands.GenerateCommitMessage(_service, _repo, _changes, _cancel.Token, progress =>
{
_onDone?.Invoke(message);
Close();
});
Dispatcher.UIThread.Invoke(() => ProgressMessage.Text = progress);
},
message =>
{
Dispatcher.UIThread.Invoke(() => _onDone?.Invoke(message));
}).Exec();
if (!_cancel.IsCancellationRequested)
Dispatcher.UIThread.Invoke(Close);
}, _cancel.Token);
}
@ -54,11 +56,6 @@ namespace SourceGit.Views
_cancel.Cancel();
}
private void SetDescription(string message)
{
Dispatcher.UIThread.Invoke(() => ProgressMessage.Text = message);
}
private Models.OpenAIService _service;
private string _repo;
private List<Models.Change> _changes;