refactor: using official OpenAI library for AI integration

Signed-off-by: leo <longshuang@msn.cn>
This commit is contained in:
leo 2024-11-20 21:54:52 +08:00
parent b3ebd84af5
commit efdfd53368
No known key found for this signature in database
5 changed files with 74 additions and 144 deletions

View file

@ -46,8 +46,6 @@ namespace SourceGit
[JsonSerializable(typeof(Models.ExternalToolPaths))] [JsonSerializable(typeof(Models.ExternalToolPaths))]
[JsonSerializable(typeof(Models.InteractiveRebaseJobCollection))] [JsonSerializable(typeof(Models.InteractiveRebaseJobCollection))]
[JsonSerializable(typeof(Models.JetBrainsState))] [JsonSerializable(typeof(Models.JetBrainsState))]
[JsonSerializable(typeof(Models.OpenAIChatRequest))]
[JsonSerializable(typeof(Models.OpenAIChatResponse))]
[JsonSerializable(typeof(Models.ThemeOverrides))] [JsonSerializable(typeof(Models.ThemeOverrides))]
[JsonSerializable(typeof(Models.Version))] [JsonSerializable(typeof(Models.Version))]
[JsonSerializable(typeof(Models.RepositorySettings))] [JsonSerializable(typeof(Models.RepositorySettings))]

View file

@ -20,76 +20,75 @@ namespace SourceGit.Commands
} }
} }
public GenerateCommitMessage(Models.OpenAIService service, string repo, List<Models.Change> changes, CancellationToken cancelToken, Action<string> onProgress) public GenerateCommitMessage(Models.OpenAIService service, string repo, List<Models.Change> changes, CancellationToken cancelToken, Action<string> onProgress, Action<string> onResponse)
{ {
_service = service; _service = service;
_repo = repo; _repo = repo;
_changes = changes; _changes = changes;
_cancelToken = cancelToken; _cancelToken = cancelToken;
_onProgress = onProgress; _onProgress = onProgress;
_onResponse = onResponse;
} }
public string Result() public void Exec()
{ {
try try
{ {
var summarybuilder = new StringBuilder(); var summaryBuilder = new StringBuilder();
var bodyBuilder = new StringBuilder(); var bodyBuilder = new StringBuilder();
_onResponse?.Invoke("Wait for all file analysis to complete...");
foreach (var change in _changes) foreach (var change in _changes)
{ {
if (_cancelToken.IsCancellationRequested) if (_cancelToken.IsCancellationRequested)
return ""; return;
_onProgress?.Invoke($"Analyzing {change.Path}..."); _onProgress?.Invoke($"Analyzing {change.Path}...");
_onResponse?.Invoke($"Wait for all file analysis to complete...\n\n{bodyBuilder}");
var summary = GenerateChangeSummary(change);
summarybuilder.Append("- ");
summarybuilder.Append(summary);
summarybuilder.Append("(file: ");
summarybuilder.Append(change.Path);
summarybuilder.Append(")");
summarybuilder.AppendLine();
bodyBuilder.Append("- "); bodyBuilder.Append("- ");
bodyBuilder.Append(summary); summaryBuilder.Append("- ");
bodyBuilder.AppendLine(); GenerateChangeSummary(change, summaryBuilder, bodyBuilder);
bodyBuilder.Append("\n");
summaryBuilder.Append("(file: ");
summaryBuilder.Append(change.Path);
summaryBuilder.Append(")\n");
} }
if (_cancelToken.IsCancellationRequested) if (_cancelToken.IsCancellationRequested)
return ""; return;
_onProgress?.Invoke($"Generating commit message..."); _onProgress?.Invoke($"Generating commit message...");
var body = bodyBuilder.ToString(); var body = bodyBuilder.ToString();
var subject = GenerateSubject(summarybuilder.ToString()); GenerateSubject(summaryBuilder.ToString(), body);
return string.Format("{0}\n\n{1}", subject, body);
} }
catch (Exception e) catch (Exception e)
{ {
App.RaiseException(_repo, $"Failed to generate commit message: {e}"); App.RaiseException(_repo, $"Failed to generate commit message: {e}");
return "";
} }
} }
private string GenerateChangeSummary(Models.Change change) private void GenerateChangeSummary(Models.Change change, StringBuilder summary, StringBuilder body)
{ {
var rs = new GetDiffContent(_repo, new Models.DiffOption(change, false)).ReadToEnd(); var rs = new GetDiffContent(_repo, new Models.DiffOption(change, false)).ReadToEnd();
var diff = rs.IsSuccess ? rs.StdOut : "unknown change"; var diff = rs.IsSuccess ? rs.StdOut : "unknown change";
var rsp = _service.Chat(_service.AnalyzeDiffPrompt, $"Here is the `git diff` output: {diff}", _cancelToken); _service.Chat(_service.AnalyzeDiffPrompt, $"Here is the `git diff` output: {diff}", _cancelToken, update =>
if (rsp != null && rsp.Choices.Count > 0) {
return rsp.Choices[0].Message.Content; body.Append(update);
summary.Append(update);
return string.Empty; _onResponse?.Invoke($"Wait for all file analysis to complete...\n\n{body}");
});
} }
private string GenerateSubject(string summary) private void GenerateSubject(string summary, string body)
{ {
var rsp = _service.Chat(_service.GenerateSubjectPrompt, $"Here are the summaries changes:\n{summary}", _cancelToken); StringBuilder result = new StringBuilder();
if (rsp != null && rsp.Choices.Count > 0) _service.Chat(_service.GenerateSubjectPrompt, $"Here are the summaries changes:\n{summary}", _cancelToken, update =>
return rsp.Choices[0].Message.Content; {
result.Append(update);
return string.Empty; _onResponse?.Invoke($"{result}\n\n{body}");
});
} }
private Models.OpenAIService _service; private Models.OpenAIService _service;
@ -97,5 +96,6 @@ namespace SourceGit.Commands
private List<Models.Change> _changes; private List<Models.Change> _changes;
private CancellationToken _cancelToken; private CancellationToken _cancelToken;
private Action<string> _onProgress; private Action<string> _onProgress;
private Action<string> _onResponse;
} }
} }

View file

@ -1,81 +1,16 @@
using System; using System;
using System.Collections.Generic; using System.ClientModel;
using System.Net.Http;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading; using System.Threading;
using Azure.AI.OpenAI;
using CommunityToolkit.Mvvm.ComponentModel; using CommunityToolkit.Mvvm.ComponentModel;
using OpenAI;
using OpenAI.Chat;
namespace SourceGit.Models namespace SourceGit.Models
{ {
public class OpenAIChatMessage
{
[JsonPropertyName("role")]
public string Role
{
get;
set;
}
[JsonPropertyName("content")]
public string Content
{
get;
set;
}
}
public class OpenAIChatChoice
{
[JsonPropertyName("index")]
public int Index
{
get;
set;
}
[JsonPropertyName("message")]
public OpenAIChatMessage Message
{
get;
set;
}
}
public class OpenAIChatResponse
{
[JsonPropertyName("choices")]
public List<OpenAIChatChoice> Choices
{
get;
set;
} = [];
}
public class OpenAIChatRequest
{
[JsonPropertyName("model")]
public string Model
{
get;
set;
}
[JsonPropertyName("messages")]
public List<OpenAIChatMessage> Messages
{
get;
set;
} = [];
public void AddMessage(string role, string content)
{
Messages.Add(new OpenAIChatMessage { Role = role, Content = content });
}
}
public class OpenAIService : ObservableObject public class OpenAIService : ObservableObject
{ {
public string Name public string Name
@ -147,45 +82,43 @@ namespace SourceGit.Models
"""; """;
} }
public OpenAIChatResponse Chat(string prompt, string question, CancellationToken cancellation) public void Chat(string prompt, string question, CancellationToken cancellation, Action<string> onUpdate)
{ {
var chat = new OpenAIChatRequest() { Model = Model }; Uri server = new(Server);
chat.AddMessage("system", prompt); ApiKeyCredential key = new(ApiKey);
chat.AddMessage("user", question); ChatClient client = null;
if (Server.Contains("openai.azure.com/", StringComparison.Ordinal))
var client = new HttpClient() { Timeout = TimeSpan.FromSeconds(60) };
if (!string.IsNullOrEmpty(ApiKey))
{ {
if (Server.Contains("openai.azure.com/", StringComparison.Ordinal)) var azure = new AzureOpenAIClient(server, key);
client.DefaultRequestHeaders.Add("api-key", ApiKey); client = azure.GetChatClient(Model);
else }
client.DefaultRequestHeaders.Add("Authorization", $"Bearer {ApiKey}"); else
{
var openai = new OpenAIClient(key, new()
{
Endpoint = server
});
client = openai.GetChatClient(Model);
} }
var req = new StringContent(JsonSerializer.Serialize(chat, JsonCodeGen.Default.OpenAIChatRequest), Encoding.UTF8, "application/json");
try try
{ {
var task = client.PostAsync(Server, req, cancellation); var updates = client.CompleteChatStreaming([
task.Wait(cancellation); new UserChatMessage(prompt),
new UserChatMessage(question),
], null, cancellation);
var rsp = task.Result; foreach (var update in updates)
var reader = rsp.Content.ReadAsStringAsync(cancellation);
reader.Wait(cancellation);
var body = reader.Result;
if (!rsp.IsSuccessStatusCode)
{ {
throw new Exception($"AI service returns error code {rsp.StatusCode}. Body: {body??string.Empty}"); if (update.ContentUpdate.Count > 0)
onUpdate.Invoke(update.ContentUpdate[0].Text);
} }
return JsonSerializer.Deserialize(reader.Result, JsonCodeGen.Default.OpenAIChatResponse);
} }
catch catch
{ {
if (cancellation.IsCancellationRequested) if (!cancellation.IsCancellationRequested)
return null; throw;
throw;
} }
} }

View file

@ -44,8 +44,10 @@
<PackageReference Include="Avalonia.Diagnostics" Version="11.2.1" Condition="'$(Configuration)' == 'Debug'" /> <PackageReference Include="Avalonia.Diagnostics" Version="11.2.1" Condition="'$(Configuration)' == 'Debug'" />
<PackageReference Include="Avalonia.AvaloniaEdit" Version="11.1.0" /> <PackageReference Include="Avalonia.AvaloniaEdit" Version="11.1.0" />
<PackageReference Include="AvaloniaEdit.TextMate" Version="11.1.0" /> <PackageReference Include="AvaloniaEdit.TextMate" Version="11.1.0" />
<PackageReference Include="Azure.AI.OpenAI" Version="2.0.0" />
<PackageReference Include="CommunityToolkit.Mvvm" Version="8.3.2" /> <PackageReference Include="CommunityToolkit.Mvvm" Version="8.3.2" />
<PackageReference Include="LiveChartsCore.SkiaSharpView.Avalonia" Version="2.0.0-rc4.5" /> <PackageReference Include="LiveChartsCore.SkiaSharpView.Avalonia" Version="2.0.0-rc4.5" />
<PackageReference Include="OpenAI" Version="2.0.0" />
<PackageReference Include="TextMateSharp" Version="1.0.64" /> <PackageReference Include="TextMateSharp" Version="1.0.64" />
<PackageReference Include="TextMateSharp.Grammars" Version="1.0.64" /> <PackageReference Include="TextMateSharp.Grammars" Version="1.0.64" />
</ItemGroup> </ItemGroup>

View file

@ -36,15 +36,17 @@ namespace SourceGit.Views
Task.Run(() => Task.Run(() =>
{ {
var message = new Commands.GenerateCommitMessage(_service, _repo, _changes, _cancel.Token, SetDescription).Result(); new Commands.GenerateCommitMessage(_service, _repo, _changes, _cancel.Token, progress =>
if (_cancel.IsCancellationRequested)
return;
Dispatcher.UIThread.Invoke(() =>
{ {
_onDone?.Invoke(message); Dispatcher.UIThread.Invoke(() => ProgressMessage.Text = progress);
Close(); },
}); message =>
{
Dispatcher.UIThread.Invoke(() => _onDone?.Invoke(message));
}).Exec();
if (!_cancel.IsCancellationRequested)
Dispatcher.UIThread.Invoke(Close);
}, _cancel.Token); }, _cancel.Token);
} }
@ -54,11 +56,6 @@ namespace SourceGit.Views
_cancel.Cancel(); _cancel.Cancel();
} }
private void SetDescription(string message)
{
Dispatcher.UIThread.Invoke(() => ProgressMessage.Text = message);
}
private Models.OpenAIService _service; private Models.OpenAIService _service;
private string _repo; private string _repo;
private List<Models.Change> _changes; private List<Models.Change> _changes;