wip - implementing RAG system

This commit is contained in:
Dario Ghunney Ware 2025-11-13 15:38:06 +00:00
parent 1973c55d10
commit c37707d9ad
31 changed files with 765 additions and 308 deletions

View File

@ -612,6 +612,7 @@ public class ApplicationProperties {
private Audit audit = new Audit();
private long maxPromptCharacters = 4000;
private double minConfidenceNano = 0.65;
private Usage usage = new Usage();
@Data
public static class Cache {
@ -626,6 +627,8 @@ public class ApplicationProperties {
private String primary = "gpt-5-nano";
private String fallback = "gpt-5-mini";
private String embedding = "text-embedding-3-small";
private double temperature = 0.2;
private double topP = 0.95;
private long connectTimeoutMillis = 10000;
private long readTimeoutMillis = 60000;
}
@ -646,6 +649,12 @@ public class ApplicationProperties {
public static class Audit {
private boolean enabled = true;
}
@Data
public static class Usage {
private long perUserMonthlyTokens = 200000;
private double warnAtRatio = 0.7;
}
}
}

View File

@ -52,7 +52,7 @@ springdoc.swagger-ui.url=/v1/api-docs
# Spring AI OpenAI Configuration
# Uses GPT-5-nano as primary model and GPT-5-mini as fallback (configured in settings.yml)
spring.ai.openai.enabled=true
spring.ai.openai.api-key=# todo <API-KEY-HERE>
#spring.ai.openai.api-key=# todo <API-KEY-HERE>
spring.ai.openai.base-url=https://api.openai.com
spring.ai.openai.chat.enabled=true
spring.ai.openai.chat.options.model=gpt-5-nano

View File

@ -97,11 +97,13 @@ premium:
cache:
ttlMinutes: 720 # Cache entry lifetime (12h)
maxEntries: 200 # Maximum number of cached documents per node
maxDocumentCharacters: 200000 # Reject uploads exceeding this character count
maxDocumentCharacters: 600000 # Reject uploads exceeding this character count
models:
primary: gpt-5-nano # Default lightweight model
fallback: gpt-5-mini # Escalation model for complex prompts
embedding: text-embedding-3-small # Embedding model for vector store usage
temperature: 0.2 # Sampling temperature for LLM responses
topP: 0.95 # Top-p (nucleus) sampling for LLM responses
rag:
chunkSizeTokens: 512 # Token window used when chunking text
chunkOverlapTokens: 128 # Overlap between successive chunks
@ -112,6 +114,9 @@ premium:
enabled: true # Emit audit records for chatbot activity
maxPromptCharacters: 4000 # Server-side guardrail for incoming prompts
minConfidenceNano: 0.65 # Minimum nano confidence to avoid escalation
usage:
perUserMonthlyTokens: 200000 # Monthly RAG + chat token budget per user
warnAtRatio: 0.7 # Warn users when usage exceeds 70%
enterpriseFeatures:
audit:
enabled: true # Enable audit logging

View File

@ -52,9 +52,9 @@ dependencies {
api 'org.springframework.boot:spring-boot-starter-cache'
api 'com.github.ben-manes.caffeine:caffeine'
api 'io.swagger.core.v3:swagger-core-jakarta:2.2.38'
api 'org.springframework.ai:spring-ai-starter-model-openai'
api 'org.springframework.ai:spring-ai-starter-model-ollama'
api 'org.springframework.ai:spring-ai-redis-store'
implementation 'org.springframework.ai:spring-ai-starter-model-openai'
implementation 'org.springframework.ai:spring-ai-starter-model-ollama'
implementation 'org.springframework.ai:spring-ai-starter-vector-store-redis'
implementation 'com.bucket4j:bucket4j_jdk17-core:8.15.0'
// https://mvnrepository.com/artifact/com.bucket4j/bucket4j_jdk17

View File

@ -0,0 +1,54 @@
package stirling.software.proprietary.config;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.redis.RedisVectorStore;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import lombok.extern.slf4j.Slf4j;
import redis.clients.jedis.JedisPooled;
@Configuration
@org.springframework.boot.autoconfigure.condition.ConditionalOnProperty(
value = "premium.proFeatures.chatbot.enabled",
havingValue = "true")
@Slf4j
public class ChatbotVectorStoreConfig {
private static final String DEFAULT_INDEX = "stirling-chatbot-index";
private static final String DEFAULT_PREFIX = "stirling:chatbot:";
@Bean
@Primary
public VectorStore chatbotVectorStore(
ObjectProvider<JedisPooled> jedisProvider, EmbeddingModel embeddingModel) {
JedisPooled jedis = jedisProvider.getIfAvailable();
if (jedis != null) {
try {
jedis.ping();
log.info("Initialising Redis vector store for chatbot usage");
return RedisVectorStore.builder(jedis, embeddingModel)
.indexName(DEFAULT_INDEX)
.prefix(DEFAULT_PREFIX)
.initializeSchema(true)
.build();
} catch (RuntimeException ex) {
log.warn(
"Redis vector store unavailable ({}). Falling back to SimpleVectorStore.",
sanitize(ex.getMessage()));
}
} else {
log.info("No Redis connection detected; using SimpleVectorStore for chatbot.");
}
return SimpleVectorStore.builder(embeddingModel).build();
}
private String sanitize(String message) {
return message == null ? "unknown error" : message.replaceAll("\\s+", " ").trim();
}
}

View File

@ -21,6 +21,7 @@ import stirling.software.proprietary.model.chatbot.ChatbotResponse;
import stirling.software.proprietary.model.chatbot.ChatbotSession;
import stirling.software.proprietary.model.chatbot.ChatbotSessionCreateRequest;
import stirling.software.proprietary.model.chatbot.ChatbotSessionResponse;
import stirling.software.proprietary.model.chatbot.ChatbotUsageSummary;
import stirling.software.proprietary.service.chatbot.ChatbotCacheService;
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties;
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.ChatbotSettings;
@ -28,12 +29,10 @@ import stirling.software.proprietary.service.chatbot.ChatbotService;
import stirling.software.proprietary.service.chatbot.ChatbotSessionRegistry;
import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
@RestController
@RequestMapping("/api/v1/internal/chatbot")
@RequiredArgsConstructor
@Slf4j
// @ConditionalOnProperty(value = "premium.proFeatures.chatbot.enabled", havingValue = "true")
// @ConditionalOnBean(ChatbotService.class)
@RestController
@RequiredArgsConstructor
@RequestMapping("/api/v1/internal/chatbot")
public class ChatbotController {
private final ChatbotService chatbotService;
@ -54,10 +53,12 @@ public class ChatbotController {
.ocrRequested(session.isOcrRequested())
.imageContentDetected(session.isImageContentDetected())
.textCharacters(session.getTextCharacters())
.estimatedTokens(session.getEstimatedTokens())
.maxCachedCharacters(cacheService.getMaxDocumentCharacters())
.createdAt(session.getCreatedAt())
.warnings(sessionWarnings(settings, session))
.metadata(session.getMetadata())
.usageSummary(session.getUsageSummary())
.build();
return ResponseEntity.status(HttpStatus.CREATED).body(response);
}
@ -83,10 +84,12 @@ public class ChatbotController {
.ocrRequested(session.isOcrRequested())
.imageContentDetected(session.isImageContentDetected())
.textCharacters(session.getTextCharacters())
.estimatedTokens(session.getEstimatedTokens())
.maxCachedCharacters(cacheService.getMaxDocumentCharacters())
.createdAt(session.getCreatedAt())
.warnings(sessionWarnings(settings, session))
.metadata(session.getMetadata())
.usageSummary(session.getUsageSummary())
.build();
return ResponseEntity.ok(response);
}
@ -106,7 +109,16 @@ public class ChatbotController {
warnings.add("Only extracted text is sent for analysis.");
if (session != null && session.isOcrRequested()) {
warnings.add("OCR was requested extra processing charges may apply.");
warnings.add("OCR requested uses credits .");
}
if (session != null && session.getUsageSummary() != null) {
ChatbotUsageSummary usage = session.getUsageSummary();
if (usage.isLimitExceeded()) {
warnings.add("Monthly chatbot allocation exceeded requests may be throttled.");
} else if (usage.isNearingLimit()) {
warnings.add("You are approaching the monthly chatbot allocation.");
}
}
return warnings;

View File

@ -2,7 +2,6 @@ package stirling.software.proprietary.model.chatbot;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
@ -20,19 +19,12 @@ public class ChatbotDocumentCacheEntry {
private String sessionId;
private String documentId;
private Map<String, String> metadata;
private String text;
private List<ChatbotTextChunk> chunks;
private boolean ocrApplied;
private boolean imageContentDetected;
private long textCharacters;
private String vectorStoreId;
private Instant storedAt;
public Map<String, String> getMetadata() {
return metadata == null ? Collections.emptyMap() : metadata;
}
public List<ChatbotTextChunk> getChunks() {
return chunks == null ? Collections.emptyList() : chunks;
}
}

View File

@ -26,6 +26,10 @@ public class ChatbotResponse {
private Instant respondedAt;
private List<String> warnings;
private Map<String, Object> metadata;
private long promptTokens;
private long completionTokens;
private long totalTokens;
private ChatbotUsageSummary usageSummary;
public List<String> getWarnings() {
return warnings == null ? Collections.emptyList() : warnings;

View File

@ -21,9 +21,11 @@ public class ChatbotSession {
private boolean alphaWarningRequired;
private boolean imageContentDetected;
private long textCharacters;
private long estimatedTokens;
private String cacheKey;
private String vectorStoreId;
private Instant createdAt;
private ChatbotUsageSummary usageSummary;
public static String randomSessionId() {
return UUID.randomUUID().toString();

View File

@ -23,9 +23,11 @@ public class ChatbotSessionResponse {
private boolean imageContentDetected;
private long maxCachedCharacters;
private long textCharacters;
private long estimatedTokens;
private Instant createdAt;
private List<String> warnings;
private Map<String, String> metadata;
private ChatbotUsageSummary usageSummary;
public List<String> getWarnings() {
return warnings == null ? Collections.emptyList() : warnings;

View File

@ -1,20 +0,0 @@
package stirling.software.proprietary.model.chatbot;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatbotTextChunk {
private String id;
private String text;
private int order;
private List<Double> embedding;
}

View File

@ -0,0 +1,22 @@
package stirling.software.proprietary.model.chatbot;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatbotUsageSummary {
private long allocatedTokens;
private long consumedTokens;
private long remainingTokens;
private double usageRatio;
private boolean nearingLimit;
private boolean limitExceeded;
private long lastIncrementTokens;
private String window;
}

View File

@ -11,7 +11,6 @@ import org.springframework.context.annotation.DependsOn;
import org.springframework.context.annotation.Lazy;
import org.springframework.security.authentication.ProviderManager;
import org.springframework.security.authentication.dao.DaoAuthenticationProvider;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
@ -225,9 +224,19 @@ public class SecurityConfiguration {
csrf.ignoringRequestMatchers(
request -> {
String uri = request.getRequestURI();
String contextPath = request.getContextPath();
String trimmedUri =
uri.startsWith(contextPath)
? uri.substring(
contextPath.length())
: uri;
// Ignore CSRF for auth endpoints
if (uri.startsWith("/api/v1/auth/")) {
// Ignore CSRF for auth endpoints + oauth/saml
if (trimmedUri.startsWith("/api/v1/auth/")
|| trimmedUri.startsWith("/oauth2")
|| trimmedUri.startsWith("/saml2")
|| trimmedUri.startsWith(
"/login/oauth2/code/")) {
return true;
}
@ -363,7 +372,8 @@ public class SecurityConfiguration {
loginAttemptService,
securityProperties.getOauth2(),
userService,
jwtService))
jwtService,
applicationProperties))
.failureHandler(new CustomOAuth2AuthenticationFailureHandler())
// Add existing Authorities from the database
.userInfoEndpoint(

View File

@ -50,6 +50,7 @@ public class CustomOAuth2AuthenticationSuccessHandler
private final ApplicationProperties.Security.OAUTH2 oauth2Properties;
private final UserService userService;
private final JwtServiceInterface jwtService;
private final ApplicationProperties applicationProperties;
@Override
@Audited(type = AuditEventType.USER_LOGIN, level = AuditLevel.BASIC)
@ -57,6 +58,12 @@ public class CustomOAuth2AuthenticationSuccessHandler
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws ServletException, IOException {
System.out.println("[OAuth Success Handler] Starting authentication success handling");
System.out.println("[OAuth Success Handler] Request URL: " + request.getRequestURL());
System.out.println(
"[OAuth Success Handler] Frontend URL from config: "
+ applicationProperties.getSystem().getFrontendUrl());
Object principal = authentication.getPrincipal();
String username = "";
@ -127,10 +134,16 @@ public class CustomOAuth2AuthenticationSuccessHandler
jwtService.generateToken(
authentication, Map.of("authType", AuthenticationType.OAUTH2));
System.out.println(
"[OAuth Success Handler] JWT generated: "
+ (jwt != null ? "YES (length: " + jwt.length() + ")" : "NO"));
// Build context-aware redirect URL based on the original request
String redirectUrl =
buildContextAwareRedirectUrl(request, response, contextPath, jwt);
System.out.println(
"[OAuth Success Handler] Final redirect URL: " + redirectUrl);
response.sendRedirect(redirectUrl);
} else {
// v1: redirect directly to home
@ -170,14 +183,44 @@ public class CustomOAuth2AuthenticationSuccessHandler
String contextPath,
String jwt) {
String redirectPath = resolveRedirectPath(request, contextPath);
String origin =
resolveForwardedOrigin(request)
.orElseGet(
() ->
resolveOriginFromReferer(request)
.orElseGet(() -> buildOriginFromRequest(request)));
System.out.println("[OAuth Success Handler] Resolving origin...");
Optional<String> configuredOrigin = resolveConfiguredFrontendOrigin();
System.out.println(
"[OAuth Success Handler] Configured frontend origin: "
+ configuredOrigin.orElse("NOT SET"));
String origin;
if (configuredOrigin.isPresent()) {
origin = configuredOrigin.get();
System.out.println("[OAuth Success Handler] Using configured origin: " + origin);
} else {
System.out.println(
"[OAuth Success Handler] No configured origin, trying other methods...");
origin =
resolveForwardedOrigin(request)
.orElseGet(
() -> {
System.out.println(
"[OAuth Success Handler] No forwarded origin, trying referer...");
return resolveOriginFromReferer(request)
.orElseGet(
() -> {
System.out.println(
"[OAuth Success Handler] No referer, building from request...");
return buildOriginFromRequest(request);
});
});
System.out.println("[OAuth Success Handler] Resolved origin: " + origin);
}
clearRedirectCookie(response);
return origin + redirectPath + "#access_token=" + jwt;
String finalUrl = origin + redirectPath + "#access_token=" + jwt;
System.out.println("[OAuth Success Handler] Building redirect URL:");
System.out.println(" - Origin: " + origin);
System.out.println(" - Redirect path: " + redirectPath);
System.out.println(" - Final URL: " + finalUrl);
return finalUrl;
}
private String resolveRedirectPath(HttpServletRequest request, String contextPath) {
@ -278,6 +321,26 @@ public class CustomOAuth2AuthenticationSuccessHandler
return origin.toString();
}
private Optional<String> resolveConfiguredFrontendOrigin() {
System.out.println("[OAuth Success Handler] Checking configured frontend URL...");
if (applicationProperties.getSystem() == null) {
System.out.println("[OAuth Success Handler] applicationProperties.getSystem() is NULL");
return Optional.empty();
}
String configured = applicationProperties.getSystem().getFrontendUrl();
System.out.println("[OAuth Success Handler] Frontend URL from config: " + configured);
if (configured == null || configured.isBlank()) {
System.out.println("[OAuth Success Handler] Frontend URL is null or blank");
return Optional.empty();
}
String trimmed = configured.trim();
if (trimmed.endsWith("/")) {
trimmed = trimmed.substring(0, trimmed.length() - 1);
}
System.out.println("[OAuth Success Handler] Returning configured frontend URL: " + trimmed);
return Optional.of(trimmed);
}
private boolean isDefaultPort(String scheme, String port) {
if (port == null) {
return true;

View File

@ -2,7 +2,6 @@ package stirling.software.proprietary.service.chatbot;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
@ -21,15 +20,14 @@ import stirling.software.common.model.ApplicationProperties.Premium;
import stirling.software.common.model.ApplicationProperties.Premium.ProFeatures;
import stirling.software.common.model.ApplicationProperties.Premium.ProFeatures.Chatbot;
import stirling.software.proprietary.model.chatbot.ChatbotDocumentCacheEntry;
import stirling.software.proprietary.model.chatbot.ChatbotTextChunk;
import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
@Service
// @ConditionalOnProperty(value = "premium.proFeatures.chatbot.enabled", havingValue = "true")
@Slf4j
public class ChatbotCacheService {
private final Cache<String, ChatbotDocumentCacheEntry> documentCache;
private final Cache<String, ChatbotDocumentCacheEntry>
documentCache; // todo: can redis be used instead?
private final long maxDocumentCharacters;
private final Map<String, String> sessionToCacheKey = new ConcurrentHashMap<>();
@ -61,18 +59,12 @@ public class ChatbotCacheService {
public String register(
String sessionId,
String documentId,
String rawText,
Map<String, String> metadata,
boolean ocrApplied,
boolean imageContentDetected,
long textCharacters) {
Objects.requireNonNull(sessionId, "sessionId must not be null");
Objects.requireNonNull(documentId, "documentId must not be null");
Objects.requireNonNull(rawText, "rawText must not be null");
if (rawText.length() > maxDocumentCharacters) {
throw new ChatbotException(
"Document text exceeds maximum allowed characters: " + maxDocumentCharacters);
}
String cacheKey =
sessionToCacheKey.computeIfAbsent(sessionId, k -> UUID.randomUUID().toString());
ChatbotDocumentCacheEntry entry =
@ -81,7 +73,6 @@ public class ChatbotCacheService {
.sessionId(sessionId)
.documentId(documentId)
.metadata(metadata)
.text(rawText)
.ocrApplied(ocrApplied)
.imageContentDetected(imageContentDetected)
.textCharacters(textCharacters)
@ -91,17 +82,6 @@ public class ChatbotCacheService {
return cacheKey;
}
public void attachChunks(String cacheKey, List<ChatbotTextChunk> chunks) {
documentCache
.asMap()
.computeIfPresent(
cacheKey,
(key, existing) -> {
existing.setChunks(chunks);
return existing;
});
}
public Optional<ChatbotDocumentCacheEntry> resolveByCacheKey(String cacheKey) {
return Optional.ofNullable(documentCache.getIfPresent(cacheKey));
}

View File

@ -16,6 +16,7 @@ import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.Document;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
@ -32,12 +33,12 @@ import stirling.software.proprietary.model.chatbot.ChatbotDocumentCacheEntry;
import stirling.software.proprietary.model.chatbot.ChatbotQueryRequest;
import stirling.software.proprietary.model.chatbot.ChatbotResponse;
import stirling.software.proprietary.model.chatbot.ChatbotSession;
import stirling.software.proprietary.model.chatbot.ChatbotTextChunk;
import stirling.software.proprietary.model.chatbot.ChatbotUsageSummary;
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.ChatbotSettings;
import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
@Service
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatbotConversationService {
@ -46,6 +47,7 @@ public class ChatbotConversationService {
private final ChatbotCacheService cacheService;
private final ChatbotFeatureProperties featureProperties;
private final ChatbotRetrievalService retrievalService;
private final ChatbotUsageService usageService;
private final ObjectMapper objectMapper;
private final AtomicBoolean modelSwitchVerified = new AtomicBoolean(false);
@ -74,7 +76,7 @@ public class ChatbotConversationService {
List<String> warnings = buildWarnings(settings, session);
List<ChatbotTextChunk> context =
List<Document> context =
retrievalService.retrieveTopK(
request.getSessionId(), request.getPrompt(), settings);
@ -97,17 +99,23 @@ public class ChatbotConversationService {
boolean escalated = false;
if (shouldEscalate) {
escalated = true;
List<ChatbotTextChunk> expandedContext = ensureMinimumContext(context, cacheEntry);
finalReply =
invokeModel(
settings,
settings.models().fallback(),
request.getPrompt(),
session,
expandedContext,
context,
cacheEntry.getMetadata());
}
ChatbotUsageSummary usageSummary =
usageService.registerGeneration(
session.getUserId(),
finalReply.promptTokens(),
finalReply.completionTokens());
session.setUsageSummary(usageSummary);
return ChatbotResponse.builder()
.sessionId(request.getSessionId())
.modelUsed(
@ -120,6 +128,10 @@ public class ChatbotConversationService {
.respondedAt(Instant.now())
.warnings(warnings)
.metadata(buildMetadata(settings, session, finalReply, context.size(), escalated))
.promptTokens(finalReply.promptTokens())
.completionTokens(finalReply.completionTokens())
.totalTokens(finalReply.totalTokens())
.usageSummary(usageSummary)
.build();
}
@ -151,6 +163,9 @@ public class ChatbotConversationService {
metadata.put("modelProvider", settings.models().provider().name());
metadata.put("imageContentDetected", session.isImageContentDetected());
metadata.put("charactersCached", session.getTextCharacters());
metadata.put("promptTokens", reply.promptTokens());
metadata.put("completionTokens", reply.completionTokens());
metadata.put("totalTokens", reply.totalTokens());
return metadata;
}
@ -179,29 +194,12 @@ public class ChatbotConversationService {
}
}
private List<ChatbotTextChunk> ensureMinimumContext(
List<ChatbotTextChunk> context, ChatbotDocumentCacheEntry entry) {
if (context.size() >= 3 || entry.getChunks().size() <= context.size()) {
return context;
}
List<ChatbotTextChunk> augmented = new ArrayList<>(context);
for (ChatbotTextChunk chunk : entry.getChunks()) {
if (augmented.size() >= 3) {
break;
}
if (!augmented.contains(chunk)) {
augmented.add(chunk);
}
}
return augmented;
}
private ModelReply invokeModel(
ChatbotSettings settings,
String model,
String prompt,
ChatbotSession session,
List<ChatbotTextChunk> context,
List<Document> context,
Map<String, String> metadata) {
Prompt requestPrompt = buildPrompt(settings, model, prompt, session, context, metadata);
ChatResponse response;
@ -217,13 +215,27 @@ public class ChatbotConversationService {
+ sanitizeRemoteMessage(ex.getMessage()),
ex);
}
long promptTokens = 0L;
long completionTokens = 0L;
long totalTokens = 0L;
if (response != null && response.getMetadata() != null) {
org.springframework.ai.chat.metadata.Usage usage = response.getMetadata().getUsage();
if (usage != null) {
promptTokens = toLong(usage.getPromptTokens());
completionTokens = toLong(usage.getCompletionTokens());
totalTokens =
usage.getTotalTokens() != null
? usage.getTotalTokens()
: promptTokens + completionTokens;
}
}
String content =
Optional.ofNullable(response)
.map(ChatResponse::getResults)
.filter(results -> !results.isEmpty())
.map(results -> results.get(0).getOutput().getText())
.orElse("");
return parseModelResponse(content);
return parseModelResponse(content, promptTokens, completionTokens, totalTokens);
}
private Prompt buildPrompt(
@ -231,13 +243,13 @@ public class ChatbotConversationService {
String model,
String question,
ChatbotSession session,
List<ChatbotTextChunk> context,
List<Document> context,
Map<String, String> metadata) {
StringBuilder contextBuilder = new StringBuilder();
for (ChatbotTextChunk chunk : context) {
for (Document chunk : context) {
contextBuilder
.append("[Chunk ")
.append(chunk.getOrder())
.append(chunk.getMetadata().getOrDefault("chunkOrder", "?"))
.append("]\n")
.append(chunk.getText())
.append("\n\n");
@ -270,18 +282,24 @@ public class ChatbotConversationService {
+ "Question: "
+ question;
OpenAiChatOptions options = buildChatOptions(model);
OpenAiChatOptions options = buildChatOptions(settings, model);
return new Prompt(
List.of(new SystemMessage(systemPrompt), new UserMessage(userPrompt)), options);
}
private OpenAiChatOptions buildChatOptions(String model) {
// Note: Some models only support default temperature value of 1.0
return OpenAiChatOptions.builder().model(model).temperature(1.0).build();
private OpenAiChatOptions buildChatOptions(ChatbotSettings settings, String model) {
OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder().model(model);
String normalizedModel = model == null ? "" : model.toLowerCase();
boolean reasoningModel = normalizedModel.startsWith("gpt-5-");
if (!reasoningModel) {
builder.temperature(settings.models().temperature()).topP(settings.models().topP());
}
return builder.build();
}
private ModelReply parseModelResponse(String raw) {
private ModelReply parseModelResponse(
String raw, long promptTokens, long completionTokens, long totalTokens) {
if (!StringUtils.hasText(raw)) {
throw new ChatbotException("Model returned empty response");
}
@ -301,15 +319,35 @@ public class ChatbotConversationService {
Optional.ofNullable(node.get("rationale"))
.map(JsonNode::asText)
.orElse("Model did not provide rationale");
return new ModelReply(answer, confidence, requiresEscalation, rationale);
return new ModelReply(
answer,
confidence,
requiresEscalation,
rationale,
promptTokens,
completionTokens,
totalTokens);
} catch (IOException ex) {
log.warn("Failed to parse model JSON response, returning raw text", ex);
return new ModelReply(raw, 0.0D, true, "Unable to parse JSON response");
return new ModelReply(
raw,
0.0D,
true,
"Unable to parse JSON response",
promptTokens,
completionTokens,
totalTokens);
}
}
private record ModelReply(
String answer, double confidence, boolean requiresEscalation, String rationale) {}
String answer,
double confidence,
boolean requiresEscalation,
String rationale,
long promptTokens,
long completionTokens,
long totalTokens) {}
private String sanitizeRemoteMessage(String message) {
if (!StringUtils.hasText(message)) {
@ -317,4 +355,8 @@ public class ChatbotConversationService {
}
return message.replaceAll("(?i)api[-_ ]?key\\s*=[^\\s]+", "api-key=***");
}
private long toLong(Integer value) {
return value == null ? 0L : value.longValue();
}
}

View File

@ -26,7 +26,9 @@ public class ChatbotFeatureProperties {
resolveProvider(chatbot.getModels().getProvider()),
chatbot.getModels().getPrimary(),
chatbot.getModels().getFallback(),
chatbot.getModels().getEmbedding());
chatbot.getModels().getEmbedding(),
chatbot.getModels().getTemperature(),
chatbot.getModels().getTopP());
return new ChatbotSettings(
chatbot.isEnabled(),
chatbot.isAlphaWarning(),
@ -42,7 +44,10 @@ public class ChatbotFeatureProperties {
chatbot.getCache().getMaxEntries(),
chatbot.getCache().getMaxDocumentCharacters()),
new ChatbotSettings.OcrSettings(chatbot.getOcr().isEnabledByDefault()),
new ChatbotSettings.AuditSettings(chatbot.getAudit().isEnabled()));
new ChatbotSettings.AuditSettings(chatbot.getAudit().isEnabled()),
new ChatbotSettings.UsageSettings(
chatbot.getUsage().getPerUserMonthlyTokens(),
chatbot.getUsage().getWarnAtRatio()));
}
public boolean isEnabled() {
@ -77,10 +82,16 @@ public class ChatbotFeatureProperties {
RagSettings rag,
CacheSettings cache,
OcrSettings ocr,
AuditSettings audit) {
AuditSettings audit,
UsageSettings usage) {
public record ModelSettings(
ModelProvider provider, String primary, String fallback, String embedding) {}
ModelProvider provider,
String primary,
String fallback,
String embedding,
double temperature,
double topP) {}
public record RagSettings(int chunkSizeTokens, int chunkOverlapTokens, int topK) {}
@ -90,6 +101,8 @@ public class ChatbotFeatureProperties {
public record AuditSettings(boolean enabled) {}
public record UsageSettings(long perUserMonthlyTokens, double warnAtRatio) {}
public enum ModelProvider {
OPENAI,
OLLAMA

View File

@ -5,10 +5,9 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
@ -17,7 +16,6 @@ import lombok.extern.slf4j.Slf4j;
import stirling.software.proprietary.model.chatbot.ChatbotSession;
import stirling.software.proprietary.model.chatbot.ChatbotSessionCreateRequest;
import stirling.software.proprietary.model.chatbot.ChatbotTextChunk;
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.ChatbotSettings;
import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
import stirling.software.proprietary.service.chatbot.exception.NoTextDetectedException;
@ -30,7 +28,8 @@ public class ChatbotIngestionService {
private final ChatbotCacheService cacheService;
private final ChatbotSessionRegistry sessionRegistry;
private final ChatbotFeatureProperties featureProperties;
private final EmbeddingModel embeddingModel;
private final VectorStore vectorStore;
private final ChatbotUsageService usageService;
public ChatbotSession ingest(ChatbotSessionCreateRequest request) {
ChatbotSettings settings = featureProperties.current();
@ -40,18 +39,23 @@ public class ChatbotIngestionService {
if (!request.isWarningsAccepted() && settings.alphaWarning()) {
throw new ChatbotException("Alpha warning must be accepted before use");
}
boolean hasText = StringUtils.hasText(request.getText());
if (!hasText) {
if (!StringUtils.hasText(request.getText())) {
throw new NoTextDetectedException(
"No text detected in document payload. Images are currently unsupported enable OCR to continue.");
}
long characterLimit = cacheService.getMaxDocumentCharacters();
long textCharacters = request.getText().length();
if (textCharacters > characterLimit) {
throw new ChatbotException(
"Document text exceeds maximum allowed characters: " + characterLimit);
}
String sessionId =
StringUtils.hasText(request.getSessionId())
? request.getSessionId()
: ChatbotSession.randomSessionId();
boolean imagesDetected = request.isImagesDetected();
long textCharacters = request.getText().length();
boolean ocrApplied = request.isOcrRequested();
Map<String, String> metadata = new HashMap<>();
if (request.getMetadata() != null) {
@ -63,23 +67,28 @@ public class ChatbotIngestionService {
"content.extractionSource", ocrApplied ? "ocr-text-layer" : "embedded-text-layer");
Map<String, String> immutableMetadata = Map.copyOf(metadata);
List<Document> documents =
buildDocuments(
sessionId, request.getDocumentId(), request.getText(), metadata, settings);
try {
vectorStore.add(documents);
} catch (RuntimeException ex) {
throw new ChatbotException(
"Failed to index document content in vector store: "
+ sanitizeRemoteMessage(ex.getMessage()),
ex);
}
String cacheKey =
cacheService.register(
sessionId,
request.getDocumentId(),
request.getText(),
immutableMetadata,
ocrApplied,
imagesDetected,
textCharacters);
List<String> chunkTexts =
chunkText(
request.getText(),
settings.rag().chunkSizeTokens(),
settings.rag().chunkOverlapTokens());
List<ChatbotTextChunk> chunks = embedChunks(sessionId, cacheKey, chunkTexts, metadata);
cacheService.attachChunks(cacheKey, chunks);
long estimatedTokens = Math.max(1L, Math.round(textCharacters / 4.0));
ChatbotSession session =
ChatbotSession.builder()
@ -90,93 +99,65 @@ public class ChatbotIngestionService {
.ocrRequested(ocrApplied)
.imageContentDetected(imagesDetected)
.textCharacters(textCharacters)
.estimatedTokens(estimatedTokens)
.warningsAccepted(request.isWarningsAccepted())
.alphaWarningRequired(settings.alphaWarning())
.cacheKey(cacheKey)
.createdAt(Instant.now())
.build();
session.setUsageSummary(
usageService.registerIngestion(session.getUserId(), estimatedTokens));
sessionRegistry.register(session);
log.info(
"Registered chatbot session {} for document {} with {} chunks",
"Registered chatbot session {} for document {} with {} RAG chunks",
sessionId,
request.getDocumentId(),
chunks.size());
documents.size());
return session;
}
private List<String> chunkText(String text, int chunkSizeTokens, int overlapTokens) {
String[] tokens = text.split("\\s+");
List<String> chunks = new ArrayList<>();
if (tokens.length == 0) {
return chunks;
private List<Document> buildDocuments(
String sessionId,
String documentId,
String text,
Map<String, String> metadata,
ChatbotSettings settings) {
List<Document> documents = new ArrayList<>();
if (!StringUtils.hasText(text)) {
return documents;
}
int effectiveChunk = Math.max(chunkSizeTokens, 1);
int effectiveOverlap = Math.max(Math.min(overlapTokens, effectiveChunk - 1), 0);
int chunkChars = Math.max(512, settings.rag().chunkSizeTokens() * 4);
int overlapChars = Math.max(64, settings.rag().chunkOverlapTokens() * 4);
int index = 0;
while (index < tokens.length) {
int end = Math.min(tokens.length, index + effectiveChunk);
String chunk = String.join(" ", java.util.Arrays.copyOfRange(tokens, index, end));
if (StringUtils.hasText(chunk)) {
chunks.add(chunk);
int order = 0;
while (index < text.length()) {
int end = Math.min(text.length(), index + chunkChars);
String chunk = text.substring(index, end).trim();
if (!chunk.isEmpty()) {
Document document = new Document(chunk);
document.getMetadata().putAll(metadata);
document.getMetadata().put("sessionId", sessionId);
document.getMetadata().put("documentId", documentId);
document.getMetadata().put("chunkOrder", Integer.toString(order));
documents.add(document);
order++;
}
if (end == tokens.length) {
if (end == text.length()) {
break;
}
index = end - effectiveOverlap;
if (index <= 0) {
index = end;
int nextIndex = end - overlapChars;
if (nextIndex <= index) {
nextIndex = end;
}
index = nextIndex;
}
return chunks;
}
private List<ChatbotTextChunk> embedChunks(
String sessionId,
String cacheKey,
List<String> chunkTexts,
Map<String, String> metadata) {
if (chunkTexts.isEmpty()) {
throw new ChatbotException("Unable to split document text into retrievable chunks");
if (documents.isEmpty()) {
throw new ChatbotException("Unable to split document text into searchable chunks");
}
EmbeddingResponse response;
try {
response = embeddingModel.embedForResponse(chunkTexts);
} catch (org.eclipse.jetty.client.HttpResponseException ex) {
throw new ChatbotException(
"Embedding provider rejected the request: "
+ sanitizeRemoteMessage(ex.getMessage()),
ex);
} catch (RuntimeException ex) {
throw new ChatbotException(
"Failed to compute embeddings for chatbot ingestion: "
+ sanitizeRemoteMessage(ex.getMessage()),
ex);
}
if (response.getResults().size() != chunkTexts.size()) {
throw new ChatbotException("Mismatch between chunks and embedding results");
}
List<ChatbotTextChunk> chunks = new ArrayList<>();
for (int i = 0; i < chunkTexts.size(); i++) {
String chunkId = sessionId + ":" + i + ":" + UUID.randomUUID();
float[] embeddingArray = response.getResults().get(i).getOutput();
List<Double> embedding = new ArrayList<>(embeddingArray.length);
for (float value : embeddingArray) {
embedding.add((double) value);
}
chunks.add(
ChatbotTextChunk.builder()
.id(chunkId)
.order(i)
.text(chunkTexts.get(i))
.embedding(embedding)
.build());
}
log.debug(
"Computed embeddings for session {} cacheKey {} ({} vectors)",
sessionId,
cacheKey,
chunks.size());
return chunks;
return documents;
}
private String sanitizeRemoteMessage(String message) {

View File

@ -1,20 +1,16 @@
package stirling.software.proprietary.service.chatbot;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import stirling.software.proprietary.model.chatbot.ChatbotDocumentCacheEntry;
import stirling.software.proprietary.model.chatbot.ChatbotTextChunk;
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.ChatbotSettings;
import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
@ -24,69 +20,53 @@ import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
public class ChatbotRetrievalService {
private final ChatbotCacheService cacheService;
private final EmbeddingModel embeddingModel;
private final VectorStore vectorStore;
public List<ChatbotTextChunk> retrieveTopK(
String sessionId, String query, ChatbotSettings settings) {
ChatbotDocumentCacheEntry entry =
cacheService
.resolveBySessionId(sessionId)
.orElseThrow(() -> new ChatbotException("Unknown chatbot session"));
List<ChatbotTextChunk> chunks = entry.getChunks();
if (CollectionUtils.isEmpty(chunks)) {
throw new ChatbotException("Chatbot cache does not contain pre-computed chunks");
public List<Document> retrieveTopK(String sessionId, String query, ChatbotSettings settings) {
cacheService
.resolveBySessionId(sessionId)
.orElseThrow(() -> new ChatbotException("Unknown chatbot session"));
int topK = Math.max(settings.rag().topK(), 1);
String sanitizedQuery = StringUtils.hasText(query) ? query : "";
String filterExpression = "sessionId == '" + escape(sessionId) + "'";
SearchRequest searchRequest =
SearchRequest.builder()
.query(sanitizedQuery)
.topK(topK)
.filterExpression(filterExpression)
.build();
List<Document> results;
try {
results = vectorStore.similaritySearch(searchRequest);
} catch (RuntimeException ex) {
throw new ChatbotException(
"Failed to perform vector similarity search: "
+ sanitizeRemoteMessage(ex.getMessage()),
ex);
}
List<Double> queryEmbedding = computeQueryEmbedding(query);
List<ScoredChunk> scoredChunks = new ArrayList<>();
for (ChatbotTextChunk chunk : chunks) {
if (CollectionUtils.isEmpty(chunk.getEmbedding())) {
log.warn("Chunk {} missing embedding, skipping", chunk.getId());
continue;
}
double score = cosineSimilarity(queryEmbedding, chunk.getEmbedding());
scoredChunks.add(new ScoredChunk(chunk, score));
results =
results.stream()
.filter(
doc ->
sessionId.equals(
doc.getMetadata().getOrDefault("sessionId", "")))
.limit(topK)
.toList();
if (results.isEmpty()) {
throw new ChatbotException("No context available for this chatbot session");
}
return scoredChunks.stream()
.sorted(Comparator.comparingDouble(ScoredChunk::score).reversed())
.limit(Math.max(settings.rag().topK(), 1))
.map(ScoredChunk::chunk)
.toList();
return results;
}
private List<Double> computeQueryEmbedding(String query) {
EmbeddingResponse response = embeddingModel.embedForResponse(List.of(query));
float[] embeddingArray =
Optional.ofNullable(response.getResults().stream().findFirst().orElse(null))
.map(org.springframework.ai.embedding.Embedding::getOutput)
.orElseThrow(
() -> new ChatbotException("Failed to compute query embedding"));
List<Double> embedding = new ArrayList<>(embeddingArray.length);
for (float value : embeddingArray) {
embedding.add((double) value);
private String sanitizeRemoteMessage(String message) {
if (!StringUtils.hasText(message)) {
return "unexpected provider error";
}
return embedding;
return message.replaceAll("(?i)api[-_ ]?key\\s*=[^\\s]+", "api-key=***");
}
private double cosineSimilarity(List<Double> v1, List<Double> v2) {
int size = Math.min(v1.size(), v2.size());
if (size == 0) {
return -1.0;
}
double dot = 0.0;
double mag1 = 0.0;
double mag2 = 0.0;
for (int i = 0; i < size; i++) {
double a = v1.get(i);
double b = v2.get(i);
dot += a * b;
mag1 += a * a;
mag2 += b * b;
}
if (mag1 == 0.0 || mag2 == 0.0) {
return -1.0;
}
return dot / (Math.sqrt(mag1) * Math.sqrt(mag2));
private String escape(String value) {
return value.replace("'", "\\'");
}
private record ScoredChunk(ChatbotTextChunk chunk, double score) {}
}

View File

@ -4,6 +4,7 @@ import java.util.HashMap;
import java.util.Map;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
@ -12,6 +13,7 @@ import stirling.software.proprietary.model.chatbot.ChatbotQueryRequest;
import stirling.software.proprietary.model.chatbot.ChatbotResponse;
import stirling.software.proprietary.model.chatbot.ChatbotSession;
import stirling.software.proprietary.model.chatbot.ChatbotSessionCreateRequest;
import stirling.software.proprietary.security.service.UserService;
import stirling.software.proprietary.service.AuditService;
import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
@ -26,8 +28,12 @@ public class ChatbotService {
private final ChatbotCacheService cacheService;
private final ChatbotFeatureProperties featureProperties;
private final AuditService auditService;
private final UserService userService;
public ChatbotSession createSession(ChatbotSessionCreateRequest request) {
if (!StringUtils.hasText(request.getUserId())) {
request.setUserId(userService.getCurrentUsername());
}
ChatbotSession session = ingestionService.ingest(request);
log.debug("Chatbot session {} initialised", session.getSessionId());
audit(

View File

@ -0,0 +1,106 @@
package stirling.software.proprietary.service.chatbot;
import java.time.YearMonth;
import java.time.ZoneOffset;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import stirling.software.proprietary.model.chatbot.ChatbotUsageSummary;
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.ChatbotSettings;
@Service
@RequiredArgsConstructor
@Slf4j
public class ChatbotUsageService {
private final ChatbotFeatureProperties featureProperties;
private final Map<String, UsageWindow> usageByUser = new ConcurrentHashMap<>();
public ChatbotUsageSummary registerIngestion(String userId, long estimatedTokens) {
return incrementUsage(userId, Math.max(estimatedTokens, 0L));
}
public ChatbotUsageSummary registerGeneration(
String userId, long promptTokens, long completionTokens) {
long total = Math.max(promptTokens + completionTokens, 0L);
return incrementUsage(userId, total);
}
public ChatbotUsageSummary currentUsage(String userId) {
String key = normalizeUserId(userId);
UsageWindow window = usageByUser.get(key);
if (window == null) {
return buildSummary(key, 0L, 0L);
}
return buildSummary(key, window.tokens.get(), 0L);
}
private ChatbotUsageSummary incrementUsage(String userId, long deltaTokens) {
String key = normalizeUserId(userId);
YearMonth now = YearMonth.now(ZoneOffset.UTC);
UsageWindow window =
usageByUser.compute(
key,
(ignored, existing) -> {
if (existing == null || !existing.window.equals(now)) {
existing = new UsageWindow(now);
}
if (deltaTokens > 0) {
existing.tokens.addAndGet(deltaTokens);
}
return existing;
});
return buildSummary(key, window.tokens.get(), deltaTokens);
}
private ChatbotUsageSummary buildSummary(String userKey, long consumed, long deltaTokens) {
ChatbotSettings settings = featureProperties.current();
long allocation = Math.max(settings.usage().perUserMonthlyTokens(), 1L);
double ratio = allocation == 0 ? 1.0 : (double) consumed / allocation;
long remaining = Math.max(allocation - consumed, 0L);
boolean limitExceeded = consumed > allocation;
boolean nearingLimit = ratio >= settings.usage().warnAtRatio();
return ChatbotUsageSummary.builder()
.allocatedTokens(allocation)
.consumedTokens(consumed)
.remainingTokens(remaining)
.usageRatio(Math.min(ratio, 1.0))
.nearingLimit(nearingLimit)
.limitExceeded(limitExceeded)
.lastIncrementTokens(deltaTokens)
.window(currentWindowDescription(userKey))
.build();
}
private String currentWindowDescription(String userKey) {
UsageWindow window = usageByUser.get(userKey);
if (window == null) {
return YearMonth.now(ZoneOffset.UTC).toString();
}
return window.window.toString();
}
private String normalizeUserId(String userId) {
if (!StringUtils.hasText(userId)) {
return "anonymous";
}
return userId.trim().toLowerCase();
}
private static final class UsageWindow {
private final YearMonth window;
private final AtomicLong tokens = new AtomicLong();
private UsageWindow(YearMonth window) {
this.window = window;
}
}
}

View File

@ -1,7 +1,6 @@
package stirling.software.proprietary.service.chatbot;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.Map;
@ -11,7 +10,6 @@ import org.junit.jupiter.api.Test;
import stirling.software.common.model.ApplicationProperties;
import stirling.software.proprietary.model.chatbot.ChatbotDocumentCacheEntry;
import stirling.software.proprietary.service.chatbot.exception.ChatbotException;
class ChatbotCacheServiceTest {
@ -34,23 +32,6 @@ class ChatbotCacheServiceTest {
properties.setPremium(premium);
}
@Test
void registerRejectsOversizedText() {
ChatbotCacheService cacheService = new ChatbotCacheService(properties);
String longText = "a".repeat(51);
assertThrows(
ChatbotException.class,
() ->
cacheService.register(
"session",
"doc",
longText,
Map.of(),
false,
false,
longText.length()));
}
@Test
void registerAndResolveSession() {
ChatbotCacheService cacheService = new ChatbotCacheService(properties);
@ -58,7 +39,6 @@ class ChatbotCacheServiceTest {
cacheService.register(
"session1",
"doc1",
"hello world",
Map.of("title", "Sample"),
false,
false,

View File

@ -22,6 +22,7 @@ import stirling.software.proprietary.model.chatbot.ChatbotQueryRequest;
import stirling.software.proprietary.model.chatbot.ChatbotResponse;
import stirling.software.proprietary.model.chatbot.ChatbotSession;
import stirling.software.proprietary.model.chatbot.ChatbotSessionCreateRequest;
import stirling.software.proprietary.security.service.UserService;
import stirling.software.proprietary.service.AuditService;
import stirling.software.proprietary.service.chatbot.ChatbotFeatureProperties.ChatbotSettings;
@ -34,6 +35,7 @@ class ChatbotServiceTest {
@Mock private ChatbotCacheService cacheService;
@Mock private ChatbotFeatureProperties featureProperties;
@Mock private AuditService auditService;
@Mock private UserService userService;
@InjectMocks private ChatbotService chatbotService;
@ -52,11 +54,14 @@ class ChatbotServiceTest {
ChatbotSettings.ModelProvider.OPENAI,
"gpt-5-nano",
"gpt-5-mini",
"embed"),
"embed",
0.2D,
0.95D),
new ChatbotSettings.RagSettings(512, 128, 4),
new ChatbotSettings.CacheSettings(60, 10, 1000),
new ChatbotSettings.OcrSettings(false),
new ChatbotSettings.AuditSettings(true));
new ChatbotSettings.AuditSettings(true),
new ChatbotSettings.UsageSettings(100000L, 0.7D));
auditDisabledSettings =
new ChatbotSettings(
@ -68,11 +73,14 @@ class ChatbotServiceTest {
ChatbotSettings.ModelProvider.OPENAI,
"gpt-5-nano",
"gpt-5-mini",
"embed"),
"embed",
0.2D,
0.95D),
new ChatbotSettings.RagSettings(512, 128, 4),
new ChatbotSettings.CacheSettings(60, 10, 1000),
new ChatbotSettings.OcrSettings(false),
new ChatbotSettings.AuditSettings(false));
new ChatbotSettings.AuditSettings(false),
new ChatbotSettings.UsageSettings(100000L, 0.7D));
}
@Test
@ -86,6 +94,7 @@ class ChatbotServiceTest {
.build();
when(ingestionService.ingest(any())).thenReturn(session);
when(featureProperties.current()).thenReturn(auditEnabledSettings);
when(userService.getCurrentUsername()).thenReturn("tester");
chatbotService.createSession(
ChatbotSessionCreateRequest.builder().text("abc").warningsAccepted(true).build());
@ -97,6 +106,7 @@ class ChatbotServiceTest {
payloadCaptor.capture());
Map<String, Object> payload = payloadCaptor.getValue();
verify(cacheService, times(0)).invalidateSession(any());
verify(userService).getCurrentUsername();
org.junit.jupiter.api.Assertions.assertEquals("session-1", payload.get("sessionId"));
}

View File

@ -30,6 +30,7 @@ import { runOcrForChat } from '@app/services/chatbotOcrService';
import {
ChatbotMessageResponse,
ChatbotSessionInfo,
ChatbotUsageSummary,
createChatbotSession,
sendChatbotPrompt,
} from '@app/services/chatbotService';
@ -81,6 +82,7 @@ const ChatbotDrawer = () => {
const [pendingOcrRetry, setPendingOcrRetry] = useState(false);
const scrollViewportRef = useRef<HTMLDivElement>(null);
const [panelAnchor, setPanelAnchor] = useState<{ right: number; top: number } | null>(null);
const usageAlertState = useRef<'none' | 'warned' | 'limit'>('none');
const selectedFile = useMemo<StirlingFile | undefined>(
() => files.find((file) => file.fileId === selectedFileId),
@ -120,6 +122,39 @@ const ChatbotDrawer = () => {
}
}, [messages, isOpen]);
useEffect(() => {
usageAlertState.current = 'none';
}, [sessionInfo?.sessionId]);
const maybeShowUsageWarning = (usage?: ChatbotUsageSummary | null) => {
if (!usage) {
return;
}
if (usage.limitExceeded && usageAlertState.current !== 'limit') {
usageAlertState.current = 'limit';
show({
alertType: 'warning',
title: t('chatbot.usage.limitReachedTitle', 'Chatbot limit reached'),
body: t(
'chatbot.usage.limitReachedBody',
'You have exceeded the current monthly allocation for the chatbot. Further responses may be throttled.'
),
});
return;
}
if (usage.nearingLimit && usageAlertState.current === 'none') {
usageAlertState.current = 'warned';
show({
alertType: 'warning',
title: t('chatbot.usage.nearingLimitTitle', 'Approaching usage limit'),
body: t(
'chatbot.usage.nearingLimitBody',
'You are nearing your monthly chatbot allocation. Consider limiting very large requests.'
),
});
}
};
useEffect(() => {
if (sessionInfo && sessionInfo.documentId !== selectedFileId) {
setSessionInfo(null);
@ -245,6 +280,7 @@ const ChatbotDrawer = () => {
);
setSessionInfo(response);
maybeShowUsageWarning(response.usageSummary);
setContextStats({
pageCount: extractionResult.pageCount,
characterCount: extractionResult.characterCount,
@ -295,6 +331,7 @@ const ChatbotDrawer = () => {
prompt: trimmedPrompt,
allowEscalation: true,
});
maybeShowUsageWarning(reply.usageSummary);
setWarnings(reply.warnings ?? []);
const assistant = convertAssistantMessage(reply);
setMessages((prev) => [...prev, assistant]);

View File

@ -87,6 +87,34 @@ export const AppConfigProvider: React.FC<AppConfigProviderProps> = ({
const initialDelay = retryOptions?.initialDelay ?? 1000;
const fetchConfig = useCallback(async (force = false) => {
// First check if user has a JWT token - if not, they're not authenticated
const hasJWT = localStorage.getItem('stirling_jwt');
// Check if on auth page
// Need to check for paths with or without base path
const pathname = window.location.pathname;
const isAuthPage = pathname.endsWith('/login') ||
pathname.endsWith('/signup') ||
pathname.endsWith('/auth/callback') ||
pathname.includes('/auth/') ||
pathname.includes('/invite/');
// Skip config fetch if:
// 1. On auth page, OR
// 2. No JWT token (not authenticated) and not forcing
if (isAuthPage || (!hasJWT && !force)) {
console.debug('[AppConfig] Skipping config fetch:', {
reason: isAuthPage ? 'On auth page' : 'No JWT token',
pathname,
hasJWT: !!hasJWT,
force
});
setLoading(false);
setConfig({ enableLogin: true });
setHasResolvedConfig(true);
return;
}
// Prevent duplicate fetches unless forced
if (!force && fetchCount > 0) {
console.debug('[AppConfig] Already fetched, skipping');
@ -109,6 +137,16 @@ export const AppConfigProvider: React.FC<AppConfigProviderProps> = ({
console.log('[AppConfig] Fetching app config...');
}
// GUARD: Only make the API call if user has JWT token
const currentJWT = localStorage.getItem('stirling_jwt');
if (!currentJWT && !force) {
console.debug('[AppConfig] No JWT token, skipping API call entirely');
setConfig({ enableLogin: true });
setHasResolvedConfig(true);
setLoading(false);
return;
}
// apiClient automatically adds JWT header if available via interceptors
const response = await apiClient.get<AppConfig>('/api/v1/config/app-config', !isBlockingMode ? { suppressErrorToast: true } : undefined);
const data = response.data;
@ -155,12 +193,31 @@ export const AppConfigProvider: React.FC<AppConfigProviderProps> = ({
}, [fetchCount, hasResolvedConfig, isBlockingMode, maxRetries, initialDelay]);
useEffect(() => {
// Skip fetching config on login and auth callback pages
// Need to check for paths with or without base path
const pathname = window.location.pathname;
const isAuthPage = pathname.endsWith('/login') ||
pathname.endsWith('/signup') ||
pathname.endsWith('/auth/callback') ||
pathname.includes('/auth/') ||
pathname.includes('/invite/');
if (isAuthPage) {
console.debug('[AppConfig] On auth page, skipping config fetch in useEffect');
console.debug('[AppConfig] Current pathname:', pathname);
setLoading(false);
// Set minimal config for auth pages
setConfig({ enableLogin: true });
setHasResolvedConfig(true);
return;
}
// Always try to fetch config to check if login is disabled
// The endpoint should be public and return proper JSON
if (autoFetch) {
fetchConfig();
}
}, [autoFetch, fetchConfig]);
}, [autoFetch]); // Remove fetchConfig from deps to prevent re-runs
// Listen for JWT availability (triggered on login/signup)
useEffect(() => {

View File

@ -1,5 +1,16 @@
import apiClient from '@app/services/apiClient';
export interface ChatbotUsageSummary {
allocatedTokens: number;
consumedTokens: number;
remainingTokens: number;
usageRatio: number;
nearingLimit: boolean;
limitExceeded: boolean;
lastIncrementTokens: number;
window?: string;
}
export interface ChatbotSessionPayload {
sessionId?: string;
documentId: string;
@ -17,8 +28,11 @@ export interface ChatbotSessionInfo {
ocrRequested: boolean;
maxCachedCharacters: number;
createdAt: string;
textCharacters: number;
estimatedTokens: number;
warnings?: string[];
metadata?: Record<string, string>;
usageSummary?: ChatbotUsageSummary;
}
export interface ChatbotQueryPayload {
@ -37,6 +51,10 @@ export interface ChatbotMessageResponse {
cacheHit?: boolean;
warnings?: string[];
metadata?: Record<string, unknown>;
promptTokens?: number;
completionTokens?: number;
totalTokens?: number;
usageSummary?: ChatbotUsageSummary;
}
export async function createChatbotSession(payload: ChatbotSessionPayload) {
@ -48,4 +66,3 @@ export async function sendChatbotPrompt(payload: ChatbotQueryPayload) {
const { data } = await apiClient.post<ChatbotMessageResponse>('/api/v1/internal/chatbot/query', payload);
return data;
}

View File

@ -95,6 +95,30 @@ export function AuthProvider({ children }: { children: ReactNode }) {
try {
console.debug('[Auth] Initializing auth...');
// Skip auth check if we're on auth pages
// Need to check for paths with or without base path
const pathname = window.location.pathname;
const isAuthPage = pathname.endsWith('/login') ||
pathname.endsWith('/signup') ||
pathname.endsWith('/auth/callback') ||
pathname.includes('/auth/') ||
pathname.includes('/invite/');
if (isAuthPage) {
console.log('[Auth] On auth page, completely skipping session check');
console.log('[Auth] Current path:', pathname);
setLoading(false);
return;
}
// GUARD: Check if JWT exists before making session call
const hasJWT = localStorage.getItem('stirling_jwt');
if (!hasJWT) {
console.debug('[Auth] No JWT token found, skipping session check');
setLoading(false);
return;
}
// Skip config check entirely - let the app handle login state
// The config will be fetched by useAppConfig when needed
const { data, error } = await springAuth.getSession();
@ -126,6 +150,22 @@ export function AuthProvider({ children }: { children: ReactNode }) {
initializeAuth();
// Listen for JWT availability after OAuth callback or login
const handleJwtAvailable = async () => {
console.debug('[Auth] JWT available event detected, loading session');
try {
const { data, error } = await springAuth.getSession();
if (!error && data.session) {
console.debug('[Auth] Session loaded after JWT available:', data.session);
setSession(data.session);
setLoading(false);
}
} catch (err) {
console.error('[Auth] Error loading session after JWT available:', err);
}
};
window.addEventListener('jwt-available', handleJwtAvailable);
// Subscribe to auth state changes
const { data: { subscription } } = springAuth.onAuthStateChange(
async (event: AuthChangeEvent, newSession: Session | null) => {
@ -163,6 +203,7 @@ export function AuthProvider({ children }: { children: ReactNode }) {
return () => {
mounted = false;
subscription.unsubscribe();
window.removeEventListener('jwt-available', handleJwtAvailable);
};
}, []);

View File

@ -120,25 +120,32 @@ class SpringAuthClient {
async getSession(): Promise<{ data: { session: Session | null }; error: AuthError | null }> {
try {
// Get JWT from localStorage
console.log('[SpringAuth] getSession: Checking localStorage for JWT...');
const token = localStorage.getItem('stirling_jwt');
// Log all localStorage keys for debugging
console.log('[SpringAuth] All localStorage keys:', Object.keys(localStorage));
if (!token) {
console.debug('[SpringAuth] getSession: No JWT in localStorage');
console.warn('[SpringAuth] getSession: No JWT found in localStorage!');
console.warn('[SpringAuth] This will cause logout on refresh. Make sure JWT is saved after login.');
return { data: { session: null }, error: null };
}
console.log('[SpringAuth] getSession: Found JWT in localStorage, length:', token.length);
// Verify with backend
// Note: We pass the token explicitly here, overriding the interceptor's default
console.debug('[SpringAuth] getSession: Verifying JWT with /api/v1/auth/me');
console.log('[SpringAuth] getSession: Verifying JWT with /api/v1/auth/me');
const response = await apiClient.get('/api/v1/auth/me', {
headers: {
'Authorization': `Bearer ${token}`,
},
});
console.debug('[SpringAuth] /me response status:', response.status);
console.log('[SpringAuth] /me response status:', response.status);
const data = response.data;
console.debug('[SpringAuth] /me response data:', data);
console.log('[SpringAuth] /me response data:', data);
// Create session object
const session: Session = {
@ -148,7 +155,8 @@ class SpringAuthClient {
expires_at: Date.now() + 3600 * 1000,
};
console.debug('[SpringAuth] getSession: Session retrieved successfully');
console.log('[SpringAuth] getSession: ✓ Session retrieved successfully');
console.log('[SpringAuth] User is logged in as:', data.user?.email || data.user?.username);
return { data: { session }, error: null };
} catch (error: unknown) {
console.error('[SpringAuth] getSession error:', error);
@ -187,9 +195,22 @@ class SpringAuthClient {
const data = response.data;
const token = data.session.access_token;
// Store JWT in localStorage
// Store JWT in localStorage - CRITICAL for persistence
localStorage.setItem('stirling_jwt', token);
console.log('[SpringAuth] JWT stored in localStorage');
console.log('[SpringAuth] JWT stored in localStorage after password login');
console.log('[SpringAuth] JWT token length:', token ? token.length : 0);
// Verify it was actually saved
const savedToken = localStorage.getItem('stirling_jwt');
if (!savedToken) {
console.error('[SpringAuth] CRITICAL: JWT was not saved to localStorage!');
// Try again
localStorage.setItem('stirling_jwt', token);
} else if (savedToken !== token) {
console.error('[SpringAuth] CRITICAL: Saved token differs from received token!');
} else {
console.log('[SpringAuth] ✓ Verified JWT is correctly saved in localStorage');
}
// Dispatch custom event for other components to react to JWT availability
window.dispatchEvent(new CustomEvent('jwt-available'));
@ -317,19 +338,37 @@ class SpringAuthClient {
});
const data = response.data;
const token = data.session.access_token;
// Handle different response structures - the API might return the token directly or nested
const token = data?.session?.access_token || data?.access_token || data?.token;
if (!token) {
console.error('[SpringAuth] refreshSession: No access token in response:', data);
throw new Error('No access token received from refresh endpoint');
}
// Update local storage with new token
localStorage.setItem('stirling_jwt', token);
console.log('[SpringAuth] refreshSession: New JWT stored in localStorage');
// Verify it was saved
const savedToken = localStorage.getItem('stirling_jwt');
if (savedToken !== token) {
console.error('[SpringAuth] CRITICAL: JWT was not properly saved during refresh!');
} else {
console.log('[SpringAuth] refreshSession: ✓ JWT refreshed and verified in localStorage');
}
// Dispatch custom event for other components to react to JWT availability
window.dispatchEvent(new CustomEvent('jwt-available'));
// Build session object, handling different response structures
const expires_in = data?.session?.expires_in || data?.expires_in || 3600; // Default to 1 hour
const session: Session = {
user: data.user,
user: data?.user || data?.session?.user || null,
access_token: token,
expires_in: data.session.expires_in,
expires_at: Date.now() + data.session.expires_in * 1000,
expires_in: expires_in,
expires_at: Date.now() + expires_in * 1000,
};
// Notify listeners

View File

@ -1,6 +1,5 @@
import { useEffect } from 'react';
import { useNavigate } from 'react-router-dom';
import { useAuth } from '@app/auth/UseSession';
/**
* OAuth Callback Handler
@ -11,7 +10,6 @@ import { useAuth } from '@app/auth/UseSession';
*/
export default function AuthCallback() {
const navigate = useNavigate();
const { refreshSession } = useAuth();
useEffect(() => {
const handleCallback = async () => {
@ -32,17 +30,31 @@ export default function AuthCallback() {
return;
}
// Store JWT in localStorage
// Store JWT in localStorage - CRITICAL for persistence
localStorage.setItem('stirling_jwt', token);
console.log('[AuthCallback] JWT stored in localStorage');
console.log('[AuthCallback] JWT stored in localStorage after OAuth');
console.log('[AuthCallback] JWT token length:', token.length);
// Verify it was actually saved
const savedToken = localStorage.getItem('stirling_jwt');
if (!savedToken) {
console.error('[AuthCallback] CRITICAL: JWT was not saved to localStorage!');
// Try again
localStorage.setItem('stirling_jwt', token);
} else if (savedToken !== token) {
console.error('[AuthCallback] CRITICAL: Saved token differs from received token!');
} else {
console.log('[AuthCallback] ✓ Verified JWT is correctly saved in localStorage');
}
// Dispatch custom event for other components to react to JWT availability
window.dispatchEvent(new CustomEvent('jwt-available'))
// This will trigger the auth provider to load the session with the new JWT
window.dispatchEvent(new CustomEvent('jwt-available'));
// Refresh session to load user info into state
await refreshSession();
// Small delay to ensure event is processed
await new Promise(resolve => setTimeout(resolve, 100));
console.log('[AuthCallback] Session refreshed, redirecting to home');
console.log('[AuthCallback] JWT saved, redirecting to home');
// Clear the hash from URL and redirect to home page
navigate('/', { replace: true });
@ -56,7 +68,7 @@ export default function AuthCallback() {
};
handleCallback();
}, [navigate, refreshSession]);
}, [navigate]);
return (
<div style={{

View File

@ -118,8 +118,12 @@ export default function Login() {
setError(error.message);
} else if (user && session) {
console.log('[Login] Email sign in successful');
// Auth state will update automatically and Landing will redirect to home
// No need to navigate manually here
// Dispatch event to trigger auth state update
window.dispatchEvent(new CustomEvent('jwt-available'));
// Navigate to home page
setTimeout(() => {
navigate('/', { replace: true });
}, 100); // Small delay to ensure auth state updates
}
} catch (err) {
console.error('[Login] Unexpected error:', err);

View File

@ -36,19 +36,16 @@ export default defineConfig(({ mode }) => {
target: 'http://localhost:8080',
changeOrigin: true,
secure: false,
xfwd: true,
},
'/oauth2': {
target: 'http://localhost:8080',
changeOrigin: true,
secure: false,
xfwd: true,
},
'/login/oauth2': {
target: 'http://localhost:8080',
changeOrigin: true,
secure: false,
xfwd: true,
},
},
},