diff --git a/app/proprietary/build.gradle b/app/proprietary/build.gradle index e20fa095e..c11152810 100644 --- a/app/proprietary/build.gradle +++ b/app/proprietary/build.gradle @@ -52,8 +52,7 @@ dependencies { api 'org.springframework.boot:spring-boot-starter-cache' api 'com.github.ben-manes.caffeine:caffeine' api 'io.swagger.core.v3:swagger-core-jakarta:2.2.38' - implementation 'org.springframework.ai:spring-ai-spring-boot-starter:1.0.1' - implementation 'org.springframework.ai:spring-ai-openai-spring-boot-starter:1.0.1' + implementation 'org.springframework.ai:spring-ai-openai-spring-boot-starter:1.0.0-M6' implementation 'com.bucket4j:bucket4j_jdk17-core:8.15.0' // https://mvnrepository.com/artifact/com.bucket4j/bucket4j_jdk17 diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/service/chatbot/ChatbotConversationService.java b/app/proprietary/src/main/java/stirling/software/proprietary/service/chatbot/ChatbotConversationService.java index a98b8ebc5..7e8a637ee 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/service/chatbot/ChatbotConversationService.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/service/chatbot/ChatbotConversationService.java @@ -146,9 +146,9 @@ public class ChatbotConversationService { if (modelSwitchVerified.compareAndSet(false, true)) { ChatbotSettings settings = featureProperties.current(); OpenAiChatOptions primary = - OpenAiChatOptions.builder().withModel(settings.models().primary()).build(); + OpenAiChatOptions.builder().model(settings.models().primary()).build(); OpenAiChatOptions fallback = - OpenAiChatOptions.builder().withModel(settings.models().fallback()).build(); + OpenAiChatOptions.builder().model(settings.models().fallback()).build(); log.info( "Verified runtime model override support ({} -> {})", primary.getModel(), @@ -185,7 +185,7 @@ public class ChatbotConversationService { Optional.ofNullable(response) .map(ChatResponse::getResults) .filter(results -> !results.isEmpty()) - .map(results -> results.get(0).getOutput().getContent()) + .map(results -> results.get(0).getOutput().getText()) .orElse(""); return parseModelResponse(content); } @@ -227,7 +227,7 @@ public class ChatbotConversationService { + question; OpenAiChatOptions options = - OpenAiChatOptions.builder().withModel(model).withTemperature(0.2).build(); + OpenAiChatOptions.builder().model(model).temperature(0.2).build(); return new Prompt( List.of(new SystemMessage(systemPrompt), new UserMessage(userPrompt)), options); diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/service/chatbot/ChatbotIngestionService.java b/app/proprietary/src/main/java/stirling/software/proprietary/service/chatbot/ChatbotIngestionService.java index 2ee402bbc..5cda9f34c 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/service/chatbot/ChatbotIngestionService.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/service/chatbot/ChatbotIngestionService.java @@ -6,7 +6,7 @@ import java.util.List; import java.util.Map; import java.util.UUID; -import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.stereotype.Service; import org.springframework.util.StringUtils; @@ -29,7 +29,7 @@ public class ChatbotIngestionService { private final ChatbotCacheService cacheService; private final ChatbotSessionRegistry sessionRegistry; private final ChatbotFeatureProperties featureProperties; - private final EmbeddingClient embeddingClient; + private final EmbeddingModel embeddingModel; public ChatbotSession ingest(ChatbotSessionCreateRequest request) { ChatbotSettings settings = featureProperties.current(); @@ -122,19 +122,24 @@ public class ChatbotIngestionService { if (chunkTexts.isEmpty()) { throw new ChatbotException("Unable to split document text into retrievable chunks"); } - EmbeddingResponse response = embeddingClient.embedForResponse(chunkTexts); - if (response.getData().size() != chunkTexts.size()) { + EmbeddingResponse response = embeddingModel.embedForResponse(chunkTexts); + if (response.getResults().size() != chunkTexts.size()) { throw new ChatbotException("Mismatch between chunks and embedding results"); } List chunks = new ArrayList<>(); for (int i = 0; i < chunkTexts.size(); i++) { String chunkId = sessionId + ":" + i + ":" + UUID.randomUUID(); + float[] embeddingArray = response.getResults().get(i).getOutput(); + List embedding = new ArrayList<>(embeddingArray.length); + for (float value : embeddingArray) { + embedding.add((double) value); + } chunks.add( ChatbotTextChunk.builder() .id(chunkId) .order(i) .text(chunkTexts.get(i)) - .embedding(response.getData().get(i).getEmbedding()) + .embedding(embedding) .build()); } log.debug( diff --git a/app/proprietary/src/main/java/stirling/software/proprietary/service/chatbot/ChatbotRetrievalService.java b/app/proprietary/src/main/java/stirling/software/proprietary/service/chatbot/ChatbotRetrievalService.java index 4626e196d..c2a35620d 100644 --- a/app/proprietary/src/main/java/stirling/software/proprietary/service/chatbot/ChatbotRetrievalService.java +++ b/app/proprietary/src/main/java/stirling/software/proprietary/service/chatbot/ChatbotRetrievalService.java @@ -5,7 +5,7 @@ import java.util.Comparator; import java.util.List; import java.util.Optional; -import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; @@ -24,7 +24,7 @@ import stirling.software.proprietary.service.chatbot.exception.ChatbotException; public class ChatbotRetrievalService { private final ChatbotCacheService cacheService; - private final EmbeddingClient embeddingClient; + private final EmbeddingModel embeddingModel; public List retrieveTopK( String sessionId, String query, ChatbotSettings settings) { @@ -54,10 +54,17 @@ public class ChatbotRetrievalService { } private List computeQueryEmbedding(String query) { - EmbeddingResponse response = embeddingClient.embedForResponse(List.of(query)); - return Optional.ofNullable(response.getData().stream().findFirst().orElse(null)) - .map(org.springframework.ai.embedding.Embedding::getEmbedding) - .orElseThrow(() -> new ChatbotException("Failed to compute query embedding")); + EmbeddingResponse response = embeddingModel.embedForResponse(List.of(query)); + float[] embeddingArray = + Optional.ofNullable(response.getResults().stream().findFirst().orElse(null)) + .map(org.springframework.ai.embedding.Embedding::getOutput) + .orElseThrow( + () -> new ChatbotException("Failed to compute query embedding")); + List embedding = new ArrayList<>(embeddingArray.length); + for (float value : embeddingArray) { + embedding.add((double) value); + } + return embedding; } private double cosineSimilarity(List v1, List v2) { diff --git a/build.gradle b/build.gradle index d9220c9f5..034486798 100644 --- a/build.gradle +++ b/build.gradle @@ -15,7 +15,7 @@ import com.github.jk1.license.render.* ext { springBootVersion = "3.5.6" - springAiVersion = "1.0.1" + springAiVersion = "1.0.0-M6" pdfboxVersion = "3.0.5" imageioVersion = "3.12.0" lombokVersion = "1.18.42" @@ -54,7 +54,6 @@ springBoot { repositories { mavenCentral() maven { url = 'https://build.shibboleth.net/maven/releases' } - maven { url = 'https://repo.spring.io/release' } } allprojects { @@ -96,6 +95,7 @@ subprojects { repositories { mavenCentral() maven { url = 'https://repo.spring.io/release' } + maven { url 'https://repo.spring.io/milestone' } } configurations.configureEach {