From 2ea3236645bff3e6892d1252a01d74f4a8a02cbb Mon Sep 17 00:00:00 2001 From: leo Date: Tue, 24 Dec 2024 15:51:27 +0800 Subject: [PATCH] refactor: rewrite OpenAI integration - use `OpenAI` and `Azure.AI.OpenAI` - use streaming response --- README.md | 4 +- src/App.JsonCodeGen.cs | 2 - src/Commands/GenerateCommitMessage.cs | 63 ++++++------- src/Models/OpenAI.cs | 124 ++++++-------------------- src/SourceGit.csproj | 3 + src/Views/AIAssistant.axaml.cs | 23 +++-- 6 files changed, 72 insertions(+), 147 deletions(-) diff --git a/README.md b/README.md index f0e6b563..be9cf2c1 100644 --- a/README.md +++ b/README.md @@ -137,11 +137,11 @@ This software supports using OpenAI or other AI service that has an OpenAI comap For `OpenAI`: -* `Server` must be `https://api.openai.com/v1/chat/completions` +* `Server` must be `https://api.openai.com/v1` For other AI service: -* The `Server` should fill in a URL equivalent to OpenAI's `https://api.openai.com/v1/chat/completions`. For example, when using `Ollama`, it should be `http://localhost:11434/v1/chat/completions` instead of `http://localhost:11434/api/generate` +* The `Server` should fill in a URL equivalent to OpenAI's `https://api.openai.com/v1`. For example, when using `Ollama`, it should be `http://localhost:11434/v1` instead of `http://localhost:11434/api/generate` * The `API Key` is optional that depends on the service ## External Tools diff --git a/src/App.JsonCodeGen.cs b/src/App.JsonCodeGen.cs index 70567af5..8daea4f9 100644 --- a/src/App.JsonCodeGen.cs +++ b/src/App.JsonCodeGen.cs @@ -46,8 +46,6 @@ namespace SourceGit [JsonSerializable(typeof(Models.ExternalToolPaths))] [JsonSerializable(typeof(Models.InteractiveRebaseJobCollection))] [JsonSerializable(typeof(Models.JetBrainsState))] - [JsonSerializable(typeof(Models.OpenAIChatRequest))] - [JsonSerializable(typeof(Models.OpenAIChatResponse))] [JsonSerializable(typeof(Models.ThemeOverrides))] [JsonSerializable(typeof(Models.Version))] [JsonSerializable(typeof(Models.RepositorySettings))] diff --git a/src/Commands/GenerateCommitMessage.cs b/src/Commands/GenerateCommitMessage.cs index e4f25f38..dc9c3a40 100644 --- a/src/Commands/GenerateCommitMessage.cs +++ b/src/Commands/GenerateCommitMessage.cs @@ -20,76 +20,76 @@ namespace SourceGit.Commands } } - public GenerateCommitMessage(Models.OpenAIService service, string repo, List changes, CancellationToken cancelToken, Action onProgress) + public GenerateCommitMessage(Models.OpenAIService service, string repo, List changes, CancellationToken cancelToken, Action onProgress, Action onResponse) { _service = service; _repo = repo; _changes = changes; _cancelToken = cancelToken; _onProgress = onProgress; + _onResponse = onResponse; } - public string Result() + public void Exec() { try { - var summarybuilder = new StringBuilder(); + var summaryBuilder = new StringBuilder(); var bodyBuilder = new StringBuilder(); + _onResponse?.Invoke("Wait for all file analysis to complete..."); foreach (var change in _changes) { if (_cancelToken.IsCancellationRequested) - return ""; + return; _onProgress?.Invoke($"Analyzing {change.Path}..."); - - var summary = GenerateChangeSummary(change); - summarybuilder.Append("- "); - summarybuilder.Append(summary); - summarybuilder.Append("(file: "); - summarybuilder.Append(change.Path); - summarybuilder.Append(")"); - summarybuilder.AppendLine(); - + _onResponse?.Invoke($"Wait for all file analysis to complete...\n\n{bodyBuilder}"); + bodyBuilder.Append("- "); - bodyBuilder.Append(summary); - bodyBuilder.AppendLine(); + summaryBuilder.Append("- "); + GenerateChangeSummary(change, summaryBuilder, bodyBuilder); + + bodyBuilder.Append("\n"); + summaryBuilder.Append("(file: "); + summaryBuilder.Append(change.Path); + summaryBuilder.Append(")\n"); } if (_cancelToken.IsCancellationRequested) - return ""; + return; _onProgress?.Invoke($"Generating commit message..."); var body = bodyBuilder.ToString(); - var subject = GenerateSubject(summarybuilder.ToString()); - return string.Format("{0}\n\n{1}", subject, body); + GenerateSubject(summaryBuilder.ToString(), body); } catch (Exception e) { App.RaiseException(_repo, $"Failed to generate commit message: {e}"); - return ""; } } - private string GenerateChangeSummary(Models.Change change) + private void GenerateChangeSummary(Models.Change change, StringBuilder summary, StringBuilder body) { var rs = new GetDiffContent(_repo, new Models.DiffOption(change, false)).ReadToEnd(); var diff = rs.IsSuccess ? rs.StdOut : "unknown change"; - var rsp = _service.Chat(_service.AnalyzeDiffPrompt, $"Here is the `git diff` output: {diff}", _cancelToken); - if (rsp != null && rsp.Choices.Count > 0) - return rsp.Choices[0].Message.Content; - - return string.Empty; + _service.Chat(_service.AnalyzeDiffPrompt, $"Here is the `git diff` output: {diff}", _cancelToken, update => + { + body.Append(update); + summary.Append(update); + _onResponse?.Invoke($"Wait for all file analysis to complete...\n\n{body}"); + }); } - private string GenerateSubject(string summary) + private void GenerateSubject(string summary, string body) { - var rsp = _service.Chat(_service.GenerateSubjectPrompt, $"Here are the summaries changes:\n{summary}", _cancelToken); - if (rsp != null && rsp.Choices.Count > 0) - return rsp.Choices[0].Message.Content; - - return string.Empty; + StringBuilder result = new StringBuilder(); + _service.Chat(_service.GenerateSubjectPrompt, $"Here are the summaries changes:\n{summary}", _cancelToken, update => + { + result.Append(update); + _onResponse?.Invoke($"{result}\n\n{body}"); + }); } private Models.OpenAIService _service; @@ -97,5 +97,6 @@ namespace SourceGit.Commands private List _changes; private CancellationToken _cancelToken; private Action _onProgress; + private Action _onResponse; } } diff --git a/src/Models/OpenAI.cs b/src/Models/OpenAI.cs index df67ff66..317ba322 100644 --- a/src/Models/OpenAI.cs +++ b/src/Models/OpenAI.cs @@ -1,81 +1,13 @@ using System; -using System.Collections.Generic; -using System.Net.Http; -using System.Text; -using System.Text.Json; -using System.Text.Json.Serialization; +using System.ClientModel; using System.Threading; - +using Azure.AI.OpenAI; using CommunityToolkit.Mvvm.ComponentModel; +using OpenAI; +using OpenAI.Chat; namespace SourceGit.Models { - public class OpenAIChatMessage - { - [JsonPropertyName("role")] - public string Role - { - get; - set; - } - - [JsonPropertyName("content")] - public string Content - { - get; - set; - } - } - - public class OpenAIChatChoice - { - [JsonPropertyName("index")] - public int Index - { - get; - set; - } - - [JsonPropertyName("message")] - public OpenAIChatMessage Message - { - get; - set; - } - } - - public class OpenAIChatResponse - { - [JsonPropertyName("choices")] - public List Choices - { - get; - set; - } = []; - } - - public class OpenAIChatRequest - { - [JsonPropertyName("model")] - public string Model - { - get; - set; - } - - [JsonPropertyName("messages")] - public List Messages - { - get; - set; - } = []; - - public void AddMessage(string role, string content) - { - Messages.Add(new OpenAIChatMessage { Role = role, Content = content }); - } - } - public class OpenAIService : ObservableObject { public string Name @@ -147,45 +79,39 @@ namespace SourceGit.Models """; } - public OpenAIChatResponse Chat(string prompt, string question, CancellationToken cancellation) + public void Chat(string prompt, string question, CancellationToken cancellation, Action onUpdate) { - var chat = new OpenAIChatRequest() { Model = Model }; - chat.AddMessage("user", prompt); - chat.AddMessage("user", question); - - var client = new HttpClient() { Timeout = TimeSpan.FromSeconds(60) }; - if (!string.IsNullOrEmpty(ApiKey)) + Uri server = new(Server); + ApiKeyCredential key = new(ApiKey); + ChatClient client = null; + if (Server.Contains("openai.azure.com/", StringComparison.Ordinal)) { - if (Server.Contains("openai.azure.com/", StringComparison.Ordinal)) - client.DefaultRequestHeaders.Add("api-key", ApiKey); - else - client.DefaultRequestHeaders.Add("Authorization", $"Bearer {ApiKey}"); + var azure = new AzureOpenAIClient(server, key); + client = azure.GetChatClient(Model); + } + else + { + var openai = new OpenAIClient(key, new() { Endpoint = server }); + client = openai.GetChatClient(Model); } - var req = new StringContent(JsonSerializer.Serialize(chat, JsonCodeGen.Default.OpenAIChatRequest), Encoding.UTF8, "application/json"); try { - var task = client.PostAsync(Server, req, cancellation); - task.Wait(cancellation); + var updates = client.CompleteChatStreaming([ + new UserChatMessage(prompt), + new UserChatMessage(question), + ], null, cancellation); - var rsp = task.Result; - var reader = rsp.Content.ReadAsStringAsync(cancellation); - reader.Wait(cancellation); - - var body = reader.Result; - if (!rsp.IsSuccessStatusCode) + foreach (var update in updates) { - throw new Exception($"AI service returns error code {rsp.StatusCode}. Body: {body ?? string.Empty}"); + if (update.ContentUpdate.Count > 0) + onUpdate.Invoke(update.ContentUpdate[0].Text); } - - return JsonSerializer.Deserialize(reader.Result, JsonCodeGen.Default.OpenAIChatResponse); } catch { - if (cancellation.IsCancellationRequested) - return null; - - throw; + if (!cancellation.IsCancellationRequested) + throw; } } diff --git a/src/SourceGit.csproj b/src/SourceGit.csproj index 8e8c2b3f..b5ca239e 100644 --- a/src/SourceGit.csproj +++ b/src/SourceGit.csproj @@ -24,6 +24,7 @@ true true link + true @@ -52,8 +53,10 @@ + + diff --git a/src/Views/AIAssistant.axaml.cs b/src/Views/AIAssistant.axaml.cs index d81335eb..39279c83 100644 --- a/src/Views/AIAssistant.axaml.cs +++ b/src/Views/AIAssistant.axaml.cs @@ -36,15 +36,17 @@ namespace SourceGit.Views Task.Run(() => { - var message = new Commands.GenerateCommitMessage(_service, _repo, _changes, _cancel.Token, SetDescription).Result(); - if (_cancel.IsCancellationRequested) - return; - - Dispatcher.UIThread.Invoke(() => + new Commands.GenerateCommitMessage(_service, _repo, _changes, _cancel.Token, progress => { - _onDone?.Invoke(message); - Close(); - }); + Dispatcher.UIThread.Invoke(() => ProgressMessage.Text = progress); + }, + message => + { + Dispatcher.UIThread.Invoke(() => _onDone?.Invoke(message)); + }).Exec(); + + if (!_cancel.IsCancellationRequested) + Dispatcher.UIThread.Invoke(Close); }, _cancel.Token); } @@ -54,11 +56,6 @@ namespace SourceGit.Views _cancel.Cancel(); } - private void SetDescription(string message) - { - Dispatcher.UIThread.Invoke(() => ProgressMessage.Text = message); - } - private Models.OpenAIService _service; private string _repo; private List _changes;