Optimising caching and retrieval

This commit is contained in:
Dario Ghunney Ware 2025-11-19 16:33:25 +00:00
parent c9cd6404ae
commit d3f4a40f68
11 changed files with 245 additions and 40 deletions

View File

@ -636,7 +636,6 @@ public class ApplicationProperties {
private String primary = "gpt-5-nano";
private String fallback = "gpt-5-mini";
private String embedding = "text-embedding-3-small";
private double temperature = 0.2;
private double topP = 0.95;
private long connectTimeoutMillis = 10000;
private long readTimeoutMillis = 60000;

View File

@ -55,6 +55,7 @@ dependencies {
implementation 'org.springframework.ai:spring-ai-starter-model-openai'
implementation 'org.springframework.ai:spring-ai-starter-model-ollama'
implementation 'org.springframework.ai:spring-ai-starter-vector-store-redis'
implementation 'redis.clients:jedis:5.1.0'
implementation 'com.bucket4j:bucket4j_jdk17-core:8.15.0'
// https://mvnrepository.com/artifact/com.bucket4j/bucket4j_jdk17

View File

@ -0,0 +1,102 @@
package stirling.software.proprietary.config;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import lombok.extern.slf4j.Slf4j;
import redis.clients.jedis.Connection;
import redis.clients.jedis.DefaultJedisClientConfig;
import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.JedisClientConfig;
import redis.clients.jedis.JedisPooled;
@Configuration
@ConditionalOnProperty(value = "premium.proFeatures.chatbot.enabled", havingValue = "true")
@Slf4j
public class ChatbotRedisConfig {
@Value("${spring.data.redis.host:localhost}")
private String redisHost;
@Value("${spring.data.redis.port:6379}")
private int redisPort;
@Value("${spring.data.redis.password:}")
private String redisPassword;
@Value("${spring.data.redis.timeout:60000}")
private int redisTimeout;
@Value("${spring.data.redis.ssl.enabled:false}")
private boolean sslEnabled;
@Bean
public JedisPooled jedisPooled() {
try {
log.info("Creating JedisPooled connection to {}:{}", redisHost, redisPort);
// Create pool configuration
GenericObjectPoolConfig<Connection> poolConfig = new GenericObjectPoolConfig<>();
poolConfig.setMaxTotal(50);
poolConfig.setMaxIdle(25);
poolConfig.setMinIdle(5);
poolConfig.setTestOnBorrow(true);
poolConfig.setTestOnReturn(true);
poolConfig.setTestWhileIdle(true);
// Create host and port configuration
HostAndPort hostAndPort = new HostAndPort(redisHost, redisPort);
// Create client configuration with authentication if password is provided
JedisClientConfig clientConfig;
if (redisPassword != null && !redisPassword.trim().isEmpty()) {
clientConfig =
DefaultJedisClientConfig.builder()
.password(redisPassword)
.connectionTimeoutMillis(redisTimeout)
.socketTimeoutMillis(redisTimeout)
.ssl(sslEnabled)
.build();
} else {
clientConfig =
DefaultJedisClientConfig.builder()
.connectionTimeoutMillis(redisTimeout)
.socketTimeoutMillis(redisTimeout)
.ssl(sslEnabled)
.build();
}
// Create JedisPooled with configuration
JedisPooled jedisPooled = new JedisPooled(poolConfig, hostAndPort, clientConfig);
// Test the connection
try {
jedisPooled.ping();
log.info("Successfully connected to Redis at {}:{}", redisHost, redisPort);
} catch (Exception pingException) {
log.warn(
"Redis ping failed at {}:{} - {}. Redis might be unavailable.",
redisHost,
redisPort,
pingException.getMessage());
// Close the pool if ping fails
try {
jedisPooled.close();
} catch (Exception closeException) {
// Ignore close exceptions
}
return null;
}
return jedisPooled;
} catch (Exception e) {
log.error("Failed to create JedisPooled connection", e);
// Return null to fall back to SimpleVectorStore
return null;
}
}
}

View File

@ -4,7 +4,8 @@ import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.redis.RedisVectorStore;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
@ -14,9 +15,7 @@ import lombok.extern.slf4j.Slf4j;
import redis.clients.jedis.JedisPooled;
@Configuration
@org.springframework.boot.autoconfigure.condition.ConditionalOnProperty(
value = "premium.proFeatures.chatbot.enabled",
havingValue = "true")
@ConditionalOnProperty(value = "premium.proFeatures.chatbot.enabled", havingValue = "true")
@Slf4j
public class ChatbotVectorStoreConfig {
@ -26,15 +25,11 @@ public class ChatbotVectorStoreConfig {
@Bean
@Primary
public VectorStore chatbotVectorStore(
ObjectProvider<JedisPooled> jedisProvider, EmbeddingModel embeddingModel) {
JedisPooled jedis = jedisProvider.getIfAvailable();
if (jedis != null) {
@Autowired(required = false) JedisPooled jedisPooled, EmbeddingModel embeddingModel) {
if (jedisPooled != null) {
try {
jedis.ping();
log.info("Initialising Redis vector store for chatbot usage");
return RedisVectorStore.builder(jedis, embeddingModel)
return RedisVectorStore.builder(jedisPooled, embeddingModel)
.indexName(DEFAULT_INDEX)
.prefix(DEFAULT_PREFIX)
.initializeSchema(true)

View File

@ -38,24 +38,47 @@ public class ChatbotCacheService {
this.maxDocumentCharacters = cacheSettings.getMaxDocumentCharacters();
long ttlMinutes = Math.max(cacheSettings.getTtlMinutes(), 1);
long maxEntries = Math.max(cacheSettings.getMaxEntries(), 1);
long maxTotalCharacters =
Math.max(cacheSettings.getMaxDocumentCharacters() * maxEntries, 1);
this.documentCache =
Caffeine.newBuilder()
.maximumSize(maxEntries)
.maximumWeight(maxTotalCharacters)
.weigher(
(String key, ChatbotDocumentCacheEntry entry) ->
(int)
Math.min(
entry.getTextCharacters()
+ estimateMetadataWeight(entry),
Integer.MAX_VALUE))
.expireAfterWrite(Duration.ofMinutes(ttlMinutes))
.recordStats()
.build();
log.info(
"Initialised chatbot document cache with maxEntries={} ttlMinutes={} maxChars={}",
"Initialised chatbot document cache with maxEntries={} ttlMinutes={} maxChars={} maxWeight={} characters",
maxEntries,
ttlMinutes,
maxDocumentCharacters);
maxDocumentCharacters,
maxTotalCharacters);
}
/** Returns the per-document character cap applied when caching chatbot documents. */
public long getMaxDocumentCharacters() {
return maxDocumentCharacters;
}
/**
 * Approximates the cache weight contributed by an entry's metadata by summing
 * the character lengths of every key and value (null-safe on entry, map, and strings).
 */
private long estimateMetadataWeight(ChatbotDocumentCacheEntry entry) {
    if (entry == null) {
        return 0L;
    }
    var metadata = entry.getMetadata();
    if (metadata == null) {
        return 0L;
    }
    long weight = 0L;
    for (var mapping : metadata.entrySet()) {
        weight += safeLength(mapping.getKey());
        weight += safeLength(mapping.getValue());
    }
    return weight;
}
/** Length of {@code value}, treating {@code null} as zero characters. */
private long safeLength(String value) {
    return value != null ? value.length() : 0L;
}
public String register(
String sessionId,
String documentId,

View File

@ -1,7 +1,5 @@
package stirling.software.proprietary.service.chatbot;
import static stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.ChatbotSettings.ModelProvider.OLLAMA;
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
@ -13,6 +11,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
@ -138,7 +137,7 @@ public class ChatbotConversationService {
memoryService.recordTurn(session, request.getPrompt(), finalReply.answer());
recordHistoryTurn(session, "user", request.getPrompt());
recordHistoryTurn(session, "assistant", finalReply.answer());
maybeSummarizeConversation(settings, session);
summarizeConversation(settings, session);
enforceHistoryRetention(session);
return ChatbotResponse.builder()
@ -163,14 +162,14 @@ public class ChatbotConversationService {
// Collects the user-facing caveat messages attached to every chatbot response.
// NOTE(review): the unconditional "not yet supported in answers" warning and the
// conditional "Image content is not yet supported." warning overlap — this looks
// like an old and a new message coexisting after a merge; confirm which should remain.
private List<String> buildWarnings(ChatbotSettings settings, ChatbotSession session) {
List<String> warnings = new ArrayList<>();
warnings.add("Chatbot is in alpha behaviour may change.");
warnings.add("Image content is not yet supported in answers.");
if (session.isImageContentDetected()) {
warnings.add(
"Detected document images will be ignored until image support is available.");
warnings.add("Image content is not yet supported.");
}
if (session.isOcrRequested()) {
// OCR warning is only shown when the session actually requested OCR.
warnings.add("OCR costs may apply for this session.");
}
return warnings;
}
@ -191,11 +190,13 @@ public class ChatbotConversationService {
metadata.put("promptTokens", reply.promptTokens());
metadata.put("completionTokens", reply.completionTokens());
metadata.put("totalTokens", reply.totalTokens());
return metadata;
}
private void ensureModelSwitchCapability(ChatbotSettings settings) {
ChatbotSettings.ModelProvider provider = settings.models().provider();
switch (provider) {
case OPENAI -> {
if (!(chatModel instanceof OpenAiChatModel)) {
@ -210,6 +211,7 @@ public class ChatbotConversationService {
}
}
}
if (modelSwitchVerified.compareAndSet(false, true)) {
log.info(
"Verified runtime model override support for provider {} ({} -> {})",
@ -241,6 +243,7 @@ public class ChatbotConversationService {
history,
conversationSummary);
ChatResponse response;
try {
response = chatModel.call(requestPrompt);
} catch (org.eclipse.jetty.client.HttpResponseException ex) {
@ -256,8 +259,10 @@ public class ChatbotConversationService {
long promptTokens = 0L;
long completionTokens = 0L;
long totalTokens = 0L;
if (response != null && response.getMetadata() != null) {
org.springframework.ai.chat.metadata.Usage usage = response.getMetadata().getUsage();
Usage usage = response.getMetadata().getUsage();
if (usage != null) {
promptTokens = toLong(usage.getPromptTokens());
completionTokens = toLong(usage.getCompletionTokens());
@ -287,6 +292,7 @@ public class ChatbotConversationService {
List<ChatbotHistoryEntry> history,
String conversationSummary) {
String chunkOutline = buildChunkOutline(context);
String chunkExcerpts = buildChunkExcerpts(context);
String metadataSummary =
metadata.entrySet().stream()
.map(entry -> entry.getKey() + ": " + entry.getValue())
@ -321,6 +327,8 @@ public class ChatbotConversationService {
+ contextSummary
+ "\nContext outline:\n"
+ chunkOutline
+ "\nSelected excerpts:\n"
+ chunkExcerpts
+ "Question: "
+ question;
@ -335,7 +343,7 @@ public class ChatbotConversationService {
String normalizedModel = model == null ? "" : model.toLowerCase();
boolean reasoningModel = normalizedModel.startsWith("gpt-5-");
if (!reasoningModel) {
builder.temperature(settings.models().temperature()).topP(settings.models().topP());
builder.topP(settings.models().topP());
}
return builder.build();
}
@ -361,6 +369,30 @@ public class ChatbotConversationService {
return outline.toString();
}
/**
 * Renders the retrieved chunks as newline-separated excerpts, each labelled with its
 * "chunkOrder" metadata and whitespace-collapsed, truncated to 400 characters.
 */
private String buildChunkExcerpts(List<Document> context) {
    if (context == null || context.isEmpty()) {
        return "No excerpts available.";
    }
    final int maxExcerpt = 400;
    StringBuilder out = new StringBuilder();
    for (Document chunk : context) {
        String label = chunk.getMetadata().getOrDefault("chunkOrder", "?").toString();
        String text = chunk.getText();
        if (!StringUtils.hasText(text)) {
            continue; // skip image-only or empty chunks
        }
        // Collapse internal whitespace so each excerpt stays on one line.
        String collapsed = text.replaceAll("\\s+", " ").trim();
        if (collapsed.length() > maxExcerpt) {
            collapsed = collapsed.substring(0, maxExcerpt - 3) + "...";
        }
        out.append("[Chunk ").append(label).append("] ").append(collapsed).append("\n");
    }
    if (!StringUtils.hasText(out)) {
        return "Chunks retrieved but no text excerpts available.";
    }
    return out.toString();
}
private String buildConversationOutline(List<ChatbotHistoryEntry> history) {
if (history == null || history.isEmpty()) {
return "No earlier turns stored for this session.";
@ -488,28 +520,30 @@ public class ChatbotConversationService {
}
}
private void maybeSummarizeConversation(ChatbotSettings settings, ChatbotSession session) {
private void summarizeConversation(ChatbotSettings settings, ChatbotSession session) {
if (conversationStore == null
|| session == null
|| !StringUtils.hasText(session.getSessionId())) {
return;
}
int window = conversationStore.defaultWindow();
long historySize = conversationStore.historyLength(session.getSessionId());
if (historySize < Math.max(window * SUMMARY_TRIGGER_MULTIPLIER, window + 1)) {
return;
}
List<ChatbotHistoryEntry> entries =
conversationStore.getRecentTurns(
session.getSessionId(), conversationStore.retentionWindow());
if (entries.isEmpty() || entries.size() <= window) {
return;
}
int cutoff = entries.size() - window;
List<ChatbotHistoryEntry> summarizable = entries.subList(0, cutoff);
String existingSummary = loadConversationSummary(session.getSessionId());
String updatedSummary =
summarizeHistory(settings, session, summarizable, existingSummary);
String updatedSummary = summarizeHistory(settings, session, summarizable, existingSummary);
if (StringUtils.hasText(updatedSummary)) {
try {
conversationStore.storeSummary(session.getSessionId(), updatedSummary);

View File

@ -27,7 +27,6 @@ public class ChatbotFeatureProperties {
chatbot.getModels().getPrimary(),
chatbot.getModels().getFallback(),
chatbot.getModels().getEmbedding(),
chatbot.getModels().getTemperature(),
chatbot.getModels().getTopP());
return new ChatbotSettings(
chatbot.isEnabled(),
@ -90,7 +89,6 @@ public class ChatbotFeatureProperties {
String primary,
String fallback,
String embedding,
double temperature,
double topP) {}
public record RagSettings(int chunkSizeTokens, int chunkOverlapTokens, int topK) {}

View File

@ -1,6 +1,8 @@
package stirling.software.proprietary.service.chatbot;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
@ -8,6 +10,9 @@ import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
@ -21,6 +26,8 @@ public class ChatbotRetrievalService {
private final ChatbotCacheService cacheService;
private final VectorStore vectorStore;
private final Cache<String, List<Document>> retrievalCache =
Caffeine.newBuilder().maximumSize(200).expireAfterWrite(30, TimeUnit.SECONDS).build();
public List<Document> retrieveTopK(String sessionId, String query, ChatbotSettings settings) {
cacheService
@ -29,14 +36,22 @@ public class ChatbotRetrievalService {
int topK = Math.max(settings.rag().topK(), 1);
String sanitizedQuery = StringUtils.hasText(query) ? query : "";
String filterExpression = "sessionId == '" + escape(sessionId) + "'";
String filterExpression = "metadata.sessionId == '" + escape(sessionId) + "'";
String cacheKey = cacheKey(sessionId, sanitizedQuery, topK);
List<Document> cached = retrievalCache.getIfPresent(cacheKey);
if (cached != null) {
return cached;
}
SearchRequest searchRequest =
SearchRequest.builder()
.query(sanitizedQuery)
.topK(topK)
.filterExpression(filterExpression)
.similarityThreshold(0.7f)
.build();
List<Document> results;
try {
results = vectorStore.similaritySearch(searchRequest);
} catch (RuntimeException ex) {
@ -54,9 +69,12 @@ public class ChatbotRetrievalService {
.limit(topK)
.toList();
if (results.isEmpty()) {
throw new ChatbotException("No context available for this chatbot session");
log.warn("No context available for chatbot session {}", sessionId);
}
return results;
List<Document> immutableResults = List.copyOf(results);
retrievalCache.put(cacheKey, immutableResults);
return immutableResults;
}
private String sanitizeRemoteMessage(String message) {
@ -69,4 +87,8 @@ public class ChatbotRetrievalService {
/**
 * Escapes a value for embedding inside a single-quoted filter-expression literal.
 * Backslashes are escaped first, then single quotes; previously only quotes were
 * handled, so a value ending in {@code \} could neutralise the added escape and
 * break out of the quoted literal.
 */
private String escape(String value) {
    return value.replace("\\", "\\\\").replace("'", "\\'");
}
/**
 * Builds the retrieval-cache key for a (session, query, topK) request. The raw
 * query and topK are embedded directly rather than via {@code Objects.hash},
 * whose collisions could have returned another query's cached documents.
 */
private String cacheKey(String sessionId, String query, int topK) {
    return sessionId + "::" + topK + "::" + query;
}
}

View File

@ -0,0 +1,33 @@
# Spring AI OpenAI Configuration
# Uses GPT-5-nano as primary model and GPT-5-mini as fallback (configured in settings.yml)
spring.ai.openai.enabled=true
# spring.ai.openai.api-key= # TODO: supply the OpenAI API key (prefer an environment variable over this file)
spring.ai.openai.base-url=https://api.openai.com
spring.ai.openai.chat.enabled=true
spring.ai.openai.chat.options.model=gpt-5-nano
# Note: Some models only support default temperature value of 1.0
spring.ai.openai.chat.options.temperature=1.0
# For newer models, use max-completion-tokens instead of max-tokens
spring.ai.openai.chat.options.max-completion-tokens=4000
spring.ai.openai.embedding.enabled=true
spring.ai.openai.embedding.options.model=text-embedding-ada-002
# Increase timeout for OpenAI API calls (default is 10 seconds)
spring.ai.openai.chat.options.connection-timeout=60s
spring.ai.openai.chat.options.read-timeout=60s
spring.ai.openai.embedding.options.connection-timeout=60s
spring.ai.openai.embedding.options.read-timeout=60s
# Spring AI Ollama Configuration (disabled to avoid bean conflicts)
spring.ai.ollama.enabled=false
spring.ai.ollama.base-url=http://localhost:11434
spring.ai.ollama.chat.enabled=false
spring.ai.ollama.chat.options.model=llama3
spring.ai.ollama.chat.options.temperature=1.0
spring.ai.ollama.embedding.enabled=false
spring.ai.ollama.embedding.options.model=nomic-embed-text
spring.data.redis.host=localhost
spring.data.redis.port=6379
spring.data.redis.password=
spring.data.redis.timeout=60000
spring.data.redis.ssl.enabled=false

View File

@ -7,6 +7,7 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.*;
import java.time.Instant;
import java.util.ArrayList;
@ -22,13 +23,13 @@ import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.test.util.ReflectionTestUtils;
import com.fasterxml.jackson.databind.ObjectMapper;
import stirling.software.proprietary.model.chatbot.ChatbotHistoryEntry;
import stirling.software.proprietary.model.chatbot.ChatbotSession;
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.ChatbotSettings;
@ExtendWith(MockitoExtension.class)
class ChatbotConversationServiceTest {
@ -72,7 +73,6 @@ class ChatbotConversationServiceTest {
"gpt-5-nano",
"gpt-5-mini",
"embed",
0.2D,
0.95D),
new ChatbotSettings.RagSettings(512, 128, 4),
new ChatbotSettings.CacheSettings(60, 10, 1000),
@ -96,15 +96,15 @@ class ChatbotConversationServiceTest {
when(conversationStore.getRecentTurns("session-1", 10))
.thenReturn(historyEntries(6, "doc-123", "Quarterly Report"));
when(conversationStore.loadSummary("session-1")).thenReturn("previous summary");
when(chatModel.call(any()))
when(chatModel.call(any(Prompt.class)))
.thenReturn(
new ChatResponse(
List.of(new Generation(new AssistantMessage("updated summary")))));
ReflectionTestUtils.invokeMethod(
conversationService, "maybeSummarizeConversation", defaultSettings, session);
conversationService, "summarizeConversation", defaultSettings, session);
verify(chatModel, times(1)).call(any());
verify(chatModel, times(1)).call(any(Prompt.class));
verify(conversationStore).storeSummary("session-1", "updated summary");
verify(conversationStore).trimHistory("session-1", 2);
}
@ -118,9 +118,9 @@ class ChatbotConversationServiceTest {
when(conversationStore.historyLength("session-2")).thenReturn(5L);
ReflectionTestUtils.invokeMethod(
conversationService, "maybeSummarizeConversation", defaultSettings, session);
conversationService, "summarizeConversation", defaultSettings, session);
verify(chatModel, never()).call(any());
verify(chatModel, never()).call(any(org.springframework.ai.chat.prompt.Prompt.class));
verify(conversationStore, never()).storeSummary(anyString(), anyString());
verify(conversationStore, never()).trimHistory(anyString(), anyInt());
}

View File

@ -55,7 +55,6 @@ class ChatbotServiceTest {
"gpt-5-nano",
"gpt-5-mini",
"embed",
0.2D,
0.95D),
new ChatbotSettings.RagSettings(512, 128, 4),
new ChatbotSettings.CacheSettings(60, 10, 1000),
@ -74,7 +73,6 @@ class ChatbotServiceTest {
"gpt-5-nano",
"gpt-5-mini",
"embed",
0.2D,
0.95D),
new ChatbotSettings.RagSettings(512, 128, 4),
new ChatbotSettings.CacheSettings(60, 10, 1000),