unify chat across docs 2

This commit is contained in:
DarioGii 2025-11-16 00:12:20 +00:00
parent 2fc7eedc45
commit e5e67aff82
4 changed files with 551 additions and 5 deletions

View File

@ -0,0 +1,7 @@
package stirling.software.proprietary.model.chatbot;
import java.time.Instant;
/** Simple record representing a stored chatbot conversation turn. */
public record ChatbotHistoryEntry(
String role, String content, String documentId, String documentName, Instant timestamp) {}

View File

@ -30,6 +30,7 @@ import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import stirling.software.proprietary.model.chatbot.ChatbotDocumentCacheEntry;
import stirling.software.proprietary.model.chatbot.ChatbotHistoryEntry;
import stirling.software.proprietary.model.chatbot.ChatbotQueryRequest;
import stirling.software.proprietary.model.chatbot.ChatbotResponse;
import stirling.software.proprietary.model.chatbot.ChatbotSession;
@ -42,6 +43,9 @@ import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
@RequiredArgsConstructor
public class ChatbotConversationService {
private static final int SUMMARY_TRIGGER_MULTIPLIER = 3;
private static final int SUMMARY_TRANSCRIPT_MAX_CHARS = 4000;
private final ChatModel chatModel;
private final ChatbotSessionRegistry sessionRegistry;
private final ChatbotCacheService cacheService;
@ -50,6 +54,7 @@ public class ChatbotConversationService {
private final ChatbotContextCompressor contextCompressor;
private final ChatbotMemoryService memoryService;
private final ChatbotUsageService usageService;
private final ChatbotConversationStore conversationStore;
private final ObjectMapper objectMapper;
private final AtomicBoolean modelSwitchVerified = new AtomicBoolean(false);
@ -84,6 +89,9 @@ public class ChatbotConversationService {
String contextSummary =
contextCompressor.summarize(
context, (int) Math.max(settings.maxPromptCharacters() / 2, 1000));
List<ChatbotHistoryEntry> conversationHistory =
loadConversationHistory(session.getSessionId());
String conversationSummary = loadConversationSummary(session.getSessionId());
ModelReply nanoReply =
invokeModel(
@ -93,7 +101,9 @@ public class ChatbotConversationService {
session,
context,
contextSummary,
cacheEntry.getMetadata());
cacheEntry.getMetadata(),
conversationHistory,
conversationSummary);
boolean shouldEscalate =
request.isAllowEscalation()
@ -113,7 +123,9 @@ public class ChatbotConversationService {
session,
context,
contextSummary,
cacheEntry.getMetadata());
cacheEntry.getMetadata(),
conversationHistory,
conversationSummary);
}
ChatbotUsageSummary usageSummary =
@ -124,6 +136,10 @@ public class ChatbotConversationService {
session.setUsageSummary(usageSummary);
memoryService.recordTurn(session, request.getPrompt(), finalReply.answer());
recordHistoryTurn(session, "user", request.getPrompt());
recordHistoryTurn(session, "assistant", finalReply.answer());
maybeSummarizeConversation(settings, session);
enforceHistoryRetention(session);
return ChatbotResponse.builder()
.sessionId(request.getSessionId())
@ -210,9 +226,20 @@ public class ChatbotConversationService {
ChatbotSession session,
List<Document> context,
String contextSummary,
Map<String, String> metadata) {
Map<String, String> metadata,
List<ChatbotHistoryEntry> history,
String conversationSummary) {
Prompt requestPrompt =
buildPrompt(settings, model, prompt, session, context, contextSummary, metadata);
buildPrompt(
settings,
model,
prompt,
session,
context,
contextSummary,
metadata,
history,
conversationSummary);
ChatResponse response;
try {
response = chatModel.call(requestPrompt);
@ -256,13 +283,16 @@ public class ChatbotConversationService {
ChatbotSession session,
List<Document> context,
String contextSummary,
Map<String, String> metadata) {
Map<String, String> metadata,
List<ChatbotHistoryEntry> history,
String conversationSummary) {
String chunkOutline = buildChunkOutline(context);
String metadataSummary =
metadata.entrySet().stream()
.map(entry -> entry.getKey() + ": " + entry.getValue())
.reduce((left, right) -> left + ", " + right)
.orElse("none");
String recentTurns = buildConversationOutline(history);
String imageDirective =
session.isImageContentDetected()
@ -281,6 +311,12 @@ public class ChatbotConversationService {
+ session.isOcrRequested()
+ "\n"
+ imageDirective
+ "\nConversation summary:\n"
+ (StringUtils.hasText(conversationSummary)
? conversationSummary
: "No persistent summary available.")
+ "\nRecent conversation turns:\n"
+ recentTurns
+ "\nContext summary:\n"
+ contextSummary
+ "\nContext outline:\n"
@ -325,6 +361,27 @@ public class ChatbotConversationService {
return outline.toString();
}
private String buildConversationOutline(List<ChatbotHistoryEntry> history) {
if (history == null || history.isEmpty()) {
return "No earlier turns stored for this session.";
}
StringBuilder builder = new StringBuilder();
for (ChatbotHistoryEntry entry : history) {
if (entry == null || !StringUtils.hasText(entry.content())) {
continue;
}
builder.append(entry.role()).append(": ").append(entry.content().trim());
if (StringUtils.hasText(entry.documentName())) {
builder.append(" (doc: ").append(entry.documentName()).append(")");
}
builder.append("\n");
}
if (!StringUtils.hasText(builder)) {
return "Conversation history available but empty after filtering.";
}
return builder.toString();
}
private ModelReply parseModelResponse(
String raw, long promptTokens, long completionTokens, long totalTokens) {
if (!StringUtils.hasText(raw)) {
@ -386,4 +443,157 @@ public class ChatbotConversationService {
private long toLong(Integer value) {
return value == null ? 0L : value.longValue();
}
private List<ChatbotHistoryEntry> loadConversationHistory(String sessionId) {
if (conversationStore == null || !StringUtils.hasText(sessionId)) {
return List.of();
}
try {
return conversationStore.getRecentTurns(sessionId, conversationStore.defaultWindow());
} catch (RuntimeException ex) {
log.debug("Conversation history unavailable: {}", ex.getMessage());
return List.of();
}
}
private String loadConversationSummary(String sessionId) {
if (conversationStore == null || !StringUtils.hasText(sessionId)) {
return "";
}
try {
return conversationStore.loadSummary(sessionId);
} catch (RuntimeException ex) {
log.debug("Conversation summary unavailable: {}", ex.getMessage());
return "";
}
}
private void recordHistoryTurn(ChatbotSession session, String role, String content) {
if (conversationStore == null
|| session == null
|| !StringUtils.hasText(session.getSessionId())
|| !StringUtils.hasText(content)) {
return;
}
String documentName =
Optional.ofNullable(session.getMetadata())
.map(meta -> meta.getOrDefault("documentName", ""))
.orElse("");
ChatbotHistoryEntry entry =
conversationStore.createEntry(role, content, session.getDocumentId(), documentName);
try {
conversationStore.appendTurn(session.getSessionId(), entry);
} catch (RuntimeException ex) {
log.debug("Failed to persist chatbot conversation turn: {}", ex.getMessage());
}
}
private void maybeSummarizeConversation(ChatbotSettings settings, ChatbotSession session) {
if (conversationStore == null
|| session == null
|| !StringUtils.hasText(session.getSessionId())) {
return;
}
int window = conversationStore.defaultWindow();
long historySize = conversationStore.historyLength(session.getSessionId());
if (historySize < Math.max(window * SUMMARY_TRIGGER_MULTIPLIER, window + 1)) {
return;
}
List<ChatbotHistoryEntry> entries =
conversationStore.getRecentTurns(
session.getSessionId(), conversationStore.retentionWindow());
if (entries.isEmpty() || entries.size() <= window) {
return;
}
int cutoff = entries.size() - window;
List<ChatbotHistoryEntry> summarizable = entries.subList(0, cutoff);
String existingSummary = loadConversationSummary(session.getSessionId());
String updatedSummary =
summarizeHistory(settings, session, summarizable, existingSummary);
if (StringUtils.hasText(updatedSummary)) {
try {
conversationStore.storeSummary(session.getSessionId(), updatedSummary);
conversationStore.trimHistory(session.getSessionId(), window);
} catch (RuntimeException ex) {
log.debug("Failed to persist chatbot summary: {}", ex.getMessage());
}
}
}
private String summarizeHistory(
ChatbotSettings settings,
ChatbotSession session,
List<ChatbotHistoryEntry> entries,
String existingSummary) {
if (entries == null || entries.isEmpty()) {
return existingSummary;
}
String priorSummary =
StringUtils.hasText(existingSummary)
? existingSummary
: "No previous summary available.";
String transcript = buildSummaryTranscript(entries);
if (!StringUtils.hasText(transcript)) {
return existingSummary;
}
String systemPrompt =
"You maintain a concise running summary of Stirling PDF Bot conversations. "
+ "Capture user goals, referenced documents, and key conclusions in under 200 words.";
String userPrompt =
"Existing summary:\n"
+ priorSummary
+ "\n\nNew conversation turns:\n"
+ transcript
+ "\n\nRespond with the updated summary only.";
Prompt prompt =
new Prompt(
List.of(new SystemMessage(systemPrompt), new UserMessage(userPrompt)),
buildChatOptions(settings, settings.models().primary()));
try {
ChatResponse response = chatModel.call(prompt);
return Optional.ofNullable(response)
.map(ChatResponse::getResults)
.filter(results -> !results.isEmpty())
.map(results -> results.get(0).getOutput().getText())
.map(String::trim)
.filter(StringUtils::hasText)
.orElse(existingSummary);
} catch (RuntimeException ex) {
log.debug("Conversation summarisation failed: {}", ex.getMessage());
return existingSummary;
}
}
private String buildSummaryTranscript(List<ChatbotHistoryEntry> entries) {
StringBuilder builder = new StringBuilder();
for (ChatbotHistoryEntry entry : entries) {
if (entry == null || !StringUtils.hasText(entry.content())) {
continue;
}
if (builder.length() >= SUMMARY_TRANSCRIPT_MAX_CHARS) {
builder.append("\n[conversation truncated]");
break;
}
builder.append(entry.role()).append(": ").append(entry.content().trim());
if (StringUtils.hasText(entry.documentName())) {
builder.append(" (doc: ").append(entry.documentName()).append(")");
}
builder.append("\n");
}
return builder.toString();
}
private void enforceHistoryRetention(ChatbotSession session) {
if (conversationStore == null
|| session == null
|| !StringUtils.hasText(session.getSessionId())) {
return;
}
try {
conversationStore.trimHistory(
session.getSessionId(), conversationStore.retentionWindow());
} catch (RuntimeException ex) {
log.debug("Failed to enforce chatbot history retention: {}", ex.getMessage());
}
}
}

View File

@ -0,0 +1,187 @@
package stirling.software.proprietary.service.chatbot;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Supplier;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import stirling.software.proprietary.model.chatbot.ChatbotHistoryEntry;
import redis.clients.jedis.JedisPooled;
/**
* Lightweight Redis-backed conversation store that keeps a short rolling window and summary for
* each chatbot session. This lays the groundwork for richer memory handling without yet impacting
* the main conversation flow.
*/
@Component
@Slf4j
public class ChatbotConversationStore {
private static final String HISTORY_KEY = "chatbot:sessions:%s:history";
private static final String SUMMARY_KEY = "chatbot:sessions:%s:summary";
private static final Duration DEFAULT_TTL = Duration.ofHours(24);
private static final int DEFAULT_WINDOW = 10;
private static final int RETENTION_MULTIPLIER = 5;
private static final int RETENTION_WINDOW = DEFAULT_WINDOW * RETENTION_MULTIPLIER;
private final JedisPooled jedis;
private final ObjectMapper objectMapper;
public ChatbotConversationStore(
ObjectProvider<JedisPooled> jedisProvider, ObjectMapper objectMapper) {
this.jedis = jedisProvider.getIfAvailable();
this.objectMapper = objectMapper;
}
public void appendTurn(String sessionId, ChatbotHistoryEntry entry) {
if (!redisReady() || !StringUtils.hasText(sessionId) || entry == null) {
return;
}
execute(
() -> {
try {
String payload = objectMapper.writeValueAsString(entry);
String key = historyKey(sessionId);
jedis.rpush(key, payload);
jedis.expire(key, (int) DEFAULT_TTL.getSeconds());
jedis.expire(summaryKey(sessionId), (int) DEFAULT_TTL.getSeconds());
} catch (JsonProcessingException ex) {
log.debug("Failed to serialise chatbot turn", ex);
}
});
}
public List<ChatbotHistoryEntry> getRecentTurns(String sessionId, int limit) {
if (!redisReady() || !StringUtils.hasText(sessionId)) {
return Collections.emptyList();
}
return execute(
() -> {
String key = historyKey(sessionId);
long size = jedis.llen(key);
if (size <= 0) {
return Collections.emptyList();
}
long start = Math.max(0, size - Math.max(limit, 1));
List<String> raw = jedis.lrange(key, start, size);
if (CollectionUtils.isEmpty(raw)) {
return Collections.emptyList();
}
List<ChatbotHistoryEntry> entries = new ArrayList<>(raw.size());
for (String chunk : raw) {
try {
entries.add(objectMapper.readValue(chunk, ChatbotHistoryEntry.class));
} catch (JsonProcessingException ex) {
log.debug("Ignoring malformed chatbot history payload", ex);
}
}
return entries;
},
Collections.emptyList());
}
public void trimHistory(String sessionId, int retainEntries) {
if (!redisReady() || !StringUtils.hasText(sessionId) || retainEntries <= 0) {
return;
}
execute(
() -> {
String key = historyKey(sessionId);
jedis.ltrim(key, -retainEntries, -1);
});
}
public void storeSummary(String sessionId, String summary) {
if (!redisReady() || !StringUtils.hasText(sessionId)) {
return;
}
execute(() -> jedis.setex(summaryKey(sessionId), (int) DEFAULT_TTL.getSeconds(), summary));
}
public String loadSummary(String sessionId) {
if (!redisReady() || !StringUtils.hasText(sessionId)) {
return "";
}
return execute(() -> jedis.get(summaryKey(sessionId)), "");
}
public void clear(String sessionId) {
if (!redisReady() || !StringUtils.hasText(sessionId)) {
return;
}
execute(
() -> {
jedis.del(historyKey(sessionId));
jedis.del(summaryKey(sessionId));
});
}
public int defaultWindow() {
return DEFAULT_WINDOW;
}
public int retentionWindow() {
return RETENTION_WINDOW;
}
public long historyLength(String sessionId) {
if (!redisReady() || !StringUtils.hasText(sessionId)) {
return 0L;
}
return execute(() -> jedis.llen(historyKey(sessionId)), 0L);
}
private boolean redisReady() {
return jedis != null;
}
private String historyKey(String sessionId) {
return HISTORY_KEY.formatted(sessionId);
}
private String summaryKey(String sessionId) {
return SUMMARY_KEY.formatted(sessionId);
}
private void execute(Runnable action) {
if (!redisReady()) {
return;
}
try {
action.run();
} catch (RuntimeException ex) {
log.warn("Redis conversation store unavailable: {}", ex.getMessage());
}
}
private <T> T execute(Supplier<T> supplier, T fallback) {
if (!redisReady()) {
return fallback;
}
try {
return supplier.get();
} catch (RuntimeException ex) {
log.warn("Redis conversation store unavailable: {}", ex.getMessage());
return fallback;
}
}
/** Convenience factory to create entries for manual tests. */
public ChatbotHistoryEntry createEntry(
String role, String content, String documentId, String documentName) {
return new ChatbotHistoryEntry(role, content, documentId, documentName, Instant.now());
}
}

View File

@ -0,0 +1,142 @@
package stirling.software.proprietary.service.chatbot;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.test.util.ReflectionTestUtils;
import com.fasterxml.jackson.databind.ObjectMapper;
import stirling.software.proprietary.model.chatbot.ChatbotHistoryEntry;
import stirling.software.proprietary.model.chatbot.ChatbotSession;
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.ChatbotSettings;
@ExtendWith(MockitoExtension.class)
class ChatbotConversationServiceTest {
@Mock private ChatModel chatModel;
@Mock private ChatbotSessionRegistry sessionRegistry;
@Mock private ChatbotCacheService cacheService;
@Mock private ChatbotFeatureProperties featureProperties;
@Mock private ChatbotRetrievalService retrievalService;
@Mock private ChatbotContextCompressor contextCompressor;
@Mock private ChatbotMemoryService memoryService;
@Mock private ChatbotUsageService usageService;
@Mock private ChatbotConversationStore conversationStore;
private ChatbotConversationService conversationService;
private ChatbotSettings defaultSettings;
@BeforeEach
void setUp() {
conversationService =
new ChatbotConversationService(
chatModel,
sessionRegistry,
cacheService,
featureProperties,
retrievalService,
contextCompressor,
memoryService,
usageService,
conversationStore,
new ObjectMapper());
defaultSettings =
new ChatbotSettings(
true,
true,
4000,
0.65D,
new ChatbotSettings.ModelSettings(
ChatbotSettings.ModelProvider.OPENAI,
"gpt-5-nano",
"gpt-5-mini",
"embed",
0.2D,
0.95D),
new ChatbotSettings.RagSettings(512, 128, 4),
new ChatbotSettings.CacheSettings(60, 10, 1000),
new ChatbotSettings.OcrSettings(false),
new ChatbotSettings.AuditSettings(false),
new ChatbotSettings.UsageSettings(100_000L, 0.7D));
}
@Test
void summarizesAndTrimsHistoryWhenThresholdReached() {
ChatbotSession session =
ChatbotSession.builder()
.sessionId("session-1")
.documentId("doc-123")
.metadata(Map.of("documentName", "Quarterly Report"))
.build();
when(conversationStore.defaultWindow()).thenReturn(2);
when(conversationStore.retentionWindow()).thenReturn(10);
when(conversationStore.historyLength("session-1")).thenReturn(6L);
when(conversationStore.getRecentTurns("session-1", 10))
.thenReturn(historyEntries(6, "doc-123", "Quarterly Report"));
when(conversationStore.loadSummary("session-1")).thenReturn("previous summary");
when(chatModel.call(any()))
.thenReturn(
new ChatResponse(
List.of(new Generation(new AssistantMessage("updated summary")))));
ReflectionTestUtils.invokeMethod(
conversationService, "maybeSummarizeConversation", defaultSettings, session);
verify(chatModel, times(1)).call(any());
verify(conversationStore).storeSummary("session-1", "updated summary");
verify(conversationStore).trimHistory("session-1", 2);
}
@Test
void skipsSummarizationWhenHistoryBelowThreshold() {
ChatbotSession session =
ChatbotSession.builder().sessionId("session-2").documentId("doc").build();
when(conversationStore.defaultWindow()).thenReturn(4);
when(conversationStore.historyLength("session-2")).thenReturn(5L);
ReflectionTestUtils.invokeMethod(
conversationService, "maybeSummarizeConversation", defaultSettings, session);
verify(chatModel, never()).call(any());
verify(conversationStore, never()).storeSummary(anyString(), anyString());
verify(conversationStore, never()).trimHistory(anyString(), anyInt());
}
private List<ChatbotHistoryEntry> historyEntries(
int count, String documentId, String documentName) {
List<ChatbotHistoryEntry> entries = new ArrayList<>();
for (int i = 0; i < count; i++) {
entries.add(
new ChatbotHistoryEntry(
i % 2 == 0 ? "user" : "assistant",
"message-" + i,
documentId,
documentName,
Instant.now().minusSeconds(60L - i)));
}
return entries;
}
}